package cc.mallet.fst.semi_supervised;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.LogNumber;
import java.util.ArrayList;
import java.util.Iterator;
import org.apache.commons.lang3.StringUtils;

/* loaded from: input_file:cc/mallet/fst/semi_supervised/GELattice.class */
/**
 * Lattice that propagates GE constraint feature values forward (alpha) and backward (beta)
 * through the states of a {@link CRF} and adds the resulting per-transition contribution
 * {@code alpha + beta - exp(xi) * dotEx} into a {@link CRF.Factors} gradient.
 *
 * <p>All alpha/beta/cache values are {@link LogNumber}s. A {@code LogNumber} built with
 * ({@code Double.NEGATIVE_INFINITY}, {@code true}) is used as the additive identity (its
 * {@code exp()} is 0).
 */
public class GELattice {
    /** Number of lattice positions: input sequence length + 1. */
    protected int latticeLength;
    /** The transducer; the constructor requires it to be a {@link CRF}. */
    protected Transducer transducer;
    /** Number of states of {@link #transducer}. */
    protected int numStates;
    /** {@code lattice[position][sourceState]}. */
    protected LatticeNode[][] lattice;
    /**
     * {@code dotCache[position][from][to]}: signed log-space sum of the constraint values for
     * that transition, or {@code null} when it is zero. Filled by the forward pass for positions
     * {@code 0 .. latticeLength - 2}; the last row stays {@code null}.
     */
    protected LogNumber[][][] dotCache;

    /**
     * One lattice node per (position, source state). Each holds one alpha and one beta value per
     * destination state, all initialized to the log-space zero.
     *
     * <p>NOTE(review): the decompiler reports this class was originally {@code protected}; it is
     * kept {@code public} so existing users are not broken.
     */
    public class LatticeNode {
        protected LogNumber[] alpha;
        protected LogNumber[] beta;

        public LatticeNode() {
            this.alpha = new LogNumber[GELattice.this.numStates];
            this.beta = new LogNumber[GELattice.this.numStates];
            for (int i = 0; i < GELattice.this.numStates; i++) {
                this.alpha[i] = new LogNumber(Double.NEGATIVE_INFINITY, true);
                this.beta[i] = new LogNumber(Double.NEGATIVE_INFINITY, true);
            }
        }
    }

    /**
     * Builds the lattice, runs the forward and backward passes and accumulates the gradient.
     *
     * @param fvs input sequence; the lattice has {@code fvs.size() + 1} positions
     * @param gammas {@code gammas[position][state]}, log-space values (exponentiated with
     *     {@code Math.exp}); presumably log state marginals — confirm with the caller
     * @param xis {@code xis[position][from][to]}, log-space values; presumably log transition
     *     marginals — confirm with the caller
     * @param transducer the model; must be a {@link CRF} (it is cast unconditionally)
     * @param reverseTrans {@code reverseTrans[s]} lists the source states that transition into
     *     {@code s}
     * @param reverseTransIndices {@code reverseTransIndices[s][k]} is the index passed to
     *     {@code getWeightNames} of source state {@code reverseTrans[s][k]} for its transition
     *     into {@code s}
     * @param gradient factors that the gradient contributions are added into; must not be null
     * @param constraints constraints to apply; split into one-state and two-state constraints
     * @param check when true, run {@link #check} afterwards (its tests are assertions, so they
     *     only fire when assertions are enabled)
     */
    public GELattice(FeatureVectorSequence fvs, double[][] gammas, double[][][] xis, Transducer transducer, int[][] reverseTrans, int[][] reverseTransIndices, CRF.Factors gradient, ArrayList<GEConstraint> constraints, boolean check) {
        assert gradient != null;
        this.latticeLength = fvs.size() + 1;
        this.transducer = transducer;
        this.numStates = transducer.numStates();
        this.lattice = new LatticeNode[this.latticeLength][this.numStates];
        for (int ip = 0; ip < this.latticeLength; ip++) {
            for (int s = 0; s < this.numStates; s++) {
                this.lattice[ip][s] = new LatticeNode();
            }
        }
        this.dotCache = new LogNumber[this.latticeLength][this.numStates][this.numStates];

        ArrayList<GEConstraint> oneStateConstraints = new ArrayList<>();
        ArrayList<GEConstraint> twoStateConstraints = new ArrayList<>();
        for (GEConstraint constraint : constraints) {
            if (constraint.isOneStateConstraint()) {
                oneStateConstraints.add(constraint);
            } else {
                twoStateConstraints.add(constraint);
            }
        }

        CRF crf = (CRF) transducer;
        double dotEx = runForward(crf, oneStateConstraints, twoStateConstraints, gammas, xis, reverseTrans, fvs);
        runBackward(crf, gammas, xis, reverseTrans, reverseTransIndices, fvs, dotEx, gradient);

        // Previously the flag was accepted but silently ignored.
        if (check) {
            this.check(constraints, gammas, xis, fvs);
        }
    }

    /**
     * Forward pass: fills {@link #dotCache} and the alpha values.
     *
     * @return the sum over positions and transitions of {@code exp(xi) * twoStateValue} plus,
     *     once per (position, destination), {@code exp(gamma[position + 1][dest]) * oneStateValue}
     */
    private double runForward(CRF crf, ArrayList<GEConstraint> oneStateConstraints, ArrayList<GEConstraint> twoStateConstraints, double[][] gammas, double[][][] xis, int[][] reverseTrans, FeatureVectorSequence fvs) {
        double dotEx = 0.0d;
        // One-state constraint value per destination state at the current position (log space, signed).
        LogNumber[] oneStateCache = new LogNumber[this.numStates];
        LogNumber nuAlpha = new LogNumber(Double.NEGATIVE_INFINITY, true);
        LogNumber temp = new LogNumber(Double.NEGATIVE_INFINITY, true);
        for (int ip = 0; ip < this.latticeLength - 1; ip++) {
            FeatureVector fv = fvs.get(ip);
            for (GEConstraint constraint : oneStateConstraints) {
                constraint.preProcess(fv);
            }
            for (GEConstraint constraint : twoStateConstraints) {
                constraint.preProcess(fv);
            }
            // The one-state value for a destination is computed once per position, using the first
            // source state that reaches it. This assumes it does not depend on the source state — TODO confirm.
            boolean[] oneStateComputed = new boolean[this.numStates];
            for (int a = 0; a < this.numStates; a++) {
                // Sum of incoming alphas from the previous position (zero at position 0).
                nuAlpha.set(Double.NEGATIVE_INFINITY, true);
                if (ip != 0) {
                    for (int prev : reverseTrans[a]) {
                        nuAlpha.plusEquals(this.lattice[ip - 1][prev].alpha[a]);
                    }
                }
                assert !Double.isNaN(nuAlpha.logVal);
                CRF.State aState = (CRF.State) crf.getState(a);
                LatticeNode node = this.lattice[ip][a];
                double[] xiRow = xis[ip][a];
                double gamma = gammas[ip][a];
                for (int bi = 0; bi < aState.numDestinations(); bi++) {
                    int b = aState.getDestinationState(bi).getIndex();
                    double dot = 0.0d;
                    for (GEConstraint constraint : twoStateConstraints) {
                        dot += constraint.getCompositeConstraintFeatureValue(fv, ip, a, b);
                    }
                    if (!oneStateComputed[b]) {
                        double oneStateValue = 0.0d;
                        for (GEConstraint constraint : oneStateConstraints) {
                            oneStateValue += constraint.getCompositeConstraintFeatureValue(fv, ip, a, b);
                        }
                        if (oneStateValue < 0.0d) {
                            dotEx += Math.exp(gammas[ip + 1][b]) * oneStateValue;
                            oneStateCache[b] = new LogNumber(Math.log(-oneStateValue), false);
                        } else if (oneStateValue > 0.0d) {
                            dotEx += Math.exp(gammas[ip + 1][b]) * oneStateValue;
                            oneStateCache[b] = new LogNumber(Math.log(oneStateValue), true);
                        } else {
                            oneStateCache[b] = null;
                        }
                        oneStateComputed[b] = true;
                    }
                    // Combine two-state and one-state values for this transition. A zero two-state
                    // value reuses the cached one-state value (possibly null).
                    if (dot != 0.0d) {
                        dotEx += Math.exp(xiRow[b]) * dot;
                        LogNumber combined = dot < 0.0d
                                ? new LogNumber(Math.log(-dot), false)
                                : new LogNumber(Math.log(dot), true);
                        if (oneStateCache[b] != null) {
                            combined.plusEquals(oneStateCache[b]);
                        }
                        this.dotCache[ip][a][b] = combined;
                    } else {
                        this.dotCache[ip][a][b] = oneStateCache[b];
                    }
                    // alpha += exp(xi) * constraintValue
                    if (this.dotCache[ip][a][b] != null) {
                        temp.set(xiRow[b], true);
                        temp.timesEquals(this.dotCache[ip][a][b]);
                        node.alpha[b].plusEquals(temp);
                    }
                    // alpha += exp(xi - gamma) * nuAlpha; a -inf gamma forces alpha to zero.
                    if (gamma == Double.NEGATIVE_INFINITY) {
                        node.alpha[b] = new LogNumber(Double.NEGATIVE_INFINITY, true);
                    } else {
                        temp.set(xiRow[b] - gamma, true);
                        temp.timesEquals(nuAlpha);
                        node.alpha[b].plusEquals(temp);
                    }
                    assert !Double.isNaN(node.alpha[b].logVal) : "xi: " + xiRow[b] + ", gamma: " + gamma + ", constraint feature: " + this.dotCache[ip][a][b] + ", nuApha: " + nuAlpha + " dot: " + dot;
                }
            }
        }
        return dotEx;
    }

    /**
     * Backward pass: fills the beta values and adds
     * {@code alpha + beta - exp(xi) * dotEx} for every transition into the gradient, both sparsely
     * over the position's feature vector and into {@code defaultWeights}.
     */
    private void runBackward(CRF crf, double[][] gammas, double[][][] xis, int[][] reverseTrans, int[][] reverseTransIndices, FeatureVectorSequence fvs, double dotEx, CRF.Factors gradient) {
        LogNumber nuBeta = new LogNumber(Double.NEGATIVE_INFINITY, true);
        LogNumber dotBeta = new LogNumber(Double.NEGATIVE_INFINITY, true);
        LogNumber temp = new LogNumber(Double.NEGATIVE_INFINITY, true);
        LogNumber ratio = new LogNumber(Double.NEGATIVE_INFINITY, true);
        for (int ip = this.latticeLength - 2; ip >= 0; ip--) {
            for (int b = 0; b < this.numStates; b++) {
                // Gather the betas and exp(xi)-weighted constraint values leaving state b at ip + 1.
                nuBeta.set(Double.NEGATIVE_INFINITY, true);
                dotBeta.set(Double.NEGATIVE_INFINITY, true);
                CRF.State bState = (CRF.State) crf.getState(b);
                for (int ci = 0; ci < bState.numDestinations(); ci++) {
                    int c = bState.getDestinationState(ci).getIndex();
                    nuBeta.plusEquals(this.lattice[ip + 1][b].beta[c]);
                    assert !Double.isNaN(nuBeta.logVal);
                    LogNumber dot = this.dotCache[ip + 1][b][c];
                    if (dot != null) {
                        temp.set(xis[ip + 1][b][c], true);
                        temp.timesEquals(dot);
                        dotBeta.plusEquals(temp);
                    }
                }
                double gamma = gammas[ip + 1][b];
                int[] sources = reverseTrans[b];
                for (int ai = 0; ai < sources.length; ai++) {
                    int a = sources[ai];
                    CRF.State aState = (CRF.State) crf.getState(a);
                    LatticeNode node = this.lattice[ip][a];
                    double xi = xis[ip][a][b];
                    // beta += (dotBeta + nuBeta) * exp(xi - gamma); a -inf gamma forces beta to zero.
                    if (gamma == Double.NEGATIVE_INFINITY) {
                        node.beta[b] = new LogNumber(Double.NEGATIVE_INFINITY, true);
                    } else {
                        temp.set(dotBeta.logVal, dotBeta.sign);
                        temp.plusEquals(nuBeta);
                        ratio.set(xi - gamma, true);
                        temp.timesEquals(ratio);
                        node.beta[b].plusEquals(temp);
                    }
                    assert !Double.isNaN(node.beta[b].logVal) : "xi: " + xi + ", gamma: " + gamma + ", xi: " + xi + ", log(indicatorFeat): " + this.dotCache[ip][b];
                    double contribution = (node.alpha[b].exp() + node.beta[b].exp()) - (Math.exp(xi) * dotEx);
                    // Hoisted: the weight names of this transition do not change inside the loop.
                    String[] weightNames = aState.getWeightNames(reverseTransIndices[b][ai]);
                    for (String weightName : weightNames) {
                        int weightsIndex = ((CRF) this.transducer).getWeightsIndex(weightName);
                        gradient.weights[weightsIndex].plusEqualsSparse(fvs.get(ip), contribution);
                        gradient.defaultWeights[weightsIndex] += contribution;
                    }
                }
            }
        }
    }

    /**
     * Debugging consistency check (assertion-based; a no-op unless assertions are enabled).
     *
     * <p>Recomputes {@code sum exp(xi) * constraintValue} over all positions and state pairs. It
     * then asserts that, at every position, the sum of {@code alpha.exp() + beta.exp()} over all
     * nodes matches it within 1e-6, and that the average over positions matches as well.
     */
    public void check(ArrayList<GEConstraint> constraints, double[][] gammas, double[][][] xis, FeatureVectorSequence fvs) {
        double dotEx = 0.0d;
        for (int ip = 0; ip < this.latticeLength - 1; ip++) {
            FeatureVector fv = fvs.get(ip);
            // Match the forward pass: preprocess each position before querying constraint values.
            for (GEConstraint constraint : constraints) {
                constraint.preProcess(fv);
            }
            for (int a = 0; a < this.numStates; a++) {
                for (int b = 0; b < this.numStates; b++) {
                    double dot = 0.0d;
                    for (GEConstraint constraint : constraints) {
                        dot += constraint.getCompositeConstraintFeatureValue(fv, ip, a, b);
                    }
                    dotEx += Math.exp(xis[ip][a][b]) * dot;
                }
            }
        }
        double total = 0.0d;
        for (int ip = 0; ip < this.latticeLength - 1; ip++) {
            double positionSum = 0.0d;
            for (int a = 0; a < this.numStates; a++) {
                LatticeNode node = this.lattice[ip][a];
                for (int b = 0; b < this.numStates; b++) {
                    positionSum += node.alpha[b].exp() + node.beta[b].exp();
                }
            }
            // Two-sided tolerance: previously only an under-count was detected.
            assert Math.abs(dotEx - positionSum) < 1.0E-6d : String.valueOf(dotEx) + StringUtils.SPACE + positionSum;
            total += positionSum;
        }
        double average = total / (this.latticeLength - 1);
        assert Math.abs(dotEx - average) < 1.0E-6d : String.valueOf(dotEx) + StringUtils.SPACE + average;
    }

    /** Returns the alpha value of node ({@code position}, {@code sourceState}) for {@code destState}. */
    public LogNumber getAlpha(int position, int sourceState, int destState) {
        return this.lattice[position][sourceState].alpha[destState];
    }

    /** Returns the beta value of node ({@code position}, {@code sourceState}) for {@code destState}. */
    public LogNumber getBeta(int position, int sourceState, int destState) {
        return this.lattice[position][sourceState].beta[destState];
    }
}
