package com.github.steveash.jg2p.rerank;

import cc.mallet.classify.RankMaxEntTrainer;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import com.github.steveash.jg2p.syll.PhoneSyllTagModel;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/steveash/jg2p/rerank/Rerank3Trainer.class */
/**
 * Trains a {@link Rerank3Model} from groups of {@link RerankExample} candidates using Mallet's
 * {@link RankMaxEntTrainer}.
 *
 * <p>Each group (one {@code List<RerankExample>}) is turned into a single Mallet {@link Instance}
 * and pushed through a feature pipe built by {@link #makePipe()}. The pipe always contains a fixed
 * set of rerank features, plus a syllable-agreement feature when a {@link PhoneSyllTagModel} has
 * been supplied via {@link #setPhoneSyllModel(PhoneSyllTagModel)}.
 */
public class Rerank3Trainer {
    private static final Logger log = LoggerFactory.getLogger(Rerank3Trainer.class);

    /**
     * Value passed to {@link RankMaxEntTrainer}'s single-{@code double} constructor. In Mallet this
     * is the Gaussian prior variance used for regularization (verify against the Mallet version in use).
     */
    private static final double PRIOR_VARIANCE = 10.0d;

    /** Number of groups between progress log lines while converting training data. */
    private static final int PROGRESS_INTERVAL = 10000;

    /** Optional syllable/phone tagger; when non-null an extra rerank feature is added to the pipe. */
    private PhoneSyllTagModel phoneSyllModel = null;

    /**
     * Sets the syllable/phone tag model used to add a {@link SyllAgreeRerankFeature} to the feature
     * pipe of subsequent {@link #trainFor} calls.
     *
     * @param phoneSyllTagModel the tagger to use, or {@code null} to disable that feature
     */
    public void setPhoneSyllModel(PhoneSyllTagModel phoneSyllTagModel) {
        this.phoneSyllModel = phoneSyllTagModel;
    }

    /**
     * Builds a fresh feature pipe, converts the training groups into Mallet instances and trains a
     * rank max-ent classifier on them.
     *
     * @param collection training data; each element is one group of candidate examples. Empty
     *     groups are skipped with a warning.
     * @return the trained reranking model
     */
    public Rerank3Model trainFor(Collection<List<RerankExample>> collection) {
        Pipe pipe = makePipe();
        InstanceList instances = convert(pipe, collection);
        return new Rerank3Model(new RankMaxEntTrainer(PRIOR_VARIANCE).train(instances));
    }

    /**
     * Pushes every non-empty candidate group through {@code pipe} and collects the results.
     *
     * @param pipe the feature pipe the resulting {@link InstanceList} is bound to
     * @param collection the candidate groups to convert
     * @return an instance list holding one instance per non-empty group
     */
    private InstanceList convert(Pipe pipe, Collection<List<RerankExample>> collection) {
        InstanceList instanceList = new InstanceList(pipe, collection.size());
        int seen = 0;
        int skipped = 0;
        for (List<RerankExample> group : collection) {
            seen++;
            if (group.isEmpty()) {
                // The instance source is taken from the first candidate, so an empty group cannot be
                // represented (get(0) would throw); it also has nothing to rank, so skip it.
                skipped++;
            } else {
                // data = whole candidate group, target = 1, no name, source = first candidate's word
                // graphs. NOTE(review): the meaning of target 1 is defined by LoadTargetPipe — confirm.
                instanceList.addThruPipe(new Instance(group, 1, null, group.get(0).getWordGraphs()));
            }
            if (seen % PROGRESS_INTERVAL == 0) {
                log.info("Loaded {} instances ...", seen);
            }
        }
        if (skipped > 0) {
            log.warn("Skipped {} empty candidate groups", skipped);
        }
        log.info("Loaded all {} instances", instanceList.size());
        return instanceList;
    }

    /**
     * Creates the Mallet pipe used to featurize candidate groups: a target-loading stage followed
     * by a stage that applies all configured {@link RerankFeature}s. Both stages share one
     * {@link Alphabet} and one {@link LabelAlphabet}.
     *
     * @return a new serial pipe with fresh alphabets
     */
    private Pipe makePipe() {
        Alphabet alphabet = new Alphabet();
        LabelAlphabet labelAlphabet = new LabelAlphabet();
        List<RerankFeature> features = Lists.<RerankFeature>newArrayList(
                new DupsPipe(),
                new ModePipe(),
                new PrefixPipe(),
                new RanksPipe(),
                new ScoresPipe(),
                new ShapePipe(),
                new ShapePrefixPipe());
        if (this.phoneSyllModel != null) {
            features.add(new SyllAgreeRerankFeature(this.phoneSyllModel));
            log.info("Using the syll phone tagger in the reranker");
        }
        return new SerialPipes(ImmutableList.of(
                new LoadTargetPipe(alphabet, labelAlphabet),
                new RerankFeaturePipe(alphabet, labelAlphabet, features)));
    }
}
