package com.github.steveash.jg2p.seq;

import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.learning.DefaultAcrfTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.Target2LabelSequence;
import cc.mallet.pipe.TokenSequence2FeatureVectorSequence;
import cc.mallet.pipe.TokenSequenceLowercase;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import com.github.steveash.jg2p.align.Alignment;
import com.github.steveash.jg2p.seq.SeqInputReader;
import com.github.steveash.jg2p.util.GramBuilder;
import com.google.common.base.Function;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/steveash/jg2p/seq/PhonemeACrfTrainer.class */
public class PhonemeACrfTrainer {
    private static final Logger log = LoggerFactory.getLogger(PhonemeACrfTrainer.class);

    public void train(Collection<Alignment> collection) {
        Pipe makePipe = makePipe();
        InstanceList makeExamplesFromAligns = makeExamplesFromAligns(collection, makePipe);
        ACRF acrf = new ACRF(makePipe, new ACRF.Template[]{new ACRF.BigramTemplate(0)});
        DefaultAcrfTrainer defaultAcrfTrainer = new DefaultAcrfTrainer();
        acrf.setSupportedOnly(true);
        acrf.setGaussianPriorVariance(2.0d);
        DefaultAcrfTrainer.LogEvaluator logEvaluator = new DefaultAcrfTrainer.LogEvaluator();
        logEvaluator.setNumIterToSkip(2);
        defaultAcrfTrainer.train(acrf, makeExamplesFromAligns, (InstanceList) null, (InstanceList) null, logEvaluator, 9999);
    }

    private static InstanceList makeExamplesFromAligns(Iterable<Alignment> iterable, Pipe pipe) {
        int i = 0;
        InstanceList instanceList = new InstanceList(pipe);
        for (Alignment alignment : iterable) {
            List<String> allYTokensAsList = alignment.getAllYTokensAsList();
            updateEpsilons(allYTokensAsList);
            instanceList.addThruPipe(new Instance(alignment.getAllXTokensAsList(), allYTokensAsList, (Object) null, (Object) null));
            i++;
        }
        log.info("Read {} instances of training data", Integer.valueOf(i));
        return instanceList;
    }

    private Iterable<Alignment> getAlignsFromGroup(List<SeqInputReader.AlignGroup> list) {
        return FluentIterable.from(list).transformAndConcat(new Function<SeqInputReader.AlignGroup, Iterable<Alignment>>() { // from class: com.github.steveash.jg2p.seq.PhonemeACrfTrainer.1
            public Iterable<Alignment> apply(SeqInputReader.AlignGroup alignGroup) {
                return alignGroup.alignments;
            }
        });
    }

    private static void updateEpsilons(List<String> list) {
        for (int i = 0; i < list.size(); i++) {
            if (StringUtils.isBlank(list.get(i))) {
                list.set(i, GramBuilder.EPS);
            }
        }
    }

    private static Pipe makePipe() {
        Alphabet alphabet = new Alphabet();
        Target2LabelSequence target2LabelSequence = new Target2LabelSequence();
        LabelAlphabet targetAlphabet = target2LabelSequence.getTargetAlphabet();
        return new SerialPipes(ImmutableList.of(new StringListToTokenSequence(alphabet, targetAlphabet), new TokenSequenceLowercase(), new NeighborTokenFeature(true, makeNeighbors()), new NeighborShapeFeature(true, makeShapeNeighs()), new TokenSequenceToFeature(), new TokenSequence2FeatureVectorSequence(alphabet, true, true), target2LabelSequence, new LabelSequenceToLabelsAssignment(alphabet, targetAlphabet)));
    }

    private static List<TokenWindow> makeShapeNeighs() {
        return ImmutableList.of(new TokenWindow(-5, 5), new TokenWindow(-4, 4), new TokenWindow(-3, 3), new TokenWindow(-2, 2), new TokenWindow(-1, 1), new TokenWindow(1, 1), new TokenWindow(1, 2), new TokenWindow(1, 3), new TokenWindow(1, 4), new TokenWindow(1, 5));
    }

    private static List<TokenWindow> makeNeighbors() {
        return ImmutableList.of(new TokenWindow(1, 1), new TokenWindow(2, 1), new TokenWindow(3, 1), new TokenWindow(-1, 1), new TokenWindow(-2, 1), new TokenWindow(-3, 1), new TokenWindow(-2, 2));
    }
}
