package cc.mallet.fst;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import java.io.Serializable;
import java.text.DecimalFormat;

/* loaded from: input_file:cc/mallet/fst/MEMM.class */
/**
 * A {@link CRF} variant whose transition weights are normalized locally: every
 * {@link TransitionIterator} created for this model subtracts, from each outgoing
 * weight, the log-space sum of all of that iterator's outgoing weights (see
 * {@code TransitionIterator#normalizeCosts}). Apart from that per-state
 * normalization, structure and parameters are inherited from {@link CRF}.
 */
public class MEMM extends CRF implements Serializable {

    /**
     * A state of an {@link MEMM}. It differs from {@link CRF.State} only in that
     * its transition iterators are the locally-normalizing
     * {@link MEMM.TransitionIterator}.
     */
    public static class State extends CRF.State implements Serializable {
        // NOTE(review): not read anywhere in this file; kept (package-private)
        // because code elsewhere in the package may rely on it.
        InstanceList trainingSet;

        /**
         * Creates a state. All arguments are forwarded unchanged to the
         * {@link CRF.State} constructor.
         */
        protected State(String name, int index, double initialWeight, double finalWeight,
                String[] destinationNames, String[] labelNames, String[][] weightNames, CRF crf) {
            super(name, index, initialWeight, finalWeight, destinationNames, labelNames, weightNames, crf);
        }

        /**
         * Returns an iterator over this state's outgoing transitions at the given
         * input position, with weights normalized by {@link MEMM.TransitionIterator}.
         *
         * @param inputSequence  the input; must be a non-null {@link FeatureVectorSequence}
         * @param inputPosition  index into {@code inputSequence}; must be non-negative
         * @param outputSequence optional output sequence; when non-null, the element at
         *                       {@code outputPosition} is cast to {@code String} and
         *                       passed to the iterator
         * @param outputPosition index into {@code outputSequence}; must be non-negative
         * @return the transition iterator
         * @throws UnsupportedOperationException if either position is negative
         *         (epsilon transitions) or {@code inputSequence} is null
         */
        @Override
        public Transducer.TransitionIterator transitionIterator(Sequence inputSequence, int inputPosition,
                Sequence outputSequence, int outputPosition) {
            if (inputPosition < 0 || outputPosition < 0) {
                throw new UnsupportedOperationException("Epsilon transitions not implemented.");
            }
            if (inputSequence == null) {
                throw new UnsupportedOperationException("CRFs are not generative models; must have an input sequence.");
            }
            String output = outputSequence == null ? null : (String) outputSequence.get(outputPosition);
            return new TransitionIterator(this, (FeatureVectorSequence) inputSequence, inputPosition, output, this.crf);
        }
    }

    /**
     * A {@link CRF.TransitionIterator} that, right after construction, rewrites its
     * outgoing {@code weights} so they are normalized against their log-space sum.
     * This local normalization is what distinguishes an MEMM from a globally
     * normalized CRF.
     */
    protected static class TransitionIterator extends CRF.TransitionIterator implements Serializable {
        /** Log-space sum of the raw outgoing weights (reported as "Log Z"). */
        private double sum;

        public TransitionIterator(State source, FeatureVectorSequence inputSeq, int inputPosition,
                String output, CRF crf) {
            super(source, inputSeq, inputPosition, output, crf);
            normalizeCosts();
        }

        public TransitionIterator(State source, FeatureVector fv, String output, CRF crf) {
            super(source, fv, output, crf);
            normalizeCosts();
        }

        /**
         * Computes the log-space sum of all outgoing weights and subtracts it
         * from each weight. If that sum is infinite, the weights are left
         * untouched. An empty {@code weights} array is one such case, because
         * the sum then stays at negative infinity.
         */
        private void normalizeCosts() {
            // Start at -infinity (the identity for log-space addition, presumably
            // what sumLogProb implements) and fold every weight in.
            sum = Double.NEGATIVE_INFINITY;
            for (double weight : weights) {
                sum = MEMM.sumLogProb(sum, weight);
            }
            assert !Double.isNaN(sum);
            // Subtracting an infinite normalizer would turn weights into NaN or
            // infinities, so skip normalization in that case.
            if (Double.isInfinite(sum)) {
                return;
            }
            for (int i = 0; i < weights.length; i++) {
                weights[i] -= sum;
            }
        }

        /**
         * Extends the inherited description with the normalizer,
         * formatted to at most three decimal places.
         */
        @Override
        public String describeTransition(double cutoff) {
            // DecimalFormat is not thread-safe, so a fresh instance is created per call.
            DecimalFormat format = new DecimalFormat("0.###");
            return super.describeTransition(cutoff) + "Log Z = " + format.format(sum) + "\n";
        }
    }

    public MEMM(Pipe inputPipe, Pipe outputPipe) {
        super(inputPipe, outputPipe);
    }

    public MEMM(Alphabet inputAlphabet, Alphabet outputAlphabet) {
        super(inputAlphabet, outputAlphabet);
    }

    public MEMM(CRF initFrom) {
        super(initFrom);
    }

    /**
     * Factory hook used by {@link CRF} to create states; returns an
     * {@link MEMM.State} so that transitions are locally normalized.
     */
    @Override
    protected CRF.State newState(String name, int index, double initialWeight, double finalWeight,
            String[] destinationNames, String[] labelNames, String[][] weightNames, CRF crf) {
        return new State(name, index, initialWeight, finalWeight, destinationNames, labelNames, weightNames, crf);
    }
}
