package de.jungblut.classification.bayes;

import de.jungblut.classification.AbstractClassifier;
import de.jungblut.datastructure.Iterables;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import de.jungblut.math.tuple.Tuple;
import de.jungblut.writable.MatrixWritable;
import de.jungblut.writable.VectorWritable;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;
import org.apache.commons.math3.util.FastMath;

/* loaded from: input_file:de/jungblut/classification/bayes/MultinomialNaiveBayes.class */
/**
 * Multinomial naive Bayes classifier.
 *
 * <p>During training, each document's feature values are summed into the row of its class in a
 * sparse class-by-feature matrix. Every non-zero entry is then replaced by
 * {@code log(count) - log(tokensOfClass + numFeatures - 1)}. Features that never occurred for a
 * class stay at zero in the sparse matrix. When predicting, a stored value of exactly {@code 0.0}
 * is scored with the fixed log-probability {@code log(1e-8)} instead.
 *
 * <p>Outcome encoding: a one-dimensional outcome vector is treated as a binary label whose
 * (truncated) value is the class index. Any wider outcome vector is treated as one-hot, and its
 * arg-max is the class index.
 */
public final class MultinomialNaiveBayes extends AbstractClassifier {

    /** Log-probability used for features that have no stored (non-zero) entry for a class. */
    private static final double LOW_PROBABILITY = FastMath.log(1.0E-8d);

    /** Rows are classes, columns are features; non-zero entries hold log-probabilities. */
    private DoubleMatrix probabilityMatrix;
    /** Per-class log prior: log(documents of that class) - log(documents seen). */
    private DoubleVector classPriorProbability;
    /** When true, {@link #train} prints one progress line per class to stdout. */
    private boolean verbose;

    /** Creates an untrained, non-verbose classifier. */
    public MultinomialNaiveBayes() {
    }

    /**
     * Creates an untrained classifier.
     *
     * @param verbose whether to print per-class progress to stdout while training
     */
    public MultinomialNaiveBayes(boolean verbose) {
        this.verbose = verbose;
    }

    /** Creates an already-trained classifier; used by {@link #deserialize(DataInput)}. */
    private MultinomialNaiveBayes(DoubleMatrix probabilityMatrix, DoubleVector classPriorProbability) {
        this.probabilityMatrix = probabilityMatrix;
        this.classPriorProbability = classPriorProbability;
    }

    /**
     * Trains the model on aligned (feature vector, outcome vector) pairs.
     *
     * <p>The number of classes and features is taken from the first pair. A one-dimensional outcome
     * yields two classes (binary case); otherwise there is one class per outcome column.
     *
     * @param iterable feature vectors, one per document
     * @param iterable2 outcome vectors, aligned with {@code iterable}
     * @throws IllegalArgumentException if no (feature, outcome) pair is available
     */
    @Override
    public void train(Iterable<DoubleVector> iterable, Iterable<DoubleVector> iterable2) {
        Iterator<DoubleVector> featureIterator = iterable.iterator();
        Iterator<DoubleVector> outcomeIterator = iterable2.iterator();

        Tuple first = Iterables.consumeNext(featureIterator, outcomeIterator);
        if (first == null) {
            throw new IllegalArgumentException("Can't train on an empty data set!");
        }
        DoubleVector firstFeatures = (DoubleVector) first.getFirst();
        DoubleVector firstOutcome = (DoubleVector) first.getSecond();

        // A single outcome column encodes a binary label, which still needs two class rows.
        boolean binaryOutcome = firstOutcome.getDimension() == 1;
        int numClasses = binaryOutcome ? 2 : firstOutcome.getDimension();
        this.probabilityMatrix = new SparseDoubleRowMatrix(numClasses, firstFeatures.getDimension());

        // long counters: int totals can overflow on large corpora.
        long[] tokensPerClass = new long[numClasses];
        long[] documentsPerClass = new long[numClasses];

        observe(firstFeatures, firstOutcome, binaryOutcome, tokensPerClass, documentsPerClass);
        long documentsSeen = 1;
        Tuple next;
        while ((next = Iterables.consumeNext(featureIterator, outcomeIterator)) != null) {
            observe((DoubleVector) next.getFirst(), (DoubleVector) next.getSecond(), binaryOutcome,
                    tokensPerClass, documentsPerClass);
            documentsSeen++;
        }

        normalizeRows(tokensPerClass);
        computePriors(documentsPerClass, documentsSeen);
    }

    /**
     * Accumulates one document into its class row and updates the class counters.
     *
     * @param features the document's feature vector; its non-zero values are added to the class row
     * @param outcome the document's outcome vector
     * @param binaryOutcome true when outcomes are single 0/1 labels rather than one-hot vectors
     * @param tokensPerClass per-class token counter, incremented in place
     * @param documentsPerClass per-class document counter, incremented in place
     */
    private void observe(DoubleVector features, DoubleVector outcome, boolean binaryOutcome,
            long[] tokensPerClass, long[] documentsPerClass) {
        // Only a one-dimensional outcome carries the class index as its value. A two-column
        // one-hot outcome must use the arg-max like every other multi-class outcome.
        // NOTE(review): binary labels are expected to be 0 or 1; other values are truncated by the
        // int cast and may index out of range.
        int clazz = binaryOutcome ? (int) outcome.get(0) : outcome.maxIndex();

        // NOTE(review): this adds features.getLength(), not the sum of the feature values. Whether
        // that is the dimension or the non-zero count depends on the DoubleVector implementation.
        // Confirm this is the intended normalizer.
        tokensPerClass[clazz] += features.getLength();
        documentsPerClass[clazz]++;

        Iterator nonZero = features.iterateNonZero();
        while (nonZero.hasNext()) {
            DoubleVector.DoubleVectorElement element = (DoubleVector.DoubleVectorElement) nonZero.next();
            int column = element.getIndex();
            this.probabilityMatrix.set(clazz, column,
                    this.probabilityMatrix.get(clazz, column) + element.getValue());
        }
    }

    /**
     * Converts each class row of summed counts into log-probabilities in place:
     * {@code log(count) - log(tokensPerClass[row] + numFeatures - 1)}. Only non-zero entries are
     * visited; missing features stay at 0 and are handled at prediction time.
     *
     * @param tokensPerClass per-class token totals gathered by {@link #observe}
     */
    private void normalizeRows(long[] tokensPerClass) {
        int numClasses = tokensPerClass.length;
        int numFeatures = this.probabilityMatrix.getColumnCount();
        for (int row = 0; row < numClasses; row++) {
            double logNormalizer = FastMath.log(tokensPerClass[row] + numFeatures - 1);
            Iterator nonZero = this.probabilityMatrix.getRowVector(row).iterateNonZero();
            while (nonZero.hasNext()) {
                DoubleVector.DoubleVectorElement element = (DoubleVector.DoubleVectorElement) nonZero.next();
                // NOTE(review): a result of exactly 0.0 (count == normalizer) is later
                // indistinguishable from "unseen" and will be scored as LOW_PROBABILITY.
                this.probabilityMatrix.set(row, element.getIndex(),
                        FastMath.log(element.getValue()) - logNormalizer);
            }
            if (this.verbose) {
                System.out.println("Computed " + row + " / " + numClasses + "!");
            }
        }
    }

    /**
     * Stores {@code log(documentsPerClass[c]) - log(documentsSeen)} for every class. A class
     * without documents gets {@code -Infinity}.
     */
    private void computePriors(long[] documentsPerClass, long documentsSeen) {
        this.classPriorProbability = new DenseDoubleVector(documentsPerClass.length);
        double logDocumentsSeen = FastMath.log(documentsSeen);
        for (int clazz = 0; clazz < documentsPerClass.length; clazz++) {
            this.classPriorProbability.set(clazz, FastMath.log(documentsPerClass[clazz]) - logDocumentsSeen);
        }
    }

    /**
     * Predicts the normalized class distribution for a feature vector.
     *
     * @param doubleVector the feature vector to classify
     * @return one probability per class, summing to 1
     */
    @Override
    public DoubleVector predict(DoubleVector doubleVector) {
        return getProbabilityDistribution(doubleVector);
    }

    /**
     * Returns the (unnormalized, prior-free) log-likelihood of {@code doubleVector} under class
     * {@code i}. This is the sum of {@code value * logProbability} over the vector's non-zero
     * entries. Stored values of exactly 0.0 are replaced by {@link #LOW_PROBABILITY}.
     */
    private double getProbabilityForClass(DoubleVector doubleVector, int i) {
        double logLikelihood = 0.0d;
        Iterator nonZero = doubleVector.iterateNonZero();
        while (nonZero.hasNext()) {
            DoubleVector.DoubleVectorElement element = (DoubleVector.DoubleVectorElement) nonZero.next();
            double logProbability = this.probabilityMatrix.get(i, element.getIndex());
            if (logProbability == 0.0d) {
                // Sparse matrix stores no entry for features unseen in this class.
                logProbability = LOW_PROBABILITY;
            }
            logLikelihood += element.getValue() * logProbability;
        }
        return logLikelihood;
    }

    /**
     * Computes the posterior class distribution using a numerically stable softmax over the joint
     * log scores {@code logLikelihood + logPrior}.
     */
    private DenseDoubleVector getProbabilityDistribution(DoubleVector doubleVector) {
        int numClasses = this.classPriorProbability.getLength();
        DenseDoubleVector distribution = new DenseDoubleVector(numClasses);
        for (int clazz = 0; clazz < numClasses; clazz++) {
            distribution.set(clazz,
                    getProbabilityForClass(doubleVector, clazz) + this.classPriorProbability.get(clazz));
        }
        // Shift by the largest *joint* score, prior included. The top class then contributes
        // exp(0) = 1, so the sum can't underflow to 0 (which previously produced 0/0 = NaN).
        double max = distribution.max();
        double sum = 0.0d;
        for (int clazz = 0; clazz < numClasses; clazz++) {
            double probability = FastMath.exp(distribution.get(clazz) - max);
            distribution.set(clazz, probability);
            sum += probability;
        }
        return distribution.divide(sum);
    }

    /** Returns the per-class log priors (package-private, e.g. for tests). */
    DoubleVector getClassProbability() {
        return this.classPriorProbability;
    }

    /** Returns the class-by-feature log-probability matrix (package-private, e.g. for tests). */
    DoubleMatrix getProbabilityMatrix() {
        return this.probabilityMatrix;
    }

    /**
     * Reads a model written by {@link #serialize}: the probability matrix, then the prior vector.
     * The returned instance is non-verbose.
     *
     * @param dataInput source to read from
     * @return the deserialized classifier
     * @throws IOException if reading fails
     */
    public static MultinomialNaiveBayes deserialize(DataInput dataInput) throws IOException {
        MatrixWritable matrixWritable = new MatrixWritable();
        matrixWritable.readFields(dataInput);
        return new MultinomialNaiveBayes(matrixWritable.getMatrix(), VectorWritable.readVector(dataInput));
    }

    /**
     * Writes the probability matrix followed by the class prior vector.
     *
     * @param multinomialNaiveBayes the trained classifier to write
     * @param dataOutput sink to write to
     * @throws IOException if writing fails
     */
    public static void serialize(MultinomialNaiveBayes multinomialNaiveBayes, DataOutput dataOutput) throws IOException {
        new MatrixWritable(multinomialNaiveBayes.probabilityMatrix).write(dataOutput);
        VectorWritable.writeVector(multinomialNaiveBayes.classPriorProbability, dataOutput);
    }
}
