package org.nd4j.linalg.lossfunctions;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/nd4j/linalg/lossfunctions/LossCalculation.class */
/**
 * Computes a scalar loss ("score") from {@code labels} and {@code z} for one of the
 * {@link LossFunctions.LossFunction} variants. It optionally adds regularization terms and
 * divides by a mini-batch size.
 *
 * <p>Instances are normally created through {@link #builder()}. Every field has a setter, so an
 * instance is mutable and can be reconfigured and re-scored.
 */
public class LossCalculation {
    /** Target values; compared element-wise against {@link #z} by {@link #score()}. */
    private INDArray labels;
    /** Values scored against {@link #labels} (presumably the model output — confirm with callers). */
    private INDArray z;
    /** Added verbatim to the score when {@link #useRegularization} is set. */
    private double l1;
    /** Added verbatim to the score when {@link #useRegularization} is set. */
    private double l2;
    /** Selects the formula used by {@link #score()}; must not be {@code null} when scoring. */
    private LossFunctions.LossFunction lossFunction;
    /** Whether {@link #l1} and {@link #l2} are added to the score. */
    private boolean useRegularization;
    /** Optional precomputed residual; when {@code null}, residual-based losses use {@code labels - z}. */
    private INDArray delta;
    /** Whether the score is divided by {@link #miniBatchSize}. */
    private boolean miniBatch;
    /** Divisor applied when {@link #miniBatch} is set; must then be positive. */
    private int miniBatchSize;

    /** Fluent builder for {@link LossCalculation}; unset fields keep Java defaults (null / 0 / false). */
    public static class LossCalculationBuilder {
        private INDArray labels;
        private INDArray z;
        private double l1;
        private double l2;
        private LossFunctions.LossFunction lossFunction;
        private boolean useRegularization;
        private INDArray delta;
        private boolean miniBatch;
        private int miniBatchSize;

        LossCalculationBuilder() {
        }

        public LossCalculationBuilder labels(INDArray iNDArray) {
            this.labels = iNDArray;
            return this;
        }

        public LossCalculationBuilder z(INDArray iNDArray) {
            this.z = iNDArray;
            return this;
        }

        public LossCalculationBuilder l1(double d) {
            this.l1 = d;
            return this;
        }

        public LossCalculationBuilder l2(double d) {
            this.l2 = d;
            return this;
        }

        public LossCalculationBuilder lossFunction(LossFunctions.LossFunction lossFunction) {
            this.lossFunction = lossFunction;
            return this;
        }

        public LossCalculationBuilder useRegularization(boolean z) {
            this.useRegularization = z;
            return this;
        }

        public LossCalculationBuilder delta(INDArray iNDArray) {
            this.delta = iNDArray;
            return this;
        }

        public LossCalculationBuilder miniBatch(boolean z) {
            this.miniBatch = z;
            return this;
        }

        public LossCalculationBuilder miniBatchSize(int i) {
            this.miniBatchSize = i;
            return this;
        }

        /** Creates a {@link LossCalculation} from the values set so far. */
        public LossCalculation build() {
            return new LossCalculation(this.labels, this.z, this.l1, this.l2, this.lossFunction, this.useRegularization, this.delta, this.miniBatch, this.miniBatchSize);
        }

        @Override
        public String toString() {
            return "LossCalculation.LossCalculationBuilder(labels=" + this.labels + ", z=" + this.z + ", l1=" + this.l1 + ", l2=" + this.l2 + ", lossFunction=" + this.lossFunction + ", useRegularization=" + this.useRegularization + ", delta=" + this.delta + ", miniBatch=" + this.miniBatch + ", miniBatchSize=" + this.miniBatchSize + ")";
        }
    }

    /**
     * Computes the loss for the configured {@link #lossFunction}.
     *
     * <p>With {@code y = labels} and {@code r = delta} (or {@code labels - z} when {@code delta}
     * is null), the data term is:
     * <ul>
     *   <li>{@code XENT}, {@code RECONSTRUCTION_CROSSENTROPY}: {@code -sum(y*log(z) + (1-y)*log(1-z))}</li>
     *   <li>{@code MCXENT}, {@code NEGATIVELOGLIKELIHOOD}: {@code -sum(y*log(z))}</li>
     *   <li>{@code RMSE_XENT}: {@code sum(sqrt(r^2))}, i.e. the sum of absolute residuals</li>
     *   <li>{@code MSE}: {@code 0.5 * sum(r^2)}</li>
     *   <li>{@code SQUARED_LOSS}: {@code sum(r^2)}</li>
     *   <li>{@code EXPLL}: {@code sum(z - y*log(z))}</li>
     *   <li>any other non-CUSTOM value: {@code 0}</li>
     * </ul>
     * If {@link #useRegularization} is set, {@code l1 + l2} is then added. If {@link #miniBatch}
     * is set, the total is divided by {@link #miniBatchSize}.
     *
     * <p>No clipping is applied, so {@code z} values of exactly 0 (or 1 for the binary
     * cross-entropy cases) yield non-finite results.
     *
     * @return the scalar loss
     * @throws IllegalStateException if the loss function is {@code CUSTOM}, or if mini-batch
     *     averaging is enabled with a non-positive {@code miniBatchSize}
     * @throws NullPointerException if {@code lossFunction} is null (switch on a null enum)
     */
    public double score() {
        double d = 0.0d;
        switch (this.lossFunction) {
            case CUSTOM:
                throw new IllegalStateException("Unable to score custom operation. Please define an alternative mechanism");
            case RECONSTRUCTION_CROSSENTROPY:
            case XENT:
                d = binaryCrossEntropy();
                break;
            case MCXENT:
            case NEGATIVELOGLIKELIHOOD:
                d = -this.labels.mul(Transforms.log(this.z)).sumNumber().doubleValue();
                break;
            case RMSE_XENT:
                // NOTE(review): element-wise sqrt(r^2) is |r|, so this sums absolute residuals
                // rather than computing a root-mean-square; kept as-is pending confirmation.
                d = Transforms.sqrt(Transforms.pow(residual(), 2.0d)).sumNumber().doubleValue();
                break;
            case MSE:
                d = 0.5d * Transforms.pow(residual(), 2).sum(1).sumNumber().doubleValue();
                break;
            case EXPLL:
                d = this.z.sub(this.labels.mul(Transforms.log(this.z))).sumNumber().doubleValue();
                break;
            case SQUARED_LOSS:
                d = Transforms.pow(residual(), 2).sumNumber().doubleValue();
                break;
            default:
                // Loss functions not handled above contribute no data term.
                break;
        }
        if (this.useRegularization) {
            d += this.l1 + this.l2;
        }
        if (this.miniBatch) {
            // The builder defaults miniBatchSize to 0; dividing by it would silently yield
            // Infinity/NaN (or flip the sign for negatives), so fail loudly instead.
            if (this.miniBatchSize <= 0) {
                throw new IllegalStateException("miniBatch is enabled but miniBatchSize is " + this.miniBatchSize + "; it must be positive");
            }
            d /= this.miniBatchSize;
        }
        return d;
    }

    /** Returns {@code -sum(labels*log(z) + (1-labels)*log(1-z))} without mutating the inputs. */
    private double binaryCrossEntropy() {
        INDArray logZ = Transforms.log(this.z);
        INDArray logOneMinusZ = Transforms.log(this.z.rsub((Number) 1));
        INDArray oneMinusLabels = this.labels.rsub((Number) 1);
        // muli is safe here: oneMinusLabels is a fresh array produced by rsub.
        return -this.labels.mul(logZ).add(oneMinusLabels.muli(logOneMinusZ)).sumNumber().doubleValue();
    }

    /** Returns the supplied {@code delta}, or {@code labels - z} when no delta was given. */
    private INDArray residual() {
        return this.delta == null ? this.labels.sub(this.z) : this.delta;
    }

    LossCalculation(INDArray iNDArray, INDArray iNDArray2, double d, double d2, LossFunctions.LossFunction lossFunction, boolean z, INDArray iNDArray3, boolean z2, int i) {
        this.labels = iNDArray;
        this.z = iNDArray2;
        this.l1 = d;
        this.l2 = d2;
        this.lossFunction = lossFunction;
        this.useRegularization = z;
        this.delta = iNDArray3;
        this.miniBatch = z2;
        this.miniBatchSize = i;
    }

    /** Returns a new, empty builder. */
    public static LossCalculationBuilder builder() {
        return new LossCalculationBuilder();
    }

    // ---- Plain accessors ----

    public INDArray getLabels() {
        return this.labels;
    }

    public INDArray getZ() {
        return this.z;
    }

    public double getL1() {
        return this.l1;
    }

    public double getL2() {
        return this.l2;
    }

    public LossFunctions.LossFunction getLossFunction() {
        return this.lossFunction;
    }

    public boolean isUseRegularization() {
        return this.useRegularization;
    }

    public INDArray getDelta() {
        return this.delta;
    }

    public boolean isMiniBatch() {
        return this.miniBatch;
    }

    public int getMiniBatchSize() {
        return this.miniBatchSize;
    }

    // ---- Plain mutators ----

    public void setLabels(INDArray iNDArray) {
        this.labels = iNDArray;
    }

    public void setZ(INDArray iNDArray) {
        this.z = iNDArray;
    }

    public void setL1(double d) {
        this.l1 = d;
    }

    public void setL2(double d) {
        this.l2 = d;
    }

    public void setLossFunction(LossFunctions.LossFunction lossFunction) {
        this.lossFunction = lossFunction;
    }

    public void setUseRegularization(boolean z) {
        this.useRegularization = z;
    }

    public void setDelta(INDArray iNDArray) {
        this.delta = iNDArray;
    }

    public void setMiniBatch(boolean z) {
        this.miniBatch = z;
    }

    public void setMiniBatchSize(int i) {
        this.miniBatchSize = i;
    }

    /** Field-by-field equality; array fields are compared with their own {@code equals}. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof LossCalculation)) {
            return false;
        }
        LossCalculation lossCalculation = (LossCalculation) obj;
        if (!lossCalculation.canEqual(this)) {
            return false;
        }
        INDArray labels = getLabels();
        INDArray labels2 = lossCalculation.getLabels();
        if (labels == null) {
            if (labels2 != null) {
                return false;
            }
        } else if (!labels.equals(labels2)) {
            return false;
        }
        INDArray z = getZ();
        INDArray z2 = lossCalculation.getZ();
        if (z == null) {
            if (z2 != null) {
                return false;
            }
        } else if (!z.equals(z2)) {
            return false;
        }
        if (Double.compare(getL1(), lossCalculation.getL1()) != 0 || Double.compare(getL2(), lossCalculation.getL2()) != 0) {
            return false;
        }
        LossFunctions.LossFunction lossFunction = getLossFunction();
        LossFunctions.LossFunction lossFunction2 = lossCalculation.getLossFunction();
        if (lossFunction == null) {
            if (lossFunction2 != null) {
                return false;
            }
        } else if (!lossFunction.equals(lossFunction2)) {
            return false;
        }
        if (isUseRegularization() != lossCalculation.isUseRegularization()) {
            return false;
        }
        INDArray delta = getDelta();
        INDArray delta2 = lossCalculation.getDelta();
        if (delta == null) {
            if (delta2 != null) {
                return false;
            }
        } else if (!delta.equals(delta2)) {
            return false;
        }
        return isMiniBatch() == lossCalculation.isMiniBatch() && getMiniBatchSize() == lossCalculation.getMiniBatchSize();
    }

    /** Symmetry hook for subclasses overriding {@link #equals(Object)}. */
    protected boolean canEqual(Object obj) {
        return obj instanceof LossCalculation;
    }

    /** Hash consistent with {@link #equals(Object)}; algorithm unchanged so hash values are stable. */
    @Override
    public int hashCode() {
        INDArray labels = getLabels();
        int hashCode = (1 * 59) + (labels == null ? 0 : labels.hashCode());
        INDArray z = getZ();
        int hashCode2 = (hashCode * 59) + (z == null ? 0 : z.hashCode());
        long doubleToLongBits = Double.doubleToLongBits(getL1());
        int i = (hashCode2 * 59) + ((int) ((doubleToLongBits >>> 32) ^ doubleToLongBits));
        long doubleToLongBits2 = Double.doubleToLongBits(getL2());
        int i2 = (i * 59) + ((int) ((doubleToLongBits2 >>> 32) ^ doubleToLongBits2));
        LossFunctions.LossFunction lossFunction = getLossFunction();
        int hashCode3 = (((i2 * 59) + (lossFunction == null ? 0 : lossFunction.hashCode())) * 59) + (isUseRegularization() ? 79 : 97);
        INDArray delta = getDelta();
        return (((((hashCode3 * 59) + (delta == null ? 0 : delta.hashCode())) * 59) + (isMiniBatch() ? 79 : 97)) * 59) + getMiniBatchSize();
    }

    @Override
    public String toString() {
        return "LossCalculation(labels=" + getLabels() + ", z=" + getZ() + ", l1=" + getL1() + ", l2=" + getL2() + ", lossFunction=" + getLossFunction() + ", useRegularization=" + isUseRegularization() + ", delta=" + getDelta() + ", miniBatch=" + isMiniBatch() + ", miniBatchSize=" + getMiniBatchSize() + ")";
    }
}
