package org.nd4j.autodiff.validation;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.validation.listeners.NonInplaceValidationListener;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.AdaGradUpdater;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/validation/GradCheckUtil.class */
public class GradCheckUtil {
    private static final Logger log = LoggerFactory.getLogger(GradCheckUtil.class);

    /** Default value of the {@code print} flag (per-element pass logging disabled). */
    public static final boolean DEFAULT_PRINT = false;
    /** Default value of the {@code exitOnFirstFailure} flag (keep checking after a failure). */
    public static final boolean DEFAULT_EXIT_FIRST_FAILURE = false;
    /** Default value of the {@code debugMode} flag (SameDiff debug mode not enabled). */
    public static final boolean DEFAULT_DEBUG_MODE = false;
    /** Default perturbation size used for central-difference numerical gradients. */
    public static final double DEFAULT_EPS = 1e-5;
    /** Default maximum relative error between analytic and numerical gradients. */
    public static final double DEFAULT_MAX_REL_ERROR = 1e-5;
    /** Default absolute-error floor: errors below this value pass regardless of relative error. */
    public static final double DEFAULT_MIN_ABS_ERROR = 1e-6;

    /**
     * Settings consumed by {@code checkActivationGradients(ActGradConfig)}; obtain instances via
     * {@link #builder()}.
     *
     * <p>When the corresponding builder method is never called, the builder falls back to:
     * {@code eps = 1e-5}, {@code maxRelError = 1e-5}, {@code minAbsError = 1e-6}, and {@code false} for
     * {@code print}, {@code exitOnFirstFailure}, {@code skipValidation} and {@code debugMode}. All other
     * properties stay {@code null} (objects) or {@code 0} ({@code maxPerParam}).
     */
    public static class ActGradConfig {
        private SameDiff sd;
        private Map<String, INDArray> placeholderValues;
        private List<String> activationGradsToCheck;
        private double eps;
        private double maxRelError;
        private double minAbsError;
        private boolean print;
        boolean exitOnFirstFailure;
        private boolean skipValidation;
        private boolean debugMode;
        private Set<String> skipVariables;
        private Map<String, INDArray> gradCheckMask;
        int maxPerParam;
        private Subset subset;

        /**
         * Fluent builder for {@link ActGradConfig}. For each defaulted property it remembers whether the
         * value was set explicitly, so that {@link #build()} can substitute the default otherwise.
         */
        public static class ActGradConfigBuilder {
            private SameDiff sd;
            private Map<String, INDArray> placeholderValues;
            private List<String> activationGradsToCheck;
            private boolean eps$set;
            private double eps$value;
            private boolean maxRelError$set;
            private double maxRelError$value;
            private boolean minAbsError$set;
            private double minAbsError$value;
            private boolean print$set;
            private boolean print$value;
            private boolean exitOnFirstFailure$set;
            private boolean exitOnFirstFailure$value;
            private boolean skipValidation$set;
            private boolean skipValidation$value;
            private boolean debugMode$set;
            private boolean debugMode$value;
            private Set<String> skipVariables;
            private Map<String, INDArray> gradCheckMask;
            private int maxPerParam;
            private Subset subset;

            ActGradConfigBuilder() {
                // Created through ActGradConfig.builder()
            }

            public ActGradConfigBuilder sd(SameDiff value) {
                sd = value;
                return this;
            }

            public ActGradConfigBuilder placeholderValues(Map<String, INDArray> value) {
                placeholderValues = value;
                return this;
            }

            public ActGradConfigBuilder activationGradsToCheck(List<String> value) {
                activationGradsToCheck = value;
                return this;
            }

            public ActGradConfigBuilder eps(double value) {
                eps$value = value;
                eps$set = true;
                return this;
            }

            public ActGradConfigBuilder maxRelError(double value) {
                maxRelError$value = value;
                maxRelError$set = true;
                return this;
            }

            public ActGradConfigBuilder minAbsError(double value) {
                minAbsError$value = value;
                minAbsError$set = true;
                return this;
            }

            public ActGradConfigBuilder print(boolean value) {
                print$value = value;
                print$set = true;
                return this;
            }

            public ActGradConfigBuilder exitOnFirstFailure(boolean value) {
                exitOnFirstFailure$value = value;
                exitOnFirstFailure$set = true;
                return this;
            }

            public ActGradConfigBuilder skipValidation(boolean value) {
                skipValidation$value = value;
                skipValidation$set = true;
                return this;
            }

            public ActGradConfigBuilder debugMode(boolean value) {
                debugMode$value = value;
                debugMode$set = true;
                return this;
            }

            public ActGradConfigBuilder skipVariables(Set<String> value) {
                skipVariables = value;
                return this;
            }

            public ActGradConfigBuilder gradCheckMask(Map<String, INDArray> value) {
                gradCheckMask = value;
                return this;
            }

            public ActGradConfigBuilder maxPerParam(int value) {
                maxPerParam = value;
                return this;
            }

            public ActGradConfigBuilder subset(Subset value) {
                subset = value;
                return this;
            }

            /** Creates the configuration, substituting defaults for properties that were never set. */
            public ActGradConfig build() {
                return new ActGradConfig(
                        sd,
                        placeholderValues,
                        activationGradsToCheck,
                        eps$set ? eps$value : $default$eps(),
                        maxRelError$set ? maxRelError$value : $default$maxRelError(),
                        minAbsError$set ? minAbsError$value : $default$minAbsError(),
                        print$set ? print$value : $default$print(),
                        exitOnFirstFailure$set ? exitOnFirstFailure$value : $default$exitOnFirstFailure(),
                        skipValidation$set ? skipValidation$value : $default$skipValidation(),
                        debugMode$set ? debugMode$value : $default$debugMode(),
                        skipVariables,
                        gradCheckMask,
                        maxPerParam,
                        subset);
            }

            @Override
            public String toString() {
                return "GradCheckUtil.ActGradConfig.ActGradConfigBuilder(sd=" + sd
                        + ", placeholderValues=" + placeholderValues
                        + ", activationGradsToCheck=" + activationGradsToCheck
                        + ", eps$value=" + eps$value
                        + ", maxRelError$value=" + maxRelError$value
                        + ", minAbsError$value=" + minAbsError$value
                        + ", print$value=" + print$value
                        + ", exitOnFirstFailure$value=" + exitOnFirstFailure$value
                        + ", skipValidation$value=" + skipValidation$value
                        + ", debugMode$value=" + debugMode$value
                        + ", skipVariables=" + skipVariables
                        + ", gradCheckMask=" + gradCheckMask
                        + ", maxPerParam=" + maxPerParam
                        + ", subset=" + subset + ")";
            }
        }

        // Builder fallback values for properties that were not set explicitly.

        private static double $default$eps() {
            return 1e-5;
        }

        private static double $default$maxRelError() {
            return 1e-5;
        }

        private static double $default$minAbsError() {
            return 1e-6;
        }

        private static boolean $default$print() {
            return false;
        }

        private static boolean $default$exitOnFirstFailure() {
            return false;
        }

        private static boolean $default$skipValidation() {
            return false;
        }

        private static boolean $default$debugMode() {
            return false;
        }

        ActGradConfig(SameDiff sameDiff, Map<String, INDArray> placeholders, List<String> gradsToCheck,
                      double epsilon, double relErrorLimit, double absErrorFloor, boolean printAll,
                      boolean exitFirst, boolean skipValid, boolean debug, Set<String> skipVars,
                      Map<String, INDArray> mask, int perParamLimit, Subset subsetMode) {
            this.sd = sameDiff;
            this.placeholderValues = placeholders;
            this.activationGradsToCheck = gradsToCheck;
            this.eps = epsilon;
            this.maxRelError = relErrorLimit;
            this.minAbsError = absErrorFloor;
            this.print = printAll;
            this.exitOnFirstFailure = exitFirst;
            this.skipValidation = skipValid;
            this.debugMode = debug;
            this.skipVariables = skipVars;
            this.gradCheckMask = mask;
            this.maxPerParam = perParamLimit;
            this.subset = subsetMode;
        }

        /** @return a new builder with every property unset */
        public static ActGradConfigBuilder builder() {
            return new ActGradConfigBuilder();
        }

        public SameDiff getSd() {
            return sd;
        }

        public Map<String, INDArray> getPlaceholderValues() {
            return placeholderValues;
        }

        public List<String> getActivationGradsToCheck() {
            return activationGradsToCheck;
        }

        public double getEps() {
            return eps;
        }

        public double getMaxRelError() {
            return maxRelError;
        }

        public double getMinAbsError() {
            return minAbsError;
        }

        public boolean isPrint() {
            return print;
        }

        public boolean isExitOnFirstFailure() {
            return exitOnFirstFailure;
        }

        public boolean isSkipValidation() {
            return skipValidation;
        }

        public boolean isDebugMode() {
            return debugMode;
        }

        public Set<String> getSkipVariables() {
            return skipVariables;
        }

        public Map<String, INDArray> getGradCheckMask() {
            return gradCheckMask;
        }

        public int getMaxPerParam() {
            return maxPerParam;
        }

        public Subset getSubset() {
            return subset;
        }

        public void setSd(SameDiff value) {
            sd = value;
        }

        public void setPlaceholderValues(Map<String, INDArray> value) {
            placeholderValues = value;
        }

        public void setActivationGradsToCheck(List<String> value) {
            activationGradsToCheck = value;
        }

        public void setEps(double value) {
            eps = value;
        }

        public void setMaxRelError(double value) {
            maxRelError = value;
        }

        public void setMinAbsError(double value) {
            minAbsError = value;
        }

        public void setPrint(boolean value) {
            print = value;
        }

        public void setExitOnFirstFailure(boolean value) {
            exitOnFirstFailure = value;
        }

        public void setSkipValidation(boolean value) {
            skipValidation = value;
        }

        public void setDebugMode(boolean value) {
            debugMode = value;
        }

        public void setSkipVariables(Set<String> value) {
            skipVariables = value;
        }

        public void setGradCheckMask(Map<String, INDArray> value) {
            gradCheckMask = value;
        }

        public void setMaxPerParam(int value) {
            maxPerParam = value;
        }

        public void setSubset(Subset value) {
            subset = value;
        }

        /** Null-safe equality: both null, or {@code a.equals(b)}. */
        private static boolean bothNullOrEqual(Object a, Object b) {
            if (a == null) {
                return b == null;
            }
            return a.equals(b);
        }

        /** Hash contribution of a possibly-null object (43 for null). */
        private static int hashOrDefault(Object o) {
            return o == null ? 43 : o.hashCode();
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof ActGradConfig)) {
                return false;
            }
            ActGradConfig other = (ActGradConfig) obj;
            return other.canEqual(this)
                    && bothNullOrEqual(getSd(), other.getSd())
                    && bothNullOrEqual(getPlaceholderValues(), other.getPlaceholderValues())
                    && bothNullOrEqual(getActivationGradsToCheck(), other.getActivationGradsToCheck())
                    && Double.compare(getEps(), other.getEps()) == 0
                    && Double.compare(getMaxRelError(), other.getMaxRelError()) == 0
                    && Double.compare(getMinAbsError(), other.getMinAbsError()) == 0
                    && isPrint() == other.isPrint()
                    && isExitOnFirstFailure() == other.isExitOnFirstFailure()
                    && isSkipValidation() == other.isSkipValidation()
                    && isDebugMode() == other.isDebugMode()
                    && bothNullOrEqual(getSkipVariables(), other.getSkipVariables())
                    && bothNullOrEqual(getGradCheckMask(), other.getGradCheckMask())
                    && getMaxPerParam() == other.getMaxPerParam()
                    && bothNullOrEqual(getSubset(), other.getSubset());
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof ActGradConfig;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            result = result * prime + hashOrDefault(getSd());
            result = result * prime + hashOrDefault(getPlaceholderValues());
            result = result * prime + hashOrDefault(getActivationGradsToCheck());
            result = result * prime + Double.hashCode(getEps());
            result = result * prime + Double.hashCode(getMaxRelError());
            result = result * prime + Double.hashCode(getMinAbsError());
            result = result * prime + (isPrint() ? 79 : 97);
            result = result * prime + (isExitOnFirstFailure() ? 79 : 97);
            result = result * prime + (isSkipValidation() ? 79 : 97);
            result = result * prime + (isDebugMode() ? 79 : 97);
            result = result * prime + hashOrDefault(getSkipVariables());
            result = result * prime + hashOrDefault(getGradCheckMask());
            result = result * prime + getMaxPerParam();
            result = result * prime + hashOrDefault(getSubset());
            return result;
        }

        @Override
        public String toString() {
            return "GradCheckUtil.ActGradConfig(sd=" + getSd()
                    + ", placeholderValues=" + getPlaceholderValues()
                    + ", activationGradsToCheck=" + getActivationGradsToCheck()
                    + ", eps=" + getEps()
                    + ", maxRelError=" + getMaxRelError()
                    + ", minAbsError=" + getMinAbsError()
                    + ", print=" + isPrint()
                    + ", exitOnFirstFailure=" + isExitOnFirstFailure()
                    + ", skipValidation=" + isSkipValidation()
                    + ", debugMode=" + isDebugMode()
                    + ", skipVariables=" + getSkipVariables()
                    + ", gradCheckMask=" + getGradCheckMask()
                    + ", maxPerParam=" + getMaxPerParam()
                    + ", subset=" + getSubset() + ")";
        }

        // Package-visible accessors for the default values (kept for compatibility).

        static /* synthetic */ double access$000() {
            return $default$eps();
        }

        static /* synthetic */ double access$100() {
            return $default$maxRelError();
        }

        static /* synthetic */ double access$200() {
            return $default$minAbsError();
        }

        static /* synthetic */ boolean access$300() {
            return $default$print();
        }

        static /* synthetic */ boolean access$400() {
            return $default$exitOnFirstFailure();
        }

        static /* synthetic */ boolean access$500() {
            return $default$skipValidation();
        }

        static /* synthetic */ boolean access$600() {
            return $default$debugMode();
        }
    }

    /**
     * Strategy for choosing which elements of a parameter to check when a per-parameter limit
     * ({@code maxPerParam}) applies.
     */
    public enum Subset {
        /** Deterministic stride: linear indices 0, N, 2N, ... with N = length / maxPerParam. */
        EVERY_N,
        /** {@code maxPerParam} distinct linear indices drawn from a fixed-seed random generator, checked in ascending order. */
        RANDOM;
    }

    /**
     * Runs a gradient check using the graph, placeholders, tolerances and flags carried by a
     * {@link TestCase}. Internal-state validation is always performed (it is never skipped here).
     *
     * @param testCase source of the graph and all gradient-check settings
     * @return true if every checked element passed
     */
    public static boolean checkGradients(TestCase testCase) {
        SameDiff graph = testCase.sameDiff();
        Map<String, INDArray> placeholders = testCase.placeholderValues();
        double epsilon = testCase.gradCheckEpsilon();
        double relErrorLimit = testCase.gradCheckMaxRelativeError();
        double absErrorFloor = testCase.gradCheckMinAbsError();
        boolean printAll = testCase.gradCheckPrint();
        boolean exitFirst = testCase.gradCheckDefaultExitFirstFailure();
        boolean skipValidation = false;
        boolean debug = testCase.gradCheckDebugMode();
        Set<String> skipVars = testCase.gradCheckSkipVariables();
        Map<String, INDArray> mask = testCase.gradCheckMask();
        return checkGradients(graph, placeholders, epsilon, relErrorLimit, absErrorFloor, printAll, exitFirst,
                skipValidation, debug, skipVars, mask);
    }

    /**
     * Gradient check using the class defaults ({@link #DEFAULT_EPS}, {@link #DEFAULT_MAX_REL_ERROR},
     * {@link #DEFAULT_MIN_ABS_ERROR}, {@link #DEFAULT_PRINT}, {@link #DEFAULT_EXIT_FIRST_FAILURE},
     * {@link #DEFAULT_DEBUG_MODE}), with internal-state validation enabled.
     *
     * @param sd                graph to check
     * @param placeholderValues placeholder values for the forward/backward passes
     * @param skipVariables     names of variables to exclude from the check; a null array means "none"
     * @return true if every checked element passed
     */
    public static boolean checkGradients(SameDiff sd, Map<String, INDArray> placeholderValues, String... skipVariables) {
        Set<String> skip = null;
        if (skipVariables != null) {
            skip = new HashSet<>(Arrays.asList(skipVariables));
        }
        // Use the declared defaults rather than repeating their literal values
        return checkGradients(sd, placeholderValues, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR,
                DEFAULT_PRINT, DEFAULT_EXIT_FIRST_FAILURE, false, DEFAULT_DEBUG_MODE, skip, null);
    }

    /**
     * Gradient check using the default tolerances ({@link #DEFAULT_EPS}, {@link #DEFAULT_MAX_REL_ERROR},
     * {@link #DEFAULT_MIN_ABS_ERROR}).
     *
     * @param sd                 graph to check
     * @param placeholderValues  placeholder values for the forward/backward passes
     * @param print              whether to log every checked element, not only failures
     * @param exitOnFirstFailure whether to stop and return false at the first failing element
     * @return true if every checked element passed
     */
    public static boolean checkGradients(SameDiff sd, Map<String, INDArray> placeholderValues, boolean print, boolean exitOnFirstFailure) {
        return checkGradients(sd, placeholderValues, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR,
                print, exitOnFirstFailure);
    }

    /**
     * Gradient check with explicit tolerances. Internal-state validation is performed, debug mode is
     * not enabled, no variables are skipped, and no gradient-check mask is applied.
     *
     * @param sd                 graph to check
     * @param placeholderValues  placeholder values for the forward/backward passes
     * @param eps                perturbation size for the numerical gradient
     * @param maxRelError        maximum allowed relative error
     * @param minAbsError        absolute errors below this value always pass
     * @param print              whether to log every checked element
     * @param exitOnFirstFailure whether to stop at the first failing element
     * @return true if every checked element passed
     */
    public static boolean checkGradients(SameDiff sd, Map<String, INDArray> placeholderValues, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure) {
        final boolean skipValidation = false;
        final boolean debugMode = false;
        final Set<String> noSkippedVariables = null;
        final Map<String, INDArray> noMask = null;
        return checkGradients(sd, placeholderValues, eps, maxRelError, minAbsError, print, exitOnFirstFailure,
                skipValidation, debugMode, noSkippedVariables, noMask);
    }

    /**
     * Gradient check over every element of every checked variable (no per-parameter subsetting).
     *
     * @param sd                 graph to check
     * @param placeholderValues  placeholder values for the forward/backward passes
     * @param eps                perturbation size for the numerical gradient
     * @param maxRelError        maximum allowed relative error
     * @param minAbsError        absolute errors below this value always pass
     * @param print              whether to log every checked element
     * @param exitOnFirstFailure whether to stop at the first failing element
     * @param skipValidation     whether to skip the internal-state validation step
     * @param debugMode          whether to enable SameDiff debug mode during the check
     * @param skipVariables      names of variables to exclude; may be null
     * @param gradCheckMask      optional per-variable masks; may be null
     * @return true if every checked element passed
     */
    public static boolean checkGradients(SameDiff sd, Map<String, INDArray> placeholderValues, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure, boolean skipValidation, boolean debugMode, Set<String> skipVariables, Map<String, INDArray> gradCheckMask) {
        // A non-positive limit with no subset strategy means "check all elements"
        final int allElements = -1;
        final Subset noSubset = null;
        return checkGradients(sd, placeholderValues, eps, maxRelError, minAbsError, print, exitOnFirstFailure,
                skipValidation, debugMode, skipVariables, gradCheckMask, allElements, noSubset);
    }

    /**
     * Compares the analytic (backprop) gradients of a {@link SameDiff} graph with central-difference
     * numerical gradients, for every floating-point variable that is not an op output.
     *
     * <p>For each checked element the parameter is set to {@code value + eps} and {@code value - eps}.
     * The sum over all loss variables (each summed over its elements) is recomputed for both, and the
     * numerical gradient {@code (scorePlus - scoreMinus) / (2 * eps)} is compared with the analytic
     * gradient. An element fails when its relative error exceeds {@code maxRelError} (or is NaN) and its
     * absolute error is at least {@code minAbsError}.
     *
     * @param sd                 graph to check; must define at least one loss variable
     * @param placeholderValues  placeholder values used for every forward/backward pass
     * @param eps                perturbation size for the numerical gradient
     * @param maxRelError        maximum allowed relative error
     * @param minAbsError        absolute errors below this value always pass
     * @param print              if true, log every checked element, not only failures
     * @param exitOnFirstFailure if true, return {@code false} as soon as one element fails
     * @param skipValidation     if true, skip {@code validateInternalState(sd, true)}
     * @param debugMode          if true, enable SameDiff debug mode for the duration of this call
     * @param skipVariables      names of variables not to check; may be null
     * @param gradCheckMask      optional per-variable BOOL masks (same shape as the variable); elements
     *                           whose mask value is 0 are not checked; may be null
     * @param maxPerParam        if positive, smaller than the variable length and {@code subset} is non-null,
     *                           only roughly this many elements of each variable are checked
     * @param subset             how elements are chosen when {@code maxPerParam} applies
     * @return true if no checked element failed
     * @throws IllegalStateException if the global data type is not DOUBLE, a variable lacks its array,
     *                               a gradient is missing or has the wrong shape, or a gradient is NaN/infinite
     */
    public static boolean checkGradients(SameDiff sd, Map<String, INDArray> placeholderValues, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure, boolean skipValidation, boolean debugMode, Set<String> skipVariables, Map<String, INDArray> gradCheckMask, int maxPerParam, Subset subset) {
        boolean debugBefore = sd.isDebugMode();
        if (debugMode) {
            sd.enableDebugMode();
        }
        try {
            return runGradientCheck(sd, placeholderValues, eps, maxRelError, minAbsError, print,
                    exitOnFirstFailure, skipValidation, skipVariables, gradCheckMask, maxPerParam, subset);
        } finally {
            // Restore the caller's debug setting on every exit path: normal, early return, or exception
            if (debugMode && !debugBefore) {
                sd.disableDebugging();
            }
        }
    }

    /** Performs the gradient check itself; debug-mode handling is done by the public entry point. */
    private static boolean runGradientCheck(SameDiff sd, Map<String, INDArray> placeholderValues, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure, boolean skipValidation, Set<String> skipVariables, Map<String, INDArray> gradCheckMask, int maxPerParam, Subset subset) {
        if (!skipValidation) {
            validateInternalState(sd, true);
        }
        if (Nd4j.dataType() != DataType.DOUBLE) {
            throw new IllegalStateException("Data type must be set to double");
        }

        // Op outputs are activations, not parameters: they are never perturbed
        Set<String> opOutputs = new HashSet<>();
        for (DifferentialFunction op : sd.ops()) {
            for (SDVariable out : op.outputVariables()) {
                opOutputs.add(out.name());
            }
        }

        for (Variable v : sd.getVariables().values()) {
            if (v.getVariable().getVariableType() != VariableType.ARRAY && v.getVariable().getArr(true) == null) {
                throw new IllegalStateException("Variable \"" + v.getName() + "\" does not have array associated with it");
            }
        }

        List<String> lossVariables = sd.getLossVariables();
        Preconditions.checkState(lossVariables != null && !lossVariables.isEmpty(),
                "Expected 1 or more loss function variables for gradient check, got %s", lossVariables);

        // Analytic gradients are requested for floating-point VARIABLE and PLACEHOLDER variables only
        Set<String> varsNeedingGrads = new HashSet<>();
        for (Variable v : sd.getVariables().values()) {
            SDVariable sdv = v.getVariable();
            if (sdv.dataType().isFPType()
                    && (sdv.getVariableType() == VariableType.VARIABLE || sdv.getVariableType() == VariableType.PLACEHOLDER)) {
                Preconditions.checkNotNull(sdv.getGradient(), "No gradient variable found for variable %s", sdv);
                varsNeedingGrads.add(v.getName());
            }
        }

        // The analytic pass runs with a NonInplaceValidationListener attached; it is detached afterwards,
        // presumably to keep the many forward passes below cheap.
        // NOTE(review): a NonInplaceValidationListener the caller had already registered is removed too.
        int listenerIdx = ensureNonInplaceValidationListener(sd);
        Map<String, INDArray> gradients;
        try {
            gradients = sd.calculateGradients(placeholderValues, varsNeedingGrads);
        } finally {
            sd.getListeners().remove(listenerIdx);
        }

        Map<String, INDArray> analyticGrads = new HashMap<>();
        for (SDVariable s : sd.variables()) {
            if (opOutputs.contains(s.name()) || !s.hasGradient()) {
                continue;
            }
            if (sd.grad(s.name()) == null) {
                throw new IllegalStateException("Null gradient variable for \"" + s.name() + "\"");
            }
            INDArray g = gradients.get(s.name());
            if (g == null) {
                throw new IllegalStateException("Null gradient array encountered for variable: " + s.name());
            }
            if (!Arrays.equals(s.getArr().shape(), g.shape())) {
                throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" + s.name() + "\": shape " + Arrays.toString(s.getArr().shape()) + " vs. gradient shape " + Arrays.toString(g.shape()));
            }
            // Copy: the arrays returned by calculateGradients may be reused by later executions
            analyticGrads.put(s.name(), g.dup());
        }

        int totalFailures = 0;
        int totalChecked = 0;
        double maxError = 0.0;
        // Fixed seed so RANDOM subsets are reproducible; one generator is shared across all variables
        Random rng = new Random(12345L);
        for (SDVariable s : sd.variables()) {
            String name = s.name();
            if (opOutputs.contains(name) || !s.dataType().isFPType()) {
                continue;
            }
            if (skipVariables != null && skipVariables.contains(name)) {
                log.info("Grad check: skipping variable \"{}\"", name);
                continue;
            }
            if (s.dataType() != DataType.DOUBLE) {
                log.warn("DataType for variable {} is not double (is: {}) may cause precision issues in gradient checks", name, s.dataType());
            }
            INDArray arr = s.getArr();
            long length = arr.length();
            if (print) {
                log.info("Starting test for variable \"{}\" with {} values", name, length);
            }
            // Built before the missing-gradient check below so the shared RNG advances identically
            // regardless of which variables end up being skipped
            Iterator<long[]> positions = paramIndexIterator(arr, maxPerParam, subset, rng);

            INDArray mask = gradCheckMask == null ? null : gradCheckMask.get(name);
            if (mask != null) {
                Preconditions.checkState(arr.equalShapes(mask), "Variable \"%s\": Gradient check mask and array shapes must be equal: got %s vs. mask shape %s", name, arr.shape(), mask.shape());
                Preconditions.checkState(mask.dataType() == DataType.BOOL, "Variable \"%s\": Gradient check mask must be BOOLEAN datatype, got %s", name, mask.dataType());
            }

            INDArray analyticGrad = analyticGrads.get(name);
            if (analyticGrad == null) {
                // Without an analytic gradient no element can be compared; skip the whole variable
                // instead of spending two forward passes per element on it
                log.warn("No gradient array for variable \"{}\" was found, skipping variable...", name);
                continue;
            }

            int paramIdx = 0;
            while (positions.hasNext()) {
                long[] idx = positions.next();
                if (mask != null && mask.getDouble(idx) == 0.0) {
                    continue;
                }
                totalChecked++;

                double origValue = arr.getDouble(idx);
                double scorePlus;
                double scoreMinus;
                try {
                    arr.putScalar(idx, origValue + eps);
                    scorePlus = summedLoss(sd, placeholderValues, lossVariables);
                    arr.putScalar(idx, origValue - eps);
                    scoreMinus = summedLoss(sd, placeholderValues, lossVariables);
                } finally {
                    // Always restore the parameter, even if a forward pass throws
                    arr.putScalar(idx, origValue);
                }
                double numericalGrad = (scorePlus - scoreMinus) / (2.0 * eps);
                double backpropGrad = analyticGrad.getDouble(idx);

                if (Double.isInfinite(numericalGrad) || Double.isNaN(numericalGrad)) {
                    throw new IllegalStateException("Numerical gradient was " + numericalGrad + " for variable \"" + name + "\", parameter " + paramIdx + " of " + length + " (position: " + formatPosition(idx) + ")");
                }
                if (Double.isInfinite(backpropGrad) || Double.isNaN(backpropGrad)) {
                    throw new IllegalStateException("Analytic (SameDiff) gradient was " + backpropGrad + " for variable \"" + name + "\", parameter " + paramIdx + " of " + length + " (position: " + formatPosition(idx) + ")");
                }

                double relError = (numericalGrad == 0.0 && backpropGrad == 0.0)
                        ? 0.0
                        : Math.abs(backpropGrad - numericalGrad) / (Math.abs(backpropGrad) + Math.abs(numericalGrad));
                if (relError > maxError) {
                    maxError = relError;
                }

                if (relError > maxRelError || Double.isNaN(relError)) {
                    double absError = Math.abs(backpropGrad - numericalGrad);
                    if (absError >= minAbsError) {
                        log.info("Param " + paramIdx + " (" + name + formatPosition(idx) + ") FAILED: grad= " + backpropGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + ", absError=" + absError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                        if (exitOnFirstFailure) {
                            return false;
                        }
                        totalFailures++;
                    } else if (print) {
                        log.info("Param " + paramIdx + " (" + name + formatPosition(idx) + ") passed: grad= " + backpropGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsError);
                    }
                } else if (print) {
                    log.info("Param " + paramIdx + " (" + name + formatPosition(idx) + ") passed: grad= " + backpropGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError);
                }
                paramIdx++;
            }
        }
        log.info("GradCheckUtil.checkGradients(): " + totalChecked + " params checked, " + (totalChecked - totalFailures) + " passed, " + totalFailures + " failed. Largest relative error = " + maxError);
        return totalFailures == 0;
    }

    /**
     * Ensures a {@link NonInplaceValidationListener} is registered on {@code sd}.
     *
     * @return the index of that listener in {@code sd.getListeners()}: the first existing one if present,
     *         otherwise the position of the newly added one
     */
    private static int ensureNonInplaceValidationListener(SameDiff sd) {
        List<Listener> current = new ArrayList<>(sd.getListeners());
        for (int idx = 0; idx < current.size(); idx++) {
            if (current.get(idx) instanceof NonInplaceValidationListener) {
                return idx;
            }
        }
        sd.addListeners(new NonInplaceValidationListener());
        // NOTE(review): assumes addListeners appends at the end of the listener list
        return current.size();
    }

    /**
     * Builds the iterator over element positions of {@code arr} to be checked: all positions in C order,
     * or a subset chosen by {@code subset} when {@code 0 < maxPerParam < arr.length()}.
     */
    private static Iterator<long[]> paramIndexIterator(INDArray arr, int maxPerParam, Subset subset, Random rng) {
        long length = arr.length();
        if (maxPerParam <= 0 || subset == null || maxPerParam >= length) {
            return new NdIndexIterator('c', arr.shape());
        }
        long[] shape = arr.shape();
        List<long[]> positions = new ArrayList<>();
        if (subset == Subset.RANDOM) {
            // maxPerParam < length here, so enough distinct indices exist and the loop terminates
            Set<Integer> picked = new HashSet<>();
            while (picked.size() < maxPerParam) {
                picked.add(rng.nextInt((int) length));
            }
            List<Integer> sorted = new ArrayList<>(picked);
            Collections.sort(sorted);
            for (int linearIdx : sorted) {
                positions.add(Shape.ind2subC(shape, linearIdx));
            }
        } else {
            // step >= 1 because maxPerParam < length
            long step = length / maxPerParam;
            for (long linearIdx = 0; linearIdx < length; linearIdx += step) {
                positions.add(Shape.ind2subC(shape, linearIdx));
            }
        }
        return positions.iterator();
    }

    /** One forward pass: the sum over all loss variables of each loss array's element sum. */
    private static double summedLoss(SameDiff sd, Map<String, INDArray> placeholderValues, List<String> lossVariables) {
        double total = 0.0;
        for (INDArray loss : sd.output(placeholderValues, lossVariables).values()) {
            total += loss.sumNumber().doubleValue();
        }
        return total;
    }

    /** Compact position string for log and exception messages, e.g. {@code [0,3]}. */
    private static String formatPosition(long[] idx) {
        return Arrays.toString(idx).replaceAll(StringUtils.SPACE, "");
    }

    /**
     * Gradient-checks the activation (ARRAY-type variable) gradients listed in the configuration.
     *
     * <p>Analytic gradients are computed once with {@code SameDiff.calculateGradients}. Then, for each
     * selected element of each activation, a central-difference numerical gradient
     * {@code (L(+eps) - L(-eps)) / (2 * eps)} is computed by running two forward passes with an
     * {@link ActivationGradientCheckListener} configured with the element index and +eps / -eps.
     * L is the sum of all elements of all loss variables.
     *
     * <p>An element fails when its relative error exceeds {@code maxRelError} AND its absolute error
     * is at least {@code minAbsError}. Elements whose gradient-check mask value is 0 are skipped.
     * When {@code maxPerParam > 0}, a subset is set, and the activation has more elements than
     * {@code maxPerParam}, only a subset of elements is checked (random with a fixed seed, or every N-th).
     *
     * <p>The SameDiff instance's listeners and debug mode are restored before this method returns,
     * including on early exit (exit-on-first-failure) and on exceptions.
     *
     * @param actGradConfig configuration: SameDiff instance, activations to check, placeholder values,
     *                      eps, error thresholds, and optional subset/mask/print/debug settings
     * @return true if no checked element failed
     * @throws IllegalStateException if the configuration or graph is invalid, or if a numerical or
     *                               analytic gradient is NaN or infinite
     */
    public static boolean checkActivationGradients(ActGradConfig actGradConfig) {
        SameDiff sd = actGradConfig.getSd();
        List<String> activationGradsToCheck = actGradConfig.getActivationGradsToCheck();
        double maxRelError = actGradConfig.getMaxRelError();
        double minAbsError = actGradConfig.getMinAbsError();
        Preconditions.checkState(sd != null, "SameDiff instance was not set in configuration");
        Preconditions.checkState(activationGradsToCheck != null && !activationGradsToCheck.isEmpty(),
                "No activation gradients were specified to gradient check");
        Preconditions.checkState(actGradConfig.getEps() > 0.0d, "Epsilon has not been set");
        Preconditions.checkState(maxRelError > 0.0d, "Max relative error must be set (is 0.0)");
        for (String name : activationGradsToCheck) {
            // Check the map entry before dereferencing it: otherwise an unknown name surfaces as a
            // NullPointerException instead of the descriptive message below.
            Preconditions.checkState(sd.getVariables().containsKey(name), "No variable with name \"%s\" was found", name);
            SDVariable variable = sd.getVariables().get(name).getVariable();
            Preconditions.checkState(variable != null, "No variable with name \"%s\" was found", name);
            Preconditions.checkState(variable.getVariableType() == VariableType.ARRAY, "Only variables with type ARRAY may be gradient checked using this method. Variable \"%s\" has type %s", name, variable.getVariableType());
            Preconditions.checkState(variable.dataType().isFPType(), "Cannot gradient check activation variable \"%s\": must be floating point type. Is type: %s", name, variable.dataType());
            if (variable.dataType() != DataType.DOUBLE) {
                log.warn("Floating point variable {} is not double precision - this may result in spurious failures due to limited precision. Variable is type: {}", name, variable.dataType());
            }
        }

        // Snapshot caller-visible state so it can be restored however this method exits.
        boolean debugBefore = sd.isDebugMode();
        List<Listener> listenersBefore = new ArrayList<>(sd.getListeners());
        if (actGradConfig.isDebugMode()) {
            sd.enableDebugMode();
        }
        try {
            if (!actGradConfig.isSkipValidation()) {
                validateInternalState(sd, true);
            }
            List<String> lossVariables = sd.getLossVariables();
            Preconditions.checkState(lossVariables != null && !lossVariables.isEmpty(), "Expected 1 or more loss function variables for gradient check, got %s", lossVariables);

            // Analytic gradients for every requested activation (duplicates requested only once).
            sd.createGradFunction();
            Set<String> gradNames = new HashSet<>();
            for (String name : activationGradsToCheck) {
                Preconditions.checkState(sd.getVariable(name).gradient() != null, "Could not get gradient for activation \"%s\": gradient variable is null", name);
                gradNames.add(name);
            }
            Map<String, INDArray> calculated = sd.calculateGradients(actGradConfig.getPlaceholderValues(), new ArrayList<>(gradNames));
            Map<String, INDArray> analyticGrads = new HashMap<>();
            for (String name : activationGradsToCheck) {
                INDArray grad = calculated.get(name);
                Preconditions.checkState(grad != null, "No activation gradient array for variable \"%s\"", name);
                // Keep an independent copy (presumably because later forward passes may reuse
                // SameDiff-owned arrays - confirm).
                analyticGrads.put(name, grad.dup());
            }

            int failures = 0;
            int checked = 0;
            double maxError = 0.0d;
            ActivationGradientCheckListener listener = new ActivationGradientCheckListener();
            sd.setListeners(listener);
            // Fixed seed so RANDOM subsets are reproducible; shared across all activations.
            Random random = new Random(12345L);
            for (String name : activationGradsToCheck) {
                INDArray analytic = analyticGrads.get(name);
                long length = analytic.length();
                if (actGradConfig.isPrint()) {
                    log.info("Starting test for variable \"{}\" with {} values", name, length);
                }
                Iterator<long[]> indices = actGradIndicesToCheck(actGradConfig, analytic.shape(), length, random);
                INDArray mask = actGradConfig.getGradCheckMask() == null ? null : actGradConfig.getGradCheckMask().get(name);
                listener.setVariableName(name);
                int paramIdx = 0;
                while (indices.hasNext()) {
                    long[] idx = indices.next();
                    if (mask != null && mask.getDouble(idx) == 0.0d) {
                        continue; // masked out: neither checked nor counted
                    }
                    String idxStr = actGradConfig.isPrint() ? Arrays.toString(idx).replace(StringUtils.SPACE, "") : null;

                    // Central difference: the listener perturbs element idx by +eps, then by -eps.
                    listener.setIdx(idx);
                    listener.setEps(actGradConfig.getEps());
                    double scorePlus = actGradSumLoss(sd, actGradConfig.getPlaceholderValues(), lossVariables);
                    listener.setEps(-actGradConfig.getEps());
                    double scoreMinus = actGradSumLoss(sd, actGradConfig.getPlaceholderValues(), lossVariables);
                    double numericalGrad = (scorePlus - scoreMinus) / (2.0d * actGradConfig.getEps());
                    double analyticGrad = analytic.getDouble(idx);

                    if (Double.isInfinite(numericalGrad) || Double.isNaN(numericalGrad)) {
                        throw new IllegalStateException("Numerical gradient was " + numericalGrad + " for variable \"" + name + "\", parameter " + paramIdx + " of " + length + " (position: " + idxStr + ")");
                    }
                    if (Double.isInfinite(analyticGrad) || Double.isNaN(analyticGrad)) {
                        throw new IllegalStateException("Analytic (SameDiff) gradient was " + analyticGrad + " for variable \"" + name + "\", parameter " + paramIdx + " of " + length + " (position: " + idxStr + ")");
                    }

                    double relError = (numericalGrad == 0.0d && analyticGrad == 0.0d)
                            ? 0.0d
                            : Math.abs(analyticGrad - numericalGrad) / (Math.abs(analyticGrad) + Math.abs(numericalGrad));
                    if (relError > maxError) {
                        maxError = relError;
                    }
                    if (relError > maxRelError || Double.isNaN(relError)) {
                        double absError = Math.abs(analyticGrad - numericalGrad);
                        if (absError >= minAbsError) {
                            if (actGradConfig.isPrint()) {
                                log.info("Param " + paramIdx + " (" + name + idxStr + ") FAILED: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + ", absError=" + absError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                            }
                            if (actGradConfig.isExitOnFirstFailure()) {
                                return false;
                            }
                            failures++;
                        } else if (actGradConfig.isPrint()) {
                            log.info("Param " + paramIdx + " (" + name + idxStr + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsError);
                        }
                    } else if (actGradConfig.isPrint()) {
                        log.info("Param " + paramIdx + " (" + name + idxStr + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError);
                    }
                    paramIdx++;
                    checked++;
                }
            }
            log.info("GradCheckUtil.checkActivationGradients(): {} values checked, {} passed, {} failed. Largest relative error = {}",
                    checked, checked - failures, failures, maxError);
            return failures == 0;
        } finally {
            // Remove the perturbing gradient-check listener and restore whatever the caller had installed.
            sd.setListeners(listenersBefore.toArray(new Listener[0]));
            // Mirror checkGradients(): only undo debug mode if this method turned it on.
            if (actGradConfig.isDebugMode() && !debugBefore) {
                sd.disableDebugging();
            }
        }
    }

    /**
     * Chooses which elements (as C-order index arrays) of an activation gradient to check.
     * Returns all elements unless maxPerParam > 0, a subset mode is set, and the array is larger
     * than maxPerParam. In that case it returns either a sorted random sample of maxPerParam
     * distinct linear indices (RANDOM) or every (length / maxPerParam)-th linear index.
     */
    private static Iterator<long[]> actGradIndicesToCheck(ActGradConfig config, long[] shape, long length, Random random) {
        int maxPerParam = config.getMaxPerParam();
        if (maxPerParam <= 0 || config.getSubset() == null || maxPerParam >= length) {
            return new NdIndexIterator('c', shape);
        }
        List<long[]> indices = new ArrayList<>();
        if (config.getSubset() == Subset.RANDOM) {
            Set<Integer> chosen = new HashSet<>();
            while (chosen.size() < maxPerParam) {
                chosen.add(random.nextInt((int) length));
            }
            List<Integer> sorted = new ArrayList<>(chosen);
            Collections.sort(sorted);
            for (Integer linearIdx : sorted) {
                indices.add(Shape.ind2subC(shape, linearIdx.intValue()));
            }
        } else {
            long step = length / maxPerParam; // >= 1 because maxPerParam < length here
            for (long linearIdx = 0; linearIdx < length; linearIdx += step) {
                indices.add(Shape.ind2subC(shape, linearIdx));
            }
        }
        return indices.iterator();
    }

    /** Runs one forward pass and returns the sum of all elements of all requested loss outputs. */
    private static double actGradSumLoss(SameDiff sd, Map<String, INDArray> placeholders, List<String> lossVariables) {
        double total = 0.0d;
        for (INDArray out : sd.output(placeholders, lossVariables).values()) {
            total += out.sumNumber().doubleValue();
        }
        return total;
    }

    /**
     * Performs structural sanity checks on a SameDiff instance.
     *
     * <p>Verifies that:
     * <ul>
     *   <li>variable names in {@code variables()} are unique;</li>
     *   <li>{@code ops()} and {@code getOps()} contain the same number of ops, and every op from
     *       {@code ops()} is present in the ops map under its own name;</li>
     *   <li>every input and output name of every op is a known variable name;</li>
     *   <li>no variable is listed as an output of more than one op;</li>
     *   <li>{@code variables()} and {@code getVariables()} have the same size, and each map key equals
     *       the name of the variable it maps to.</li>
     * </ul>
     *
     * <p>If {@code z} is true, the gradient function is created if absent. It is then validated
     * recursively (without a further gradient-function check), and every op of {@code sameDiff} must
     * also be present in it.
     *
     * @param sameDiff instance to validate
     * @param z        whether to also generate (if needed) and validate the gradient function
     * @throws IllegalStateException if any structural check fails
     */
    public static void validateInternalState(SameDiff sameDiff, boolean z) {
        DifferentialFunction[] ops = sameDiff.ops();
        List<SDVariable> variables = sameDiff.variables();

        Set<String> varNames = new HashSet<>();
        for (SDVariable v : variables) {
            if (!varNames.add(v.name())) {
                throw new IllegalStateException("Variable with name " + v.name() + " already encountered");
            }
        }

        Map<String, SameDiffOp> opMap = sameDiff.getOps();
        Preconditions.checkState(ops.length == opMap.size(),
                "Number of ops (%s) does not match number of entries in the ops map (%s)", ops.length, opMap.size());
        for (DifferentialFunction df : ops) {
            String opName = df.getOwnName();
            SameDiffOp op = opMap.get(opName);
            Preconditions.checkState(op != null, "%s not present in ops map", opName);
            List<String> inputs = op.getInputsToOp();
            if (inputs != null) {
                for (String in : inputs) {
                    Preconditions.checkState(varNames.contains(in), "Variable %s in op inputs not a known variable name", in);
                }
            }
            List<String> outputs = op.getOutputsOfOp();
            if (outputs != null) {
                for (String out : outputs) {
                    Preconditions.checkState(varNames.contains(out), "Variable %s in op outputs not a known variable name", out);
                }
            }
        }

        // Each variable may be produced by at most one op.
        Map<String, String> producerOf = new HashMap<>();
        for (Map.Entry<String, SameDiffOp> entry : opMap.entrySet()) {
            List<String> outputs = entry.getValue().getOutputsOfOp();
            if (outputs == null) {
                continue;
            }
            for (String out : outputs) {
                if (producerOf.containsKey(out)) {
                    throw new IllegalStateException("Already saw variable \"" + out + "\" as output for op \"" + producerOf.get(out) + "\": expected variables to be present as an output only once; also seen as output for op \"" + entry.getKey() + "\"");
                }
                producerOf.put(out, entry.getKey());
            }
        }

        Map<String, Variable> varMap = sameDiff.getVariables();
        Preconditions.checkState(variables.size() == varMap.size(), "Variable map size check failed");
        for (Map.Entry<String, Variable> entry : varMap.entrySet()) {
            Preconditions.checkState(entry.getKey().equals(entry.getValue().getVariable().name()), "Name not equal");
        }

        if (z) {
            // NOTE(review): the decompiled original looked this up via AdaGradUpdater.GRAD_STATE, which is
            // an updater-state key, not a SameDiff constant. It is assumed to hold the same value "grad"
            // (the decompiler substitutes matching constants) - confirm.
            String gradFnName = "grad";
            if (sameDiff.getFunction(gradFnName) == null) {
                sameDiff.createGradFunction();
            }
            SameDiff gradFn = sameDiff.getFunction(gradFnName);
            Preconditions.checkState(gradFn != null, "Gradient function \"%s\" was not created", gradFnName);
            validateInternalState(gradFn, false);
            for (DifferentialFunction df : ops) {
                Preconditions.checkNotNull(gradFn.getOpById(df.getOwnName()), "DifferentialFunction " + df.getOwnName() + " from original SameDiff instance not present in grad fn");
            }
        }
    }

    /**
     * Reads the value of a field declared directly on {@code cls} via reflection, bypassing access checks.
     *
     * @param str name of the field; only fields declared on {@code cls} itself are found
     *            ({@code Class.getDeclaredField} semantics)
     * @param obj instance to read the field from ({@code null} for a static field)
     * @param cls class that declares the field
     * @param <T> expected type of the value. The cast is unchecked, so a mismatch surfaces at the caller.
     * @return the field's current value
     * @throws RuntimeException wrapping any failure (missing field, access denied, wrong instance, ...)
     */
    @SuppressWarnings("unchecked")
    private static <T> T getObject(String str, Object obj, Class<?> cls) {
        try {
            Field field = cls.getDeclaredField(str);
            field.setAccessible(true);
            return (T) field.get(obj);
        } catch (Exception e) {
            // Deliberately broad: setAccessible/get can also throw unchecked exceptions, and all failures
            // are reported uniformly, with context about which field was being read.
            throw new RuntimeException("Could not read field \"" + str + "\" of class " + cls.getName(), e);
        }
    }
}
