package org.nd4j.linalg.lossfunctions;

import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Static helpers that score predictions against labels using a selectable {@link LossFunction}.
 */
public class LossFunctions {

    /**
     * Loss functions understood by {@link LossFunctions#score}.
     *
     * <p>NOTE: {@link #NEGATIVELOGLIKELIHOOD} has no branch in {@code score}; selecting it yields a data
     * term of 0 (plus the optional regularization term).
     */
    public enum LossFunction {
        MSE,
        EXPLL,
        XENT,
        MCXENT,
        RMSE_XENT,
        SQUARED_LOSS,
        RECONSTRUCTION_CROSSENTROPY,
        NEGATIVELOGLIKELIHOOD
    }

    /**
     * Scores the predictions {@code z} against {@code labels}.
     *
     * <p>Every implemented branch reduces over all elements and divides by {@code labels.rows()}, so the
     * data term is a per-row average.
     *
     * @param labels expected values; its row count is the normalizer
     * @param lossFunction which loss to compute
     * @param z predictions; when assertions are enabled they must not contain NaN or infinite values
     * @param l2 regularization coefficient
     * @param useRegularization when {@code true} and {@code l2 > 0}, {@code 0.5 * l2} is added to the score
     * @return the loss value
     * @throws AssertionError if assertions are enabled and {@code z} contains NaN or infinite values
     */
    public static float score(INDArray labels, LossFunction lossFunction, INDArray z, double l2,
            boolean useRegularization) {
        assert !Nd4j.hasInvalidNumber(z)
                : "Invalid output on labels. Must not contain nan or infinite numbers.";
        final int rows = labels.rows();
        float ret = 0.0f;
        switch (lossFunction) {
            case RECONSTRUCTION_CROSSENTROPY:
            case XENT:
                // Binary cross entropy: -(y*log(z) + (1-y)*log(1-z)). The earlier expression evaluated
                // (y*log(z) + (1-y)) * (1 - log(z)), which mis-grouped the terms and used 1-log(z)
                // in place of log(1-z).
                ret = -sumAll(binaryCrossEntropyTerms(labels, z)) / rows;
                break;
            case MCXENT:
                // Multi-class cross entropy: -sum(y * log(z)). dup() keeps the caller's z untouched.
                ret = -sumAll(labels.mul(Transforms.log(z.dup()))) / rows;
                break;
            case RMSE_XENT:
                // NOTE(review): sqrt((y - z)^2) is |y - z| element-wise, so this is a mean absolute error
                // per row rather than a root-mean-square error. Kept unchanged to preserve existing scores.
                ret = sumAll(Transforms.sqrt(Transforms.pow(labels.sub(z), (Number) 2))) / rows;
                break;
            case MSE:
                ret = 0.5f * sumAll(Transforms.pow(labels.sub(z), (Number) 2)) / rows;
                break;
            case EXPLL:
                // NOTE(review): this is -sum(z - y*log(z)), i.e. the (non-negated) Poisson-style
                // log-likelihood, so its sign is opposite to the other losses here. Kept unchanged; confirm
                // what callers expect.
                ret = -sumAll(z.sub(labels.mul(Transforms.log(z.dup())))) / rows;
                break;
            case SQUARED_LOSS:
                ret = sumAll(Transforms.pow(labels.sub(z), (Number) 2)) / rows;
                break;
            default:
                // NEGATIVELOGLIKELIHOOD is not implemented here; the data term stays 0.
                break;
        }
        if (useRegularization && l2 > 0.0d) {
            // NOTE(review): this adds a constant 0.5 * l2 that does not depend on any weights; confirm
            // whether a weight norm was intended here.
            ret += 0.5d * l2;
        }
        return ret;
    }

    /**
     * Reconstruction log-likelihood of {@code input} after one encode/decode pass.
     *
     * <p>Computes {@code sigH = f(input . weights + hBias)} and {@code sigV = f(sigH . weights^T + vBias)}
     * with {@code f = activationFunction}. It then returns the mean over rows of
     * {@code sum_j [x*log(sigV) + (1-x)*log(1-sigV)]}, divided again by {@code input.rows()}.
     *
     * <p>NOTE(review): {@code mean} already averages over rows, so the trailing division normalizes twice,
     * and the value is not negated (unlike {@code score}'s XENT). Both are kept for compatibility.
     *
     * @param input visible data; {@code input.mmul(weights)} must be defined
     * @param hBias row vector added to the hidden pre-activation
     * @param vBias row vector added to the visible (reconstruction) pre-activation
     * @param weights weight matrix shared by the encode and (transposed) decode step
     * @param activationFunction activation applied after both affine steps
     * @return the reconstruction log-likelihood described above
     * @throws AssertionError if assertions are enabled and an intermediate activation contains NaN or
     *     infinite values
     */
    public static float reconEntropy(INDArray input, INDArray hBias, INDArray vBias, INDArray weights,
            ActivationFunction activationFunction) {
        INDArray sigH = (INDArray) activationFunction.apply(input.mmul(weights).addRowVector(hBias));
        assert !Nd4j.hasInvalidNumber(sigH);
        INDArray sigV = (INDArray) activationFunction.apply(
                sigH.mmul(weights.transpose()).addRowVector(vBias));
        // Previously this re-checked sigH; the reconstruction is what must be validated here.
        assert !Nd4j.hasInvalidNumber(sigV);
        INDArray perRow = binaryCrossEntropyTerms(input, sigV).sum(1);
        return perRow.mean(Integer.MAX_VALUE).get(0) / input.rows();
    }

    /** Sums {@code arr} over dimension 1, then over everything that remains, and returns the scalar. */
    private static float sumAll(INDArray arr) {
        return arr.sum(1).sum(Integer.MAX_VALUE).get(0);
    }

    /**
     * Element-wise {@code y*log(z) + (1-y)*log(1-z)} (not negated).
     *
     * <p>{@code z} is duplicated before taking the log so it is never overwritten, even if
     * {@code Transforms.log} works in place under the current Nd4j configuration.
     */
    private static INDArray binaryCrossEntropyTerms(INDArray labels, INDArray z) {
        INDArray logZ = Transforms.log(z.dup());
        INDArray logOneMinusZ = Transforms.log(z.rsub((Number) 1));
        return labels.mul(logZ).add(labels.rsub((Number) 1).mul(logOneMinusZ));
    }
}
