package de.jungblut.nlp;

import com.google.common.base.Preconditions;
import de.jungblut.classification.AbstractClassifier;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.ViterbiUtils;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import de.jungblut.writable.MatrixWritable;
import de.jungblut.writable.VectorWritable;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;
import java.util.Random;
import org.apache.commons.math3.util.FastMath;
import org.apache.hadoop.io.Writable;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

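/**
 * Hidden Markov model with a transition matrix, an emission matrix and a
 * prior over the hidden states. Supports supervised and unsupervised training
 * and can be serialized through the {@link Writable} methods.
 *
 * <p>Decompiled source, loaded from: input_file:de/jungblut/nlp/HMM.class
 */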
public final class HMM extends AbstractClassifier implements Writable {
    /** Logger used for the optional per-iteration output of the unsupervised training. */
    private static final Logger LOG = LogManager.getLogger(HMM.class);
    /** Number of visible states; the constructor uses it as the column count of the emission matrix. */
    private int numVisibleStates;
    /** Number of hidden states; the constructor uses it as the row count of both matrices and the prior length. */
    private int numHiddenStates;
    /** Hidden-to-hidden transition weights; both training methods leave this matrix log-normalized per row. */
    private DoubleMatrix transitionProbabilityMatrix;
    /** Hidden-to-visible emission weights; both training methods leave this matrix log-normalized per row. */
    private DoubleMatrix emissionProbabilityMatrix;
    /** Prior over the hidden states; normalization scales it to sum to one but never takes its logarithm. */
    private DoubleVector hiddenPriorProbability;
    /** Seed of the random generator that initializes the model in the unsupervised training; not serialized. */
    private long seed;

    /**
     * Creates an instance that only carries a time-based seed. State counts,
     * matrices and the prior are left unset; {@link #readFields(DataInput)}
     * can fill them in afterwards.
     */
    public HMM() {
        seed = System.currentTimeMillis();
    }

    /**
     * Creates a model for the given state counts, seeded with the current
     * system time.
     *
     * @param i number of visible states
     * @param i2 number of hidden states
     */
    public HMM(int i, int i2) {
        this(i, i2, System.currentTimeMillis());
    }

    /**
     * Package-private constructor that takes an explicit seed and allocates
     * all model structures.
     *
     * @param i number of visible states (columns of the emission matrix)
     * @param i2 number of hidden states (rows of both matrices, length of the
     *        prior vector)
     * @param j seed later handed to the random generator of the unsupervised
     *        training
     */
    HMM(int i, int i2, long j) {
        seed = j;
        numVisibleStates = i;
        numHiddenStates = i2;
        // transition: hidden x hidden, emission: hidden x visible
        transitionProbabilityMatrix = new DenseDoubleMatrix(i2, i2);
        emissionProbabilityMatrix = new DenseDoubleMatrix(i2, i);
        hiddenPriorProbability = new DenseDoubleVector(i2);
    }

    /**
     * Normalizes the prior and every row of both matrices in place, keeping
     * the values in probability space (no logarithm is applied).
     */
    private void normalizeProbabilities() {
        final boolean logarithmic = false;
        normalize(hiddenPriorProbability, transitionProbabilityMatrix, emissionProbabilityMatrix, logarithmic);
    }

    /**
     * Normalizes the prior and every row of both matrices in place and then
     * replaces the matrix rows by their logarithms. The prior itself is only
     * scaled, never logged.
     */
    private void logNormalizeProbabilities() {
        final boolean logarithmic = true;
        normalize(hiddenPriorProbability, transitionProbabilityMatrix, emissionProbabilityMatrix, logarithmic);
    }

    /**
     * Runs the forward pass over the given sequence with the current model
     * parameters and returns the sum of the last row of the resulting matrix.
     *
     * <p>NOTE(review): both training methods leave the matrices
     * log-normalized, while the forward pass multiplies their entries
     * directly. Confirm that calling this on a trained model is intended.
     *
     * @param doubleVectorArr the observation sequence, one vector per step
     * @return the summed final row of the forward matrix
     */
    public double estimateLikelihood(DoubleVector[] doubleVectorArr) {
        DoubleMatrix alpha = new DenseDoubleMatrix(doubleVectorArr.length, numHiddenStates);
        DoubleMatrix filled = forward(alpha, transitionProbabilityMatrix, emissionProbabilityMatrix, hiddenPriorProbability, doubleVectorArr);
        return estimateLikelihood(filled);
    }

    /**
     * Sums the entries of the last row of the given forward matrix.
     *
     * @param alpha the matrix produced by the forward pass
     * @return the sum over the final row
     */
    private static double estimateLikelihood(DoubleMatrix alpha) {
        final int lastRow = alpha.getRowCount() - 1;
        final DoubleVector finalStep = alpha.getRowVector(lastRow);
        return finalStep.sum();
    }

    /**
     * Delegates to {@link ViterbiUtils#decode}, handing over the emission
     * matrix, both arrays wrapped as sparse row matrices and the number of
     * hidden states.
     *
     * @param doubleVectorArr first array, wrapped as a sparse row matrix
     *        (presumably the feature sequence; verify against ViterbiUtils)
     * @param doubleVectorArr2 second array, wrapped as a sparse row matrix
     *        (meaning defined by ViterbiUtils; verify there)
     * @return whatever the Viterbi utility returns for these inputs
     */
    public DoubleMatrix decode(DoubleVector[] doubleVectorArr, DoubleVector[] doubleVectorArr2) {
        SparseDoubleRowMatrix firstRows = new SparseDoubleRowMatrix(doubleVectorArr);
        SparseDoubleRowMatrix secondRows = new SparseDoubleRowMatrix(doubleVectorArr2);
        return ViterbiUtils.decode(emissionProbabilityMatrix, firstRows, secondRows, numHiddenStates);
    }

    /**
     * Trains the model on an unlabelled observation sequence by iterative
     * re-estimation from forward and backward passes.
     *
     * <p>The model is first re-initialized randomly from the stored seed and
     * normalized. Each iteration re-estimates transition weights, emission
     * weights and the prior, normalizes them, and measures the summed squared
     * difference to the previous parameters. Training stops once that
     * difference drops below {@code d} or after {@code i} iterations.
     * Finally the parameters are normalized once more and both matrices are
     * converted to log space (the prior is not logged).
     *
     * @param doubleVectorArr the observation sequence; the indices of the
     *        non-zero entries of each vector are treated as the observed
     *        visible states of that step
     * @param d convergence threshold on the squared model difference
     * @param i maximum number of iterations
     * @param z if true, the model difference is logged after every iteration
     * @throws IllegalArgumentException if the sequence is empty
     */
    public void trainUnsupervised(DoubleVector[] doubleVectorArr, double d, int i, boolean z) {
        // The forward pass reads the first observation unconditionally, so an
        // empty sequence is rejected up front, as in the supervised training.
        Preconditions.checkArgument(doubleVectorArr.length > 0, "Feature array length must be at least 1! Given: " + doubleVectorArr.length);
        final int sequenceLength = doubleVectorArr.length;

        // random initialization, reproducible through the seed
        Random random = new Random(seed);
        transitionProbabilityMatrix = new DenseDoubleMatrix(numHiddenStates, numHiddenStates, random);
        emissionProbabilityMatrix = new DenseDoubleMatrix(numHiddenStates, numVisibleStates, random);
        hiddenPriorProbability = new DenseDoubleVector(numHiddenStates);
        for (int state = 0; state < numHiddenStates; state++) {
            hiddenPriorProbability.set(state, random.nextDouble());
        }
        normalizeProbabilities();

        DoubleMatrix alpha = new DenseDoubleMatrix(sequenceLength, numHiddenStates);
        DoubleMatrix beta = new DenseDoubleMatrix(sequenceLength, numHiddenStates);
        for (int iteration = 0; iteration < i; iteration++) {
            // work on copies so the previous parameters stay available for
            // the convergence check below
            DoubleMatrix transition = transitionProbabilityMatrix.deepCopy();
            DoubleMatrix emission = emissionProbabilityMatrix.deepCopy();
            DoubleVector prior = hiddenPriorProbability.deepCopy();
            alpha = forward(alpha, transition, emission, prior, doubleVectorArr);
            beta = backward(beta, transition, emission, prior, doubleVectorArr);
            DoubleVector newPrior = alpha.getRowVector(0).multiply(beta.getRowVector(0));
            double likelihood = estimateLikelihood(alpha);

            // re-estimate the transitions; this still reads the old emission
            // values, so it has to run before the emission update
            for (int from = 0; from < numHiddenStates; from++) {
                for (int to = 0; to < numHiddenStates; to++) {
                    double expected = 0.0d;
                    for (int t = 0; t < sequenceLength - 1; t++) {
                        Iterator<DoubleVector.DoubleVectorElement> nextObservation = doubleVectorArr[t + 1].iterateNonZero();
                        while (nextObservation.hasNext()) {
                            expected += alpha.get(t, from) * emission.get(to, nextObservation.next().getIndex()) * beta.get(t + 1, to);
                        }
                    }
                    transition.set(from, to, (transition.get(from, to) * expected) / likelihood);
                }
            }

            // re-estimate the emissions: accumulate alpha * beta per observed
            // symbol in a single pass over the sequence instead of rescanning
            // the sequence once per visible state (same summation order)
            for (int state = 0; state < numHiddenStates; state++) {
                double[] expectedEmissions = new double[numVisibleStates];
                for (int t = 0; t < sequenceLength; t++) {
                    double occupancy = alpha.get(t, state) * beta.get(t, state);
                    Iterator<DoubleVector.DoubleVectorElement> observed = doubleVectorArr[t].iterateNonZero();
                    while (observed.hasNext()) {
                        int symbol = observed.next().getIndex();
                        // indices outside the visible range never matched a
                        // column before either, so they are skipped
                        if (symbol >= 0 && symbol < numVisibleStates) {
                            expectedEmissions[symbol] += occupancy;
                        }
                    }
                }
                for (int symbol = 0; symbol < numVisibleStates; symbol++) {
                    emission.set(state, symbol, expectedEmissions[symbol] / likelihood);
                }
            }

            normalize(newPrior, transition, emission, false);
            double difference = transitionProbabilityMatrix.subtract(transition).pow(2.0d).sum()
                    + emissionProbabilityMatrix.subtract(emission).pow(2.0d).sum()
                    + hiddenPriorProbability.subtract(newPrior).pow(2.0d).sum();
            if (z) {
                LOG.info("Iteration " + iteration + " | Model difference: " + difference + "\r");
            }
            transitionProbabilityMatrix = transition;
            emissionProbabilityMatrix = emission;
            hiddenPriorProbability = newPrior;
            if (difference < d) {
                break;
            }
        }
        // final pass: normalize again and move both matrices to log space
        logNormalizeProbabilities();
    }

    /**
     * Fills the given matrix with backward values, walking the sequence from
     * its last step to its first. The last row is set to all ones; every
     * earlier row is derived from the row after it, the transition weights
     * and the emission weights of the non-zero indices of the next
     * observation.
     *
     * @param beta matrix to fill, one row per step and one column per hidden
     *        state; written in place
     * @param transition transition weights, read as (from, to)
     * @param emission emission weights, read as (hidden state, visible index)
     * @param prior unused by the backward pass; kept for a signature parallel
     *        to the forward pass
     * @param observations the observation sequence
     * @return the same matrix instance that was passed in
     */
    private static DoubleMatrix backward(DoubleMatrix beta, DoubleMatrix transition, DoubleMatrix emission, DoubleVector prior, DoubleVector[] observations) {
        final int numStates = beta.getColumnCount();
        final int lastStep = observations.length - 1;
        beta.setRowVector(lastStep, DenseDoubleVector.ones(numStates));
        for (int step = lastStep - 1; step >= 0; step--) {
            final int next = step + 1;
            for (int from = 0; from < numStates; from++) {
                double accumulated = 0.0d;
                for (int to = 0; to < numStates; to++) {
                    for (Iterator<DoubleVector.DoubleVectorElement> it = observations[next].iterateNonZero(); it.hasNext();) {
                        accumulated += beta.get(next, to) * transition.get(from, to) * emission.get(to, it.next().getIndex());
                    }
                }
                beta.set(step, from, accumulated);
            }
        }
        return beta;
    }

    /**
     * Fills the given matrix with forward values. The first row is the prior
     * of each hidden state times the summed emission weights of the first
     * observation's non-zero indices. Every later row is, per hidden state,
     * the transition-weighted sum over the previous row times the summed
     * emission weights of the current observation.
     *
     * <p>The decompiled version of this method contained two
     * {@code while (true)} loops without an exit that read uninitialized
     * locals; the emission sums are rebuilt here as plain accumulation loops.
     *
     * @param alpha matrix to fill, one row per step and one column per hidden
     *        state; written in place
     * @param transition transition weights, read as (from, to)
     * @param emission emission weights, read as (hidden state, visible index)
     * @param prior prior weight of every hidden state
     * @param observations the observation sequence; must contain at least
     *        one element
     * @return the same matrix instance that was passed in
     */
    private static DoubleMatrix forward(DoubleMatrix alpha, DoubleMatrix transition, DoubleMatrix emission, DoubleVector prior, DoubleVector[] observations) {
        final int numStates = alpha.getColumnCount();
        // initialization: prior times the emission mass of the first step
        for (int state = 0; state < numStates; state++) {
            double emitted = emissionSum(emission, state, observations[0]);
            alpha.set(0, state, prior.get(state) * emitted);
        }
        // induction over the remaining steps
        for (int step = 1; step < observations.length; step++) {
            for (int state = 0; state < numStates; state++) {
                double reached = 0.0d;
                for (int previous = 0; previous < numStates; previous++) {
                    reached += alpha.get(step - 1, previous) * transition.get(previous, state);
                }
                double emitted = emissionSum(emission, state, observations[step]);
                alpha.set(step, state, reached * emitted);
            }
        }
        return alpha;
    }

    /** Sums the emission weights of one hidden state over the non-zero indices of an observation. */
    private static double emissionSum(DoubleMatrix emission, int state, DoubleVector observation) {
        double sum = 0.0d;
        Iterator<DoubleVector.DoubleVectorElement> nonZero = observation.iterateNonZero();
        while (nonZero.hasNext()) {
            sum += emission.get(state, nonZero.next().getIndex());
        }
        return sum;
    }

    /**
     * Normalizes the given model parameters in place. The vector is divided
     * by its sum unless that sum is zero. Every row of both matrices is
     * divided by its own sum (without a zero check) and, if requested,
     * replaced by its element-wise logarithm.
     *
     * <p>The row loop runs over the row count of the first matrix and uses
     * the same indices for the second one, so both are expected to have the
     * same number of rows.
     *
     * @param prior vector to scale to a sum of one; never logged
     * @param transition first matrix to normalize row by row
     * @param emission second matrix to normalize row by row
     * @param log whether to store the logarithm of the normalized rows
     */
    private static void normalize(DoubleVector prior, DoubleMatrix transition, DoubleMatrix emission, boolean log) {
        final double priorTotal = prior.sum();
        if (priorTotal != 0.0d) {
            for (int index = 0; index < prior.getDimension(); index++) {
                prior.set(index, prior.get(index) / priorTotal);
            }
        }
        for (int row = 0; row < transition.getRowCount(); row++) {
            transition.setRowVector(row, normalizedRow(transition, row, log));
            emission.setRowVector(row, normalizedRow(emission, row, log));
        }
    }

    /** Returns one matrix row divided by its sum, optionally as element-wise logarithm. */
    private static DoubleVector normalizedRow(DoubleMatrix matrix, int row, boolean log) {
        DoubleVector values = matrix.getRowVector(row);
        DoubleVector scaled = values.divide(values.sum());
        if (log) {
            return scaled.log();
        }
        return scaled;
    }

    /**
     * Trains the model from a labelled sequence by counting. One is added to
     * every entry of the current prior and every row of both matrices is
     * reset to ones. Then, per step, the count of the labelled state, of each
     * non-zero feature index under that state and of the transition to the
     * next step's state is incremented. Finally the counts are normalized
     * and both matrices are moved to log space.
     *
     * <p>NOTE(review): a one-dimensional outcome vector is always rejected,
     * because the expected dimension is then set to 2 and compared against
     * 1. Confirm whether binary outcomes were meant to be accepted.
     *
     * @param doubleVectorArr the feature sequence; the dimension of the first
     *        vector must equal the number of visible states
     * @param doubleVectorArr2 the outcome sequence, same length as the
     *        features
     * @throws IllegalArgumentException if any of the argument checks fails
     */
    public void trainSupervised(DoubleVector[] doubleVectorArr, DoubleVector[] doubleVectorArr2) {
        final int sequenceLength = doubleVectorArr.length;
        Preconditions.checkArgument(sequenceLength == doubleVectorArr2.length, "Feature array length must match outcome array length: " + sequenceLength + " != " + doubleVectorArr2.length);
        Preconditions.checkArgument(sequenceLength > 0, "Feature array length be at least 1! Given: " + sequenceLength);
        final int featureDimension = doubleVectorArr[0].getDimension();
        Preconditions.checkArgument(featureDimension == numVisibleStates, "Feature vector's dimension must match the number of visible states! Given: " + featureDimension + ", but expected " + numVisibleStates);
        final int outcomeDimension = doubleVectorArr2[0].getDimension();
        final int expectedDimension;
        if (outcomeDimension == 1) {
            expectedDimension = 2;
        } else {
            expectedDimension = numHiddenStates;
        }
        Preconditions.checkArgument(outcomeDimension == expectedDimension, "Outcome dimension didn't match the given number of hidden states: " + outcomeDimension + " != " + expectedDimension);

        // every count starts at one before the observed events are added
        hiddenPriorProbability = hiddenPriorProbability.add(1.0d);
        for (int state = 0; state < numHiddenStates; state++) {
            transitionProbabilityMatrix.setRowVector(state, DenseDoubleVector.ones(numHiddenStates));
            emissionProbabilityMatrix.setRowVector(state, DenseDoubleVector.ones(numVisibleStates));
        }

        for (int step = 0; step < sequenceLength; step++) {
            final int state = getOutcomeState(doubleVectorArr2[step]);
            hiddenPriorProbability.set(state, hiddenPriorProbability.get(state) + 1.0d);
            for (Iterator<DoubleVector.DoubleVectorElement> it = doubleVectorArr[step].iterateNonZero(); it.hasNext();) {
                final int symbol = it.next().getIndex();
                emissionProbabilityMatrix.set(state, symbol, emissionProbabilityMatrix.get(state, symbol) + 1.0d);
            }
            final int nextStep = step + 1;
            if (nextStep < sequenceLength) {
                final int nextState = getOutcomeState(doubleVectorArr2[nextStep]);
                transitionProbabilityMatrix.set(state, nextState, transitionProbabilityMatrix.get(state, nextState) + 1.0d);
            }
        }
        logNormalizeProbabilities();
    }

    /**
     * Classifier entry point; forwards both arrays unchanged to
     * {@link #trainSupervised(DoubleVector[], DoubleVector[])}.
     *
     * @param doubleVectorArr the feature sequence
     * @param doubleVectorArr2 the outcome sequence
     */
    @Override // de.jungblut.classification.AbstractClassifier, de.jungblut.classification.Classifier
    public void train(DoubleVector[] doubleVectorArr, DoubleVector[] doubleVectorArr2) {
        trainSupervised(doubleVectorArr, doubleVectorArr2);
    }

    /**
     * Scores every hidden state for a single feature vector. The emission
     * matrix is multiplied with the features, the maximum score is
     * subtracted before exponentiating, each value is weighted with the
     * state's prior, and the result is divided by its sum.
     *
     * @param doubleVector the feature vector
     * @return the normalized score of every hidden state
     */
    @Override // de.jungblut.classification.Predictor
    public DoubleVector predict(DoubleVector doubleVector) {
        final DoubleVector scores = emissionProbabilityMatrix.multiplyVectorRow(doubleVector);
        final double peak = scores.max();
        for (int state = 0; state < scores.getDimension(); state++) {
            // shifting by the maximum keeps the exponent at or below zero
            final double shifted = FastMath.exp(scores.get(state) - peak);
            scores.set(state, shifted * hiddenPriorProbability.get(state));
        }
        final double total = scores.sum();
        return scores.divide(total);
    }

    /**
     * Scores every hidden state for a feature vector, additionally taking
     * the previous outcome into account. The emission scores and the
     * transition scores are added, the maximum is subtracted before
     * exponentiating, each value is weighted with the state's prior, and the
     * result is divided by its sum.
     *
     * <p>The sum of both score vectors is now assigned back. Previously the
     * result of {@code add} was discarded, so the previous outcome had no
     * influence on the prediction.
     *
     * @param doubleVector the feature vector
     * @param doubleVector2 the previous outcome, multiplied with the
     *        transition matrix
     * @return the normalized score of every hidden state
     */
    public DoubleVector predict(DoubleVector doubleVector, DoubleVector doubleVector2) {
        DoubleVector emissionScores = emissionProbabilityMatrix.multiplyVectorRow(doubleVector);
        DoubleVector transitionScores = transitionProbabilityMatrix.multiplyVectorRow(doubleVector2);
        // add returns a new vector, so its result has to be kept
        DoubleVector scores = emissionScores.add(transitionScores);
        double max = scores.max();
        for (int state = 0; state < scores.getDimension(); state++) {
            // shifting by the maximum keeps the exponent at or below zero
            scores.set(state, FastMath.exp(scores.get(state) - max) * hiddenPriorProbability.get(state));
        }
        return scores.divide(scores.sum());
    }

    /** @return the number of hidden states of this model */
    public int getNumHiddenStates() {
        return numHiddenStates;
    }

    /** @return the number of visible states of this model */
    public int getNumVisibleStates() {
        return numVisibleStates;
    }

    /**
     * Returns the internal emission matrix itself, not a copy. The
     * misspelled method name is kept so existing callers keep working.
     *
     * @return the emission matrix
     */
    public DoubleMatrix getEmissionProbabilitiyMatrix() {
        return emissionProbabilityMatrix;
    }

    /**
     * Returns the internal prior vector itself, not a copy.
     *
     * @return the prior over the hidden states
     */
    public DoubleVector getHiddenPriorProbability() {
        return hiddenPriorProbability;
    }

    /**
     * Returns the internal transition matrix itself, not a copy.
     *
     * @return the transition matrix
     */
    public DoubleMatrix getTransitionProbabilityMatrix() {
        return transitionProbabilityMatrix;
    }

    /**
     * Maps an outcome vector to a hidden state index. A two-dimensional
     * vector yields its first component truncated to an int; any other
     * dimension yields the index of the largest component.
     *
     * @param outcome the outcome vector of one step
     * @return the hidden state index encoded by the vector
     */
    private int getOutcomeState(DoubleVector outcome) {
        if (outcome.getDimension() == 2) {
            return (int) outcome.get(0);
        }
        return outcome.maxIndex();
    }

    /**
     * Serializes the model: both state counts first, then the prior, the
     * transition matrix and the emission matrix. The seed is not written.
     *
     * @param dataOutput the sink to write to
     * @throws IOException if writing to the sink fails
     */
    public void write(DataOutput dataOutput) throws IOException {
        // header: state counts
        dataOutput.writeInt(numVisibleStates);
        dataOutput.writeInt(numHiddenStates);
        // payload, in the order readFields expects it
        VectorWritable.writeVector(hiddenPriorProbability, dataOutput);
        MatrixWritable.writeDenseMatrix(transitionProbabilityMatrix, dataOutput);
        MatrixWritable.writeDenseMatrix(emissionProbabilityMatrix, dataOutput);
    }

    /**
     * Restores the model from a stream, reading both state counts first and
     * then the prior, the transition matrix and the emission matrix. The
     * seed is left untouched.
     *
     * @param dataInput the source to read from
     * @throws IOException if reading from the source fails
     */
    public void readFields(DataInput dataInput) throws IOException {
        // header: state counts
        numVisibleStates = dataInput.readInt();
        numHiddenStates = dataInput.readInt();
        // payload, in the order write emits it
        hiddenPriorProbability = VectorWritable.readVector(dataInput);
        transitionProbabilityMatrix = MatrixWritable.readDenseMatrix(dataInput);
        emissionProbabilityMatrix = MatrixWritable.readDenseMatrix(dataInput);
    }
}
