package de.jungblut.classification.nn;

import com.google.common.base.Preconditions;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.activation.ActivationFunction;
import de.jungblut.math.cuda.JCUDAMatrixUtils;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.loss.LossFunction;
import de.jungblut.math.minimize.AbstractMiniBatchCostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.math.minimize.DenseMatrixFolder;
import de.jungblut.writable.MatrixWritable;
import java.util.Random;

/* loaded from: input_file:de/jungblut/classification/nn/MultilayerPerceptronCostFunction.class */
public final class MultilayerPerceptronCostFunction extends AbstractMiniBatchCostFunction {
    private final NetworkConfiguration configuration;

    /**
     * Compiler-generated lookup table (originally the synthetic class
     * {@code MultilayerPerceptronCostFunction$1}; the decompiler widened its access from package-private).
     *
     * <p>The table is indexed by {@link TrainingType} ordinal and holds the dense case index used by the
     * {@code switch} in {@code multiply}: {@code CPU -> 1}, {@code GPU -> 2}. Entries for any other
     * constant stay at {@code 0}.
     */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$de$jungblut$classification$nn$TrainingType = buildSwitchMap();

        /** Builds the ordinal-to-case-index table; a constant missing at runtime simply keeps index 0. */
        private static int[] buildSwitchMap() {
            final int[] switchMap = new int[TrainingType.values().length];
            try {
                switchMap[TrainingType.CPU.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // Constant not present in the loaded enum: leave its slot unmapped.
            }
            try {
                switchMap[TrainingType.GPU.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // Constant not present in the loaded enum: leave its slot unmapped.
            }
            return switchMap;
        }
    }

    /* loaded from: input_file:de/jungblut/classification/nn/MultilayerPerceptronCostFunction$NetworkConfiguration.class */
    /**
     * Mutable parameter bag that the static training routines of this class read from.
     *
     * <p>All fields are public and unchecked; the constructor of the enclosing class fills every one of
     * them from a {@code MultilayerPerceptron}.
     */
    public static class NetworkConfiguration {
        /** Regularization weight; {@code 0} switches off the regularization terms of both cost and gradient. */
        public double lambda;
        /**
         * Size of every layer. Index 0 is validated against the input column count minus one, the last
         * index denotes the output layer.
         */
        public int[] layerSizes;
        /**
         * Dimension pairs passed to {@code DenseMatrixFolder.unfoldMatrices}; produced by
         * {@code computeUnfoldParameters} as {@code {layerSizes[i + 1], layerSizes[i] + 1}}.
         */
        public int[][] unfoldParameters;
        /**
         * Activation function per layer, indexed like {@link #layerSizes}. Forward and backward
         * propagation only read indices {@code >= 1}.
         */
        public ActivationFunction[] activations;
        /** Loss evaluated on the expected outcomes and the output layer activations. */
        public LossFunction error;
        /** Chooses between the CPU and the GPU matrix multiplication. */
        public TrainingType trainingType;
        /** When {@code > 0}, dropout is applied with this probability to a copy of the input matrix. */
        public double visibleDropoutProbability;
        /** When {@code > 0}, dropout is applied with this probability to every hidden layer's activations. */
        public double hiddenDropoutProbability;
        /** Random source consumed by the dropout routines. */
        public Random rnd;
    }

    public MultilayerPerceptronCostFunction(MultilayerPerceptron multilayerPerceptron, DoubleVector[] doubleVectorArr, DoubleVector[] doubleVectorArr2) {
        super(doubleVectorArr, doubleVectorArr2, multilayerPerceptron.getMiniBatchSize(), multilayerPerceptron.getBatchParallelism(), multilayerPerceptron.isStochastic());
        this.configuration = new NetworkConfiguration();
        this.configuration.lambda = multilayerPerceptron.getLambda();
        this.configuration.layerSizes = multilayerPerceptron.getLayers();
        this.configuration.unfoldParameters = computeUnfoldParameters(this.configuration.layerSizes);
        this.configuration.activations = multilayerPerceptron.getActivations();
        this.configuration.error = multilayerPerceptron.getErrorFunction();
        this.configuration.trainingType = multilayerPerceptron.getTrainingType();
        this.configuration.visibleDropoutProbability = multilayerPerceptron.getVisibleDropoutProbability();
        this.configuration.hiddenDropoutProbability = multilayerPerceptron.getHiddenDropoutProbability();
        this.configuration.rnd = new Random();
    }

    /**
     * Evaluates a single mini batch by delegating to
     * {@link #computeNextStep(DoubleVector, DoubleMatrix, DoubleMatrix, NetworkConfiguration)} with the
     * configuration captured at construction time.
     *
     * @param doubleVector folded network parameters
     * @param doubleMatrix batch inputs
     * @param doubleMatrix2 expected outcomes for the batch
     * @return cost and folded gradient for this batch
     */
    @Override // de.jungblut.math.minimize.AbstractMiniBatchCostFunction
    protected CostGradientTuple evaluateBatch(DoubleVector doubleVector, DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2) {
        return computeNextStep(doubleVector, doubleMatrix, doubleMatrix2, this.configuration);
    }

    /**
     * Computes cost and gradient of the network for one batch.
     *
     * <p>The folded parameter vector is unfolded into one weight matrix per layer transition, the batch
     * is forward propagated (with optional dropout), the errors are propagated backwards and turned
     * into per-matrix gradients, which are folded back into a single vector.
     *
     * @param doubleVector folded network parameters
     * @param doubleMatrix batch inputs, one example per row; its column count minus one must equal
     *     {@code layerSizes[0]} (presumably column 0 is the bias column; confirm with the caller)
     * @param doubleMatrix2 expected outcomes, one example per row
     * @param networkConfiguration hyper parameters and layer layout
     * @return the loss plus regularization term, together with the folded gradient
     * @throws IllegalArgumentException if the input width does not match the input layer size
     */
    public static CostGradientTuple computeNextStep(DoubleVector doubleVector, DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2, NetworkConfiguration networkConfiguration) {
        final int inputWidth = doubleMatrix.getColumnCount() - 1;
        final int expectedWidth = networkConfiguration.layerSizes[0];
        // Template overload: the message is only rendered when the check actually fails.
        Preconditions.checkArgument(inputWidth == expectedWidth, "Input layer size must match the given vector dimension! Given: %s, expected: %s", inputWidth, expectedWidth);

        final int numExamples = doubleMatrix.getRowCount();
        final int numLayers = networkConfiguration.layerSizes.length;
        final DoubleMatrix[] thetas = DenseMatrixFolder.unfoldMatrices(doubleVector, networkConfiguration.unfoldParameters);
        final DoubleMatrix[] thetaGradients = new DoubleMatrix[thetas.length];
        final DoubleMatrix[] activations = new DoubleMatrix[numLayers];
        final DoubleMatrix[] weightedInputs = new DoubleMatrix[numLayers];

        dropoutVisibleLayer(doubleMatrix, activations, networkConfiguration);
        forwardPropagate(thetas, activations, weightedInputs, networkConfiguration);
        final double regularization = calculateRegularization(thetas, numExamples, networkConfiguration);

        final DoubleMatrix[] deltas = backwardPropagate(doubleMatrix2, thetas, activations, weightedInputs, networkConfiguration);
        calculateGradients(thetas, thetaGradients, activations, deltas, numExamples, networkConfiguration);

        final double cost = networkConfiguration.error.calculateLoss(doubleMatrix2, activations[numLayers - 1]) + regularization;
        return new CostGradientTuple(cost, DenseMatrixFolder.foldMatrices(thetaGradients));
    }

    /**
     * Runs the forward pass and fills the activation and weighted-input arrays for every layer
     * {@code >= 1}.
     *
     * <p>Hidden layers get a leading column of ones prepended to their activations and, if
     * {@code hiddenDropoutProbability > 0}, are subject to dropout. The output layer receives the plain
     * activation without the extra column and without dropout.
     *
     * @param doubleMatrixArr weight matrices, entry {@code i - 1} connects layer {@code i - 1} to {@code i}
     * @param doubleMatrixArr2 activations per layer; entry 0 must already be set, the rest is written
     * @param doubleMatrixArr3 weighted inputs per layer; entries {@code >= 1} are written
     * @param networkConfiguration layer layout, activations and dropout settings
     */
    public static void forwardPropagate(DoubleMatrix[] doubleMatrixArr, DoubleMatrix[] doubleMatrixArr2, DoubleMatrix[] doubleMatrixArr3, NetworkConfiguration networkConfiguration) {
        final int outputLayer = networkConfiguration.layerSizes.length - 1;
        for (int layer = 1; layer <= outputLayer; layer++) {
            final DoubleMatrix weightedInput = multiply(doubleMatrixArr2[layer - 1], doubleMatrixArr[layer - 1], false, true, networkConfiguration);
            doubleMatrixArr3[layer] = weightedInput;

            if (layer == outputLayer) {
                // Output layer: no bias column, no dropout.
                doubleMatrixArr2[layer] = networkConfiguration.activations[layer].apply(weightedInput);
                continue;
            }

            // Hidden layer: prepend a column of ones in front of the activated values.
            doubleMatrixArr2[layer] = new DenseDoubleMatrix(DenseDoubleVector.ones(weightedInput.getRowCount()), networkConfiguration.activations[layer].apply(weightedInput));
            if (networkConfiguration.hiddenDropoutProbability > 0.0d) {
                dropout(networkConfiguration.rnd, doubleMatrixArr2[layer], networkConfiguration.hiddenDropoutProbability);
            }
        }
    }

    /**
     * Propagates the output error backwards through the hidden layers.
     *
     * <p>The output delta is the output activation minus the expected outcome. For every hidden layer
     * the next layer's delta is multiplied with that layer's weights (first column sliced off) and then
     * scaled element-wise by the activation gradient of the layer's weighted input.
     *
     * @param doubleMatrix expected outcomes
     * @param doubleMatrixArr weight matrices
     * @param doubleMatrixArr2 activations per layer as produced by the forward pass
     * @param doubleMatrixArr3 weighted inputs per layer as produced by the forward pass
     * @param networkConfiguration layer layout and activation functions
     * @return deltas indexed by layer; index 0 is left {@code null}
     */
    public static DoubleMatrix[] backwardPropagate(DoubleMatrix doubleMatrix, DoubleMatrix[] doubleMatrixArr, DoubleMatrix[] doubleMatrixArr2, DoubleMatrix[] doubleMatrixArr3, NetworkConfiguration networkConfiguration) {
        final int layerCount = networkConfiguration.layerSizes.length;
        final DoubleMatrix[] deltas = new DoubleMatrix[layerCount];
        deltas[layerCount - 1] = doubleMatrixArr2[layerCount - 1].subtract(doubleMatrix);

        for (int layer = layerCount - 2; layer >= 1; layer--) {
            final DoubleMatrix theta = doubleMatrixArr[layer];
            // Skip the first weight column so the product lines up with the layer's weighted input.
            final DoubleMatrix thetaWithoutFirstColumn = theta.slice(0, theta.getRowCount(), 1, theta.getColumnCount());
            final DoubleMatrix propagated = multiply(deltas[layer + 1], thetaWithoutFirstColumn, false, false, networkConfiguration);
            deltas[layer] = propagated.multiplyElementWise(networkConfiguration.activations[layer].gradient(doubleMatrixArr3[layer]));
        }
        return deltas;
    }

    /**
     * Computes the gradient of every weight matrix from the backpropagated deltas.
     *
     * <p>Each gradient is {@code delta[l + 1]^T * activation[l]}, averaged over the batch when the batch
     * holds more than one example. With {@code lambda != 0} the penalty {@code lambda / i * theta} is
     * added to every column except column 0.
     *
     * <p>Column 0 is excluded because {@code calculateRegularization} also leaves it out of the cost.
     * The previous implementation overwrote that column with the penalty alone, which dropped the data
     * gradient of the first column and made the gradient inconsistent with the cost.
     *
     * @param doubleMatrixArr weight matrices
     * @param doubleMatrixArr2 output array receiving one gradient per weight matrix
     * @param doubleMatrixArr3 activations per layer
     * @param doubleMatrixArr4 deltas per layer, entry {@code l + 1} is paired with activation {@code l}
     * @param i number of examples in the batch
     * @param networkConfiguration supplies {@code lambda} and the training type
     */
    public static void calculateGradients(DoubleMatrix[] doubleMatrixArr, DoubleMatrix[] doubleMatrixArr2, DoubleMatrix[] doubleMatrixArr3, DoubleMatrix[] doubleMatrixArr4, int i, NetworkConfiguration networkConfiguration) {
        final boolean regularize = networkConfiguration.lambda != 0.0d;
        final double weightDecay = networkConfiguration.lambda / i;
        for (int layer = 0; layer < doubleMatrixArr2.length; layer++) {
            DoubleMatrix gradient = multiply(doubleMatrixArr4[layer + 1], doubleMatrixArr3[layer], true, false, networkConfiguration);
            if (i != 1) {
                gradient = gradient.divide(i);
            }
            if (regularize) {
                // Remember the unregularized first column before the penalty is added to the
                // whole matrix, then put it back so that column stays free of weight decay.
                final DoubleVector unregularizedFirstColumn = gradient.getColumnVector(0);
                gradient = gradient.add(doubleMatrixArr[layer].multiply(weightDecay));
                gradient.setColumnVector(0, unregularizedFirstColumn);
            }
            doubleMatrixArr2[layer] = gradient;
        }
    }

    /**
     * Computes the regularization term of the cost:
     * {@code lambda / (2 * i)} times the sum of squared weights, where the first column of every
     * weight matrix is left out.
     *
     * @param doubleMatrixArr weight matrices
     * @param i number of examples in the batch
     * @param networkConfiguration supplies {@code lambda}
     * @return the regularization term, or {@code 0} when {@code lambda} is zero
     */
    public static double calculateRegularization(DoubleMatrix[] doubleMatrixArr, int i, NetworkConfiguration networkConfiguration) {
        if (networkConfiguration.lambda == 0.0d) {
            return 0.0d;
        }
        double squaredWeights = 0.0d;
        for (int k = 0; k < doubleMatrixArr.length; k++) {
            final DoubleMatrix theta = doubleMatrixArr[k];
            final DoubleMatrix withoutFirstColumn = theta.slice(0, theta.getRowCount(), 1, theta.getColumnCount());
            squaredWeights += withoutFirstColumn.pow(2.0d).sum();
        }
        return (networkConfiguration.lambda / (2.0d * i)) * squaredWeights;
    }

    /**
     * Stores the input layer activations in slot 0 of the activation array.
     *
     * <p>Without visible dropout the input matrix is stored as is. Otherwise a deep copy is stored and
     * dropout is applied to the copy, so the caller's matrix is never modified.
     *
     * @param doubleMatrix batch inputs
     * @param doubleMatrixArr activation array whose entry 0 is written
     * @param networkConfiguration supplies the visible dropout probability and the random source
     */
    public static void dropoutVisibleLayer(DoubleMatrix doubleMatrix, DoubleMatrix[] doubleMatrixArr, NetworkConfiguration networkConfiguration) {
        final double probability = networkConfiguration.visibleDropoutProbability;
        if (probability <= 0.0d) {
            doubleMatrixArr[0] = doubleMatrix;
            return;
        }
        final DoubleMatrix inputCopy = doubleMatrix.deepCopy();
        doubleMatrixArr[0] = inputCopy;
        dropout(networkConfiguration.rnd, inputCopy, probability);
    }

    /**
     * Multiplies two matrices on the device selected by the configuration's training type.
     *
     * <p>The dispatch switches on the enum itself, so it does not depend on ordinal values or on
     * unrelated integer constants.
     *
     * @param doubleMatrix left operand
     * @param doubleMatrix2 right operand
     * @param z whether the left operand is transposed before multiplying
     * @param z2 whether the right operand is transposed before multiplying
     * @param networkConfiguration supplies the training type
     * @return the matrix product
     * @throws IllegalArgumentException if the training type is neither {@code CPU} nor {@code GPU}
     */
    private static DoubleMatrix multiply(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2, boolean z, boolean z2, NetworkConfiguration networkConfiguration) {
        switch (networkConfiguration.trainingType) {
            case CPU:
                return multiplyCPU(doubleMatrix, doubleMatrix2, z, z2);
            case GPU:
                return multiplyGPU(doubleMatrix, doubleMatrix2, z, z2);
            default:
                throw new IllegalArgumentException("Trainingtype couldn't be anticipated by switch.");
        }
    }

    /**
     * Multiplies two matrices on the CPU, transposing either operand first when requested.
     *
     * @param doubleMatrix left operand
     * @param doubleMatrix2 right operand
     * @param z whether to transpose the left operand
     * @param z2 whether to transpose the right operand
     * @return the matrix product
     */
    private static DoubleMatrix multiplyCPU(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2, boolean z, boolean z2) {
        DoubleMatrix left = doubleMatrix;
        if (z) {
            left = doubleMatrix.transpose();
        }
        DoubleMatrix right = doubleMatrix2;
        if (z2) {
            right = doubleMatrix2.transpose();
        }
        return left.multiply(right);
    }

    /**
     * Multiplies two matrices through {@code JCUDAMatrixUtils}; both operands must be
     * {@code DenseDoubleMatrix} instances.
     *
     * @param doubleMatrix left operand
     * @param doubleMatrix2 right operand
     * @param z transpose flag for the left operand, forwarded unchanged
     * @param z2 transpose flag for the right operand, forwarded unchanged
     * @return the matrix product
     * @throws ClassCastException if either operand is not a {@code DenseDoubleMatrix}
     */
    private static DoubleMatrix multiplyGPU(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2, boolean z, boolean z2) {
        final DenseDoubleMatrix left = (DenseDoubleMatrix) doubleMatrix;
        final DenseDoubleMatrix right = (DenseDoubleMatrix) doubleMatrix2;
        return JCUDAMatrixUtils.multiply(left, right, z, z2);
    }

    /* JADX WARN: Type inference failed for: r0v3, types: [int[], int[][]] */
    public static int[][] computeUnfoldParameters(int[] iArr) {
        ?? r0 = new int[iArr.length - 1];
        for (int i = 0; i < r0.length; i++) {
            int[] iArr2 = new int[2];
            iArr2[0] = iArr[i + 1];
            iArr2[1] = iArr[i] + 1;
            r0[i] = iArr2;
        }
        return r0;
    }

    /**
     * Applies dropout to a matrix in place.
     *
     * <p>Entries are visited row by row; one uniform draw is taken per entry and the entry is set to
     * zero whenever the draw is less than or equal to {@code d}.
     *
     * @param random source of the uniform draws
     * @param doubleMatrix matrix that is modified in place
     * @param d dropout probability
     */
    public static void dropout(Random random, DoubleMatrix doubleMatrix, double d) {
        final int rows = doubleMatrix.getRowCount();
        final int columns = doubleMatrix.getColumnCount();
        for (int row = 0; row < rows; row++) {
            for (int column = 0; column < columns; column++) {
                final boolean drop = random.nextDouble() <= d;
                if (drop) {
                    doubleMatrix.set(row, column, 0.0d);
                }
            }
        }
    }
}
