package com.github.steveash.jg2p.seq;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.CRFTrainerByThreadedLabelLikelihood;
import cc.mallet.fst.TokenAccuracyEvaluator;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2LabelSequence;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.pipe.TokenSequenceLowercase;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.RankedFeatureVector;
import com.github.steveash.jg2p.align.Alignment;
import com.github.steveash.jg2p.align.TrainOptions;
import com.github.steveash.jg2p.seq.SeqInputReader;
import com.github.steveash.jg2p.util.FeatureSelections;
import com.github.steveash.jg2p.util.GramBuilder;
import com.github.steveash.jg2p.util.ModelReadWrite;
import com.github.steveash.jg2p.util.ReadWrite;
import com.google.common.base.Function;
import com.google.common.base.Stopwatch;
import com.google.common.base.Throwables;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
import java.io.File;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.regex.Pattern;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains the phoneme CRF used by the pronouncer from a collection of grapheme/phoneme
 * {@link Alignment}s.
 *
 * <p>Each alignment is turned into a MALLET {@link Instance} by the pipe built in
 * {@link #makePipe(Alphabet)}. The CRF can be warm-started from the CRF stored in a previously
 * written pronouncer model ({@code TrainOptions.initCrfFromModelFile}). Optional feature trimming
 * is driven by {@code trimFeaturesUnderPercentile} and {@code trimFeaturesByGradientGain}.
 *
 * <p>This class keeps mutable training state in its fields and performs no synchronization.
 */
public class PhonemeCrfTrainer {
    public static final String PROP_STRUCTURE = "prop_structure";
    private static final Logger log = LoggerFactory.getLogger(PhonemeCrfTrainer.class);

    /** Gaussian prior variance applied by both trainer factories below. */
    private static final double GAUSSIAN_PRIOR_VARIANCE = 2.0d;

    private final TrainOptions opts;
    /** CRF whose applicable parameters seed the next CRF built by {@link #initializeFor}; may be null. */
    private CRF crfFrom;
    /** CRF produced by the most recent training round; null until {@link #trainFor} runs. */
    private CRF crf = null;
    /** Trainer used in the most recent training round; null until {@link #trainFor} runs. */
    private TransducerTrainer lastTrainer = null;

    /**
     * Creates a trainer for the given options.
     *
     * @param trainOptions training configuration
     * @return a new trainer
     * @throws IllegalStateException if an initial CRF model file is configured but cannot be read
     */
    public static PhonemeCrfTrainer open(TrainOptions trainOptions) {
        return new PhonemeCrfTrainer(trainOptions);
    }

    private PhonemeCrfTrainer(TrainOptions trainOptions) {
        this.crfFrom = null;
        // opts must be assigned before readCrfFrom(), which reads this.opts
        this.opts = trainOptions;
        if (trainOptions.initCrfFromModelFile != null) {
            log.info("Loading initial weights from {}", trainOptions.initCrfFromModelFile);
            try {
                this.crfFrom = readCrfFrom();
            } catch (IOException e) {
                throw initialCrfLoadFailure(e);
            } catch (ClassNotFoundException e) {
                throw initialCrfLoadFailure(e);
            }
        }
    }

    /** Wraps a failure to read the warm-start CRF, keeping the cause and naming the file. */
    private IllegalStateException initialCrfLoadFailure(Exception cause) {
        return new IllegalStateException(
            "Could not load initial CRF weights from " + this.opts.initCrfFromModelFile, cause);
    }

    /**
     * Builds a fresh CRF for the given training instances and stores it in {@link #crf}.
     *
     * <p>The CRF uses the instances' pipe, an order-1 state structure and a start state. Its
     * weights are dimensioned from the instances. If {@link #crfFrom} is set, the applicable
     * parameters are copied from it.
     */
    private void initializeFor(InstanceList instanceList) {
        this.crf = new CRF(instanceList.getPipe(), (Pipe) null);
        this.crf.addOrderNStates(instanceList, new int[]{1}, (boolean[]) null, (String) null, (Pattern) null, (Pattern) null, false);
        this.crf.addStartState();
        this.crf.setWeightsDimensionAsIn(instanceList, false);
        if (this.crfFrom != null) {
            this.crf.initializeApplicableParametersFrom(this.crfFrom);
        }
    }

    /** Reads the CRF out of the pronouncer model named by {@code opts.initCrfFromModelFile}. */
    private CRF readCrfFrom() throws IOException, ClassNotFoundException {
        return ModelReadWrite.readPronouncerFrom(this.opts.initCrfFromModelFile).getCrf();
    }

    /**
     * Trains the CRF on the given alignments.
     *
     * <p>This runs one or two rounds (see {@link #trainRound}). It then stops growth of the
     * CRF's input and output alphabets and logs the elapsed time.
     *
     * @param collection training alignments
     */
    public void trainFor(Collection<Alignment> collection) {
        Stopwatch watch = Stopwatch.createStarted();
        trainRound(collection, new Alphabet(), 0);
        this.crf.getInputAlphabet().stopGrowth();
        this.crf.getOutputAlphabet().stopGrowth();
        watch.stop();
        log.info("Training took {}", watch);
    }

    /**
     * Runs one training round with the given feature alphabet.
     *
     * <p>The first round ({@code round == 0}) can also prune low-percentile features and retrain.
     * It can also select features by gradient gain ratio. In that case it recurses once with the
     * reduced alphabet, warm-starting from the CRF just trained.
     */
    private void trainRound(Collection<Alignment> collection, Alphabet alphabet, int round) {
        InstanceList examples = makeExamplesFromAligns(makePipe(alphabet), collection);
        initializeFor(examples);
        CRFTrainerByThreadedLabelLikelihood trainer = makeNewTrainer(this.crf);
        this.lastTrainer = trainer;
        trainer.train(examples, this.opts.maxPronouncerTrainingIterations);
        trainer.shutdown();

        boolean firstRound = round == 0;
        if (firstRound && this.opts.trimFeaturesUnderPercentile > 0) {
            trainer.getCRF().pruneFeaturesBelowPercentile(this.opts.trimFeaturesUnderPercentile);
            // NOTE(review): unlike the first pass, this retrain is not capped by
            // maxPronouncerTrainingIterations -- confirm that is intended.
            trainer.train(examples);
            trainer.shutdown();
        }
        if (!firstRound || this.opts.trimFeaturesByGradientGain <= 0.0d) {
            return;
        }
        log.info("Trimming based on gradiant gain ratio...");
        Alphabet selected = selectFeaturesByGradientGain(examples, alphabet);
        log.info("Feature selection before count {} after {}", alphabet.size(), selected.size());
        selected.stopGrowth();
        // warm-start the next round from the CRF we just trained
        this.crfFrom = this.crf;
        trainRound(collection, selected, round + 1);
    }

    /**
     * Returns a new alphabet with only the features from {@code alphabet} whose gradient gain
     * ratio under the current CRF is strictly above {@code opts.trimFeaturesByGradientGain}.
     */
    private Alphabet selectFeaturesByGradientGain(InstanceList examples, Alphabet alphabet) {
        RankedFeatureVector ratios = FeatureSelections.gradientGainRatioFrom(examples, this.crf);
        Alphabet selected = new Alphabet();
        for (int idx = 0; idx < ratios.singleSize(); idx++) {
            if (ratios.value(idx) > this.opts.trimFeaturesByGradientGain) {
                selected.lookupIndex(alphabet.lookupObject(idx), true);
            }
        }
        return selected;
    }

    /** Token accuracy of {@link #lastTrainer} on the given instances. Currently unused in this file. */
    private double accuracyFor(InstanceList instanceList) {
        TokenAccuracyEvaluator evaluator = new TokenAccuracyEvaluator(instanceList, "train");
        evaluator.evaluate(this.lastTrainer);
        return evaluator.getAccuracy("train");
    }

    /**
     * Wraps the trained CRF in a {@link PhonemeCrfModel}.
     *
     * @throws IllegalStateException if {@link #trainFor} has not been called
     */
    public PhonemeCrfModel buildModel() {
        return new PhonemeCrfModel(requireTrainedCrf());
    }

    /** Returns the trained CRF, failing fast instead of handing out a null model. */
    private CRF requireTrainedCrf() {
        if (this.crf == null) {
            throw new IllegalStateException("No CRF has been trained yet; call trainFor(...) first");
        }
        return this.crf;
    }

    /** Multi-threaded label-likelihood trainer, one thread per available processor. */
    private static CRFTrainerByThreadedLabelLikelihood makeNewTrainer(CRF crf) {
        CRFTrainerByThreadedLabelLikelihood trainer = new CRFTrainerByThreadedLabelLikelihood(crf, getCpuCount());
        trainer.setGaussianPriorVariance(GAUSSIAN_PRIOR_VARIANCE);
        trainer.setAddNoFactors(true);
        trainer.setUseSomeUnsupportedTrick(false);
        return trainer;
    }

    /** Single-threaded equivalent of {@link #makeNewTrainer}. Currently unused in this file. */
    private static CRFTrainerByLabelLikelihood makeNewTrainerSingleThreaded(CRF crf) {
        CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf);
        trainer.setGaussianPriorVariance(GAUSSIAN_PRIOR_VARIANCE);
        trainer.setAddNoFactors(true);
        trainer.setUseSomeUnsupportedTrick(false);
        return trainer;
    }

    private static int getCpuCount() {
        return Runtime.getRuntime().availableProcessors();
    }

    /**
     * Serializes a {@link PhonemeCrfModel} wrapping the trained CRF to {@code file}.
     *
     * @throws IOException if writing fails
     * @throws IllegalStateException if {@link #trainFor} has not been called
     */
    public void writeModel(File file) throws IOException {
        // the last trainer's transducer is this.crf, so use the same model as buildModel()
        ReadWrite.writeTo(new PhonemeCrfModel(requireTrainedCrf()), file);
        log.info("Wrote for whole data");
    }

    /**
     * Pushes every alignment through {@code pipe}.
     *
     * <p>Each alignment's Y tokens are the target, with blank tokens replaced by
     * {@link GramBuilder#EPS}.
     */
    private InstanceList makeExamplesFromAligns(Pipe pipe, Iterable<Alignment> iterable) {
        int count = 0;
        InstanceList instances = new InstanceList(pipe);
        for (Alignment alignment : iterable) {
            // NOTE(review): updateEpsilons mutates this list in place; if getAllYTokensAsList()
            // returns the alignment's internal list, the alignment itself is modified -- confirm.
            List<String> yTokens = alignment.getAllYTokensAsList();
            updateEpsilons(yTokens);
            instances.addThruPipe(new Instance(alignment, yTokens, (Object) null, (Object) null));
            count++;
        }
        log.info("Read {} instances of training data for pronouncer training", count);
        return instances;
    }

    /** Flattens the alignments of all groups into one lazy iterable. Currently unused in this file. */
    private Iterable<Alignment> getAlignsFromGroup(List<SeqInputReader.AlignGroup> list) {
        return FluentIterable.from(list).transformAndConcat(new Function<SeqInputReader.AlignGroup, Iterable<Alignment>>() {
            @Override
            public Iterable<Alignment> apply(SeqInputReader.AlignGroup alignGroup) {
                return alignGroup.alignments;
            }
        });
    }

    /** Replaces every blank entry of {@code list} with {@link GramBuilder#EPS}, in place. */
    private static void updateEpsilons(List<String> list) {
        for (int i = 0; i < list.size(); i++) {
            if (StringUtils.isBlank(list.get(i))) {
                list.set(i, GramBuilder.EPS);
            }
        }
    }

    /**
     * Builds the feature-extraction pipe.
     *
     * <p>Stage order matters: token extraction and lowercasing come first, then the feature
     * stages, then conversion to feature vectors over {@code alphabet}, and finally conversion of
     * the target to a label sequence.
     */
    private SerialPipes makePipe(Alphabet alphabet) {
        // created first: its target alphabet is shared with the token-extraction stage
        Pipe target2LabelSequence = new Target2LabelSequence();
        ImmutableList<Pipe> stages = ImmutableList.<Pipe>builder()
            .add(new AlignmentToTokenSequence(alphabet, target2LabelSequence.getTargetAlphabet(), true, true, false))
            .add(new TokenSequenceLowercase())
            .add(new NeighborTokenFeature(true, makeNeighbors()))
            .add(new NeighborShapeFeature(true, makeShapeNeighs()))
            .add(new NeighborSyllableFeature(-2, -1, 1, 2))
            .add(new SyllCountingFeature())
            .add(new SyllCharRoleFeature())
            .add(new EndingVowelFeature())
            .add(new VowelWindowFeature(2, 1, "PRESYL_", -1, false))
            .add(new VowelWindowFeature(2, 1, "PSTSYL_", 1, false))
            .add(new SurroundingTokenFeature2(false, 1, 1))
            .add(new SurroundingTokenFeature2(false, 2, 2))
            .add(new SurroundingTokenFeature2(true, 3, 3))
            .add(new TokenSequenceToFeature())
            .add(new TokenSequence2FeatureVectorSequence(alphabet, true, false))
            .add(target2LabelSequence)
            .build();
        return new SerialPipes(stages);
    }

    /** Windows used by {@link NeighborShapeFeature}. */
    private static List<TokenWindow> makeShapeNeighs() {
        return ImmutableList.of(
            new TokenWindow(-6, 6),
            new TokenWindow(-5, 5),
            new TokenWindow(-4, 4),
            new TokenWindow(1, 4),
            new TokenWindow(1, 5),
            new TokenWindow(1, 6));
    }

    /** Windows used by {@link NeighborTokenFeature}. */
    private static List<TokenWindow> makeNeighbors() {
        return ImmutableList.of(
            new TokenWindow(1, 1),
            new TokenWindow(1, 2),
            new TokenWindow(2, 1),
            new TokenWindow(1, 3),
            new TokenWindow(3, 1),
            new TokenWindow(1, 4),
            new TokenWindow(-1, 1),
            new TokenWindow(-2, 2),
            new TokenWindow(-2, 1),
            new TokenWindow(-3, 3),
            new TokenWindow(-4, 4),
            new TokenWindow(-5, 5));
    }
}
