package cc.mallet.fst;

import cc.mallet.fst.Transducer;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/fst/SumLatticeScaling.class */
/**
 * A {@link SumLattice} that runs forward-backward over a {@link Transducer} in probability space,
 * normalizing the alpha (and beta) values at every lattice position so that they sum to one.
 * The log of each normalizer is accumulated in {@link #alphaLogScaling} / {@link #betaLogScaling}
 * so that the unscaled log total weight, log state posteriors (gammas) and log transition
 * posteriors (xis) can be recovered.
 *
 * <p>If no complete path through the lattice has positive probability, {@link #getTotalWeight()}
 * returns negative infinity and no posteriors are computed or passed to the incrementor.
 */
public class SumLatticeScaling implements SumLattice {
    private static final Logger logger = MalletLogger.getLogger(SumLatticeScaling.class.getName());

    /** Upper bound tolerated for a computed probability; the slack above 1 absorbs rounding. */
    private static final double MAX_PROBABILITY = 1.000001;

    /** Whether the constructors without an explicit flag store transition posteriors (xis). */
    protected static boolean saveXis = false;

    /** The input sequence the lattice is built over. */
    Sequence input;
    /** Optional output sequence passed to the transition iterators; may be null. */
    Sequence output;
    /** The transducer whose states and transitions make up the lattice. */
    Transducer t;
    /** Natural log of the summed weight of all complete paths, or -infinity if there are none. */
    double totalWeight;
    /** Lazily created nodes indexed [position][state index]; null where never created. */
    LatticeNode[][] nodes;
    /** alphaLogScaling[i] is the sum of the log normalizers applied to alpha columns 0..i. */
    double[] alphaLogScaling;
    /** betaLogScaling[i] is the sum of the log normalizers applied to beta columns i..length-1. */
    double[] betaLogScaling;
    /** Log scaling of the final alpha column (alphaLogScaling[length - 1]). */
    double zLogScaling;
    /** Number of lattice positions: input length + 1. */
    int latticeLength;
    /** Log state posteriors indexed [position][state]; -infinity where never computed. */
    double[][] gammas;
    /** Log transition posteriors indexed [position][source][dest]; null unless xis are saved. */
    double[][][] xis;

    /** Factory producing {@link SumLatticeScaling} lattices; serializes only a version number. */
    public static class Factory extends SumLatticeFactory implements Serializable {
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 1;

        @Override // cc.mallet.fst.SumLatticeFactory
        public SumLattice newSumLattice(Transducer trans, Sequence input, Sequence output,
                Transducer.Incrementor incrementor, boolean storeXis, LabelAlphabet outputAlphabet) {
            return new SumLatticeScaling(trans, input, output, incrementor, storeXis, outputAlphabet);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.readInt(); // version number; there is no other state to restore
        }
    }

    /** One (position, state) cell of the lattice holding its scaled alpha and beta values. */
    public class LatticeNode {
        int inputPosition;
        Transducer.State state;
        /** Output of the last transition that reached this node during the forward pass. */
        Object output;
        /** Scaled forward value; NaN until an initial weight or a transition reaches the node. */
        double alpha = Double.NaN;
        /** Scaled backward value; NaN until the backward pass visits the node. */
        double beta = Double.NaN;

        LatticeNode(int inputPosition, Transducer.State state) {
            this.inputPosition = inputPosition;
            this.state = state;
        }
    }

    /** No-argument constructor for subclasses; leaves every field unset. */
    protected SumLatticeScaling() {
    }

    /**
     * Returns the node at (position, state index), creating it with NaN alpha/beta if absent.
     *
     * @param ip lattice position
     * @param stateIndex index of the state in the transducer
     * @return the existing or newly created node
     */
    protected LatticeNode getLatticeNode(int ip, int stateIndex) {
        if (nodes[ip][stateIndex] == null) {
            nodes[ip][stateIndex] = new LatticeNode(ip, t.getState(stateIndex));
        }
        return nodes[ip][stateIndex];
    }

    public SumLatticeScaling(Transducer trans, Sequence input) {
        this(trans, input, null, null, saveXis, null);
    }

    public SumLatticeScaling(Transducer trans, Sequence input, boolean storeXis) {
        this(trans, input, null, null, storeXis, null);
    }

    public SumLatticeScaling(Transducer trans, Sequence input, Transducer.Incrementor incrementor) {
        this(trans, input, null, incrementor, saveXis, null);
    }

    public SumLatticeScaling(Transducer trans, Sequence input, Sequence output) {
        this(trans, input, output, null, saveXis, null);
    }

    public SumLatticeScaling(Transducer trans, Sequence input, Sequence output,
            Transducer.Incrementor incrementor) {
        this(trans, input, output, incrementor, saveXis, null);
    }

    public SumLatticeScaling(Transducer trans, Sequence input, Sequence output,
            Transducer.Incrementor incrementor, LabelAlphabet outputAlphabet) {
        this(trans, input, output, incrementor, saveXis, outputAlphabet);
    }

    public SumLatticeScaling(Transducer trans, Sequence input, Sequence output,
            Transducer.Incrementor incrementor, boolean storeXis) {
        this(trans, input, output, incrementor, storeXis, null);
    }

    /**
     * Builds the lattice and runs the scaled forward-backward passes.
     *
     * @param trans the transducer
     * @param input the input sequence
     * @param output optional output sequence (same length as {@code input}) handed to the
     *     transition iterators; may be null
     * @param incrementor if non-null, receives the expected counts of initial states, transitions
     *     and final states; it is not called when the lattice has no complete path
     * @param storeXis whether to keep the log transition posteriors (see {@link #getXis()})
     * @param outputAlphabet if non-null, the output of every transition must be in this alphabet
     */
    public SumLatticeScaling(Transducer trans, Sequence input, Sequence output,
            Transducer.Incrementor incrementor, boolean storeXis, LabelAlphabet outputAlphabet) {
        assert output == null || input.size() == output.size();
        this.t = trans;
        this.input = input;
        this.output = output;
        this.latticeLength = input.size() + 1;
        final int numStates = t.numStates();
        nodes = new LatticeNode[latticeLength][numStates];
        // Fresh double arrays are zero-filled, i.e. "no scaling applied yet".
        alphaLogScaling = new double[latticeLength];
        betaLogScaling = new double[latticeLength];
        gammas = new double[latticeLength][numStates];
        for (double[] column : gammas) {
            for (int s = 0; s < numStates; s++) {
                column[s] = Double.NEGATIVE_INFINITY;
            }
        }
        if (storeXis) {
            xis = new double[latticeLength][numStates][numStates];
            for (double[][] plane : xis) {
                for (double[] row : plane) {
                    for (int dest = 0; dest < numStates; dest++) {
                        row[dest] = Double.NEGATIVE_INFINITY;
                    }
                }
            }
        }

        logger.fine("Starting Foward pass");
        if (!forwardPass(numStates)) {
            // Some position is unreachable, so no complete path exists.
            zLogScaling = Double.NEGATIVE_INFINITY;
            totalWeight = Double.NEGATIVE_INFINITY;
            return;
        }
        zLogScaling = alphaLogScaling[latticeLength - 1];

        double finalMass = finalMass(numStates);
        // Zero (or NaN) mass means every surviving path ends in a non-final state; dividing by it
        // below would only produce NaN posteriors.
        if (!(finalMass > 0.0)) {
            totalWeight = Double.NEGATIVE_INFINITY;
            return;
        }
        totalWeight = Math.log(finalMass) + zLogScaling;

        backwardPass(numStates, finalMass, incrementor, storeXis, outputAlphabet);
        if (incrementor != null) {
            incrementInitialStates(numStates, incrementor);
        }
    }

    /**
     * Seeds position 0 with the initial weights and propagates scaled alphas to the last position.
     *
     * @return false as soon as a position receives no forward mass; true otherwise
     */
    private boolean forwardPass(int numStates) {
        boolean anyStart = false;
        for (int s = 0; s < numStates; s++) {
            double initialWeight = t.getState(s).getInitialWeight();
            if (initialWeight > Double.NEGATIVE_INFINITY) {
                getLatticeNode(0, s).alpha = Math.exp(initialWeight);
                anyStart = true;
            }
        }
        // Warn before rescaling so the message is emitted even when position 0 has no mass.
        if (!anyStart) {
            logger.warning("There are no starting states!");
        }
        if (!rescaleAlphas(0)) {
            return false;
        }
        for (int ip = 0; ip < latticeLength - 1; ip++) {
            for (int s = 0; s < numStates; s++) {
                if (isInvalidNode(ip, s)) {
                    continue;
                }
                double sourceAlpha = nodes[ip][s].alpha;
                Transducer.TransitionIterator iter =
                        t.getState(s).transitionIterator(input, ip, output, ip);
                while (iter.hasNext()) {
                    LatticeNode dest = getLatticeNode(ip + 1, iter.next().getIndex());
                    if (Double.isNaN(dest.alpha)) {
                        dest.alpha = 0.0;
                    }
                    dest.output = iter.getOutput();
                    dest.alpha += sourceAlpha * Math.exp(iter.getWeight());
                }
            }
            if (!rescaleAlphas(ip + 1)) {
                return false;
            }
        }
        return true;
    }

    /** Sums scaled alpha times exp(final weight) over the nodes in the last column. */
    private double finalMass(int numStates) {
        final int last = latticeLength - 1;
        double mass = 0.0;
        for (int s = 0; s < numStates; s++) {
            LatticeNode node = nodes[last][s];
            if (node != null) {
                mass += node.alpha * Math.exp(t.getState(s).getFinalWeight());
            }
        }
        return mass;
    }

    /**
     * Runs the scaled backward pass, filling betas, gammas, optional xis and incrementor counts.
     *
     * @param finalMass scaled total mass of the last column; must be positive
     */
    private void backwardPass(int numStates, double finalMass, Transducer.Incrementor incrementor,
            boolean storeXis, LabelAlphabet outputAlphabet) {
        final int last = latticeLength - 1;

        // Last column: beta is exp(final weight). The alpha scaling cancels against zLogScaling and
        // betaLogScaling[last] is still zero, so gamma needs no extra log-scale term here.
        for (int s = 0; s < numStates; s++) {
            LatticeNode node = nodes[last][s];
            if (node == null) {
                continue;
            }
            Transducer.State state = t.getState(s);
            node.beta = Math.exp(state.getFinalWeight());
            double p = node.alpha * node.beta / finalMass;
            gammas[last][s] = Math.log(p);
            if (incrementor != null) {
                assert p >= 0.0 && p <= MAX_PROBABILITY : "p=" + p + ", gamma=" + gammas[last][s];
                incrementor.incrementFinalState(state, p);
            }
        }
        rescaleBetas(last);

        // NOTE(review): these per-position output counts are accumulated but never stored, so
        // getLabelingAtPosition cannot use them. They are kept for the label-membership assertion.
        double[][] outputCounts =
                outputAlphabet == null ? null : new double[latticeLength][outputAlphabet.size()];

        for (int ip = last - 1; ip >= 0; ip--) {
            // Log factor turning scaled alpha(ip) * beta(ip + 1) / finalMass into a posterior.
            double logScale = alphaLogScaling[ip] + betaLogScaling[ip + 1] - zLogScaling;
            double scale = Math.exp(logScale);
            for (int s = 0; s < numStates; s++) {
                if (isInvalidNode(ip, s)) {
                    continue;
                }
                LatticeNode source = nodes[ip][s];
                // Initialize before iterating so a node with no outgoing transitions gets beta 0
                // rather than staying NaN and poisoning the column sum in rescaleBetas.
                if (Double.isNaN(source.beta)) {
                    source.beta = 0.0;
                }
                Transducer.TransitionIterator iter =
                        t.getState(s).transitionIterator(input, ip, output, ip);
                while (iter.hasNext()) {
                    int destIndex = iter.next().getIndex();
                    LatticeNode dest = nodes[ip + 1][destIndex];
                    if (dest == null) {
                        continue;
                    }
                    double transitionFactor = Math.exp(iter.getWeight());
                    source.beta += dest.beta * transitionFactor;
                    double scaledXi = source.alpha * transitionFactor * dest.beta / finalMass;
                    if (storeXis) {
                        xis[ip][s][destIndex] = Math.log(scaledXi) + logScale;
                    }
                    if (incrementor == null && outputCounts == null) {
                        continue;
                    }
                    double p = scaledXi * scale;
                    assert p >= 0.0 && p <= MAX_PROBABILITY
                            : "p=" + p + ", xis[" + ip + "][" + s + "][" + destIndex + "]=" + scaledXi;
                    if (incrementor != null) {
                        incrementor.incrementTransition(iter, p);
                    }
                    if (outputCounts != null) {
                        int labelIndex = outputAlphabet.lookupIndex(iter.getOutput(), false);
                        assert labelIndex >= 0;
                        outputCounts[ip][labelIndex] += p;
                    }
                }
                gammas[ip][s] = Math.log(source.alpha * source.beta / finalMass) + logScale;
            }
            rescaleBetas(ip);
        }
    }

    /** Passes the initial-state posteriors exp(gammas[0][s]) to the incrementor. */
    private void incrementInitialStates(int numStates, Transducer.Incrementor incrementor) {
        for (int s = 0; s < numStates; s++) {
            double p = Math.exp(gammas[0][s]);
            assert p >= 0.0 && p <= MAX_PROBABILITY : "p=" + p;
            incrementor.incrementInitialState(t.getState(s), p);
        }
    }

    /** A node is invalid when it was never created or no forward mass reached it. */
    private boolean isInvalidNode(int ip, int s) {
        return nodes[ip][s] == null || Double.isNaN(nodes[ip][s].alpha);
    }

    /**
     * Normalizes the valid alphas at {@code ip} to sum to one and records the cumulative log
     * normalizer in {@code alphaLogScaling[ip]}.
     *
     * @return false, leaving the alphas untouched, when the column has no positive mass
     */
    private boolean rescaleAlphas(int ip) {
        double sum = 0.0;
        for (int s = 0; s < t.numStates(); s++) {
            if (!isInvalidNode(ip, s)) {
                sum += nodes[ip][s].alpha;
            }
        }
        alphaLogScaling[ip] = Math.log(sum) + (ip == 0 ? 0.0 : alphaLogScaling[ip - 1]);
        if (!(sum > 0.0)) {
            return false;
        }
        for (int s = 0; s < t.numStates(); s++) {
            if (!isInvalidNode(ip, s)) {
                nodes[ip][s].alpha /= sum;
            }
        }
        return true;
    }

    /**
     * Normalizes the betas of the valid nodes at {@code ip} to sum to one and records the
     * cumulative log normalizer in {@code betaLogScaling[ip]}.
     */
    private void rescaleBetas(int ip) {
        double sum = 0.0;
        for (int s = 0; s < t.numStates(); s++) {
            if (!isInvalidNode(ip, s)) {
                sum += nodes[ip][s].beta;
            }
        }
        assert sum > 0.0 : "Invalid sum over betas for ip=" + ip;
        betaLogScaling[ip] = Math.log(sum) + (ip == latticeLength - 1 ? 0.0 : betaLogScaling[ip + 1]);
        for (int s = 0; s < t.numStates(); s++) {
            if (!isInvalidNode(ip, s)) {
                nodes[ip][s].beta /= sum;
            }
        }
    }

    /** Returns the log transition posteriors, or null if they were not saved. */
    @Override // cc.mallet.fst.SumLattice
    public double[][][] getXis() {
        return xis;
    }

    /** Returns the log state posteriors indexed [position][state]. */
    @Override // cc.mallet.fst.SumLattice
    public double[][] getGammas() {
        return gammas;
    }

    /** Returns the natural log of the total path weight (-infinity if no path exists). */
    @Override // cc.mallet.fst.SumLattice
    public double getTotalWeight() {
        return totalWeight;
    }

    @Override // cc.mallet.fst.SumLattice
    public double getGammaWeight(int ip, Transducer.State state) {
        return gammas[ip][state.getIndex()];
    }

    public double getGammaWeight(int ip, int stateIndex) {
        return gammas[ip][stateIndex];
    }

    @Override // cc.mallet.fst.SumLattice
    public double getGammaProbability(int ip, Transducer.State state) {
        return Math.exp(gammas[ip][state.getIndex()]);
    }

    public double getGammaProbability(int ip, int stateIndex) {
        return getGammaProbability(ip, t.getState(stateIndex));
    }

    @Override // cc.mallet.fst.SumLattice
    public double getXiProbability(int ip, Transducer.State s1, Transducer.State s2) {
        return Math.exp(getXiWeight(ip, s1, s2));
    }

    /**
     * Returns the log posterior of the transition s1 -> s2 leaving position {@code ip}.
     *
     * @throws IllegalStateException if the lattice was built without saving xis
     */
    @Override // cc.mallet.fst.SumLattice
    public double getXiWeight(int ip, Transducer.State s1, Transducer.State s2) {
        if (xis == null) {
            throw new IllegalStateException("xis were not saved.");
        }
        return xis[ip][s1.getIndex()][s2.getIndex()];
    }

    @Override // cc.mallet.fst.SumLattice
    public int length() {
        return latticeLength;
    }

    /** Returns the unscaled forward value; creates the node (with NaN alpha) if it is absent. */
    @Override // cc.mallet.fst.SumLattice
    public double getAlpha(int ip, Transducer.State state) {
        return getLatticeNode(ip, state.getIndex()).alpha * Math.exp(alphaLogScaling[ip]);
    }

    /** Returns the unscaled backward value; creates the node (with NaN beta) if it is absent. */
    @Override // cc.mallet.fst.SumLattice
    public double getBeta(int ip, Transducer.State state) {
        return getLatticeNode(ip, state.getIndex()).beta * Math.exp(betaLogScaling[ip]);
    }

    /** Not supported by this lattice implementation. */
    @Override // cc.mallet.fst.SumLattice
    public LabelVector getLabelingAtPosition(int ip) {
        throw new UnsupportedOperationException("Not implemented for SumLatticeScaling!");
    }

    @Override // cc.mallet.fst.SumLattice
    public Sequence getInput() {
        return input;
    }

    @Override // cc.mallet.fst.SumLattice
    public Transducer getTransducer() {
        return t;
    }
}
