package com.github.steveash.jg2p.align;

import com.github.steveash.jg2p.Word;
import com.github.steveash.jg2p.align.ProbTable;
import com.github.steveash.jg2p.align.XyWalker;
import com.github.steveash.jg2p.util.Assert;
import com.github.steveash.jg2p.util.DoubleTable;
import com.github.steveash.jg2p.util.ReadWrite;
import com.google.common.base.Throwables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.collect.Tables;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.tuple.Pair;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/steveash/jg2p/align/AlignerTrainer.class */
/**
 * Trains an {@link AlignModel} by iterating expectation and maximization rounds over a list of
 * {@link InputRecord}s until the per-round probability mass shift drops below
 * {@code TrainOptions.probDeltaConvergenceThreshold} or
 * {@code TrainOptions.trainingAlignerMaxIterations} rounds have run.
 *
 * <p>Each x/y word pair is enumerated through an {@link XyWalker}; the walker reports candidate
 * (x gram, y gram) pairs together with their start/end positions in both words.
 *
 * <p>NOTE(review): {@code probs} and {@code blocked} are never reset between {@code train} calls,
 * so reusing one trainer instance for several trainings carries state over - confirm whether
 * callers rely on one instance per training run.
 */
public class AlignerTrainer {
    private static final Logger log = LoggerFactory.getLogger(AlignerTrainer.class);

    /** Counts accumulated for the current round; cleared at the end of every maximization. */
    private final ProbTable counts;
    /** Current probability estimate per (x gram, y gram) pair; handed to the resulting model. */
    private final ProbTable probs;
    /** Unfiltered initial counts, kept for {@link #numberOfLowSupportAlignments}. */
    private final ProbTable originalCounts;
    private final TrainOptions trainOpts;
    private final GramOptions gramOpts;
    private final XyWalker walker;
    /** Normalized copy of the table passed to {@link #train(List, ProbTable)}. */
    private ProbTable labelledProbs;
    /** Pairs read from {@code alignAllowedFile}, or {@code null} when no such file is configured. */
    private final Set<Pair<String, String>> allowed;
    /** Pairs seen while initializing counts that are not in {@link #allowed}; {@code null} when unfiltered. */
    private final Set<Pair<String, String>> blocked;
    private final Penalizer penalizer;
    /** Optional table used to seed the initial counts instead of a flat 1.0 per occurrence. */
    private ProbTable initFrom;

    /**
     * Creates a trainer whose walker is chosen from the options.
     *
     * @param trainOptions training configuration
     */
    public AlignerTrainer(TrainOptions trainOptions) {
        this(trainOptions, null);
    }

    /**
     * Creates a trainer.
     *
     * @param trainOptions training configuration
     * @param xyWalker walker to use, or {@code null} to pick a {@link WindowXyWalker} or
     *     {@link FullXyWalker} according to {@code trainOptions.useWindowWalker}
     */
    public AlignerTrainer(TrainOptions trainOptions, XyWalker xyWalker) {
        this.counts = new ProbTable();
        this.probs = new ProbTable();
        this.originalCounts = new ProbTable();
        this.initFrom = null;
        this.trainOpts = trainOptions;
        this.gramOpts = trainOptions.makeGramOptions();

        XyWalker chosenWalker;
        if (xyWalker != null) {
            chosenWalker = xyWalker;
        } else if (trainOptions.useWindowWalker) {
            chosenWalker = new WindowXyWalker(this.gramOpts);
        } else {
            chosenWalker = new FullXyWalker(this.gramOpts);
        }

        Set<Pair<String, String>> allowedPairs = null;
        Set<Pair<String, String>> blockedPairs = null;
        if (trainOptions.alignAllowedFile != null) {
            try {
                allowedPairs = FilterWalkerDecorator.readFromFile(trainOptions.alignAllowedFile);
            } catch (IOException e) {
                throw Throwables.propagate(e);
            }
            blockedPairs = Sets.newHashSet();
        }
        this.allowed = allowedPairs;
        this.blocked = blockedPairs;

        this.walker = chosenWalker;
        this.penalizer = this.gramOpts.makePenalizer();
    }

    /**
     * Sets the table used to seed initial counts; pairs with a positive probability in it
     * contribute that probability instead of 1.0 per occurrence.
     *
     * @param probTable seed table, or {@code null} to use flat counts
     */
    public void setInitFrom(ProbTable probTable) {
        this.initFrom = probTable;
    }

    /**
     * Trains without any labelled probabilities (an empty table is used).
     *
     * @param list training records
     * @return the trained model
     */
    public AlignModel train(List<InputRecord> list) {
        return train(list, new ProbTable());
    }

    /**
     * Trains a model, blending each round's estimate with a normalized copy of {@code probTable}
     * weighted by {@code TrainOptions.semiSupervisedFactor}.
     *
     * @param list training records
     * @param probTable labelled probabilities to blend in
     * @return a model built from the gram options and the final probability table
     */
    public AlignModel train(List<InputRecord> list, ProbTable probTable) {
        ListeningExecutorService pool = MoreExecutors.listeningDecorator(Executors.newCachedThreadPool());
        try {
            this.labelledProbs = probTable.makeNormalizedCopy();
            initCounts(list);
            // Turns the initial counts into the starting probabilities; its delta is not used.
            maximization();

            log.info("Starting EM rounds...");
            int round = 0;
            boolean converged;
            do {
                round++;
                expectation(list, pool);
                double delta = maximization();
                converged = hasConverged(delta, round);
                log.info("Completed EM round {} mass delta {}", round, String.format("%.15f", delta));
            } while (!converged);

            log.info("Training complete in {} rounds!", round);
            return new AlignModel(this.gramOpts, this.probs);
        } finally {
            MoreExecutors.shutdownAndAwaitTermination(pool, 60L, TimeUnit.SECONDS);
        }
    }

    /**
     * Decides whether the EM loop should stop.
     *
     * @param d mass delta reported by the last maximization
     * @param i number of rounds completed so far
     * @return {@code true} when the delta is under the threshold or the round limit is reached
     */
    private boolean hasConverged(double d, int i) {
        if (d < this.trainOpts.probDeltaConvergenceThreshold) {
            log.info("EM only had a mass shift by {} training is complete.", d);
            return true;
        }
        return i >= this.trainOpts.trainingAlignerMaxIterations;
    }

    /**
     * Runs the expectation step for all records on the given executor and merges the per-chunk
     * tables into {@link #counts}.
     *
     * @param list training records
     * @param listeningExecutorService executor that runs the per-chunk work
     */
    private void expectation(List<InputRecord> list, ListeningExecutorService listeningExecutorService) {
        int workers = Math.max(1, Runtime.getRuntime().availableProcessors());
        // Lists.partition takes the size of each sublist, not the number of sublists, so derive
        // a chunk size (ceiling division) that yields at most one chunk per processor.
        int chunkSize = Math.max(1, (list.size() - 1) / workers + 1);

        List<ListenableFuture<ProbTable>> pending = Lists.newArrayList();
        for (List<InputRecord> chunk : Lists.partition(list, chunkSize)) {
            pending.add(listeningExecutorService.submit(makeConsumer(chunk)));
        }
        try {
            List<ProbTable> partials = Futures.allAsList(pending).get();
            ProbTable.mergeAll(partials, this.counts);
        } catch (InterruptedException e) {
            // Keep the interrupt visible to callers further up the stack.
            Thread.currentThread().interrupt();
            throw Throwables.propagate(e);
        } catch (Exception e) {
            throw Throwables.propagate(e);
        }
    }

    /**
     * Builds a task that accumulates the expectation of every record in {@code list} into a
     * fresh table and returns that table.
     */
    private Callable<ProbTable> makeConsumer(final List<InputRecord> list) {
        return new Callable<ProbTable>() {
            @Override
            public ProbTable call() throws Exception {
                ProbTable partial = new ProbTable();
                for (InputRecord record : list) {
                    AlignerTrainer.this.expectationForRecord(record, partial);
                }
                return partial;
            }
        };
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Adds one record's contribution to {@code probTable}: for every pair the walker visits, the
     * forward value at the pair's start, times the penalized pair probability, times the backward
     * value at the pair's end, divided by the forward value of the complete word pair.
     * Records whose complete forward value is zero contribute nothing.
     *
     * @param inputRecord record to process
     * @param probTable table receiving the contributions
     */
    public void expectationForRecord(InputRecord inputRecord, final ProbTable probTable) {
        Word xWord = inputRecord.xWord;
        Word yWord = inputRecord.yWord;
        int xCount = xWord.unigramCount();
        int yCount = yWord.unigramCount();

        final DoubleTable alpha = new DoubleTable(xCount + 1, yCount + 1);
        final DoubleTable beta = new DoubleTable(xCount + 1, yCount + 1);
        forward(xWord, yWord, alpha);
        backward(xWord, yWord, beta);

        final double total = alpha.get(xCount, yCount);
        if (total == 0.0d) {
            // Nothing reached the end cell; also avoids dividing by zero below.
            return;
        }
        this.walker.forward(xWord, yWord, new XyWalker.Visitor() {
            @Override
            public void visit(int xStart, int xEnd, String xGram, int yStart, int yEnd, String yGram) {
                double before = alpha.get(xStart, yStart);
                double edge = AlignerTrainer.this.penalize(xGram, yGram, AlignerTrainer.this.probs.prob(xGram, yGram));
                double after = beta.get(xEnd, yEnd);
                probTable.addProb(xGram, yGram, ((before * edge) * after) / total);
            }
        });
    }

    /**
     * Fills the backward table: the end cell is set to 1.0 and every visited pair adds its
     * penalized probability times the value at its end cell into its start cell.
     */
    private void backward(Word word, Word word2, final DoubleTable doubleTable) {
        doubleTable.put(word.unigramCount(), word2.unigramCount(), 1.0d);
        this.walker.backward(word, word2, new XyWalker.Visitor() {
            @Override
            public void visit(int xStart, int xEnd, String xGram, int yStart, int yEnd, String yGram) {
                double edge = AlignerTrainer.this.penalize(xGram, yGram, AlignerTrainer.this.probs.prob(xGram, yGram));
                doubleTable.add(xStart, yStart, edge * doubleTable.get(xEnd, yEnd));
            }
        });
    }

    /**
     * Fills the forward table: cell (0, 0) is set to 1.0 and every visited pair adds its
     * penalized probability times the value at its start cell into its end cell.
     */
    private void forward(Word word, Word word2, final DoubleTable doubleTable) {
        doubleTable.put(0, 0, 1.0d);
        this.walker.forward(word, word2, new XyWalker.Visitor() {
            @Override
            public void visit(int xStart, int xEnd, String xGram, int yStart, int yEnd, String yGram) {
                double edge = AlignerTrainer.this.penalize(xGram, yGram, AlignerTrainer.this.probs.prob(xGram, yGram));
                doubleTable.add(xEnd, yEnd, edge * doubleTable.get(xStart, yStart));
            }
        });
    }

    /* JADX INFO: Access modifiers changed from: private */
    /** Delegates to the configured {@link Penalizer}. */
    public double penalize(String str, String str2, double d) {
        return this.penalizer.penalize(str, str2, d);
    }

    /**
     * Recomputes {@link #probs} from the current {@link #counts} blended with the labelled
     * probabilities, then clears the counts.
     *
     * @return the summed absolute change of all cell probabilities, passed through the
     *     maximizer's {@code normalize}
     */
    private double maximization() {
        smoothCounts();
        ProbTable.Marginals marginals = this.counts.calculateMarginals();
        double labelledWeight = this.trainOpts.semiSupervisedFactor;
        double learnedWeight = 1.0d - this.trainOpts.semiSupervisedFactor;

        double totalShift = 0.0d;
        for (Pair<String, String> cell : ProbTable.unionOfAllCells(this.counts, this.labelledProbs)) {
            String x = cell.getLeft();
            String y = cell.getRight();
            double learned = this.trainOpts.trainingAlignerMaximizer.maximize(
                Tables.immutableCell(x, y, Double.valueOf(this.counts.prob(x, y))), marginals);
            double updated = (learnedWeight * learned) + (labelledWeight * this.labelledProbs.prob(x, y));
            Assert.assertProb(updated);
            totalShift += Math.abs(this.probs.prob(x, y) - updated);
            this.probs.setProb(x, y, updated);
        }
        this.counts.clear();
        return this.trainOpts.trainingAlignerMaximizer.normalize(totalShift, marginals);
    }

    /**
     * When an allowed-pair filter is active, adds a smoothing amount to every allowed pair and
     * assigns the same amount to every blocked pair. The amount is the smallest positive allowed
     * count scaled by {@code 2 * |allowed| / |blocked|}.
     *
     * <p>Smoothing is skipped when that amount is undefined (no blocked pairs, or no allowed pair
     * with a positive count); otherwise Infinity/NaN would be written into the counts.
     */
    private void smoothCounts() {
        if (this.allowed == null || this.blocked.isEmpty()) {
            return;
        }
        double smallest = minAllowedCount();
        if (Double.isInfinite(smallest)) {
            return;
        }
        double smoothing = smallest * ((2.0d * this.allowed.size()) / this.blocked.size());
        for (Pair<String, String> pair : this.allowed) {
            this.counts.addProb(pair.getLeft(), pair.getRight(), smoothing);
        }
        for (Pair<String, String> pair : this.blocked) {
            this.counts.setProb(pair.getLeft(), pair.getRight(), smoothing);
        }
    }

    /**
     * @return the smallest strictly positive count among the allowed pairs, or
     *     {@link Double#POSITIVE_INFINITY} when none is positive
     */
    private double minAllowedCount() {
        double smallest = Double.POSITIVE_INFINITY;
        for (Pair<String, String> pair : this.allowed) {
            double count = this.counts.prob(pair.getLeft(), pair.getRight());
            if (count > 0.0d && count < smallest) {
                smallest = count;
            }
        }
        return smallest;
    }

    /**
     * Resets the count tables and seeds them from every pair the walker visits in {@code list}.
     * Each visit contributes 1.0, or the pair's probability in {@link #initFrom} when that is
     * positive. Pairs outside an active allowed set are recorded as blocked instead of counted.
     */
    private void initCounts(List<InputRecord> list) {
        this.counts.clear();
        this.originalCounts.clear();
        for (InputRecord inputRecord : list) {
            // m15getLeft/m16getRight are decompiler-assigned names; presumably the x and y words
            // of the record - verify against InputRecord.
            this.walker.forward(inputRecord.m15getLeft(), inputRecord.m16getRight(), new XyWalker.Visitor() {
                @Override
                public void visit(int xStart, int xEnd, String xGram, int yStart, int yEnd, String yGram) {
                    double weight = 1.0d;
                    if (AlignerTrainer.this.initFrom != null) {
                        double seeded = AlignerTrainer.this.initFrom.prob(xGram, yGram);
                        if (seeded > 0.0d) {
                            weight = seeded;
                        }
                    }
                    AlignerTrainer.this.originalCounts.addProb(xGram, yGram, weight);
                    if (AlignerTrainer.this.allowed == null
                        || AlignerTrainer.this.allowed.contains(Pair.of(xGram, yGram))) {
                        AlignerTrainer.this.counts.addProb(xGram, yGram, weight);
                    } else {
                        AlignerTrainer.this.blocked.add(Pair.of(xGram, yGram));
                    }
                }
            });
        }
    }

    /**
     * Counts the graphones of {@code alignment} whose original (unfiltered) count is positive
     * but not greater than {@code i}.
     *
     * @param alignment alignment whose graphones are inspected
     * @param i inclusive upper bound for a count to be considered low support
     * @return number of low-support graphones
     */
    public int numberOfLowSupportAlignments(Alignment alignment, int i) {
        int lowSupport = 0;
        for (Pair<String, String> graphone : alignment.getGraphones()) {
            double support = this.originalCounts.prob(graphone.getLeft(), graphone.getRight());
            if (support > 0.0d && support <= i) {
                lowSupport++;
            }
        }
        return lowSupport;
    }

    /**
     * Command line entry point; any failure is logged rather than rethrown.
     *
     * @param strArr command line arguments, parsed into {@link TrainOptions}
     */
    public static void main(String[] strArr) {
        try {
            trainAndSave(strArr);
        } catch (Exception e) {
            log.error("Problem training ", e);
        }
    }

    /**
     * Parses the arguments, reads the training file, trains a model and writes it to the
     * configured output file.
     *
     * @param strArr command line arguments
     * @return the trained model
     * @throws CmdLineException if the arguments cannot be parsed
     * @throws IOException if reading the input or writing the model fails
     */
    public static AlignModel trainAndSave(String[] strArr) throws CmdLineException, IOException {
        TrainOptions options = parseArgs(strArr);
        AlignerTrainer trainer = new AlignerTrainer(options);
        log.info("Reading input training records...");
        List<InputRecord> records = options.makeReader().readFromFile(options.trainingFile);
        log.info("Training the probabilistic model...");
        AlignModel model = trainer.train(records);
        log.info("Writing model to {}...", options.outputFile);
        ReadWrite.writeTo(model, options.outputFile);
        log.info("Training complete!");
        return model;
    }

    /** Parses {@code strArr} into a {@link TrainOptions} and runs its post-parse hook. */
    private static TrainOptions parseArgs(String[] strArr) throws CmdLineException {
        TrainOptions trainOptions = new TrainOptions();
        new CmdLineParser(trainOptions).parseArgument(strArr);
        trainOptions.afterParametersSet();
        return trainOptions;
    }
}
