package de.jungblut.nlp;

import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Random;
import java.util.stream.Stream;
import org.apache.commons.math3.util.FastMath;

/* loaded from: input_file:de/jungblut/nlp/MarkovChain.class */
/**
 * First-order Markov chain over integer states that are used directly as row/column indices of a
 * transition matrix.
 *
 * <p>After {@link #train(Stream)} entry {@code (from, to)} of {@link #getTransitionProbabilities()}
 * holds {@code log P(to | from)}; before training it holds raw transition counts.
 *
 * <p>NOTE(review): entries that were never observed read back as {@code 0} (counting relies on
 * this), and {@code 0} in log space means probability {@code 1}. The scoring methods therefore do
 * not penalize unseen transitions — confirm this is acceptable for callers.
 */
public final class MarkovChain {
    /** Transition counts during training, log transition probabilities afterwards (row = from, column = to). */
    private final DoubleMatrix transitionProbabilities;
    /** Number of states this chain was created for. */
    private final int numStates;

    private MarkovChain(int numStates) {
        this(numStates, new SparseDoubleRowMatrix(numStates, numStates));
    }

    private MarkovChain(int numStates, DoubleMatrix transitionProbabilities) {
        this.numStates = numStates;
        this.transitionProbabilities =
                Preconditions.checkNotNull(transitionProbabilities, "transitionProbabilities");
    }

    /**
     * Estimates the transition probabilities from the given state sequences.
     *
     * <p>Every adjacent pair {@code (seq[k], seq[k + 1])} increments the count of that transition;
     * afterwards each row is normalized and converted to log probabilities in place. Intended to
     * be called once: a second call would add new counts on top of the log probabilities stored by
     * the first one.
     *
     * @param stateSequences sequential stream of state sequences; every state must be a valid
     *     index into the transition matrix
     * @throws IllegalArgumentException if the stream is parallel (the read-modify-write counting
     *     below is not synchronized)
     */
    public void train(Stream<int[]> stateSequences) {
        Preconditions.checkArgument(!stateSequences.isParallel(), "parallel streams are not supported");
        stateSequences.forEach(this::countTransitions);
        normalizeRowsToLogProbabilities();
    }

    /** Adds one to the count of every adjacent transition in {@code sequence}. */
    private void countTransitions(int[] sequence) {
        for (int k = 0; k < sequence.length - 1; k++) {
            int from = sequence[k];
            int to = sequence[k + 1];
            this.transitionProbabilities.set(from, to, this.transitionProbabilities.get(from, to) + 1d);
        }
    }

    /** Replaces every row of counts by the log of its relative frequencies. */
    private void normalizeRowsToLogProbabilities() {
        for (int row : this.transitionProbabilities.rowIndices()) {
            DoubleVector counts = this.transitionProbabilities.getRowVector(row);
            // log(count / total) == log(count) - log(total); the total is invariant within the row
            double logTotal = FastMath.log(counts.sum());
            Iterator<DoubleVector.DoubleVectorElement> nonZero = counts.iterateNonZero();
            while (nonZero.hasNext()) {
                DoubleVector.DoubleVectorElement element = nonZero.next();
                this.transitionProbabilities.set(row, element.getIndex(),
                        FastMath.log(element.getValue()) - logTotal);
            }
        }
    }

    /**
     * Returns the probability of the transitions in the given sequence, i.e. the product of
     * {@code P(seq[k + 1] | seq[k])} over all adjacent pairs (the first state is taken as given).
     *
     * <p>Computed as {@code exp} of the summed log probabilities. No max-shift is applied: a shift
     * would have to be added back to obtain a probability. Long sequences may underflow to
     * {@code 0}; use {@link #getTransitionProbabilities(int[])} to stay in log space.
     *
     * @param stateSequence sequence to score; fewer than two states yields {@code 1}
     * @return the sequence probability (within [0, 1] for a trained chain)
     */
    public double getProbabilityForSequence(int[] stateSequence) {
        return FastMath.exp(getTransitionProbabilities(stateSequence).sum());
    }

    /**
     * Returns the geometric mean of the transition probabilities along the sequence, i.e.
     * {@code exp(mean_k log P(seq[k + 1] | seq[k]))}.
     *
     * @param stateSequence sequence to score; fewer than two states yields {@code 1}
     * @return the average per-transition probability
     */
    public double averageTransitionProbability(int[] stateSequence) {
        DoubleVector logProbabilities = getTransitionProbabilities(stateSequence);
        // max(1, n) avoids 0/0 for sequences without any transition
        return FastMath.exp(logProbabilities.sum() / Math.max(1d, logProbabilities.getDimension()));
    }

    /**
     * Looks up the stored value of every adjacent transition in the sequence.
     *
     * @param stateSequence sequence of states
     * @return a dense vector whose element {@code k} is the stored entry for
     *     {@code (seq[k], seq[k + 1])}; empty when the sequence has fewer than two states
     */
    public DoubleVector getTransitionProbabilities(int[] stateSequence) {
        int numTransitions = Math.max(0, stateSequence.length - 1);
        DenseDoubleVector logProbabilities = new DenseDoubleVector(numTransitions);
        for (int k = 0; k < numTransitions; k++) {
            logProbabilities.set(k, this.transitionProbabilities.get(stateSequence[k], stateSequence[k + 1]));
        }
        return logProbabilities;
    }

    /**
     * Fills in the requested positions of a state sequence in place and returns it.
     *
     * <p>Positions are processed in ascending order. Position 0 is inferred from its successor via
     * the matrix column of {@code seq[1]}; any other position from its predecessor via the matrix
     * row of {@code seq[i - 1]}. With a {@link Random} present the state is sampled, otherwise the
     * index of the largest entry is taken.
     *
     * <p>NOTE(review): if position 1 is also requested, position 0 is inferred from the value at
     * index 1 before that value has been completed.
     *
     * @param random source of randomness for sampling; absent means greedy selection
     * @param stateSequence sequence to complete; modified in place
     * @param indexesToComplete positions to fill; the caller's array is not modified
     * @return {@code stateSequence}
     * @throws IllegalArgumentException if position 0 is requested but has no successor
     */
    public int[] completeStateSequence(Optional<Random> random, int[] stateSequence, int... indexesToComplete) {
        // sort a copy so a caller-supplied array is left untouched
        int[] positions = indexesToComplete.clone();
        Arrays.sort(positions);
        for (int position : positions) {
            DoubleVector candidates;
            if (position == 0) {
                if (position + 1 >= stateSequence.length) {
                    throw new IllegalArgumentException("Can't guess state " + position + " in "
                            + Arrays.toString(stateSequence));
                }
                candidates = this.transitionProbabilities.getColumnVector(stateSequence[position + 1]);
            } else {
                candidates = this.transitionProbabilities.getRowVector(stateSequence[position - 1]);
            }
            stateSequence[position] = random.isPresent()
                    ? chooseState(random.get(), candidates)
                    : candidates.maxIndex();
        }
        return stateSequence;
    }

    /**
     * Samples an index from a vector of log weights by roulette-wheel selection over its non-zero
     * entries.
     *
     * <p>Weights are {@code exp(value)} normalized by their total, so the vector need not sum to
     * one (matrix columns generally do not). Falls back to {@code maxIndex()} when no non-zero
     * entry exists. NOTE(review): an entry stored as exactly {@code 0} (probability 1) is
     * presumably skipped by {@code iterateNonZero()} and then only reachable through that fallback.
     */
    private static int chooseState(Random random, DoubleVector logWeights) {
        // draw first so exactly one value is consumed per call
        double draw = random.nextDouble();
        double total = 0d;
        Iterator<DoubleVector.DoubleVectorElement> elements = logWeights.iterateNonZero();
        while (elements.hasNext()) {
            total += FastMath.exp(elements.next().getValue());
        }
        if (!(total > 0d)) {
            return logWeights.maxIndex();
        }
        double threshold = draw * total;
        double cumulative = 0d;
        int lastIndex = -1;
        elements = logWeights.iterateNonZero();
        while (elements.hasNext()) {
            DoubleVector.DoubleVectorElement element = elements.next();
            cumulative += FastMath.exp(element.getValue());
            lastIndex = element.getIndex();
            if (threshold < cumulative) {
                return lastIndex;
            }
        }
        // reached only when rounding leaves the cumulative sum marginally below the threshold
        return lastIndex >= 0 ? lastIndex : logWeights.maxIndex();
    }

    /** Returns the internal transition matrix itself (not a copy); mutations affect this chain. */
    public DoubleMatrix getTransitionProbabilities() {
        return this.transitionProbabilities;
    }

    /** Returns the number of states this chain was created for. */
    public int getNumStates() {
        return this.numStates;
    }

    /** Creates an untrained chain backed by a {@code numStates x numStates} sparse matrix. */
    public static MarkovChain create(int numStates) {
        return new MarkovChain(numStates);
    }

    /**
     * Creates a chain that uses the given matrix directly (no copy), e.g. previously trained log
     * probabilities.
     *
     * @throws NullPointerException if {@code transitionProbabilities} is null
     */
    public static MarkovChain create(int numStates, DoubleMatrix transitionProbabilities) {
        return new MarkovChain(numStates, transitionProbabilities);
    }
}
