package com.github.steveash.jg2p.syll;

import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFTrainerByThreadedLabelLikelihood;
import cc.mallet.fst.TransducerTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2LabelSequence;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.pipe.TokenSequenceLowercase;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import com.github.steveash.jg2p.seq.StringListToTokenSequence;
import com.github.steveash.jg2p.seq.TokenSequenceToFeature;
import com.github.steveash.jg2p.seq.TokenWindow;
import com.google.common.base.Stopwatch;
import com.google.common.collect.ImmutableList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/steveash/jg2p/syll/PhoneSyllTagTrainer.class */
/**
 * Trains a first-order MALLET {@link CRF} over {@link SWord} examples and wraps the
 * resulting transducer in a {@link PhoneSyllTagModel}.
 *
 * <p>Optionally, parameters from a previously trained CRF can be used to initialize
 * the new model (see {@link #setPullFrom(CRF)}).
 */
public class PhoneSyllTagTrainer {
    public static final boolean USE_ONC_CODING = false;
    private static final Logger log = LoggerFactory.getLogger(PhoneSyllTagTrainer.class);

    /** Number of worker threads handed to the threaded label-likelihood trainer. */
    private static final int TRAINER_THREADS = 8;
    /** Gaussian prior variance configured on the trainer. */
    private static final double GAUSSIAN_PRIOR_VARIANCE = 2.0d;

    /** Optional CRF whose applicable parameters seed a new model; {@code null} means none. */
    private CRF pullFrom = null;

    /**
     * Sets the CRF whose applicable parameters are copied into the new CRF before training.
     *
     * @param crf the source CRF, or {@code null} to skip parameter initialization
     */
    public void setPullFrom(CRF crf) {
        this.pullFrom = crf;
    }

    /**
     * Trains a syllable phone tag model on all of the given words.
     *
     * @param collection the training words; each one is pushed through the feature pipe
     * @return a model wrapping the trained transducer
     */
    public PhoneSyllTagModel train(Collection<SWord> collection) {
        InstanceList examples = makeExamplesFromAligns(collection);
        Pipe pipe = examples.getPipe();
        log.info("Training test-time syll phone tagger on whole data...");
        TransducerTrainer trainer = trainOnce(pipe, examples);
        return new PhoneSyllTagModel(trainer.getTransducer());
    }

    /**
     * Builds an instance list by sending every word through a freshly created pipe.
     *
     * @param collection the words to convert into training instances
     * @return the populated instance list (carrying the pipe it was built with)
     */
    private InstanceList makeExamplesFromAligns(Collection<SWord> collection) {
        InstanceList instanceList = new InstanceList(makePipe());
        int count = 0;
        for (SWord word : collection) {
            // Only the data slot is populated; target, name and source are left null.
            instanceList.addThruPipe(new Instance(word, (Object) null, (Object) null, (Object) null));
            count++;
        }
        log.info("Read {} instances of training data for syll phone tag", count);
        return instanceList;
    }

    /**
     * Builds an order-1 CRF over the given instances and trains it with the threaded
     * label-likelihood trainer.
     *
     * <p>The trainer is always shut down, even when training throws, so that it does not
     * keep running after a failure.
     *
     * @param pipe the pipe that produced {@code instanceList}; its alphabets are frozen after training
     * @param instanceList the training instances
     * @return the (already shut down) trainer, from which the trained transducer can be read
     */
    private TransducerTrainer trainOnce(Pipe pipe, InstanceList instanceList) {
        Stopwatch watch = Stopwatch.createStarted();

        CRF crf = new CRF(pipe, (Pipe) null);
        crf.addOrderNStates(instanceList, new int[]{1}, (boolean[]) null, (String) null, (Pattern) null, (Pattern) null, false);
        crf.addStartState();
        crf.setWeightsDimensionAsIn(instanceList);
        if (this.pullFrom != null) {
            crf.initializeApplicableParametersFrom(this.pullFrom);
        }

        log.info("Starting syll phone training...");
        CRFTrainerByThreadedLabelLikelihood trainer = new CRFTrainerByThreadedLabelLikelihood(crf, TRAINER_THREADS);
        try {
            trainer.setGaussianPriorVariance(GAUSSIAN_PRIOR_VARIANCE);
            trainer.setAddNoFactors(false);
            trainer.setUseSomeUnsupportedTrick(true);
            trainer.train(instanceList);
        } finally {
            // Release the trainer's workers on both the success and the failure path.
            trainer.shutdown();
        }
        watch.stop();

        // Freeze both alphabets so later use of the pipe cannot add new entries.
        pipe.getAlphabet().stopGrowth();
        pipe.getTargetAlphabet().stopGrowth();
        log.info("Syll phone tag CRF training took {}", watch);
        return trainer;
    }

    /**
     * Assembles the serial feature pipeline used for both building training instances
     * and (via the trained CRF) for tagging.
     *
     * @return the pipe chain, sharing one data alphabet and the label sequence's target alphabet
     */
    private Pipe makePipe() {
        Alphabet alphabet = new Alphabet();
        Target2LabelSequence labelPipe = new Target2LabelSequence();
        return new SerialPipes(ImmutableList.of(
                new SWordConverterPipe(),
                new StringListToTokenSequence(alphabet, labelPipe.getTargetAlphabet()),
                new TokenSequenceLowercase(),
                new PhoneNeighborPipe(true, makeNeighbors()),
                new PhoneClassPipe(true, makeClassNeighbors()),
                new VowelNeighborPipe(),
                new IsFirstPipe(),
                new ThisPhoneClassPipe(),
                new TokenSequenceToFeature(),
                new TokenSequence2FeatureVectorSequence(alphabet, true, false),
                labelPipe));
    }

    /**
     * Windows passed to {@link PhoneClassPipe}.
     * NOTE(review): the arguments are presumably (offset, width) -- confirm against TokenWindow.
     */
    private List<TokenWindow> makeClassNeighbors() {
        return ImmutableList.of(
                new TokenWindow(1, 1),
                new TokenWindow(1, 2),
                new TokenWindow(1, 3),
                new TokenWindow(-3, 3),
                new TokenWindow(-2, 2),
                new TokenWindow(-1, 1));
    }

    /**
     * Windows passed to {@link PhoneNeighborPipe}: three with positive and three with
     * negative first arguments, all with a second argument of 1.
     * NOTE(review): the arguments are presumably (offset, width) -- confirm against TokenWindow.
     */
    private List<TokenWindow> makeNeighbors() {
        return ImmutableList.of(
                new TokenWindow(1, 1),
                new TokenWindow(2, 1),
                new TokenWindow(3, 1),
                new TokenWindow(-1, 1),
                new TokenWindow(-2, 1),
                new TokenWindow(-3, 1));
    }
}
