package org.deeplearning4j.gradientcheck;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/gradientcheck/GradientCheckUtil.class */
/**
 * Numerical gradient checking for {@link MultiLayerNetwork} and {@link ComputationGraph}.
 *
 * <p>For every parameter the analytic gradient reported by the network is compared with a
 * central-difference estimate {@code (score(theta + eps) - score(theta - eps)) / (2 * eps)}.
 * A parameter fails when the relative error
 * {@code |analytic - numerical| / (|analytic| + |numerical|)} exceeds the configured maximum
 * or is NaN.
 *
 * <p>The parameters of the network are restored to their original values before either
 * {@code checkGradients} method returns or throws from the comparison loop.
 */
public class GradientCheckUtil {
    private static final Logger log = LoggerFactory.getLogger(GradientCheckUtil.class);

    /**
     * Minimal view of a network used by the shared comparison loop, so that the same loop can
     * drive both {@link MultiLayerNetwork} and {@link ComputationGraph}.
     */
    private interface ScoredModel {
        /** Installs the given flattened parameter vector on the underlying network. */
        void applyParams(INDArray params);

        /** Recomputes and returns the score of the underlying network for its current parameters. */
        double score();
    }

    /**
     * Checks the gradients of a {@link MultiLayerNetwork} against numerical estimates.
     *
     * @param multiLayerNetwork network to check; its output layer must be a {@link BaseOutputLayer}
     * @param d                 epsilon used for the central difference, in the range (0, 0.1]
     * @param d2                maximum relative error allowed per parameter, in the range (0, 0.25]
     * @param z                 if true, log the result for each parameter and a final summary
     * @param z2                if true, return false as soon as the first parameter fails
     * @param iNDArray          input passed to {@code setInput}
     * @param iNDArray2         labels passed to {@code setLabels}
     * @param z3                if true, pass the gradient through the network's updater before comparing
     * @return true if no parameter failed the check
     * @throws IllegalArgumentException if epsilon or the maximum relative error is out of range or NaN,
     *                                  or if the network has no {@link BaseOutputLayer}
     * @throws IllegalStateException    if a numerical gradient evaluates to NaN
     */
    public static boolean checkGradients(final MultiLayerNetwork multiLayerNetwork, double d, double d2, boolean z, boolean z2, INDArray iNDArray, INDArray iNDArray2, boolean z3) {
        validateSettings(d, d2);
        if (!(multiLayerNetwork.getOutputLayer() instanceof BaseOutputLayer)) {
            throw new IllegalArgumentException("Cannot check backprop gradients without OutputLayer");
        }
        multiLayerNetwork.setInput(iNDArray);
        multiLayerNetwork.setLabels(iNDArray2);
        multiLayerNetwork.computeGradientAndScore();
        Pair<Gradient, Double> gradientAndScore = multiLayerNetwork.gradientAndScore();
        if (z3) {
            // NOTE(review): presumably the updater rewrites the gradient in place - confirm against Updater.
            UpdaterCreator.getUpdater(multiLayerNetwork).update(multiLayerNetwork, gradientAndScore.getFirst(), 0, multiLayerNetwork.batchSize());
        }
        // Defensive copies: if gradient()/params() expose internal buffers (TODO confirm), the
        // repeated setParameters/computeGradientAndScore calls below must not alter the baseline.
        INDArray analyticGradient = gradientAndScore.getFirst().gradient().dup();
        INDArray originalParams = multiLayerNetwork.params().dup();

        ScoredModel model = new ScoredModel() {
            @Override
            public void applyParams(INDArray params) {
                multiLayerNetwork.setParameters(params);
            }

            @Override
            public double score() {
                multiLayerNetwork.computeGradientAndScore();
                return multiLayerNetwork.score();
            }
        };
        return compareGradients(model, originalParams, analyticGradient, d, d2, z, z2);
    }

    /**
     * Checks the gradients of a {@link ComputationGraph} against numerical estimates.
     *
     * <p>Unlike the {@link MultiLayerNetwork} overload, the gradient is always passed through a
     * {@link ComputationGraphUpdater} before comparison.
     *
     * @param computationGraph graph to check
     * @param d                epsilon used for the central difference, in the range (0, 0.1]
     * @param d2               maximum relative error allowed per parameter, in the range (0, 0.25]
     * @param z                if true, log the result for each parameter and a final summary
     * @param z2               if true, return false as soon as the first parameter fails
     * @param iNDArrayArr      input arrays; length must equal {@code getNumInputArrays()}
     * @param iNDArrayArr2     label arrays; length must equal {@code getNumOutputArrays()}
     * @return true if no parameter failed the check
     * @throws IllegalArgumentException if epsilon or the maximum relative error is out of range or NaN,
     *                                  or if the number of inputs/labels does not match the graph
     * @throws IllegalStateException    if a numerical gradient evaluates to NaN
     */
    public static boolean checkGradients(final ComputationGraph computationGraph, double d, double d2, boolean z, boolean z2, INDArray[] iNDArrayArr, INDArray[] iNDArrayArr2) {
        validateSettings(d, d2);
        if (computationGraph.getNumInputArrays() != iNDArrayArr.length) {
            throw new IllegalArgumentException("Invalid input arrays: expect " + computationGraph.getNumInputArrays() + " inputs");
        }
        if (computationGraph.getNumOutputArrays() != iNDArrayArr2.length) {
            throw new IllegalArgumentException("Invalid labels arrays: expect " + computationGraph.getNumOutputArrays() + " outputs");
        }
        for (int i = 0; i < iNDArrayArr.length; i++) {
            computationGraph.setInput(i, iNDArrayArr[i]);
        }
        for (int i = 0; i < iNDArrayArr2.length; i++) {
            computationGraph.setLabel(i, iNDArrayArr2[i]);
        }
        computationGraph.computeGradientAndScore();
        Pair<Gradient, Double> gradientAndScore = computationGraph.gradientAndScore();
        new ComputationGraphUpdater(computationGraph).update(computationGraph, gradientAndScore.getFirst(), 0, computationGraph.batchSize());
        // Defensive copies: if gradient()/params() expose internal buffers (TODO confirm), the
        // repeated setParams/computeGradientAndScore calls below must not alter the baseline.
        INDArray analyticGradient = gradientAndScore.getFirst().gradient().dup();
        INDArray originalParams = computationGraph.params().dup();

        ScoredModel model = new ScoredModel() {
            @Override
            public void applyParams(INDArray params) {
                computationGraph.setParams(params);
            }

            @Override
            public double score() {
                computationGraph.computeGradientAndScore();
                return computationGraph.score();
            }
        };
        return compareGradients(model, originalParams, analyticGradient, d, d2, z, z2);
    }

    /**
     * Validates epsilon and the maximum relative error. NaN is rejected explicitly because it
     * compares false against every bound and would otherwise make every parameter pass.
     */
    private static void validateSettings(double epsilon, double maxRelError) {
        if (Double.isNaN(epsilon) || epsilon <= 0.0d || epsilon > 0.1d) {
            throw new IllegalArgumentException("Invalid epsilon: expect epsilon in range (0,0.1], usually 1e-4 or so");
        }
        if (Double.isNaN(maxRelError) || maxRelError <= 0.0d || maxRelError > 0.25d) {
            throw new IllegalArgumentException("Invalid maxRelativeError: " + maxRelError);
        }
    }

    /** Returns the symmetric relative error between the two gradients; 0 when both are exactly zero. */
    private static double relativeError(double analytic, double numerical) {
        if (analytic == 0.0d && numerical == 0.0d) {
            return 0.0d;
        }
        return Math.abs(analytic - numerical) / (Math.abs(numerical) + Math.abs(analytic));
    }

    /**
     * Runs the central-difference comparison for every parameter and restores the original
     * parameters on the model afterwards, whether the loop finishes, exits early or throws.
     */
    private static boolean compareGradients(ScoredModel model, INDArray originalParams, INDArray analyticGradient, double epsilon, double maxRelError, boolean print, boolean exitOnFirstError) {
        int nParams = originalParams.length();
        int failures = 0;
        double maxError = 0.0d;
        try {
            for (int idx = 0; idx < nParams; idx++) {
                // Perturb symmetrically around the saved original value rather than subtracting
                // 2*eps from the already-perturbed entry, which would compound rounding error.
                double originalValue = originalParams.getDouble(idx);
                INDArray perturbed = originalParams.dup();

                perturbed.putScalar(idx, originalValue + epsilon);
                model.applyParams(perturbed);
                double scorePlus = model.score();

                perturbed.putScalar(idx, originalValue - epsilon);
                model.applyParams(perturbed);
                double scoreMinus = model.score();

                double numerical = (scorePlus - scoreMinus) / (2.0d * epsilon);
                if (Double.isNaN(numerical)) {
                    throw new IllegalStateException("Numerical gradient was NaN for parameter " + idx + " of " + nParams);
                }
                double analytic = analyticGradient.getDouble(idx);
                double relError = relativeError(analytic, numerical);
                if (relError > maxError) {
                    maxError = relError;
                }
                if (relError > maxRelError || Double.isNaN(relError)) {
                    if (print) {
                        log.info("Param " + idx + " FAILED: grad= " + analytic + ", numericalGrad= " + numerical + ", relError= " + relError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                    }
                    if (exitOnFirstError) {
                        return false;
                    }
                    failures++;
                } else if (print) {
                    log.info("Param " + idx + " passed: grad= " + analytic + ", numericalGrad= " + numerical + ", relError= " + relError);
                }
            }
        } finally {
            // Do not leave the caller's network sitting on a perturbed parameter vector.
            model.applyParams(originalParams);
        }
        if (print) {
            log.info("GradientCheckUtil.checkGradients(): " + nParams + " params checked, " + (nParams - failures) + " passed, " + failures + " failed. Largest relative error = " + maxError);
        }
        return failures == 0;
    }
}
