package de.jungblut.ner;

import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.math.minimize.DenseMatrixFolder;
import java.util.Iterator;

/* loaded from: input_file:de/jungblut/ner/ConditionalLikelihoodCostFunction.class */
/**
 * Cost function that evaluates the regularized negative conditional
 * log-likelihood of a multi-class log-linear model together with its gradient.
 *
 * <p>The parameter vector handed to {@link #evaluateCost(DoubleVector)} is
 * unfolded into a matrix with one row per class. For every training row the
 * activation of a class is the sum of that class' weights at the indices of
 * the non-zero features. Only the <em>indices</em> of the non-zero feature
 * entries are used; their values are never multiplied in.
 *
 * <p>A quadratic prior with variance {@link #SIGMA_SQUARED} is added to both
 * the cost and the gradient.
 */
public final class ConditionalLikelihoodCostFunction implements CostFunction {

    /** Variance of the quadratic prior applied to every weight. */
    private static final double SIGMA_SQUARED = 100.0d;

    /**
     * Activations that lie more than this amount below the maximum are skipped
     * by {@link #logSum(double[])}.
     */
    private static final double LOG_SUM_CUTOFF = 30.0d;

    private final DoubleMatrix features;
    private final DoubleMatrix outcome;
    /** Number of training rows, taken from the outcome matrix. */
    private final int m;
    /** Number of classes; a single outcome column is treated as two classes. */
    private final int classes;

    /**
     * @param features one row per training example; the non-zero column
     *                 indices of a row are the active features of that example.
     * @param outcome  one row per training example. With a single column the
     *                 value is read as the class index, otherwise the column
     *                 holding the maximum value marks the class.
     */
    public ConditionalLikelihoodCostFunction(DoubleMatrix features, DoubleMatrix outcome) {
        this.features = features;
        this.outcome = outcome;
        this.m = outcome.getRowCount();
        int outcomeColumns = outcome.getColumnCount();
        this.classes = outcomeColumns == 1 ? 2 : outcomeColumns;
    }

    /**
     * Computes the cost and gradient for the given folded weight vector.
     *
     * @param input the weights, folded so that they unfold into a
     *              {@code classes x (length / classes)} matrix.
     * @return the cost and the folded gradient, both including the prior.
     */
    @Override
    public CostGradientTuple evaluateCost(DoubleVector input) {
        DoubleMatrix theta = DenseMatrixFolder.unfoldMatrix(input, this.classes, (int) (input.getLength() / this.classes));
        DenseDoubleMatrix gradient = new DenseDoubleMatrix(theta.getRowCount(), theta.getColumnCount());
        double cost = 0.0d;

        for (int row = 0; row < this.m; row++) {
            DoubleVector featureRow = this.features.getRowVector(row);
            // The gold class of a row does not depend on the class/feature
            // loops below, so resolve it once per row.
            int goldClass = correctClass(this.outcome.getRowVector(row));

            double[] activations = new double[this.classes];
            Iterator<?> activeFeatures = featureRow.iterateNonZero();
            while (activeFeatures.hasNext()) {
                int featureIndex = ((DoubleVector.DoubleVectorElement) activeFeatures.next()).getIndex();
                for (int clazz = 0; clazz < this.classes; clazz++) {
                    activations[clazz] += theta.get(clazz, featureIndex);
                }
            }

            double logNormalizer = logSum(activations);
            for (int clazz = 0; clazz < this.classes; clazz++) {
                double logProbability = activations[clazz] - logNormalizer;
                double probability = Math.exp(logProbability);
                boolean correct = clazz == goldClass;

                Iterator<?> gradientFeatures = featureRow.iterateNonZero();
                while (gradientFeatures.hasNext()) {
                    int featureIndex = ((DoubleVector.DoubleVectorElement) gradientFeatures.next()).getIndex();
                    // expected count minus the observed count of the feature
                    double updated = gradient.get(clazz, featureIndex) + probability;
                    if (correct) {
                        updated -= 1.0d;
                    }
                    gradient.set(clazz, featureIndex, updated);
                }

                if (correct) {
                    // Use the log-probability directly instead of
                    // log(exp(..)): the exponential underflows to zero for
                    // very unlikely classes, which would make the cost infinite.
                    cost -= logProbability;
                }
            }
        }

        DoubleVector foldedGradient = DenseMatrixFolder.foldMatrix(gradient);
        return new CostGradientTuple(cost + computeLogPrior(input, foldedGradient), foldedGradient);
    }

    /**
     * @return true if class index {@code i} is the class encoded by the given
     *         outcome row.
     */
    static boolean correctPrediction(int i, DoubleVector doubleVector) {
        return correctClass(doubleVector) == i;
    }

    /**
     * Decodes the class of an outcome row: a single-element row stores the
     * class index as its (truncated) value, otherwise the index of the
     * maximum element is the class.
     */
    private static int correctClass(DoubleVector outcomeRow) {
        return outcomeRow.getLength() == 1 ? (int) outcomeRow.get(0) : outcomeRow.maxIndex();
    }

    /**
     * Computes the quadratic prior {@code sum(w_i^2 / 2 / SIGMA_SQUARED)} and,
     * as a side effect, adds its derivative {@code w_i / SIGMA_SQUARED} to the
     * given gradient in place.
     *
     * @param doubleVector  the weights.
     * @param doubleVector2 the gradient to be updated; it is indexed over the
     *                      length of the weights.
     * @return the prior term to add to the cost.
     */
    static double computeLogPrior(DoubleVector doubleVector, DoubleVector doubleVector2) {
        double prior = 0.0d;
        for (int i = 0; i < doubleVector.getLength(); i++) {
            double weight = doubleVector.get(i);
            prior += ((weight * weight) / 2.0d) / SIGMA_SQUARED;
            doubleVector2.set(i, doubleVector2.get(i) + (weight / SIGMA_SQUARED));
        }
        return prior;
    }

    /**
     * Computes {@code log(sum(exp(x_i)))} relative to the maximum element so
     * the exponentials cannot overflow. Elements that lie
     * {@link #LOG_SUM_CUTOFF} or more below the maximum are ignored.
     *
     * @param dArr the log-domain values; must contain at least one element.
     * @return the (approximate) logarithm of the sum of the exponentials.
     */
    static double logSum(double[] dArr) {
        int maxIndex = 0;
        double max = dArr[0];
        for (int i = 1; i < dArr.length; i++) {
            if (dArr[i] > max) {
                maxIndex = i;
                max = dArr[i];
            }
        }

        boolean haveTerms = false;
        double remainder = 0.0d;
        double cutoff = max - LOG_SUM_CUTOFF;
        for (int i = 0; i < dArr.length; i++) {
            if (i != maxIndex && dArr[i] > cutoff) {
                haveTerms = true;
                remainder += Math.exp(dArr[i] - max);
            }
        }
        return haveTerms ? max + Math.log(1.0d + remainder) : max;
    }
}
