package de.jungblut.math;

import com.google.common.base.Preconditions;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.tuple.Tuple;
import de.jungblut.math.tuple.Tuple3;
import de.jungblut.online.ml.FeatureOutcomePair;
import de.jungblut.reader.Dataset;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.function.Predicate;
import org.apache.commons.math3.util.FastMath;

/* loaded from: input_file:de/jungblut/math/MathUtils.class */
/**
 * Static helpers for feature normalization, min-max scaling, guarded logarithms,
 * numerical differentiation and AUC computation.
 */
public final class MathUtils {

    /**
     * Square root of 2.2e-16 (approximately the double-precision machine epsilon). Used as the
     * relative step size in {@link #numericalGradient(DoubleVector, CostFunction)}.
     */
    public static final double EPS = Math.sqrt(2.2E-16d);

    /**
     * A binary outcome class (0 or 1) together with the prediction made for it. The natural
     * ordering compares the prediction only (ascending, via {@link Double#compare(double, double)}).
     */
    public static class PredictionOutcomePair implements Comparable<PredictionOutcomePair> {
        private final int outcomeClass;
        private final double prediction;

        private PredictionOutcomePair(int outcomeClass, double prediction) {
            this.outcomeClass = outcomeClass;
            this.prediction = prediction;
        }

        /**
         * Creates a new pair.
         *
         * @param i the outcome class, must be 0 or 1.
         * @param d the prediction belonging to that outcome.
         * @return the new pair.
         * @throws IllegalArgumentException if {@code i} is neither 0 nor 1.
         */
        public static PredictionOutcomePair from(int i, double d) {
            Preconditions.checkArgument(i == 0 || i == 1, "Outcome class must be 0 or 1! Supplied: " + i);
            return new PredictionOutcomePair(i, d);
        }

        /** Orders pairs by ascending prediction; the outcome class is not taken into account. */
        @Override
        public int compareTo(PredictionOutcomePair predictionOutcomePair) {
            return Double.compare(this.prediction, predictionOutcomePair.prediction);
        }

        /** @return the outcome class (0 or 1). */
        public int getOutcomeClass() {
            return this.outcomeClass;
        }

        /** @return the prediction. */
        public double getPrediction() {
            return this.prediction;
        }
    }

    /** Not instantiable: this class only hosts static helpers. */
    private MathUtils() {
        throw new IllegalAccessError();
    }

    /**
     * Subtracts from every non-zero entry of a row the mean of that row's non-zero entries.
     * Entries that are exactly zero neither contribute to the mean nor get written to the result.
     *
     * @param doubleMatrix the matrix to normalize; it is only read.
     * @return a tuple of the normalized matrix and the vector of per-row means (a row without
     *         non-zero entries has mean 0).
     */
    public static Tuple<DoubleMatrix, DoubleVector> meanNormalizeRows(DoubleMatrix doubleMatrix) {
        DenseDoubleMatrix normalized = new DenseDoubleMatrix(doubleMatrix.getRowCount(), doubleMatrix.getColumnCount());
        DenseDoubleVector rowMeans = new DenseDoubleVector(normalized.getRowCount());
        for (int row = 0; row < normalized.getRowCount(); row++) {
            double mean = 0.0d;
            int nonZeroCount = 0;
            for (int col = 0; col < normalized.getColumnCount(); col++) {
                double value = doubleMatrix.get(row, col);
                if (value != 0.0d) {
                    mean += value;
                    nonZeroCount++;
                }
            }
            if (nonZeroCount != 0) {
                mean /= nonZeroCount;
            }
            rowMeans.set(row, mean);
            for (int col = 0; col < normalized.getColumnCount(); col++) {
                double value = doubleMatrix.get(row, col);
                if (value != 0.0d) {
                    normalized.set(row, col, value - mean);
                }
            }
        }
        return new Tuple<>(normalized, rowMeans);
    }

    /**
     * Standardizes every column: subtracts the column mean and divides by the column's
     * (population) standard deviation.
     *
     * <p>A column whose standard deviation is exactly zero (all entries equal) is only centered,
     * because dividing by zero would turn the whole column into NaN. The returned standard
     * deviation vector still reports the true value (0) for such a column.
     *
     * @param doubleMatrix the matrix to normalize; it is only read.
     * @return the normalized matrix, the per-column means and the per-column standard deviations.
     */
    public static Tuple3<DoubleMatrix, DoubleVector, DoubleVector> meanNormalizeColumns(DoubleMatrix doubleMatrix) {
        final int columnCount = doubleMatrix.getColumnCount();
        DenseDoubleMatrix normalized = new DenseDoubleMatrix(doubleMatrix.getRowCount(), columnCount);
        DenseDoubleVector means = new DenseDoubleVector(columnCount);
        DenseDoubleVector stdDevs = new DenseDoubleVector(columnCount);
        for (int col = 0; col < columnCount; col++) {
            DoubleVector column = doubleMatrix.getColumnVector(col);
            double mean = column.sum() / column.getLength();
            DoubleVector centered = column.subtract(mean);
            double stdDev = Math.sqrt(centered.pow(2.0d).sum() / column.getLength());
            means.set(col, mean);
            stdDevs.set(col, stdDev);
            // constant column: keep the centered values instead of producing 0/0 = NaN
            DoubleVector scaled = stdDev == 0.0d ? centered : centered.divide(stdDev);
            normalized.setColumn(col, scaled.toArray());
        }
        return new Tuple3<>(normalized, means, stdDevs);
    }

    /**
     * Normalizes the features of all examples of the dataset, see
     * {@link #meanNormalizeColumns(Dataset, Predicate)}.
     *
     * @param dataset the dataset whose feature array is rewritten.
     * @return a tuple of the mean vector and the (clamped) standard deviation vector.
     */
    public static Tuple<DoubleVector, DoubleVector> meanNormalizeColumns(Dataset dataset) {
        return meanNormalizeColumns(dataset, featureOutcomePair -> true);
    }

    /**
     * Normalizes the feature vectors of all examples accepted by the predicate: the mean of the
     * accepted examples is subtracted and the result is divided by their standard deviation,
     * where every standard deviation component smaller than 1 is replaced by 1. Examples
     * rejected by the predicate are neither used for the statistics nor modified.
     *
     * <p>The normalized vectors are written back into the array returned by
     * {@code dataset.getFeatures()}. The predicate is evaluated exactly once per example, on the
     * not yet normalized features.
     *
     * @param dataset the dataset whose feature array is rewritten.
     * @param predicate selects the examples to take into account.
     * @return a tuple of the mean vector and the (clamped) standard deviation vector.
     * @throws IllegalArgumentException if no example is accepted by the predicate (this includes
     *         an empty dataset).
     */
    public static Tuple<DoubleVector, DoubleVector> meanNormalizeColumns(Dataset dataset, Predicate<FeatureOutcomePair> predicate) {
        DoubleVector[] features = dataset.getFeatures();
        final int length = features.length;

        // pass 1: remember which examples take part and sum up their features
        boolean[] selected = new boolean[length];
        int selectedCount = 0;
        DoubleVector featureSum = null;
        for (int i = 0; i < length; i++) {
            selected[i] = predicate.test(new FeatureOutcomePair(features[i], dataset.getOutcomes()[i]));
            if (selected[i]) {
                featureSum = featureSum == null ? features[i] : featureSum.add(features[i]);
                selectedCount++;
            }
        }
        Preconditions.checkArgument(selectedCount > 0, "No example was accepted by the predicate, nothing to normalize!");
        // the statistics only cover the accepted examples, so only those may be counted
        DoubleVector mean = featureSum.divide((double) selectedCount);

        // pass 2: accumulate the squared deviations from the mean
        DoubleVector squaredDeviationSum = null;
        for (int i = 0; i < length; i++) {
            if (selected[i]) {
                DoubleVector squaredDeviation = features[i].subtract(mean).pow(2.0d);
                squaredDeviationSum = squaredDeviationSum == null ? squaredDeviation : squaredDeviationSum.add(squaredDeviation);
            }
        }
        DoubleVector stdDev = squaredDeviationSum.divide((double) selectedCount).sqrt()
                .apply((index, value) -> Math.max(1.0d, value));

        // pass 3: rewrite the accepted examples
        for (int i = 0; i < length; i++) {
            if (selected[i]) {
                features[i] = features[i].subtract(mean).divide(stdDev);
            }
        }
        return new Tuple<>(mean, stdDev);
    }

    /**
     * Expands every column {@code c} of the matrix into {@code i} adjacent columns
     * {@code c, c^2, ..., c^i}.
     *
     * @param denseDoubleMatrix the matrix to expand.
     * @param i the highest power to generate; for 1 the input matrix itself is returned
     *        (same instance, no copy).
     * @return a matrix with the same number of rows and {@code i} times the columns of the input.
     */
    public static DenseDoubleMatrix createPolynomials(DenseDoubleMatrix denseDoubleMatrix, int i) {
        if (i == 1) {
            return denseDoubleMatrix;
        }
        DenseDoubleMatrix expanded = new DenseDoubleMatrix(denseDoubleMatrix.getRowCount(), denseDoubleMatrix.getColumnCount() * i);
        int sourceColumn = 0;
        // every source column owns a group of i adjacent target columns
        for (int groupStart = 0; groupStart < expanded.getColumnCount(); groupStart += i) {
            double[] column = denseDoubleMatrix.getColumn(sourceColumn++);
            expanded.setColumn(groupStart, column);
            for (int power = 2; power <= i; power++) {
                expanded.setColumn(groupStart + power - 1, new DenseDoubleVector(column).pow(power).toArray());
            }
        }
        return expanded;
    }

    /**
     * Approximates the gradient of the cost function at the given point with central
     * differences. The step for coordinate {@code k} is {@code EPS * (|x_k| + 1)}.
     *
     * @param doubleVector the point at which to approximate the gradient; it is not modified,
     *        the cost function is evaluated on a copy.
     * @param costFunction the function to differentiate; evaluated twice per coordinate.
     * @return the approximated gradient, same length as the input.
     */
    public static DoubleVector numericalGradient(DoubleVector doubleVector, CostFunction costFunction) {
        final int length = doubleVector.getLength();
        DenseDoubleVector gradient = new DenseDoubleVector(length);
        DoubleVector probe = doubleVector.deepCopy();
        for (int i = 0; i < length; i++) {
            double center = doubleVector.get(i);
            double step = EPS * (Math.abs(center) + 1.0d);
            probe.set(i, center + step);
            double costAbove = costFunction.evaluateCost(probe).getCost();
            probe.set(i, center - step);
            double costBelow = costFunction.evaluateCost(probe).getCost();
            // undo the perturbation, otherwise all following partial derivatives would be
            // evaluated at a point that is shifted in this coordinate
            probe.set(i, center);
            gradient.set(i, (costAbove - costBelow) / (2.0d * step));
        }
        return gradient;
    }

    /**
     * Applies {@link #guardedLogarithm(double)} to every entry of the matrix.
     *
     * @param doubleMatrix the input matrix; it is only read.
     * @return a new dense matrix with the guarded logarithms.
     */
    public static DoubleMatrix logMatrix(DoubleMatrix doubleMatrix) {
        DenseDoubleMatrix logs = new DenseDoubleMatrix(doubleMatrix.getRowCount(), doubleMatrix.getColumnCount());
        for (int row = 0; row < logs.getRowCount(); row++) {
            for (int col = 0; col < logs.getColumnCount(); col++) {
                logs.set(row, col, guardedLogarithm(doubleMatrix.get(row, col)));
            }
        }
        return logs;
    }

    /**
     * Applies {@link #guardedLogarithm(double)} to every entry of the vector.
     *
     * @param doubleVector the input vector; it is only read.
     * @return a new dense vector with the guarded logarithms.
     */
    public static DoubleVector logVector(DoubleVector doubleVector) {
        DenseDoubleVector logs = new DenseDoubleVector(doubleVector.getDimension());
        for (int i = 0; i < logs.getDimension(); i++) {
            logs.set(i, guardedLogarithm(doubleVector.get(i)));
        }
        return logs;
    }

    /**
     * Applies {@link #minMaxScale(double, double, double, double, double)} to every entry of the
     * matrix.
     *
     * @param doubleMatrix the input matrix; it is only read.
     * @param d lower bound of the source interval.
     * @param d2 upper bound of the source interval.
     * @param d3 lower bound of the target interval.
     * @param d4 upper bound of the target interval.
     * @return a new dense matrix with the scaled entries.
     */
    public static DoubleMatrix minMaxScale(DoubleMatrix doubleMatrix, double d, double d2, double d3, double d4) {
        DenseDoubleMatrix scaled = new DenseDoubleMatrix(doubleMatrix.getRowCount(), doubleMatrix.getColumnCount());
        double[][] values = doubleMatrix.toArray();
        for (int row = 0; row < scaled.getRowCount(); row++) {
            for (int col = 0; col < scaled.getColumnCount(); col++) {
                scaled.set(row, col, minMaxScale(values[row][col], d, d2, d3, d4));
            }
        }
        return scaled;
    }

    /**
     * Applies {@link #minMaxScale(double, double, double, double, double)} to every entry of the
     * vector.
     *
     * @param doubleVector the input vector; it is only read.
     * @param d lower bound of the source interval.
     * @param d2 upper bound of the source interval.
     * @param d3 lower bound of the target interval.
     * @param d4 upper bound of the target interval.
     * @return a new dense vector with the scaled entries.
     */
    public static DoubleVector minMaxScale(DoubleVector doubleVector, double d, double d2, double d3, double d4) {
        DenseDoubleVector scaled = new DenseDoubleVector(doubleVector.getDimension());
        double[] values = doubleVector.toArray();
        for (int i = 0; i < values.length; i++) {
            scaled.set(i, minMaxScale(values[i], d, d2, d3, d4));
        }
        return scaled;
    }

    /**
     * Linearly maps a value from the source interval {@code [d2, d3]} to the target interval
     * {@code [d4, d5]}, so {@code d2} maps to {@code d4} and {@code d3} maps to {@code d5}.
     * Values outside the source interval are extrapolated, not clamped. If {@code d2 == d3} the
     * division by zero yields NaN or an infinity.
     *
     * @param d the value to scale.
     * @param d2 lower bound of the source interval.
     * @param d3 upper bound of the source interval.
     * @param d4 lower bound of the target interval.
     * @param d5 upper bound of the target interval.
     * @return the scaled value.
     */
    public static double minMaxScale(double d, double d2, double d3, double d4, double d5) {
        double targetRange = d5 - d4;
        double sourceRange = d3 - d2;
        return (((d - d2) * targetRange) / sourceRange) + d4;
    }

    /**
     * Natural logarithm that never returns NaN or an infinity.
     *
     * @param d the input value.
     * @return 0 if the input is NaN or infinite, -10 if the input is zero or negative, otherwise
     *         {@code FastMath.log(d)}.
     */
    public static double guardedLogarithm(double d) {
        if (Double.isNaN(d) || Double.isInfinite(d)) {
            return 0.0d;
        }
        // -0.0 compares equal to 0.0, so a single comparison covers both zeros
        if (d <= 0.0d) {
            return -10.0d;
        }
        return FastMath.log(d);
    }

    /**
     * Computes the area under the ROC curve: the fraction of (positive, negative) pairs in which
     * the positive example received the higher prediction, where a pair with equal predictions
     * counts as one half.
     *
     * <p>Note that the supplied list is sorted in place (ascending by prediction), so it must be
     * modifiable.
     *
     * @param list the outcome / prediction pairs.
     * @return the AUC in {@code [0, 1]}; 1 if the list is empty or contains only one of the two
     *         outcome classes.
     */
    public static double computeAUC(List<PredictionOutcomePair> list) {
        Collections.sort(list);
        final int size = list.size();
        int positives = 0;
        for (PredictionOutcomePair pair : list) {
            if (pair.getOutcomeClass() == 1) {
                positives++;
            }
        }
        if (positives == 0 || positives == size) {
            return 1.0d;
        }

        // positives that have not been passed yet while walking up the sorted predictions
        long positivesRemaining = positives;
        // value of positivesRemaining when the current group of equal predictions started
        long positivesAtGroupStart = positives;
        // negatives inside the current group of equal predictions
        long negativesInGroup = 0;
        // accumulates twice the area, which keeps the bookkeeping in integers
        long doubledArea = 0;
        double threshold = list.get(0).getPrediction();
        for (PredictionOutcomePair pair : list) {
            double prediction = pair.getPrediction();
            if (prediction != threshold) {
                // the group is complete: every negative in it is beaten twice by the positives
                // above the group and once (tie) by the positives inside the group
                threshold = prediction;
                doubledArea += negativesInGroup * (positivesRemaining + positivesAtGroupStart);
                positivesAtGroupStart = positivesRemaining;
                negativesInGroup = 0;
            }
            int outcome = pair.getOutcomeClass();
            negativesInGroup += 1 - outcome;
            positivesRemaining -= outcome;
        }
        doubledArea += negativesInGroup * (positivesRemaining + positivesAtGroupStart);

        // divide in floating point: an integral division would truncate the result to 0 or 1 and
        // an int product of the class counts overflows for large lists
        double doubledPairCount = 2.0d * positives * (size - positives);
        return doubledArea / doubledPairCount;
    }
}
