package com.github.steveash.jg2p.wfst;

import com.github.steveash.jg2p.align.AlignModel;
import com.github.steveash.jg2p.align.Alignment;
import com.github.steveash.jg2p.align.InputRecord;
import com.github.steveash.kylm.model.ngram.NgramLM;
import com.github.steveash.kylm.model.ngram.smoother.MKNSmoother;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/steveash/jg2p/wfst/G2pFstTrainer.class */
/**
 * Trains an n-gram language model over "graphone" token sequences and converts the
 * trained model into a {@link SeqTransducer}.
 *
 * <p>Each training example is a {@code String[]} with one token per graphone of an
 * {@link Alignment}; see {@link #alignToExample(Alignment)} for the token format.
 *
 * <p>The most recently trained {@link NgramLM} is kept in an instance field that is
 * written without any synchronization, so a single instance should not be shared
 * between threads that train concurrently.
 */
public class G2pFstTrainer {
    private static final Logger log = LoggerFactory.getLogger(G2pFstTrainer.class);
    private static final Joiner tieJoiner = Joiner.on(SeqTransducer.SEP);
    private NgramLM lastLm;

    /**
     * Aligns every input record with the given model and trains on the resulting alignments.
     *
     * <p>For each record the align model is called with a final argument of {@code 1} and only
     * the first returned alignment is used. Records for which the model returns no alignment
     * are left out of the training set; the number of such records is logged.
     *
     * @param list       records to align and train on
     * @param alignModel model used to align each record's {@code xWord} with its {@code yWord}
     * @param i          forwarded to {@link #trainWithSentences(List, int)}
     * @return the transducer built from the trained language model
     */
    public SeqTransducer alignAndTrain(List<InputRecord> list, AlignModel alignModel, int i) {
        log.info("Preparing training input for WFST trainer");
        List<String[]> examples = Lists.newArrayListWithCapacity(list.size());
        for (InputRecord inputRecord : list) {
            List<Alignment> align = alignModel.align(inputRecord.xWord, inputRecord.yWord, 1);
            if (!align.isEmpty()) {
                examples.add(alignToExample(align.get(0)));
            }
        }
        int skipped = list.size() - examples.size();
        if (skipped > 0) {
            // Make the otherwise silent loss of training data visible to the operator.
            log.warn("Skipped {} of {} input records that produced no alignment", skipped, list.size());
        }
        return trainWithSentences(examples, i);
    }

    /**
     * Trains on alignments that have already been computed.
     *
     * @param list alignments to convert into training examples, one example per alignment
     * @param i    forwarded to {@link #trainWithSentences(List, int)}
     * @return the transducer built from the trained language model
     */
    public SeqTransducer trainWithAligned(List<Alignment> list, int i) {
        log.info("Preparing training input for WFST trainer");
        List<String[]> examples = Lists.newArrayListWithCapacity(list.size());
        for (Alignment alignment : list) {
            examples.add(alignToExample(alignment));
        }
        return trainWithSentences(examples, i);
    }

    /**
     * Trains an {@link NgramLM} on the given token sequences and converts it to a transducer.
     *
     * <p>The model is built with an {@link MKNSmoother} that has unigram smoothing enabled.
     * After training, the model is remembered and can be retrieved via {@link #getLastLm()}.
     *
     * @param list training examples, each an array of tokens
     * @param i    first constructor argument of {@link NgramLM}
     *             (presumably the n-gram order -- confirm against the kylm API)
     * @return the transducer produced by {@link LangModelToFst#fromModel} for the trained model
     */
    public SeqTransducer trainWithSentences(List<String[]> list, int i) {
        log.info("Training LM on {} training examples", list.size());
        MKNSmoother smoother = new MKNSmoother();
        smoother.setSmoothUnigrams(true);
        NgramLM ngramLM = new NgramLM(i, smoother);
        ngramLM.trainModel(list);
        this.lastLm = ngramLM;
        return new LangModelToFst().fromModel(ngramLM);
    }

    /**
     * Returns the language model trained by the most recent call to
     * {@link #trainWithSentences(List, int)} (directly or via the other train methods).
     *
     * @return the last trained model, or {@code null} if nothing has been trained yet
     */
    public NgramLM getLastLm() {
        return this.lastLm;
    }

    /**
     * Converts an alignment into one training example.
     *
     * <p>Each graphone becomes a single token: its left-side symbols joined with
     * {@link SeqTransducer#SEP}, then {@link SeqTransducer#GRAPHONE_DELIM}, then its
     * right-side symbols joined with {@link SeqTransducer#SEP}.
     *
     * @param alignment alignment to convert
     * @return one token per graphone, in alignment order
     * @throws IllegalStateException if {@code getGraphonesSplit()} yields more entries than
     *                               {@code getGraphones()} has elements
     */
    private String[] alignToExample(Alignment alignment) {
        Iterator<Pair<List<String>, List<String>>> it = alignment.getGraphonesSplit().iterator();
        int graphoneCount = alignment.getGraphones().size();
        String[] tokens = new String[graphoneCount];
        for (int idx = 0; idx < graphoneCount; idx++) {
            Pair<List<String>, List<String>> graphone = it.next();
            tokens[idx] = tieJoiner.join(graphone.getLeft())
                    + SeqTransducer.GRAPHONE_DELIM
                    + tieJoiner.join(graphone.getRight());
        }
        // The split view must describe exactly the same graphones that were counted above.
        Preconditions.checkState(!it.hasNext(), "graphone split has more entries than graphones");
        return tokens;
    }
}
