package org.nd4j.autodiff.samediff;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import com.google.common.primitives.Ints;
import com.google.flatbuffers.FlatBufferBuilder;
import com.rits.cloning.Cloner;
import com.rits.cloning.IFastCloner;
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.apache.camel.util.URISupport;
import org.bytedeco.javacpp.BytePointer;
import org.nd4j.autodiff.execution.conf.ExecutionMode;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.functions.FunctionProperties;
import org.nd4j.autodiff.samediff.flow.FlowPath;
import org.nd4j.autodiff.util.cloner.DataBufferFastCloner;
import org.nd4j.autodiff.util.cloner.INDArrayFastCloner;
import org.nd4j.base.Preconditions;
import org.nd4j.graph.FlatGraph;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatVariable;
import org.nd4j.graph.IntPair;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.GradientOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.distances.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.accum.distances.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.accum.distances.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.controlflow.While;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUConfiguration;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.collection.IntArrayKeyMap;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossCosineProximity;
import org.nd4j.linalg.lossfunctions.impl.LossHinge;
import org.nd4j.linalg.lossfunctions.impl.LossKLD;
import org.nd4j.linalg.lossfunctions.impl.LossL1;
import org.nd4j.linalg.lossfunctions.impl.LossL2;
import org.nd4j.linalg.lossfunctions.impl.LossMAE;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
import org.nd4j.linalg.lossfunctions.impl.LossMSLE;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
import org.nd4j.linalg.lossfunctions.impl.LossPoisson;
import org.nd4j.linalg.lossfunctions.impl.LossSquaredHinge;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ConstantInitScheme;
import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.PropertyAccessor;

/* loaded from: input_file:org/nd4j/autodiff/samediff/SameDiff.class */
public class SameDiff {
    // Function <-> variable wiring. The *Reverse maps are keyed by a function's own name
    // (DifferentialFunction.getOwnName()) and hold the names of that function's input
    // (incoming) or output (outgoing) variables. The String[]-keyed maps are presumably the
    // inverse lookups — TODO confirm; they are not read in this part of the file.
    private Map<String[], DifferentialFunction> incomingArgs;
    private Map<String[], DifferentialFunction> outgoingArgs;
    private Map<String, String[]> incomingArgsReverse;
    private Map<String, String[]> outgoingArgsReverse;
    private Map<String, int[]> permuteOrder;
    private boolean shouldBootStrap;
    private Set<String> importedVarName;
    private Map<String, String> baseNameForFunctionInstanceId;
    // Factory returned by f().
    private DifferentialFunctionFactory functionFactory;
    // Variable name -> variable.
    private Map<String, SDVariable> variableMap;
    // Variable name -> shape, only for shapes whose dimensions are all >= 1; shapes containing a
    // dimension < 1 are recorded in placeHolderOriginalShapes instead.
    private Map<String, int[]> variableNameToShape;
    // Both keyed by variable name (see updateVariableName); value semantics are not visible here.
    private Map<String, SDVariable> gradients;
    private Map<String, SDVariable> forwardVarForGrad;
    // Variable name -> array currently associated with that variable.
    private Map<String, INDArray> variableNameToArr;
    // Keyed by variable name; presumably the functions consuming / producing that variable —
    // TODO confirm against the code that populates them.
    private Map<String, List<DifferentialFunction>> functionsArgsFor;
    private Map<String, List<DifferentialFunction>> functionOutputFor;
    private ThreadLocal<FlowPath> localFlowPath;
    private Map<String, List<String>> propertiesToResolve;
    private Map<String, Map<String, Object>> propertiesForFunction;
    // Keyed by variable name (renamed in updateVariableName).
    private Map<String, List<String[]>> placeHolderMap;
    // Variable name -> declared shape that contained at least one dimension < 1.
    private Map<String, int[]> placeHolderOriginalShapes;
    private Set<String> placeHolderVarNames;
    // Array (compared by identity, not equals) -> variable it is associated with.
    private IdentityHashMap<INDArray, SDVariable> reverseArrayLookup;
    private MemoryWorkspace workspace;
    private Map<String, SameDiffFunctionDefinition> sameDiffFunctionDefinitionMap;
    // Sub-function name -> the SameDiff instance holding it (see putSubFunction).
    private Map<String, SameDiff> sameDiffFunctionInstances;
    private Set<String> placeHolderFunctions;
    // Function id -> function; putFunctionForId rejects duplicates and SDVariable values.
    private Map<String, DifferentialFunction> functionInstancesById;
    private Table<String, String, String> fieldVariableResolutionMapping;
    private transient AtomicBoolean wasRegistered;
    // Toggled by enableDebugMode() / disableDebugging().
    private boolean debugMode;
    private Map<int[], Op> opsForResult;
    private boolean resolvedVariables;
    // Package-private (no access modifier).
    boolean logExecution;
    private static final Logger log = LoggerFactory.getLogger((Class<?>) SameDiff.class);
    // Shared deep-cloner configured by newCloner(); used by invokeGraphOn.
    private static Cloner cloner = newCloner();
    // Op name -> reflective method, consulted by invoke(...).
    private static Map<String, Method> opMethods = new HashMap();

    /**
     * Default {@link SameDiffConditional}: registers the supplied body as a function named
     * {@code "eval"}, invokes it on the same SameDiff instance, and returns the first output
     * variable of the last entry in {@code functionInstancesById}'s iteration order.
     *
     * <p>NOTE(review): "last entry" equals "most recently added" only if that map preserves
     * insertion order — TODO confirm the concrete map type.
     */
    public static class DefaultSameDiffConditional implements SameDiffConditional {
        @Override
        public SDVariable eval(SameDiff sameDiff, SameDiffFunctionDefinition sameDiffFunctionDefinition, SDVariable[] sDVariableArr) {
            sameDiff.defineFunction("eval", sameDiffFunctionDefinition, sDVariableArr);
            sameDiff.invokeFunctionOn("eval", sameDiff);

            List<DifferentialFunction> registered = new ArrayList<DifferentialFunction>(sameDiff.functionInstancesById.values());
            DifferentialFunction newest = registered.get(registered.size() - 1);
            return newest.outputVariables()[0];
        }
    }

    /**
     * Fluent builder for {@link SameDiff}. Each setter stores its argument verbatim (no copying
     * or validation) and returns this builder so calls can be chained; {@link #build()} passes
     * every stored value to the all-arguments SameDiff constructor.
     */
    public static class SameDiffBuilder {
        private Map<String[], DifferentialFunction> incomingArgs;
        private Map<String[], DifferentialFunction> outgoingArgs;
        private Map<String, String[]> incomingArgsReverse;
        private Map<String, String[]> outgoingArgsReverse;
        private Map<String, int[]> permuteOrder;
        private boolean shouldBootStrap;
        private Set<String> importedVarName;
        private Map<String, String> baseNameForFunctionInstanceId;
        private DifferentialFunctionFactory functionFactory;
        private Map<String, SDVariable> variableMap;
        private Map<String, int[]> variableNameToShape;
        private Map<String, SDVariable> gradients;
        private Map<String, SDVariable> forwardVarForGrad;
        private Map<String, INDArray> variableNameToArr;
        private Map<String, List<DifferentialFunction>> functionsArgsFor;
        private Map<String, List<DifferentialFunction>> functionOutputFor;
        private ThreadLocal<FlowPath> localFlowPath;
        private Map<String, List<String>> propertiesToResolve;
        private Map<String, Map<String, Object>> propertiesForFunction;
        private Map<String, List<String[]>> placeHolderMap;
        private Map<String, int[]> placeHolderOriginalShapes;
        private Set<String> placeHolderVarNames;
        private IdentityHashMap<INDArray, SDVariable> reverseArrayLookup;
        private MemoryWorkspace workspace;
        private Map<String, SameDiffFunctionDefinition> sameDiffFunctionDefinitionMap;
        private Map<String, SameDiff> sameDiffFunctionInstances;
        private Set<String> placeHolderFunctions;
        private Map<String, DifferentialFunction> functionInstancesById;
        private Table<String, String, String> fieldVariableResolutionMapping;
        private AtomicBoolean wasRegistered;
        private boolean debugMode;
        private Map<int[], Op> opsForResult;
        private boolean resolvedVariables;
        private boolean logExecution;

        SameDiffBuilder() {
        }

        public SameDiffBuilder incomingArgs(Map<String[], DifferentialFunction> map) {
            this.incomingArgs = map;
            return this;
        }

        public SameDiffBuilder outgoingArgs(Map<String[], DifferentialFunction> map) {
            this.outgoingArgs = map;
            return this;
        }

        public SameDiffBuilder incomingArgsReverse(Map<String, String[]> map) {
            this.incomingArgsReverse = map;
            return this;
        }

        public SameDiffBuilder outgoingArgsReverse(Map<String, String[]> map) {
            this.outgoingArgsReverse = map;
            return this;
        }

        public SameDiffBuilder permuteOrder(Map<String, int[]> map) {
            this.permuteOrder = map;
            return this;
        }

        public SameDiffBuilder shouldBootStrap(boolean z) {
            this.shouldBootStrap = z;
            return this;
        }

        public SameDiffBuilder importedVarName(Set<String> set) {
            this.importedVarName = set;
            return this;
        }

        public SameDiffBuilder baseNameForFunctionInstanceId(Map<String, String> map) {
            this.baseNameForFunctionInstanceId = map;
            return this;
        }

        public SameDiffBuilder functionFactory(DifferentialFunctionFactory differentialFunctionFactory) {
            this.functionFactory = differentialFunctionFactory;
            return this;
        }

        public SameDiffBuilder variableMap(Map<String, SDVariable> map) {
            this.variableMap = map;
            return this;
        }

        public SameDiffBuilder variableNameToShape(Map<String, int[]> map) {
            this.variableNameToShape = map;
            return this;
        }

        public SameDiffBuilder gradients(Map<String, SDVariable> map) {
            this.gradients = map;
            return this;
        }

        public SameDiffBuilder forwardVarForGrad(Map<String, SDVariable> map) {
            this.forwardVarForGrad = map;
            return this;
        }

        public SameDiffBuilder variableNameToArr(Map<String, INDArray> map) {
            this.variableNameToArr = map;
            return this;
        }

        public SameDiffBuilder functionsArgsFor(Map<String, List<DifferentialFunction>> map) {
            this.functionsArgsFor = map;
            return this;
        }

        public SameDiffBuilder functionOutputFor(Map<String, List<DifferentialFunction>> map) {
            this.functionOutputFor = map;
            return this;
        }

        public SameDiffBuilder localFlowPath(ThreadLocal<FlowPath> threadLocal) {
            this.localFlowPath = threadLocal;
            return this;
        }

        public SameDiffBuilder propertiesToResolve(Map<String, List<String>> map) {
            this.propertiesToResolve = map;
            return this;
        }

        public SameDiffBuilder propertiesForFunction(Map<String, Map<String, Object>> map) {
            this.propertiesForFunction = map;
            return this;
        }

        public SameDiffBuilder placeHolderMap(Map<String, List<String[]>> map) {
            this.placeHolderMap = map;
            return this;
        }

        public SameDiffBuilder placeHolderOriginalShapes(Map<String, int[]> map) {
            this.placeHolderOriginalShapes = map;
            return this;
        }

        public SameDiffBuilder placeHolderVarNames(Set<String> set) {
            this.placeHolderVarNames = set;
            return this;
        }

        public SameDiffBuilder reverseArrayLookup(IdentityHashMap<INDArray, SDVariable> identityHashMap) {
            this.reverseArrayLookup = identityHashMap;
            return this;
        }

        public SameDiffBuilder workspace(MemoryWorkspace memoryWorkspace) {
            this.workspace = memoryWorkspace;
            return this;
        }

        public SameDiffBuilder sameDiffFunctionDefinitionMap(Map<String, SameDiffFunctionDefinition> map) {
            this.sameDiffFunctionDefinitionMap = map;
            return this;
        }

        public SameDiffBuilder sameDiffFunctionInstances(Map<String, SameDiff> map) {
            this.sameDiffFunctionInstances = map;
            return this;
        }

        public SameDiffBuilder placeHolderFunctions(Set<String> set) {
            this.placeHolderFunctions = set;
            return this;
        }

        public SameDiffBuilder functionInstancesById(Map<String, DifferentialFunction> map) {
            this.functionInstancesById = map;
            return this;
        }

        public SameDiffBuilder fieldVariableResolutionMapping(Table<String, String, String> table) {
            this.fieldVariableResolutionMapping = table;
            return this;
        }

        public SameDiffBuilder wasRegistered(AtomicBoolean atomicBoolean) {
            this.wasRegistered = atomicBoolean;
            return this;
        }

        public SameDiffBuilder debugMode(boolean z) {
            this.debugMode = z;
            return this;
        }

        public SameDiffBuilder opsForResult(Map<int[], Op> map) {
            this.opsForResult = map;
            return this;
        }

        public SameDiffBuilder resolvedVariables(boolean z) {
            this.resolvedVariables = z;
            return this;
        }

        public SameDiffBuilder logExecution(boolean z) {
            this.logExecution = z;
            return this;
        }

        /**
         * Creates a SameDiff from the values set on this builder. Unset reference fields are
         * passed as {@code null} and unset booleans as {@code false}.
         *
         * @return a new SameDiff instance
         */
        public SameDiff build() {
            return new SameDiff(this.incomingArgs, this.outgoingArgs, this.incomingArgsReverse, this.outgoingArgsReverse, this.permuteOrder, this.shouldBootStrap, this.importedVarName, this.baseNameForFunctionInstanceId, this.functionFactory, this.variableMap, this.variableNameToShape, this.gradients, this.forwardVarForGrad, this.variableNameToArr, this.functionsArgsFor, this.functionOutputFor, this.localFlowPath, this.propertiesToResolve, this.propertiesForFunction, this.placeHolderMap, this.placeHolderOriginalShapes, this.placeHolderVarNames, this.reverseArrayLookup, this.workspace, this.sameDiffFunctionDefinitionMap, this.sameDiffFunctionInstances, this.placeHolderFunctions, this.functionInstancesById, this.fieldVariableResolutionMapping, this.wasRegistered, this.debugMode, this.opsForResult, this.resolvedVariables, this.logExecution);
        }

        /**
         * Renders every builder field as {@code name=value}.
         *
         * <p>The closing parenthesis is written as a plain literal. The decompiled form borrowed
         * Apache Camel's {@code URISupport.RAW_TOKEN_END}, whose value is {@code ")"}, which only
         * coupled this class to an unrelated library.
         */
        public String toString() {
            return "SameDiff.SameDiffBuilder(incomingArgs=" + this.incomingArgs + ", outgoingArgs=" + this.outgoingArgs + ", incomingArgsReverse=" + this.incomingArgsReverse + ", outgoingArgsReverse=" + this.outgoingArgsReverse + ", permuteOrder=" + this.permuteOrder + ", shouldBootStrap=" + this.shouldBootStrap + ", importedVarName=" + this.importedVarName + ", baseNameForFunctionInstanceId=" + this.baseNameForFunctionInstanceId + ", functionFactory=" + this.functionFactory + ", variableMap=" + this.variableMap + ", variableNameToShape=" + this.variableNameToShape + ", gradients=" + this.gradients + ", forwardVarForGrad=" + this.forwardVarForGrad + ", variableNameToArr=" + this.variableNameToArr + ", functionsArgsFor=" + this.functionsArgsFor + ", functionOutputFor=" + this.functionOutputFor + ", localFlowPath=" + this.localFlowPath + ", propertiesToResolve=" + this.propertiesToResolve + ", propertiesForFunction=" + this.propertiesForFunction + ", placeHolderMap=" + this.placeHolderMap + ", placeHolderOriginalShapes=" + this.placeHolderOriginalShapes + ", placeHolderVarNames=" + this.placeHolderVarNames + ", reverseArrayLookup=" + this.reverseArrayLookup + ", workspace=" + this.workspace + ", sameDiffFunctionDefinitionMap=" + this.sameDiffFunctionDefinitionMap + ", sameDiffFunctionInstances=" + this.sameDiffFunctionInstances + ", placeHolderFunctions=" + this.placeHolderFunctions + ", functionInstancesById=" + this.functionInstancesById + ", fieldVariableResolutionMapping=" + this.fieldVariableResolutionMapping + ", wasRegistered=" + this.wasRegistered + ", debugMode=" + this.debugMode + ", opsForResult=" + this.opsForResult + ", resolvedVariables=" + this.resolvedVariables + ", logExecution=" + this.logExecution + ")";
        }
    }

    /**
     * Strategy for evaluating a conditional body inside a SameDiff graph.
     *
     * <p>Implementations receive the graph, the function definition to evaluate and its input
     * variables, and return a single resulting variable (see DefaultSameDiffConditional for the
     * default strategy).
     */
    public interface SameDiffConditional {
        SDVariable eval(SameDiff sameDiff, SameDiffFunctionDefinition sameDiffFunctionDefinition, SDVariable[] sDVariableArr);
    }

    /**
     * A user-supplied function body that defines ops on a SameDiff graph.
     *
     * <p>{@code define} receives the graph to build into, a map of named arrays (presumably the
     * function's inputs keyed by variable name — TODO confirm), and the input variables, and
     * returns the output variables it created.
     */
    public interface SameDiffFunctionDefinition {
        SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> map, SDVariable[] sDVariableArr);
    }

    /**
     * Builds a {@link Cloner} with fast cloners registered for the backend's NDArray classes and
     * for each DataBuffer implementation class reported by the current DataBufferFactory (plus
     * {@link CompressedDataBuffer}).
     *
     * <p>Buffer classes that the factory reports as {@code null} are skipped by {@code doReg}.
     *
     * @return a newly configured cloner
     */
    public static Cloner newCloner() {
        Cloner result = new Cloner();

        INDArrayFastCloner arrayCloner = new INDArrayFastCloner();
        result.registerFastCloner(Nd4j.getBackend().getNDArrayClass(), arrayCloner);
        result.registerFastCloner(Nd4j.getBackend().getComplexNDArrayClass(), arrayCloner);

        DataBufferFastCloner bufferCloner = new DataBufferFastCloner();
        DataBufferFactory factory = Nd4j.getDataBufferFactory();
        // Registration order matches the original: int, long, half, float, double, compressed.
        Class<?>[] bufferTypes = new Class<?>[] {
                factory.intBufferClass(),
                factory.longBufferClass(),
                factory.halfBufferClass(),
                factory.floatBufferClass(),
                factory.doubleBufferClass(),
                CompressedDataBuffer.class
        };
        for (Class<?> bufferType : bufferTypes) {
            doReg(result, bufferCloner, bufferType);
        }
        return result;
    }

    /**
     * Registers {@code fastCloner} for {@code type} on {@code target}, doing nothing when
     * {@code type} is {@code null}.
     */
    private static void doReg(Cloner target, IFastCloner fastCloner, Class<?> type) {
        if (type == null) {
            return;
        }
        target.registerFastCloner(type, fastCloner);
    }

    /**
     * Renames a variable and rewrites every name-keyed structure of this graph that refers to it.
     *
     * <p>This updates:
     * <ul>
     *   <li>the variable itself and {@code variableMap};</li>
     *   <li>the name arrays held in {@code outgoingArgsReverse} and {@code incomingArgsReverse};</li>
     *   <li>the array, shape, gradient, placeholder and function-lookup maps keyed by the old name;</li>
     *   <li>the X/Y/Z vertex ids of every {@link BaseOp} found in {@code functionsArgsFor} /
     *       {@code functionOutputFor} under the old name.</li>
     * </ul>
     * Renaming a variable to its current name is a no-op.
     *
     * @param str  the current variable name
     * @param str2 the new variable name
     * @throws ND4JIllegalStateException if no variable named {@code str} exists
     */
    public void updateVariableName(String str, String str2) {
        // Previously a same-name rename removed the variable entirely (the trailing
        // variableMap.remove(str) ran after put(str2, ...)).
        if (str != null && str.equals(str2)) {
            return;
        }
        SDVariable variable = getVariable(str);
        if (variable == null) {
            throw new ND4JIllegalStateException("No variable found with name " + str);
        }

        // Drop the old entries before inserting the new one; str != str2 is guaranteed here.
        this.variableMap.remove(variable.getVarName());
        this.variableMap.remove(str);
        variable.setVarName(str2);
        this.variableMap.put(str2, variable);

        replaceNameInArrays(this.outgoingArgsReverse, str, str2);
        replaceNameInArrays(this.incomingArgsReverse, str, str2);

        renameMapKey(this.variableNameToArr, str, str2);
        renameMapKey(this.variableNameToShape, str, str2);
        renameMapKey(this.gradients, str, str2);
        renameMapKey(this.forwardVarForGrad, str, str2);
        renameMapKey(this.placeHolderMap, str, str2);
        // Placeholder bookkeeping is keyed by variable name too (see putShapeForVarName), so it
        // must follow the rename or the variable silently stops being a placeholder.
        renameMapKey(this.placeHolderOriginalShapes, str, str2);
        if (this.placeHolderVarNames.remove(str)) {
            this.placeHolderVarNames.add(str2);
        }

        List<DifferentialFunction> consumers = this.functionsArgsFor.remove(str);
        if (consumers != null) {
            renameBaseOpVertexIds(consumers, str, str2);
            this.functionsArgsFor.put(str2, consumers);
        }
        List<DifferentialFunction> producers = this.functionOutputFor.remove(str);
        if (producers != null) {
            renameBaseOpVertexIds(producers, str, str2);
            this.functionOutputFor.put(str2, producers);
        }
    }

    /** Moves the value stored under {@code from} to {@code to}, if {@code from} is present. */
    private static <V> void renameMapKey(Map<String, V> map, String from, String to) {
        if (map.containsKey(from)) {
            map.put(to, map.remove(from));
        }
    }

    /** Replaces, in place, every occurrence of {@code from} in the name arrays held by {@code map}. */
    private static void replaceNameInArrays(Map<String, String[]> map, String from, String to) {
        for (String[] names : map.values()) {
            if (names == null) {
                continue;
            }
            for (int i = 0; i < names.length; i++) {
                if (from.equals(names[i])) {
                    names[i] = to;
                }
            }
        }
    }

    /** Rewrites X/Y/Z vertex ids equal to {@code from} on every {@link BaseOp} in {@code functions}. */
    private static void renameBaseOpVertexIds(List<DifferentialFunction> functions, String from, String to) {
        for (DifferentialFunction function : functions) {
            if (!(function instanceof BaseOp)) {
                continue;
            }
            BaseOp op = (BaseOp) function;
            if (from.equals(op.getXVertexId())) {
                op.setXVertexId(to);
            }
            if (from.equals(op.getYVertexId())) {
                op.setYVertexId(to);
            }
            if (from.equals(op.getZVertexId())) {
                op.setZVertexId(to);
            }
        }
    }

    /**
     * Switches debug mode off.
     *
     * @return this instance, for call chaining
     */
    public SameDiff disableDebugging() {
        debugMode = false;
        return this;
    }

    /**
     * Switches debug mode on.
     *
     * @return this instance, for call chaining
     */
    public SameDiff enableDebugMode() {
        debugMode = true;
        return this;
    }

    /**
     * Short accessor for this graph's function factory.
     *
     * @return the {@link DifferentialFunctionFactory} held by this instance
     */
    public DifferentialFunctionFactory f() {
        return functionFactory;
    }

    /**
     * Copies this graph's variables and functions into {@code sameDiff}.
     *
     * <p>Each variable is deep-cloned with the shared {@code cloner}, passing its owning SameDiff as
     * an instance not to clone. The clone is registered in the target via {@code sameDiff.var(...)}
     * and, when the source variable has an array, that array is associated with the new variable.
     *
     * <p>Each non-variable function is then cloned, re-pointed at the target, and has its
     * argument/output wiring registered there.
     *
     * <p>Finally, this graph's {@code reverseArrayLookup} entries are copied into the target,
     * resolving each variable by name in the target.
     *
     * @param sameDiff the graph to copy into
     * @return the last element of {@code sameDiff.variables()} after the copy
     */
    public SDVariable invokeGraphOn(SameDiff sameDiff) {
        // NOTE(review): hashMap only ever receives identity entries (i -> i) and is never read.
        HashMap hashMap = new HashMap();
        int i = 1;
        for (SDVariable sDVariable : variables()) {
            SDVariable sDVariable2 = (SDVariable) cloner.deepCloneDontCloneInstances(sDVariable, sDVariable.getSameDiff());
            SDVariable var = sameDiff.var(sDVariable2);
            if (sDVariable.getArr() != null) {
                sameDiff.associateArrayWithVariable(sDVariable.getArr(), var);
            }
            hashMap.put(Integer.valueOf(i), Integer.valueOf(i));
            sDVariable2.setSameDiff(sameDiff);
            i++;
        }
        // NOTE(review): linkedHashMap (own name -> cloned function) is populated but never read.
        LinkedHashMap linkedHashMap = new LinkedHashMap();
        for (DifferentialFunction differentialFunction : this.functionInstancesById.values()) {
            if (!(differentialFunction instanceof SDVariable)) {
                DifferentialFunction differentialFunction2 = (DifferentialFunction) cloner.deepCloneDontCloneInstances(differentialFunction, differentialFunction.getSameDiff());
                differentialFunction2.setSameDiff(sameDiff);
                differentialFunction2.setOwnName(differentialFunction.getOwnName());
                // NOTE(review): putFunctionForId throws when the id already exists, so this branch
                // can only ever throw. The condition looks inverted — TODO confirm intended behavior.
                if (sameDiff.functionExists(differentialFunction.getOwnName())) {
                    sameDiff.putFunctionForId(differentialFunction.getOwnName(), differentialFunction);
                }
                linkedHashMap.put(differentialFunction.getOwnName(), differentialFunction2);
                SDVariable[] args = differentialFunction.args();
                SDVariable[] outputVariables = differentialFunction.outputVariables();
                // NOTE(review): inputs are registered against the clone, but outputs (and the
                // functionInstancesById entry below) use the ORIGINAL function — confirm intended.
                sameDiff.addArgsFor(args, differentialFunction2);
                sameDiff.addOutgoingFor(outputVariables, differentialFunction);
                for (SDVariable sDVariable3 : differentialFunction2.args()) {
                    sDVariable3.setSameDiff(sameDiff);
                }
                for (SDVariable sDVariable4 : differentialFunction2.outputVariables()) {
                    sDVariable4.setSameDiff(sameDiff);
                }
                sameDiff.functionInstancesById.put(differentialFunction.getOwnName(), differentialFunction);
            }
        }
        // Keys are the same array instances; values are re-resolved to the target's variables by name.
        for (Map.Entry<INDArray, SDVariable> entry : this.reverseArrayLookup.entrySet()) {
            sameDiff.reverseArrayLookup.put(entry.getKey(), sameDiff.getVariable(entry.getValue().getVarName()));
        }
        // NOTE(review): fails if the target ends up with no variables (size() - 1 == -1).
        return sameDiff.variables().get(sameDiff.variables().size() - 1);
    }

    /**
     * Checks whether a function is registered under the given id.
     *
     * @param str the function id
     * @return {@code true} if {@code functionInstancesById} has an entry for {@code str}
     */
    public boolean functionExists(String str) {
        Map<String, DifferentialFunction> byId = functionInstancesById;
        return byId.containsKey(str);
    }

    /**
     * Looks up a registered function by id.
     *
     * @param str the function id
     * @return the function registered under {@code str}
     * @throws ND4JIllegalStateException if no function is registered under {@code str}
     */
    public DifferentialFunction getFunctionById(String str) {
        if (!functionInstancesById.containsKey(str)) {
            throw new ND4JIllegalStateException("No function with id " + str + " found!");
        }
        return functionInstancesById.get(str);
    }

    /**
     * Registers a function under a new id.
     *
     * @param str                  the id to register under
     * @param differentialFunction the function; must not be an {@link SDVariable}
     * @throws ND4JIllegalStateException if the id is already taken (checked first) or the
     *                                   function is a variable
     */
    public void putFunctionForId(String str, DifferentialFunction differentialFunction) {
        boolean idTaken = functionInstancesById.containsKey(str);
        if (idTaken) {
            throw new ND4JIllegalStateException("Function by id already exists!");
        }
        boolean isVariable = differentialFunction instanceof SDVariable;
        if (isVariable) {
            throw new ND4JIllegalStateException("Function must not be a variable!");
        }
        functionInstancesById.put(str, differentialFunction);
    }

    /**
     * Returns the names of the input variables recorded for a function.
     *
     * @param differentialFunction the function, looked up by its own name
     * @return the input variable names from {@code incomingArgsReverse}
     * @throws ND4JIllegalStateException if no inputs are recorded under the function's name
     */
    public String[] getInputsForFunction(DifferentialFunction differentialFunction) {
        String ownName = differentialFunction.getOwnName();
        if (!incomingArgsReverse.containsKey(ownName)) {
            throw new ND4JIllegalStateException("Illegal function instance id found " + ownName);
        }
        return incomingArgsReverse.get(ownName);
    }

    /**
     * Returns the names of the output variables recorded for a function.
     *
     * @param differentialFunction the function, looked up by its own name
     * @return the output variable names, or {@code null} if none are recorded
     */
    public String[] getOutputsForFunction(DifferentialFunction differentialFunction) {
        String ownName = differentialFunction.getOwnName();
        return outgoingArgsReverse.get(ownName);
    }

    /**
     * Resolves the output variables of a function by name.
     *
     * @param differentialFunction the function whose outputs to resolve
     * @return the output variables, in the order their names are recorded; an entry may be
     *         {@code null} if {@code getVariable} finds nothing for that name
     * @throws ND4JIllegalStateException if no outputs are recorded for the function
     */
    public SDVariable[] getOutputVariablesForFunction(DifferentialFunction differentialFunction) {
        String[] outputNames = getOutputsForFunction(differentialFunction);
        if (outputNames == null) {
            // Message previously said "No inputs found", which misdirects debugging of missing outputs.
            throw new ND4JIllegalStateException("No outputs found for function " + differentialFunction);
        }
        SDVariable[] outputs = new SDVariable[outputNames.length];
        for (int i = 0; i < outputNames.length; i++) {
            outputs[i] = getVariable(outputNames[i]);
        }
        return outputs;
    }

    /**
     * Resolves the input variables of a function by name.
     *
     * @param differentialFunction the function whose inputs to resolve
     * @return the input variables, in the order their names are recorded
     * @throws ND4JIllegalStateException if no inputs are recorded, or if any name does not
     *                                   resolve to a variable
     */
    public SDVariable[] getInputVariablesForFunction(DifferentialFunction differentialFunction) {
        String[] inputNames = getInputsForFunction(differentialFunction);
        if (inputNames == null) {
            throw new ND4JIllegalStateException("No inputs found for function " + differentialFunction);
        }
        SDVariable[] resolved = new SDVariable[inputNames.length];
        int position = 0;
        for (String inputName : inputNames) {
            SDVariable candidate = getVariable(inputName);
            if (candidate == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + position);
            }
            resolved[position++] = candidate;
        }
        return resolved;
    }

    /**
     * Replaces the array already associated with a variable name and records the reverse
     * (array -> variable) mapping.
     *
     * <p>NOTE(review): the previous array's entry in {@code reverseArrayLookup} is left in place.
     *
     * @param str      the variable name; an array must already be stored under it
     * @param iNDArray the replacement array
     * @throws ND4JIllegalStateException if no array is stored under {@code str}
     */
    public void updateArrayForVarName(String str, INDArray iNDArray) {
        if (!this.variableNameToArr.containsKey(str)) {
            // The hint previously named a non-existent method (putArrayForVertexId).
            throw new ND4JIllegalStateException("Array for " + str + " does not exist. Please use putArrayForVarName instead.");
        }
        this.variableNameToArr.put(str, iNDArray);
        this.reverseArrayLookup.put(iNDArray, getVariable(str));
    }

    /**
     * Stores an array under a variable name that has no array yet.
     *
     * <p>Unlike associateArrayWithVariable, this does not touch {@code reverseArrayLookup} or the
     * shape maps.
     *
     * @param str      the variable name; must not be {@code null}
     * @param iNDArray the array to store
     * @throws ND4JIllegalStateException if {@code str} is {@code null} or already has an array
     */
    public void putArrayForVarName(String str, INDArray iNDArray) {
        if (str == null) {
            throw new ND4JIllegalStateException("No null names allowed!");
        }
        boolean alreadyPresent = variableNameToArr.containsKey(str);
        if (alreadyPresent) {
            throw new ND4JIllegalStateException("Array for " + str + " already exists!");
        }
        variableNameToArr.put(str, iNDArray);
    }

    /**
     * Returns the shape known for a variable.
     *
     * <p>If an array is stored for the name, its shape wins; otherwise the recorded shape (or
     * {@code null}) is returned.
     *
     * @param str the variable name
     * @return the shape, or {@code null} if neither an array nor a shape is recorded
     */
    public int[] getShapeForVarName(String str) {
        if (variableNameToArr.containsKey(str)) {
            return variableNameToArr.get(str).shape();
        }
        return variableNameToShape.get(str);
    }

    /**
     * Records (or overwrites) the shape of a variable.
     *
     * <p>A shape containing any dimension below 1 marks the variable as a placeholder. Such a
     * shape is stored in {@code placeHolderOriginalShapes} instead of {@code variableNameToShape}.
     *
     * @param str  the variable name
     * @param iArr the shape; must not be {@code null}
     * @throws ND4JIllegalStateException if {@code iArr} is {@code null}, or if an array already
     *                                   stored for {@code str} has a different shape
     */
    public void updateShapeForVarName(String str, int[] iArr) {
        if (iArr == null) {
            throw new ND4JIllegalStateException("Null shapes not allowed!");
        }
        boolean conflictsWithArray = variableNameToArr.containsKey(str)
                && !Arrays.equals(variableNameToArr.get(str).shape(), iArr);
        if (conflictsWithArray) {
            throw new ND4JIllegalStateException("Already found an existing array!");
        }

        boolean unresolved = false;
        for (int dim : iArr) {
            if (dim < 1) {
                unresolved = true;
                break;
            }
        }

        if (unresolved) {
            addAsPlaceHolder(str);
            placeHolderOriginalShapes.put(str, iArr);
        } else {
            variableNameToShape.put(str, iArr);
        }
    }

    /**
     * Records the shape of a variable that has no shape recorded yet.
     *
     * <p>A shape containing any dimension below 1 marks the variable as a placeholder. Such a
     * shape is stored in {@code placeHolderOriginalShapes} instead of {@code variableNameToShape}.
     *
     * @param str  the variable name
     * @param iArr the shape; must not be {@code null}
     * @throws ND4JIllegalStateException if {@code iArr} is {@code null} or a shape is already
     *                                   recorded in {@code variableNameToShape} for {@code str}
     */
    public void putShapeForVarName(String str, int[] iArr) {
        if (iArr == null) {
            throw new ND4JIllegalStateException("Shape must not be null!");
        }
        if (variableNameToShape.containsKey(str)) {
            throw new ND4JIllegalStateException("Shape for " + str + " already exists!");
        }

        boolean unresolved = false;
        for (int dim : iArr) {
            if (dim < 1) {
                unresolved = true;
                break;
            }
        }

        if (unresolved) {
            addAsPlaceHolder(str);
            placeHolderOriginalShapes.put(str, iArr);
        } else {
            variableNameToShape.put(str, iArr);
        }
    }

    /** True if a shape is recorded for {@code varName}, or an array (which implies a shape) is registered. */
    public boolean shapeAlreadyExistsForVarName(String varName) {
        if (this.variableNameToShape.containsKey(varName)) {
            return true;
        }
        return arrayAlreadyExistsForVarName(varName);
    }

    /** True if an array is registered under {@code varName}. */
    public boolean arrayAlreadyExistsForVarName(String varName) {
        final boolean present = this.variableNameToArr.containsKey(varName);
        return present;
    }

    /** Returns the array registered under {@code varName}, or {@code null} if there is none. */
    public INDArray getArrForVarName(String varName) {
        final INDArray registered = this.variableNameToArr.get(varName);
        return registered;
    }

    /**
     * Binds {@code arr} to {@code variable}. This records the reverse (array to variable) lookup,
     * stores the array under the variable's name, and records the array's shape for that name.
     *
     * @param arr      the array; must not be null
     * @param variable the variable; must not be null
     * @throws ND4JIllegalArgumentException if either argument is null
     */
    public void associateArrayWithVariable(INDArray arr, SDVariable variable) {
        if (variable == null) {
            throw new ND4JIllegalArgumentException("Variable must not be null!");
        }
        if (arr == null) {
            throw new ND4JIllegalArgumentException("Array must not be null");
        }
        String varName = variable.getVarName();
        this.reverseArrayLookup.put(arr, variable);
        this.variableNameToArr.put(varName, arr);
        // The array was just stored under varName, so shapeAlreadyExistsForVarName(varName) is
        // necessarily true here. The former putShapeForVarName branch was unreachable.
        updateShapeForVarName(varName, arr.shape());
    }

    /**
     * Registers {@code sameDiff} as the sub-function instance for {@code name}. Re-registering the
     * same instance is allowed; replacing a different one is not.
     *
     * @throws ND4JIllegalStateException if a different instance is already registered for the name
     */
    public void putSubFunction(String name, SameDiff sameDiff) {
        boolean clash = this.sameDiffFunctionInstances.containsKey(name)
                && this.sameDiffFunctionInstances.get(name) != sameDiff;
        if (clash) {
            throw new ND4JIllegalStateException("Unable to replace samediff namespace. Please choose another opName");
        }
        this.sameDiffFunctionInstances.put(name, sameDiff);
    }

    /** Returns the live (not copied) name-to-variable map backing this instance. */
    public Map<String, SDVariable> variableMap() {
        final Map<String, SDVariable> live = this.variableMap;
        return live;
    }

    /**
     * Reflectively calls the method registered in {@code opMethods} under {@code op.opName()}.
     * The method is called with a single argument when either input is null, otherwise with both.
     *
     * @param op          op whose name selects the method
     * @param sDVariable  first input
     * @param sDVariable2 second input; if null (or if the first is null) the one-argument form is used
     * @return the variable returned by the registered method
     * @throws ND4JIllegalStateException if no method is registered for the op name, or the reflective
     *                                   call fails (the underlying exception is attached as the cause)
     */
    public SDVariable invoke(Op op, SDVariable sDVariable, SDVariable sDVariable2) {
        if (!opMethods.containsKey(op.opName())) {
            throw new ND4JIllegalStateException("Illegal method opName " + op.opName());
        }
        Method method = opMethods.get(op.opName());
        try {
            if (sDVariable == null || sDVariable2 == null) {
                return (SDVariable) method.invoke(this, sDVariable);
            }
            return (SDVariable) method.invoke(this, sDVariable, sDVariable2);
        } catch (Exception e) {
            // Previously swallowed; keep the same exception type and message but preserve the cause
            // (e.g. the InvocationTargetException wrapping the real failure).
            ND4JIllegalStateException failure = new ND4JIllegalStateException("Illegal method opName " + op.opName());
            failure.initCause(e);
            throw failure;
        }
    }

    /** Returns the variable bound to this exact array instance (identity lookup), or {@code null}. */
    public SDVariable getVariableForArray(INDArray arr) {
        final SDVariable owner = this.reverseArrayLookup.get(arr);
        return owner;
    }

    /** Returns a live view of the names under which sub-function instances are registered. */
    public Collection<String> definedFunctionNames() {
        final Collection<String> names = this.sameDiffFunctionInstances.keySet();
        return names;
    }

    /**
     * Estimates the graph's memory footprint: the total element count from {@code numElements()}
     * multiplied by the per-element byte length of the current global data type.
     */
    public long memoryForGraph() {
        long elementCount = numElements();
        return elementCount * DataTypeUtil.lengthForDtype(Nd4j.dataType());
    }

    /** Single-input form of {@link #invoke(Op, SDVariable, SDVariable)}; the second input is null. */
    public SDVariable invoke(Op op, SDVariable input) {
        return invoke(op, input, (SDVariable) null);
    }

    /**
     * Private no-arg constructor, used by {@code create()}. It sets the initial flag values and gives
     * every bookkeeping collection an empty instance.
     */
    private SameDiff() {
        this.shouldBootStrap = true;
        // One flow-path slot per thread; no thread has a value until one is set.
        this.localFlowPath = new ThreadLocal<>();
        this.wasRegistered = new AtomicBoolean(false);
        this.resolvedVariables = false;
        this.logExecution = true;
        // Factory bound to this instance (create(SameDiff) rebinds it for builder-made copies).
        this.functionFactory = new DifferentialFunctionFactory(this);
        // Insertion-ordered maps/sets below keep iteration order equal to registration order.
        this.variableMap = new LinkedHashMap();
        this.sameDiffFunctionDefinitionMap = new LinkedHashMap();
        this.sameDiffFunctionInstances = new LinkedHashMap();
        this.gradients = new LinkedHashMap();
        this.forwardVarForGrad = new LinkedHashMap();
        this.opsForResult = new IntArrayKeyMap();
        // IdentityHashMap: arrays are matched by reference, not by INDArray.equals/hashCode.
        this.reverseArrayLookup = new IdentityHashMap<>();
        this.variableNameToArr = new LinkedHashMap();
        this.variableNameToShape = new LinkedHashMap();
        this.placeHolderMap = new LinkedHashMap();
        this.placeHolderVarNames = new LinkedHashSet();
        this.placeHolderOriginalShapes = new LinkedHashMap();
        this.incomingArgs = new LinkedHashMap();
        this.outgoingArgs = new LinkedHashMap();
        this.incomingArgsReverse = new LinkedHashMap();
        this.outgoingArgsReverse = new LinkedHashMap();
        this.functionInstancesById = new LinkedHashMap();
        this.placeHolderFunctions = new LinkedHashSet();
        this.functionsArgsFor = new LinkedHashMap();
        this.functionOutputFor = new LinkedHashMap();
        this.baseNameForFunctionInstanceId = new LinkedHashMap();
        this.importedVarName = new LinkedHashSet();
        this.permuteOrder = new LinkedHashMap();
        this.propertiesToResolve = new LinkedHashMap();
        this.propertiesForFunction = new LinkedHashMap();
        // Two-key table: (function own name, second key -- presumably a field name) -> variable name.
        this.fieldVariableResolutionMapping = HashBasedTable.create();
    }

    /**
     * Appends {@code property} to the list of properties pending resolution for the function
     * (keyed by its own name). The list is created on first use.
     */
    public void addPropertyToResolve(DifferentialFunction differentialFunction, String property) {
        String ownName = differentialFunction.getOwnName();
        if (!this.propertiesToResolve.containsKey(ownName)) {
            ArrayList<String> pending = new ArrayList<>();
            pending.add(property);
            this.propertiesToResolve.put(ownName, pending);
        } else {
            this.propertiesToResolve.get(ownName).add(property);
        }
    }

    /** Returns the properties pending resolution for the function, or an empty list if none are recorded. */
    public List<String> propertiesToResolveForFunction(DifferentialFunction differentialFunction) {
        String ownName = differentialFunction.getOwnName();
        if (this.propertiesToResolve.containsKey(ownName)) {
            return this.propertiesToResolve.get(ownName);
        }
        return Collections.emptyList();
    }

    /** True if a pending-properties entry exists for the function's own name. */
    public boolean hasPropertiesToResolve(DifferentialFunction differentialFunction) {
        String ownName = differentialFunction.getOwnName();
        return this.propertiesToResolve.containsKey(ownName);
    }

    /**
     * Returns the stored property value for the function, unchecked-cast to the caller's type, or
     * {@code null} if the function has no property map.
     */
    public <T> T getPropertyForFunction(DifferentialFunction differentialFunction, String property) {
        String ownName = differentialFunction.getOwnName();
        if (!this.propertiesForFunction.containsKey(ownName)) {
            return null;
        }
        return (T) this.propertiesForFunction.get(ownName).get(property);
    }

    /** Stores an array-valued property for the function; see the private Object overload. */
    public void addPropertyForFunction(DifferentialFunction differentialFunction, String property, INDArray value) {
        Object boxed = value;
        addPropertyForFunction(differentialFunction, property, boxed);
    }

    /** Stores a long-valued property (boxed as {@link Long}) for the function. */
    public void addPropertyForFunction(DifferentialFunction differentialFunction, String property, long value) {
        Long boxed = Long.valueOf(value);
        addPropertyForFunction(differentialFunction, property, boxed);
    }

    /**
     * Stores {@code value} under {@code property} in the function's property map, creating the
     * map on first use.
     *
     * @throws ND4JIllegalStateException if the property already has a value for this function
     */
    private void addPropertyForFunction(DifferentialFunction differentialFunction, String property, Object value) {
        String ownName = differentialFunction.getOwnName();
        if (this.propertiesForFunction.containsKey(ownName)) {
            Map<String, Object> existing = this.propertiesForFunction.get(ownName);
            if (existing.containsKey(property)) {
                throw new ND4JIllegalStateException("Attempting to override property " + property);
            }
            existing.put(property, value);
            return;
        }
        LinkedHashMap<String, Object> fresh = new LinkedHashMap<>();
        fresh.put(property, value);
        this.propertiesForFunction.put(ownName, fresh);
    }

    /** Records that {@code field} of the function resolves to the variable named {@code varName}. */
    public void addVariableMappingForField(DifferentialFunction differentialFunction, String field, String varName) {
        String ownName = differentialFunction.getOwnName();
        this.fieldVariableResolutionMapping.put(ownName, field, varName);
    }

    /** Returns the variable name recorded for (function, field), or {@code null}. */
    public String getVarNameForFieldAndFunction(DifferentialFunction differentialFunction, String field) {
        String ownName = differentialFunction.getOwnName();
        return this.fieldVariableResolutionMapping.get(ownName, field);
    }

    /** True if {@code varName} was registered via {@code addVarNameForImport}. */
    public boolean isImportVariable(String varName) {
        final boolean imported = this.importedVarName.contains(varName);
        return imported;
    }

    /** Marks {@code varName} as an imported variable name. */
    public void addVarNameForImport(String varName) {
        final String name = varName;
        this.importedVarName.add(name);
    }

    /** Associates {@code baseName} with the function's own name. */
    public void setBaseNameForFunctionInstanceId(String baseName, DifferentialFunction differentialFunction) {
        String ownName = differentialFunction.getOwnName();
        this.baseNameForFunctionInstanceId.put(ownName, baseName);
    }

    /** Returns the base name recorded for the function, or {@code null}. */
    public String getBaseNameForFunction(DifferentialFunction differentialFunction) {
        String ownName = differentialFunction.getOwnName();
        return this.baseNameForFunctionInstanceId.get(ownName);
    }

    /**
     * Ensures {@code x} belongs to this SameDiff instance, re-pointing it if it is attached elsewhere.
     *
     * @param x the variable; must not be null
     * @return {@code x} itself
     * @throws NullPointerException if {@code x} is null (via {@code Preconditions.checkNotNull})
     */
    public <X extends SDVariable> X setupFunction(X x) {
        Preconditions.checkNotNull(x, "Passed in function must not be null!");
        // The former `!(x instanceof SDVariable)` guard was unreachable: X is bounded by SDVariable
        // and x is non-null after the check above.
        if (x.getSameDiff() != this) {
            x.setSameDiff(this);
        }
        return x;
    }

    /**
     * Declares {@code outputs} as the outgoing variables of {@code differentialFunction}, by name.
     *
     * @param outputs              output variables; elements must not be null
     * @param differentialFunction the producing function
     * @throws ND4JIllegalStateException if any element of {@code outputs} is null (same check and
     *                                   message as {@code addArgsFor(SDVariable[], ...)}), or if the
     *                                   name-based overload rejects the declaration
     */
    public void addOutgoingFor(SDVariable[] outputs, DifferentialFunction differentialFunction) {
        String[] names = new String[outputs.length];
        for (int i = 0; i < names.length; i++) {
            if (outputs[i] == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + i);
            }
            names[i] = outputs[i].getVarName();
        }
        addOutgoingFor(names, differentialFunction);
    }

    /**
     * Declares the variables named in {@code varNames} as outputs of {@code differentialFunction}.
     * It records the forward and reverse output mappings and appends the function to each name's
     * list of producing functions.
     *
     * @throws ND4JIllegalStateException if the function has no own name, already has outputs declared,
     *                                   or {@code varNames} (or any element) is null
     */
    public void addOutgoingFor(String[] varNames, DifferentialFunction differentialFunction) {
        String ownName = differentialFunction.getOwnName();
        if (ownName == null) {
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
        }
        if (this.outgoingArgsReverse.containsKey(ownName)) {
            throw new ND4JIllegalStateException("Outgoing arguments already declared for " + differentialFunction);
        }
        if (varNames == null) {
            throw new ND4JIllegalStateException("Var names can not be null!");
        }
        for (int i = 0; i < varNames.length; i++) {
            if (varNames[i] == null) {
                throw new ND4JIllegalStateException("Variable name elements can not be null!");
            }
        }
        this.outgoingArgsReverse.put(ownName, varNames);
        this.outgoingArgs.put(varNames, differentialFunction);
        for (int i = 0; i < varNames.length; i++) {
            String name = varNames[i];
            List<DifferentialFunction> producers = this.functionOutputFor.get(name);
            if (producers == null) {
                producers = new ArrayList<>();
                this.functionOutputFor.put(name, producers);
            }
            producers.add(differentialFunction);
        }
    }

    /**
     * Declares the variables named in {@code varNames} as inputs of {@code differentialFunction}.
     * It flags the function if any input is a placeholder, records the forward and reverse input
     * mappings, and appends the function to each name's list of consuming functions.
     *
     * @throws ND4JIllegalStateException if the function has no own name
     */
    public void addArgsFor(String[] varNames, DifferentialFunction differentialFunction) {
        if (differentialFunction.getOwnName() == null) {
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
        }
        for (int i = 0; i < varNames.length; i++) {
            if (isPlaceHolder(varNames[i])) {
                this.placeHolderFunctions.add(differentialFunction.getOwnName());
            }
        }
        this.incomingArgs.put(varNames, differentialFunction);
        this.incomingArgsReverse.put(differentialFunction.getOwnName(), varNames);
        for (int i = 0; i < varNames.length; i++) {
            String name = varNames[i];
            List<DifferentialFunction> consumers = this.functionsArgsFor.get(name);
            if (consumers == null) {
                consumers = new ArrayList<>();
                this.functionsArgsFor.put(name, consumers);
            }
            consumers.add(differentialFunction);
        }
    }

    /**
     * Variable-based form of {@code addArgsFor(String[], ...)}; it converts the inputs to their names.
     *
     * @throws ND4JIllegalStateException if any element of {@code inputs} is null
     */
    public void addArgsFor(SDVariable[] inputs, DifferentialFunction differentialFunction) {
        String[] names = new String[inputs.length];
        int idx = 0;
        while (idx < names.length) {
            SDVariable input = inputs[idx];
            if (input == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + idx);
            }
            names[idx] = input.getVarName();
            idx++;
        }
        addArgsFor(names, differentialFunction);
    }

    /**
     * True if {@code key} is a key of the incoming-arguments map.
     * NOTE(review): addArgsFor stores String[] keys in incomingArgs, so an int[] key looks like it can
     * never match -- confirm whether this overload is still used.
     */
    public boolean hasArgs(int[] key) {
        final boolean present = this.incomingArgs.containsKey(key);
        return present;
    }

    /** True if input names are recorded for the function and they map back to a function. */
    public boolean hasArgs(DifferentialFunction differentialFunction) {
        String[] argNames = this.incomingArgsReverse.get(differentialFunction.getOwnName());
        if (argNames == null) {
            return false;
        }
        return this.incomingArgs.get(argNames) != null;
    }

    /** Returns a snapshot array of all registered function instances, in registration order. */
    public DifferentialFunction[] functions() {
        Collection<DifferentialFunction> registered = this.functionInstancesById.values();
        DifferentialFunction[] snapshot = new DifferentialFunction[registered.size()];
        return registered.toArray(snapshot);
    }

    /**
     * Hash code consistent with {@link #equals(Object)}. It is derived from {@code variableMap} only,
     * which {@code equals} also compares.
     * The previous version mixed in {@code super.hashCode()} (identity-based), so two instances that
     * were equal could hash differently. That violates the {@code Object.hashCode} contract.
     */
    @Override
    public int hashCode() {
        return 31 * (this.variableMap != null ? this.variableMap.hashCode() : 0);
    }

    /**
     * Builds a new instance that shares (does not copy) the source's variable map and sub-function
     * instances, and gives it its own function factory.
     */
    public static SameDiff create(SameDiff sameDiff) {
        SameDiff created = builder()
                .variableMap(sameDiff.variableMap)
                .sameDiffFunctionInstances(sameDiff.sameDiffFunctionInstances)
                .build();
        created.functionFactory = new DifferentialFunctionFactory(created);
        return created;
    }

    /**
     * Two instances of the same runtime class are equal when their variable maps, function definition
     * maps and function instance maps are all equal (null-safe).
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        SameDiff other = (SameDiff) obj;
        boolean sameVariables = this.variableMap == null
                ? other.variableMap == null
                : this.variableMap.equals(other.variableMap);
        if (!sameVariables) {
            return false;
        }
        boolean sameDefinitions = this.sameDiffFunctionDefinitionMap == null
                ? other.sameDiffFunctionDefinitionMap == null
                : this.sameDiffFunctionDefinitionMap.equals(other.sameDiffFunctionDefinitionMap);
        if (!sameDefinitions) {
            return false;
        }
        return this.sameDiffFunctionInstances == null
                ? other.sameDiffFunctionInstances == null
                : this.sameDiffFunctionInstances.equals(other.sameDiffFunctionInstances);
    }

    /** Creates a new, empty instance via the private constructor. */
    public static SameDiff create() {
        SameDiff fresh = new SameDiff();
        return fresh;
    }

    /**
     * Executes a deep copy of this graph. For each executed function, in execution order, it returns
     * the array of that function's first output variable.
     * NOTE(review): {@code map} is not read by this method -- confirm whether inputs were meant to be
     * bound before execution.
     *
     * @throws IllegalStateException if execution produced no ops
     */
    public INDArray[] eval(Map<String, INDArray> map) {
        SameDiff copy = dup();
        List<DifferentialFunction> executed = copy.exec().getRight();
        if (executed.isEmpty()) {
            throw new IllegalStateException("No ops found to execute.");
        }
        INDArray[] outputs = new INDArray[executed.size()];
        int slot = 0;
        for (DifferentialFunction fn : executed) {
            outputs[slot++] = copy.getArrForVarName(fn.outputVariables()[0].getVarName());
        }
        return outputs;
    }

    /** Returns a deep clone of this instance produced by {@code newCloner()}. */
    public SameDiff dup() {
        SameDiff clone = (SameDiff) newCloner().deepClone(this);
        return clone;
    }

    /**
     * Returns the total element count over all variables: the sum of the products of their shapes.
     * Variables with no shape ({@code getShape() == null}) contribute nothing.
     *
     * The previous version re-created the iterator on every loop test (so it never advanced and
     * looped forever on a non-empty graph), referenced an undefined variable, and multiplied
     * dimensions in {@code int} arithmetic, which can overflow. Products are now accumulated as
     * {@code long}.
     */
    public long numElements() {
        long total = 0;
        for (SDVariable variable : variables()) {
            int[] shape = variable.getShape();
            if (shape == null) {
                continue;
            }
            long count = 1;
            for (int dim : shape) {
                count *= dim;
            }
            total += count;
        }
        return total;
    }

    /**
     * Creates a workspace sized by {@code memoryForGraph()} (over-allocate policy, first-loop learning)
     * and makes it the current thread's workspace.
     */
    private void initWorkspace() {
        this.workspace = Nd4j.getWorkspaceManager().createNewWorkspace(
                WorkspaceConfiguration.builder()
                        .initialSize(memoryForGraph())
                        .policyAllocation(AllocationPolicy.OVERALLOCATE)
                        .policyLearning(LearningPolicy.FIRST_LOOP)
                        .build());
        Nd4j.getWorkspaceManager().setWorkspaceForCurrentThread(this.workspace);
    }

    /** Returns a mutable snapshot list of all variables, in registration order. */
    public List<SDVariable> variables() {
        List<SDVariable> snapshot = new ArrayList<>(this.variableMap.values());
        return snapshot;
    }

    /** Declares a variable of the given shape initialized by {@code ConstantInitScheme('f', 1.0)}. */
    public SDVariable one(String name, int[] shape) {
        return var(name, shape, new ConstantInitScheme('f', 1.0d));
    }

    /** Unnamed form of {@code onesLike}: delegates with a {@code null} name. */
    public SDVariable onesLike(SDVariable input) {
        return onesLike(null, input);
    }

    /** Delegates to the function factory's {@code onesLike(name, input)}. */
    public SDVariable onesLike(String name, SDVariable input) {
        return f().onesLike(name, input);
    }

    /** Declares a variable of the given shape initialized by {@code ZeroInitScheme}. */
    public SDVariable zero(String name, int[] shape) {
        return var(name, shape, new ZeroInitScheme());
    }

    /** Unnamed form of {@code zerosLike}: delegates with a {@code null} name. */
    public SDVariable zerosLike(SDVariable input) {
        return zerosLike(null, input);
    }

    /** Delegates to the function factory's {@code zerosLike(name, input)}. */
    public SDVariable zerosLike(String name, SDVariable input) {
        return f().zerosLike(name, input);
    }

    /**
     * Declares a variable with the given shape and weight-init scheme. If a variable with this name
     * already exists and has an array, that existing variable is returned unchanged.
     *
     * @throws IllegalArgumentException if {@code name} is null or empty
     */
    public SDVariable var(String name, int[] shape, WeightInitScheme weightInitScheme) {
        boolean alreadyMaterialized = this.variableMap.containsKey(name)
                && this.variableMap.get(name).getArr() != null;
        if (alreadyMaterialized) {
            return this.variableMap.get(name);
        }
        if (name == null || name.length() < 1) {
            throw new IllegalArgumentException("Name for variable must be defined");
        }
        if (this.workspace == null) {
            initWorkspace();
        }
        SDVariable created = SDVariable.builder()
                .sameDiff(this)
                .shape(shape)
                .weightInitScheme(weightInitScheme)
                .varName(name)
                .build();
        addVariable(created);
        this.variableMap.put(name, created);
        return created;
    }

    /** Declares a variable of the given shape with a {@code ZeroInitScheme}. */
    public SDVariable var(String name, int[] shape) {
        return var(name, shape, new ZeroInitScheme());
    }

    /**
     * Declares a variable mirroring {@code variable} (same name and shape). Its array is supplied
     * lazily from {@code variable}: the first time an array is requested and {@code variable} has
     * none, one is created from {@code variable}'s weight-init scheme and associated with it.
     * If a variable with this name already exists and has an array, it is returned unchanged.
     *
     * @param variable the source variable; must not be null
     * @throws IllegalArgumentException if {@code variable} is null or has a null/empty name
     */
    public SDVariable var(final SDVariable variable) {
        // Checked first: the previous null check ran after variable had already been dereferenced,
        // and its message dereferenced it again.
        if (variable == null) {
            throw new IllegalArgumentException("Variable must not be null");
        }
        if (this.variableMap.containsKey(variable.getVarName()) && this.variableMap.get(variable.getVarName()).getArr() != null) {
            return this.variableMap.get(variable.getVarName());
        }
        if (variable.getVarName() == null || variable.getVarName().length() < 1) {
            throw new IllegalArgumentException("Name for variable must be defined");
        }
        if (this.workspace == null) {
            initWorkspace();
        }
        SDVariable build = SDVariable.builder().sameDiff(this).shape(variable.getShape()).varName(variable.getVarName()).weightInitScheme(new NDArraySupplierInitScheme(new NDArraySupplierInitScheme.NDArraySupplier() {
            @Override
            public INDArray getArr() {
                if (variable.getArr() == null) {
                    SameDiff.this.associateArrayWithVariable(variable.getWeightInitScheme().create(variable.getShape()), variable);
                }
                return variable.getArr();
            }
        })).build();
        this.variableMap.put(variable.getVarName(), build);
        return build;
    }

    /**
     * If {@code varName} is among the function's arguments, rebuilds the function's recorded
     * input-name array without it and re-registers the forward and reverse input mappings.
     * If it is not an argument, nothing changes.
     */
    public void removeArgFromFunction(String varName, DifferentialFunction differentialFunction) {
        SDVariable[] currentArgs = differentialFunction.args();
        boolean isArgument = false;
        for (SDVariable arg : currentArgs) {
            if (arg.getVarName().equals(varName)) {
                isArgument = true;
                break;
            }
        }
        if (!isArgument) {
            return;
        }
        String ownName = differentialFunction.getOwnName();
        String[] oldNames = this.incomingArgsReverse.get(ownName);
        this.incomingArgs.remove(oldNames);
        this.incomingArgsReverse.remove(ownName);
        List<String> kept = new ArrayList<>(currentArgs.length - 1);
        for (int i = 0; i < currentArgs.length; i++) {
            if (!oldNames[i].equals(varName)) {
                kept.add(oldNames[i]);
            }
        }
        String[] newNames = kept.toArray(new String[kept.size()]);
        this.incomingArgs.put(newNames, differentialFunction);
        this.incomingArgsReverse.put(ownName, newNames);
    }

    /**
     * Declares a variable named {@code str} backed by {@code iNDArray}.
     * If a variable with this name already exists and has an array, it is returned unchanged.
     *
     * @param str      variable name; must be non-null and non-empty
     * @param iNDArray backing array; must be non-null
     * @return the new (or pre-existing) variable
     * @throws IllegalArgumentException if the name is null/empty or the array is null
     */
    public SDVariable var(String str, INDArray iNDArray) {
        if (this.variableMap.containsKey(str) && this.variableMap.get(str).getArr() != null) {
            return this.variableMap.get(str);
        }
        if (str == null || str.length() < 1) {
            throw new IllegalArgumentException("Name for variable must be defined");
        }
        if (iNDArray == null) {
            throw new IllegalArgumentException("Array for " + str + " must not be null");
        }
        if (this.workspace == null) {
            initWorkspace();
        }
        // NOTE(review): the supplier below hands out the result of migrate(), while
        // associateArrayWithVariable registers the original iNDArray -- confirm this split is intended.
        final INDArray migrate = iNDArray.migrate();
        SDVariable build = SDVariable.builder().sameDiff(this).shape(iNDArray.shape()).varName(str).weightInitScheme(new NDArraySupplierInitScheme(new NDArraySupplierInitScheme.NDArraySupplier() { // from class: org.nd4j.autodiff.samediff.SameDiff.2
            @Override // org.nd4j.weightinit.impl.NDArraySupplierInitScheme.NDArraySupplier
            public INDArray getArr() {
                return migrate;
            }
        })).build();
        associateArrayWithVariable(iNDArray, build);
        // Single-element arrays also get their value cached as the variable's scalar value.
        if (ArrayUtil.prod(iNDArray.shape()) == 1) {
            build.setScalarValue(Double.valueOf(iNDArray.getDouble(0)));
        }
        addVariable(build);
        // After associateArrayWithVariable the array is registered under str, so getShapeForVarName
        // returns its shape here; this branch appears to be a defensive no-op.
        if (getShapeForVarName(str) == null) {
            putShapeForVarName(str, iNDArray.shape());
        }
        // Same entry associateArrayWithVariable already wrote into reverseArrayLookup.
        this.reverseArrayLookup.put(iNDArray, build);
        this.variableMap.put(str, build);
        return build;
    }

    /** Returns the variable registered under {@code name}, or {@code null}. */
    public SDVariable getVariable(String name) {
        final SDVariable found = this.variableMap.get(name);
        return found;
    }

    /** Returns the gradient variable recorded for {@code varName}, or {@code null}. */
    public SDVariable getGradForVariable(String varName) {
        final SDVariable gradient = this.gradients.get(varName);
        return gradient;
    }

    /**
     * Records {@code gradient} as the gradient for {@code varName}.
     *
     * @throws ND4JIllegalStateException if {@code gradient} is null
     */
    public void setGradientForVariableName(String varName, SDVariable gradient) {
        if (gradient != null) {
            this.gradients.put(varName, gradient);
            return;
        }
        throw new ND4JIllegalStateException("Unable to set null gradient for variable name " + varName);
    }

    /**
     * Looks up {@code forwardVarForGrad} with a boxed Integer key.
     * NOTE(review): setForwardVariableForVarName stores String keys in the same map, so this lookup
     * appears unable to find those entries -- confirm the intended key type.
     */
    public SDVariable getForwardVariableForVertexId(int vertexId) {
        Integer key = Integer.valueOf(vertexId);
        return this.forwardVarForGrad.get(key);
    }

    /** Records {@code forward} as the forward variable associated with {@code varName}. */
    public void setForwardVariableForVarName(String varName, SDVariable forward) {
        final String key = varName;
        this.forwardVarForGrad.put(key, forward);
    }

    /**
     * Returns the gradient recorded for {@code varName} in the {@code "grad"} sub-function.
     *
     * @param varName name of a variable in the gradient function
     * @return the gradient variable, or {@code null} if none is recorded for it
     * @throws IllegalStateException if no {@code "grad"} sub-function exists, or it has no variable
     *                               named {@code varName} (previously an unexplained NPE)
     */
    public SDVariable grad(String varName) {
        if (!this.sameDiffFunctionInstances.containsKey("grad")) {
            throw new IllegalStateException("Unable to obtain gradient. Please run execBackwards() first.");
        }
        SameDiff gradFunction = getFunction("grad");
        SDVariable variable = gradFunction.getVariable(varName);
        if (variable == null) {
            throw new IllegalStateException("No variable named " + varName + " found in the gradient function.");
        }
        return gradFunction.getGradForVariable(variable.getVarName());
    }

    /** Unnamed form of {@code avgPooling2d}: delegates with a {@code null} name. */
    public SDVariable avgPooling2d(SDVariable[] inputs, Pooling2DConfig config) {
        return avgPooling2d(null, inputs, config);
    }

    /** Builds the op via {@code f().avgPooling2d} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable avgPooling2d(String name, SDVariable[] inputs, Pooling2DConfig config) {
        return updateVariableNameAndReference(f().avgPooling2d(inputs, config), name);
    }

    /** Unnamed form of {@code maxPooling2d}: delegates with a {@code null} name. */
    public SDVariable maxPooling2d(SDVariable[] inputs, Pooling2DConfig config) {
        return maxPooling2d(null, inputs, config);
    }

    /** Builds the op via {@code f().maxPooling2d} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable maxPooling2d(String name, SDVariable[] inputs, Pooling2DConfig config) {
        return updateVariableNameAndReference(f().maxPooling2d(inputs, config), name);
    }

    /** Unnamed form of {@code avgPooling3d}: delegates with a {@code null} name. */
    public SDVariable avgPooling3d(SDVariable[] inputs, Pooling3DConfig config) {
        return avgPooling3d(null, inputs, config);
    }

    /** Builds the op via {@code f().avgPooling3d} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable avgPooling3d(String name, SDVariable[] inputs, Pooling3DConfig config) {
        return updateVariableNameAndReference(f().avgPooling3d(inputs, config), name);
    }

    /** Unnamed form of {@code maxPooling3d}: delegates with a {@code null} name. */
    public SDVariable maxPooling3d(SDVariable[] inputs, Pooling3DConfig config) {
        return maxPooling3d(null, inputs, config);
    }

    /** Builds the op via {@code f().maxPooling3d} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable maxPooling3d(String name, SDVariable[] inputs, Pooling3DConfig config) {
        return updateVariableNameAndReference(f().maxPooling3d(inputs, config), name);
    }

    /** Unnamed form of {@code conv1d}: delegates with a {@code null} name. */
    public SDVariable conv1d(SDVariable[] inputs, Conv1DConfig config) {
        return conv1d(null, inputs, config);
    }

    /** Builds the op via {@code f().conv1d} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable conv1d(String name, SDVariable[] inputs, Conv1DConfig config) {
        return updateVariableNameAndReference(f().conv1d(inputs, config), name);
    }

    /** Unnamed form of {@code localResponseNormalization}: delegates with a {@code null} name. */
    public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig config) {
        return localResponseNormalization(null, input, config);
    }

    /** Builds the op via {@code f().localResponseNormalization} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable localResponseNormalization(String name, SDVariable input, LocalResponseNormalizationConfig config) {
        return updateVariableNameAndReference(f().localResponseNormalization(input, config), name);
    }

    /** Unnamed form of {@code conv2d}: delegates with a {@code null} name. */
    public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig config) {
        return conv2d(null, inputs, config);
    }

    /** Builds the op via {@code f().conv2d} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable conv2d(String name, SDVariable[] inputs, Conv2DConfig config) {
        return updateVariableNameAndReference(f().conv2d(inputs, config), name);
    }

    /** Unnamed form of {@code depthWiseConv2d}: delegates with a {@code null} name. */
    public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig config) {
        return depthWiseConv2d(null, inputs, config);
    }

    /** Builds the op via {@code f().depthWiseConv2d} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable depthWiseConv2d(String name, SDVariable[] inputs, Conv2DConfig config) {
        return updateVariableNameAndReference(f().depthWiseConv2d(inputs, config), name);
    }

    /** Unnamed form of {@code sconv2d}: delegates with a {@code null} name. */
    public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig config) {
        return sconv2d(null, inputs, config);
    }

    /** Builds the op via {@code f().sconv2d} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable sconv2d(String name, SDVariable[] inputs, Conv2DConfig config) {
        return updateVariableNameAndReference(f().sconv2d(inputs, config), name);
    }

    /** Unnamed form of {@code deconv2d}: delegates with a {@code null} name. */
    public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig config) {
        return deconv2d(null, inputs, config);
    }

    /** Builds the op via {@code f().deconv2d} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig config) {
        return updateVariableNameAndReference(f().deconv2d(inputs, config), name);
    }

    /** Unnamed form of {@code conv3d}: delegates with a {@code null} name. */
    public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig config) {
        return conv3d(null, inputs, config);
    }

    /** Builds the op via {@code f().conv3d} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable conv3d(String name, SDVariable[] inputs, Conv3DConfig config) {
        return updateVariableNameAndReference(f().conv3d(inputs, config), name);
    }

    /**
     * Unnamed form of {@code batchNorm}: delegates with a {@code null} name, forwarding every
     * argument in order.
     * NOTE(review): the parameter names follow the usual batch-norm argument order
     * (input, mean, variance, gamma, beta, applyGamma, applyBeta, epsilon) -- confirm against the factory.
     */
    public SDVariable batchNorm(SDVariable input, SDVariable mean, SDVariable variance, SDVariable gamma, SDVariable beta, boolean applyGamma, boolean applyBeta, double epsilon) {
        return batchNorm(null, input, mean, variance, gamma, beta, applyGamma, applyBeta, epsilon);
    }

    /**
     * Builds the op via {@code f().batchNorm} (arguments forwarded in order) and passes the result
     * and {@code name} to {@code updateVariableNameAndReference}.
     * NOTE(review): parameter names assume the usual batch-norm argument order -- confirm against the factory.
     */
    public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, SDVariable variance, SDVariable gamma, SDVariable beta, boolean applyGamma, boolean applyBeta, double epsilon) {
        return updateVariableNameAndReference(f().batchNorm(input, mean, variance, gamma, beta, applyGamma, applyBeta, epsilon), name);
    }

    /** Declares a variable named {@code name} backed by {@code Nd4j.scalar(value)}. */
    public SDVariable scalar(String name, double value) {
        INDArray scalarArray = Nd4j.scalar(value);
        return var(name, scalarArray);
    }

    /** Unnamed scalar form of {@code gte}: delegates with a {@code null} name. */
    public SDVariable gte(SDVariable input, double value) {
        return gte((String) null, input, value);
    }

    /** Unnamed scalar form of {@code lte}: delegates with a {@code null} name. */
    public SDVariable lte(SDVariable input, double value) {
        return lte((String) null, input, value);
    }

    /** Unnamed scalar form of {@code gt}: delegates with a {@code null} name. */
    public SDVariable gt(SDVariable input, double value) {
        return gt((String) null, input, value);
    }

    /** Unnamed scalar form of {@code lt}: delegates with a {@code null} name. */
    public SDVariable lt(SDVariable input, double value) {
        return lt((String) null, input, value);
    }

    /** Unnamed scalar form of {@code neq}: delegates with a {@code null} name. */
    public SDVariable neq(SDVariable input, double value) {
        return neq((String) null, input, value);
    }

    /** Unnamed scalar form of {@code eq}: delegates with a {@code null} name. */
    public SDVariable eq(SDVariable input, double value) {
        return eq((String) null, input, value);
    }

    /** Unnamed two-variable form of {@code gte}: delegates with a {@code null} name. */
    public SDVariable gte(SDVariable left, SDVariable right) {
        return gte((String) null, left, right);
    }

    /** Unnamed two-variable form of {@code lte}: delegates with a {@code null} name. */
    public SDVariable lte(SDVariable left, SDVariable right) {
        return lte((String) null, left, right);
    }

    /** Unnamed two-variable form of {@code gt}: delegates with a {@code null} name. */
    public SDVariable gt(SDVariable left, SDVariable right) {
        return gt((String) null, left, right);
    }

    /** Unnamed two-variable form of {@code lt}: delegates with a {@code null} name. */
    public SDVariable lt(SDVariable left, SDVariable right) {
        return lt((String) null, left, right);
    }

    /** Unnamed two-variable form of {@code neq}: delegates with a {@code null} name. */
    public SDVariable neq(SDVariable left, SDVariable right) {
        return neq((String) null, left, right);
    }

    /** Unnamed two-variable form of {@code eq}: delegates with a {@code null} name. */
    public SDVariable eq(SDVariable left, SDVariable right) {
        return eq((String) null, left, right);
    }

    /** Unnamed form of {@code or}: delegates with a {@code null} name. */
    public SDVariable or(SDVariable left, SDVariable right) {
        return or(null, left, right);
    }

    /** Unnamed form of {@code and}: delegates with a {@code null} name. */
    public SDVariable and(SDVariable left, SDVariable right) {
        return and(null, left, right);
    }

    /** Builds the op via {@code f().and} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable and(String name, SDVariable left, SDVariable right) {
        return updateVariableNameAndReference(f().and(left, right), name);
    }

    /** Unnamed form of {@code xor}: delegates with a {@code null} name. */
    public SDVariable xor(SDVariable left, SDVariable right) {
        return xor(null, left, right);
    }

    /** Builds the op via {@code f().xor} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable xor(String name, SDVariable left, SDVariable right) {
        return updateVariableNameAndReference(f().xor(left, right), name);
    }

    /** Unnamed form of {@code abs}: delegates with a {@code null} name. */
    public SDVariable abs(SDVariable input) {
        return abs(null, input);
    }

    /** Builds the op via {@code f().abs} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable abs(String name, SDVariable input) {
        return updateVariableNameAndReference(f().abs(input), name);
    }

    /** Unnamed form of {@code neg}: delegates with a {@code null} name. */
    public SDVariable neg(SDVariable input) {
        return neg(null, input);
    }

    /** Unnamed form of {@code cos}: delegates with a {@code null} name. */
    public SDVariable cos(SDVariable input) {
        return cos(null, input);
    }

    /** Unnamed form of {@code sin}: delegates with a {@code null} name. */
    public SDVariable sin(SDVariable input) {
        return sin(null, input);
    }

    /** Unnamed form of {@code tan}: delegates with a {@code null} name. */
    public SDVariable tan(SDVariable input) {
        return tan(null, input);
    }

    /** Unnamed form of {@code invertPermutation}: delegates with a {@code null} name. */
    public SDVariable invertPermutation(SDVariable input) {
        return invertPermutation(null, input);
    }

    /**
     * Builds the op via {@code f().invertPermutation(input, false)} and passes the result and
     * {@code name} to {@code updateVariableNameAndReference}.
     */
    public SDVariable invertPermutation(String name, SDVariable input) {
        return updateVariableNameAndReference(f().invertPermutation(input, false), name);
    }

    /** Unnamed form of {@code acos}: delegates with a {@code null} name. */
    public SDVariable acos(SDVariable input) {
        return acos(null, input);
    }

    /** Unnamed form of {@code asin}: delegates with a {@code null} name. */
    public SDVariable asin(SDVariable input) {
        return asin(null, input);
    }

    /** Unnamed form of {@code atan}: delegates with a {@code null} name. */
    public SDVariable atan(SDVariable input) {
        return atan(null, input);
    }

    /** Unnamed form of {@code atan2}: delegates with a {@code null} name, arguments in the same order. */
    public SDVariable atan2(SDVariable first, SDVariable second) {
        return atan2(null, first, second);
    }

    /** Builds the op via {@code f().atan2(first, second)} and passes the result and {@code name} to {@code updateVariableNameAndReference}. */
    public SDVariable atan2(String name, SDVariable first, SDVariable second) {
        return updateVariableNameAndReference(f().atan2(first, second), name);
    }

    /** Unnamed form of {@code cosh}: delegates with a {@code null} name. */
    public SDVariable cosh(SDVariable input) {
        return cosh(null, input);
    }

    /** Unnamed form of {@code sinh}: delegates with a {@code null} name. */
    public SDVariable sinh(SDVariable input) {
        return sinh(null, input);
    }

    /** Unnamed form of {@code tanh}: delegates with a {@code null} name. */
    public SDVariable tanh(SDVariable input) {
        return tanh(null, input);
    }

    /** Unnamed form of {@code acosh}: delegates with a {@code null} name. */
    public SDVariable acosh(SDVariable input) {
        return acosh(null, input);
    }

    /** Unnamed form of {@code asinh}: delegates with a {@code null} name. */
    public SDVariable asinh(SDVariable input) {
        return asinh(null, input);
    }

    /** Unnamed form of {@code atanh}: delegates with a {@code null} name. */
    public SDVariable atanh(SDVariable input) {
        return atanh(null, input);
    }

    /** Unnamed form of {@code exp}: delegates with a {@code null} name. */
    public SDVariable exp(SDVariable input) {
        return exp(null, input);
    }

    /** Unnamed form of {@code rsqrt}: delegates with a {@code null} name. */
    public SDVariable rsqrt(SDVariable input) {
        return rsqrt(null, input);
    }

    /** Unnamed form of {@code expm1}: delegates with a {@code null} name. */
    public SDVariable expm1(SDVariable input) {
        return expm1(null, input);
    }

    /** Unnamed form of {@code log1p}: delegates with a {@code null} name. */
    public SDVariable log1p(SDVariable input) {
        return log1p(null, input);
    }

    /** Unnamed form of {@code isInfinite}: delegates with a {@code null} name. */
    public SDVariable isInfinite(SDVariable input) {
        return isInfinite(null, input);
    }

    /** Unnamed form of {@code isNaN}: delegates with a {@code null} name. */
    public SDVariable isNaN(SDVariable input) {
        return isNaN(null, input);
    }

    /** Unnamed form of {@code round}: delegates with a {@code null} name. */
    public SDVariable round(SDVariable input) {
        return round(null, input);
    }

    /** Unnamed form of {@code isFinite}: delegates with a {@code null} name. */
    public SDVariable isFinite(SDVariable input) {
        return isFinite(null, input);
    }

    /** Unnamed form of {@code log}: delegates with a {@code null} name. */
    public SDVariable log(SDVariable input) {
        return log(null, input);
    }

    /** Unnamed form of {@code cube}: delegates with a {@code null} name. */
    public SDVariable cube(SDVariable input) {
        return cube(null, input);
    }

    /** Unnamed form of {@code pow}: delegates with a {@code null} name, forwarding the scalar unchanged. */
    public SDVariable pow(SDVariable input, double exponent) {
        return pow(null, input, exponent);
    }

    public SDVariable sqrt(SDVariable sDVariable) {
        return sqrt(null, sDVariable);
    }

    public SDVariable square(SDVariable sDVariable) {
        return square(null, sDVariable);
    }

    public SDVariable floor(SDVariable sDVariable) {
        return floor(null, sDVariable);
    }

    public SDVariable ceil(SDVariable sDVariable) {
        return ceil(null, sDVariable);
    }

    public SDVariable ceil(String str, SDVariable sDVariable) {
        return updateVariableNameAndReference(f().ceil(sDVariable), str);
    }

    /** Unnamed clipByValue overload: forwards to the named overload with a {@code null} name. */
    public SDVariable clipByValue(SDVariable input, double clipValueMin, double clipValueMax) {
        return clipByValue(null, input, clipValueMin, clipValueMax);
    }

    /**
     * Builds a clipByValue op via {@code f().clipByValue(...)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable clipByValue(String name, SDVariable input, double clipValueMin, double clipValueMax) {
        SDVariable result = f().clipByValue(input, clipValueMin, clipValueMax);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed clipByNorm overload: forwards to the named overload with a {@code null} name. */
    public SDVariable clipByNorm(SDVariable input, double clipValue) {
        return clipByNorm(null, input, clipValue);
    }

    /**
     * Builds a clipByNorm op via {@code f().clipByNorm(...)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable clipByNorm(String name, SDVariable input, double clipValue) {
        SDVariable result = f().clipByNorm(input, clipValue);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed relu overload: forwards to {@code relu(String, SDVariable, double)} with a {@code null} name. */
    public SDVariable relu(SDVariable input, double cutoff) {
        return relu(null, input, cutoff);
    }

    /** Unnamed relu6 overload: forwards to {@code relu6(String, SDVariable, double)} with a {@code null} name. */
    public SDVariable relu6(SDVariable input, double cutoff) {
        return relu6(null, input, cutoff);
    }

    /** Unnamed softmax overload: forwards to {@code softmax(String, SDVariable)} with a {@code null} name. */
    public SDVariable softmax(SDVariable input) {
        return softmax(null, input);
    }

    /** Unnamed logSoftmax overload: forwards to the named overload with a {@code null} name. */
    public SDVariable logSoftmax(SDVariable input) {
        return logSoftmax(null, input);
    }

    /**
     * Builds a logSoftmax op via {@code f().logSoftmax(input)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable logSoftmax(String name, SDVariable input) {
        SDVariable result = f().logSoftmax(input);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed selu overload: forwards to the named overload with a {@code null} name. */
    public SDVariable selu(SDVariable input) {
        return selu(null, input);
    }

    /**
     * Builds a selu op via {@code f().selu(input)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable selu(String name, SDVariable input) {
        SDVariable result = f().selu(input);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed mergeAdd overload: forwards to {@code mergeAdd(String, SDVariable[])} with a {@code null} name. */
    public SDVariable mergeAdd(SDVariable... inputs) {
        return mergeAdd(null, inputs);
    }

    /**
     * Builds a mergeadd op over {@code inputs} via {@code f().mergeadd(...)} and passes the result
     * through {@code updateVariableNameAndReference} together with {@code name}.
     *
     * <p>NOTE(review): unlike the unnamed overload, this one takes a plain array rather than
     * varargs. Turning it into varargs would make a call such as {@code mergeAdd(null, someVar)}
     * ambiguous, so the signature is intentionally left unchanged.
     */
    public SDVariable mergeAdd(String name, SDVariable[] inputs) {
        SDVariable result = f().mergeadd(inputs);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed batchToSpace overload: forwards to the named overload with a {@code null} name. */
    public SDVariable batchToSpace(SDVariable input, int[] blocks, int[][] crops) {
        return batchToSpace(null, input, blocks, crops);
    }

    /**
     * Builds a batchToSpace op via {@code f().batchToSpace(...)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable batchToSpace(String name, SDVariable input, int[] blocks, int[][] crops) {
        SDVariable result = f().batchToSpace(input, blocks, crops);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed depthToSpace overload: forwards to the named overload with a {@code null} name. */
    public SDVariable depthToSpace(SDVariable input, int blockSize, String dataFormat) {
        return depthToSpace(null, input, blockSize, dataFormat);
    }

    /**
     * Builds a depthToSpace op via {@code f().depthToSpace(...)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable depthToSpace(String name, SDVariable input, int blockSize, String dataFormat) {
        SDVariable result = f().depthToSpace(input, blockSize, dataFormat);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed spaceToBatch overload: forwards to the named overload with a {@code null} name. */
    public SDVariable spaceToBatch(SDVariable input, int[] blocks, int[][] padding) {
        return spaceToBatch(null, input, blocks, padding);
    }

    /**
     * Builds a spaceToBatch op via {@code f().spaceToBatch(...)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable spaceToBatch(String name, SDVariable input, int[] blocks, int[][] padding) {
        SDVariable result = f().spaceToBatch(input, blocks, padding);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed spaceToDepth overload: forwards to the named overload with a {@code null} name. */
    public SDVariable spaceToDepth(SDVariable input, int blockSize, String dataFormat) {
        return spaceToDepth(null, input, blockSize, dataFormat);
    }

    /**
     * Builds a spaceToDepth op via {@code f().spaceToDepth(...)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable spaceToDepth(String name, SDVariable input, int blockSize, String dataFormat) {
        SDVariable result = f().spaceToDepth(input, blockSize, dataFormat);
        return updateVariableNameAndReference(result, name);
    }

    /**
     * Unnamed dynamicPartition overload: forwards to {@code dynamicPartition(String[], ...)} with a
     * {@code null} names array.
     */
    public SDVariable[] dynamicPartition(SDVariable input, SDVariable partitions, int numPartitions) {
        return dynamicPartition(null, input, partitions, numPartitions);
    }

    /**
     * Builds a dynamicPartition op via {@code f().dynamicPartition(...)} and passes its outputs
     * through {@code updateVariableNamesAndReferences} together with {@code names}.
     */
    public SDVariable[] dynamicPartition(String[] names, SDVariable input, SDVariable partitions, int numPartitions) {
        SDVariable[] results = f().dynamicPartition(input, partitions, numPartitions);
        return updateVariableNamesAndReferences(results, names);
    }

    /** Unnamed dynamicStitch overload: forwards to the named overload with a {@code null} name. */
    public SDVariable dynamicStitch(SDVariable[] indices, SDVariable[] inputs) {
        return dynamicStitch(null, indices, inputs);
    }

    /**
     * Builds a dynamicStitch op via {@code f().dynamicStitch(...)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable dynamicStitch(String name, SDVariable[] indices, SDVariable[] inputs) {
        SDVariable result = f().dynamicStitch(indices, inputs);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed dilation2D overload: forwards to the named overload with a {@code null} name. */
    public SDVariable dilation2D(SDVariable input, SDVariable weights, int[] strides, int[] rates, boolean isSameMode) {
        return dilation2D(null, input, weights, strides, rates, isSameMode);
    }

    /**
     * Builds a dilation2D op via {@code f().dilation2D(...)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable dilation2D(String name, SDVariable input, SDVariable weights, int[] strides, int[] rates, boolean isSameMode) {
        SDVariable result = f().dilation2D(input, weights, strides, rates, isSameMode);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed shape overload: forwards to the named overload with a {@code null} name. */
    public SDVariable shape(SDVariable input) {
        return shape(null, input);
    }

    /**
     * Builds a shape op via {@code f().shape(input)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable shape(String name, SDVariable input) {
        SDVariable result = f().shape(input);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed cross overload: forwards to the named overload with a {@code null} name. */
    public SDVariable cross(SDVariable first, SDVariable second) {
        return cross(null, first, second);
    }

    /**
     * Builds a cross op via {@code f().cross(first, second)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable cross(String name, SDVariable first, SDVariable second) {
        SDVariable result = f().cross(first, second);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed gather overload: forwards to the named overload with a {@code null} name. */
    public SDVariable gather(SDVariable input, int i, int[] iArr) {
        return gather(null, input, i, iArr);
    }

    /**
     * Builds a gather op via {@code f().gather(input, i, iArr)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable gather(String name, SDVariable input, int i, int[] iArr) {
        SDVariable result = f().gather(input, i, iArr);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed gatherNd overload: forwards to the named overload with a {@code null} name. */
    public SDVariable gatherNd(SDVariable input, SDVariable indices) {
        return gatherNd(null, input, indices);
    }

    /**
     * Builds a gatherNd op via {@code f().gatherNd(...)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable gatherNd(String name, SDVariable input, SDVariable indices) {
        SDVariable result = f().gatherNd(input, indices);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed repeat overload: forwards to the named overload with a {@code null} name. */
    public SDVariable repeat(SDVariable input, int axis) {
        return repeat(null, input, axis);
    }

    /**
     * Builds a repeat op via {@code f().repeat(input, axis)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable repeat(String name, SDVariable input, int axis) {
        SDVariable result = f().repeat(input, axis);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed stack overload: forwards to the named overload with a {@code null} name. */
    public SDVariable stack(SDVariable[] inputs, int axis) {
        return stack(null, inputs, axis);
    }

    /**
     * Builds a stack op via {@code f().stack(inputs, axis)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable stack(String name, SDVariable[] inputs, int axis) {
        SDVariable result = f().stack(inputs, axis);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed parallel_stack overload: forwards to the named overload with a {@code null} name. */
    public SDVariable parallel_stack(SDVariable[] inputs) {
        return parallel_stack(null, inputs);
    }

    /**
     * Builds a parallel_stack op via {@code f().parallel_stack(inputs)} and passes the result
     * through {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable parallel_stack(String name, SDVariable[] inputs) {
        SDVariable result = f().parallel_stack(inputs);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed unstack overload: forwards to {@code unstack(String[], ...)} with a {@code null} names array. */
    public SDVariable[] unstack(SDVariable input, int axis) {
        return unstack(null, input, axis);
    }

    /**
     * Builds an unstack op via {@code f().unstack(input, axis)} and passes its outputs through
     * {@code updateVariableNamesAndReferences} together with {@code names}.
     */
    public SDVariable[] unstack(String[] names, SDVariable input, int axis) {
        SDVariable[] results = f().unstack(input, axis);
        return updateVariableNamesAndReferences(results, names);
    }

    /** Unnamed erf overload: forwards to the named overload with a {@code null} name. */
    public SDVariable erf(SDVariable input) {
        return erf(null, input);
    }

    /**
     * Builds an erf op via {@code f().erf(input)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable erf(String name, SDVariable input) {
        SDVariable result = f().erf(input);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed erfc overload: forwards to the named overload with a {@code null} name. */
    public SDVariable erfc(SDVariable input) {
        return erfc(null, input);
    }

    /**
     * Builds an erfc op via {@code f().erfc(input)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable erfc(String name, SDVariable input) {
        SDVariable result = f().erfc(input);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed diag overload: forwards to the named overload with a {@code null} name. */
    public SDVariable diag(SDVariable input) {
        return diag(null, input);
    }

    /**
     * Builds a diag op via {@code f().diag(input)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable diag(String name, SDVariable input) {
        SDVariable result = f().diag(input);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed diagPart overload: forwards to the named overload with a {@code null} name. */
    public SDVariable diagPart(SDVariable input) {
        return diagPart(null, input);
    }

    /**
     * Builds a diagPart op via {@code f().diagPart(input)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable diagPart(String name, SDVariable input) {
        SDVariable result = f().diagPart(input);
        return updateVariableNameAndReference(result, name);
    }

    /**
     * Unnamed oneHot overload using the defaults {@code axis = -1}, {@code on = 1.0} and
     * {@code off = 0.0}.
     */
    public SDVariable oneHot(SDVariable indices, int depth) {
        return oneHot(null, indices, depth, -1, 1.0d, 0.0d);
    }

    /** Unnamed oneHot overload with explicit axis/on/off values; forwards with a {@code null} name. */
    public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off) {
        return oneHot(null, indices, depth, axis, on, off);
    }

    /**
     * Named oneHot overload using the defaults {@code axis = -1}, {@code on = 1.0} and
     * {@code off = 0.0}.
     */
    public SDVariable oneHot(String name, SDVariable indices, int depth) {
        return oneHot(name, indices, depth, -1, 1.0d, 0.0d);
    }

    /**
     * Builds a onehot op via {@code f().onehot(...)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on, double off) {
        SDVariable result = f().onehot(indices, depth, axis, on, off);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed reciprocal overload: forwards to the named overload with a {@code null} name. */
    public SDVariable reciprocal(SDVariable input) {
        return reciprocal(null, input);
    }

    /**
     * Builds a reciprocal op via {@code f().reciprocal(input)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable reciprocal(String name, SDVariable input) {
        SDVariable result = f().reciprocal(input);
        return updateVariableNameAndReference(result, name);
    }

    /**
     * Unnamed gradientBackwardsMarker overload. It does not forward {@code null}. Instead it
     * generates a name from a fresh {@code GradientBackwardsMarker}'s op name via
     * {@code generateNewVarName(..., 0)} and passes that name to the named overload.
     */
    public SDVariable gradientBackwardsMarker(SDVariable input) {
        String markerOpName = new GradientBackwardsMarker().opName();
        String generatedName = generateNewVarName(markerOpName, 0);
        return gradientBackwardsMarker(generatedName, input);
    }

    /** Unnamed hardTanh overload: forwards to the named overload with a {@code null} name. */
    public SDVariable hardTanh(SDVariable input) {
        return hardTanh(null, input);
    }

    /** Unnamed hardTanhDerivative overload: forwards to the named overload with a {@code null} name. */
    public SDVariable hardTanhDerivative(SDVariable input) {
        return hardTanhDerivative(null, input);
    }

    /** Unnamed sigmoid overload: forwards to the named overload with a {@code null} name. */
    public SDVariable sigmoid(SDVariable input) {
        return sigmoid(null, input);
    }

    /** Unnamed sigmoidDerivative overload: forwards to the named overload with a {@code null} name. */
    public SDVariable sigmoidDerivative(SDVariable first, SDVariable second) {
        return sigmoidDerivative(null, first, second);
    }

    /** Unnamed logSigmoid overload: forwards to the named overload with a {@code null} name. */
    public SDVariable logSigmoid(SDVariable input) {
        return logSigmoid(null, input);
    }

    /**
     * Builds a logSigmoid op via {@code f().logSigmoid(input)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable logSigmoid(String name, SDVariable input) {
        SDVariable result = f().logSigmoid(input);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed sign overload: forwards to the named overload with a {@code null} name. */
    public SDVariable sign(SDVariable input) {
        return sign(null, input);
    }

    /** Unnamed softsign overload: forwards to the named overload with a {@code null} name. */
    public SDVariable softsign(SDVariable input) {
        return softsign(null, input);
    }

    /** Unnamed softsignDerivative overload: forwards to the named overload with a {@code null} name. */
    public SDVariable softsignDerivative(SDVariable input) {
        return softsignDerivative(null, input);
    }

    /** Unnamed softplus overload: forwards to the named overload with a {@code null} name. */
    public SDVariable softplus(SDVariable input) {
        return softplus(null, input);
    }

    /** Unnamed swish overload: forwards to the named overload with a {@code null} name. */
    public SDVariable swish(SDVariable input) {
        return swish(null, input);
    }

    /**
     * Builds a swish op via {@code f().swish(input)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable swish(String name, SDVariable input) {
        SDVariable result = f().swish(input);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed elu overload: forwards to the named overload with a {@code null} name. */
    public SDVariable elu(SDVariable input) {
        return elu(null, input);
    }

    /** Unnamed eluDerivative overload: forwards to the named overload with a {@code null} name. */
    public SDVariable eluDerivative(SDVariable input) {
        return eluDerivative(null, input);
    }

    /** Unnamed leakyRelu overload: forwards to the named overload with a {@code null} name. */
    public SDVariable leakyRelu(SDVariable input, double alpha) {
        return leakyRelu(null, input, alpha);
    }

    /**
     * Unnamed full mean overload. The {@code (String)} cast on {@code null} picks the
     * {@code mean(String, SDVariable)} overload.
     */
    public SDVariable mean(SDVariable input) {
        return mean((String) null, input);
    }

    /** Unnamed mean overload over {@code dimensions}: forwards with a {@code null} name. */
    public SDVariable mean(SDVariable input, int... dimensions) {
        return mean(null, input, dimensions);
    }

    /** Unnamed standardDeviation overload: forwards with a {@code null} name. */
    public SDVariable standardDeviation(SDVariable input, boolean biasCorrected, int... dimensions) {
        return standardDeviation(null, input, biasCorrected, dimensions);
    }

    /** Unnamed variance overload: forwards with a {@code null} name. */
    public SDVariable variance(SDVariable input, boolean biasCorrected, int... dimensions) {
        return variance(null, input, biasCorrected, dimensions);
    }

    /** Unnamed sum overload: forwards with a {@code null} name. */
    public SDVariable sum(SDVariable input, int... dimensions) {
        return sum(null, input, dimensions);
    }

    /** Unnamed prod overload: forwards with a {@code null} name. */
    public SDVariable prod(SDVariable input, int... dimensions) {
        return prod(null, input, dimensions);
    }

    /**
     * Unnamed max-reduction overload. The {@code (String)} cast selects the named
     * {@code max(String, SDVariable, int...)} overload.
     */
    public SDVariable max(SDVariable input, int... dimensions) {
        return max((String) null, input, dimensions);
    }

    /** Unnamed pairwise max overload: forwards to {@code max(String, SDVariable, SDVariable)}. */
    public SDVariable max(SDVariable first, SDVariable second) {
        return max((String) null, first, second);
    }

    /**
     * Builds a pairwise max op via {@code f().max(first, second)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable max(String name, SDVariable first, SDVariable second) {
        SDVariable result = f().max(first, second);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed countZero overload: forwards to the named overload with a {@code null} name. */
    public SDVariable countZero(SDVariable input) {
        return countZero(null, input);
    }

    /**
     * Builds a countZero op via {@code f().countZero(input)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable countZero(String name, SDVariable input) {
        SDVariable result = f().countZero(input);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed zeroFraction overload: forwards to the named overload with a {@code null} name. */
    public SDVariable zeroFraction(SDVariable input) {
        return zeroFraction(null, input);
    }

    /**
     * Builds a zeroFraction op via {@code f().zeroFraction(input)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable zeroFraction(String name, SDVariable input) {
        SDVariable result = f().zeroFraction(input);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed countNonZero overload: forwards to the named overload with a {@code null} name. */
    public SDVariable countNonZero(SDVariable input) {
        return countNonZero(null, input);
    }

    /**
     * Builds a countNonZero op via {@code f().countNonZero(input)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable countNonZero(String name, SDVariable input) {
        SDVariable result = f().countNonZero(input);
        return updateVariableNameAndReference(result, name);
    }

    /**
     * Unnamed min-reduction overload. The {@code (String)} cast selects the named
     * {@code min(String, SDVariable, int...)} overload.
     */
    public SDVariable min(SDVariable input, int... dimensions) {
        return min((String) null, input, dimensions);
    }

    /** Unnamed pairwise min overload: forwards to {@code min(String, SDVariable, SDVariable)}. */
    public SDVariable min(SDVariable first, SDVariable second) {
        return min((String) null, first, second);
    }

    /**
     * Builds a pairwise min op via {@code f().min(first, second)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable min(String name, SDVariable first, SDVariable second) {
        SDVariable result = f().min(first, second);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed argmax overload: forwards to the named overload with a {@code null} name. */
    public SDVariable argmax(SDVariable input, int... dimensions) {
        return argmax(null, input, dimensions);
    }

    /**
     * Builds an argmax op via {@code f().argmax(input, dimensions)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable argmax(String name, SDVariable input, int... dimensions) {
        SDVariable result = f().argmax(input, dimensions);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed argmin overload: forwards to the named overload with a {@code null} name. */
    public SDVariable argmin(SDVariable input, int... dimensions) {
        return argmin(null, input, dimensions);
    }

    /**
     * Builds an argmin op via {@code f().argmin(input, dimensions)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable argmin(String name, SDVariable input, int... dimensions) {
        SDVariable result = f().argmin(input, dimensions);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed cumsum overload: forwards to the named overload with a {@code null} name. */
    public SDVariable cumsum(SDVariable input, boolean exclusive, boolean reverse, int... axis) {
        return cumsum(null, input, exclusive, reverse, axis);
    }

    /**
     * Builds a cumsum op via {@code f().cumsum(...)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable cumsum(String name, SDVariable input, boolean exclusive, boolean reverse, int... axis) {
        SDVariable result = f().cumsum(input, exclusive, reverse, axis);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed cumprod overload: forwards to the named overload with a {@code null} name. */
    public SDVariable cumprod(SDVariable input, boolean exclusive, boolean reverse, int... axis) {
        return cumprod(null, input, exclusive, reverse, axis);
    }

    /**
     * Builds a cumprod op via {@code f().cumprod(...)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable cumprod(String name, SDVariable input, boolean exclusive, boolean reverse, int... axis) {
        SDVariable result = f().cumprod(input, exclusive, reverse, axis);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed biasAdd overload: forwards to the named overload with a {@code null} name. */
    public SDVariable biasAdd(SDVariable input, SDVariable bias) {
        return biasAdd(null, input, bias);
    }

    /**
     * Builds a biasAdd op via {@code f().biasAdd(input, bias)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable biasAdd(String name, SDVariable input, SDVariable bias) {
        SDVariable result = f().biasAdd(input, bias);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed reshape overload: forwards to the named overload with a {@code null} name. */
    public SDVariable reshape(SDVariable input, int... shape) {
        return reshape(null, input, shape);
    }

    /** Unnamed reverse overload: forwards to the named overload with a {@code null} name. */
    public SDVariable reverse(SDVariable input, int... dimensions) {
        return reverse(null, input, dimensions);
    }

    /**
     * Builds a reverse op via {@code f().reverse(input, dimensions)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable reverse(String name, SDVariable input, int... dimensions) {
        SDVariable result = f().reverse(input, dimensions);
        return updateVariableNameAndReference(result, name);
    }

    /**
     * Builds a reverseSequence op with explicit dimension arguments via
     * {@code f().reverseSequence(...)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable reverseSequence(String name, SDVariable input, SDVariable seqLengths, int seqDim, int batchDim) {
        SDVariable result = f().reverseSequence(input, seqLengths, seqDim, batchDim);
        return updateVariableNameAndReference(result, name);
    }

    /**
     * Builds a reverseSequence op using the factory's two-argument form and passes the result
     * through {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable reverseSequence(String name, SDVariable input, SDVariable seqLengths) {
        SDVariable result = f().reverseSequence(input, seqLengths);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed reverseSequence overload with dimension arguments; forwards with a {@code null} name. */
    public SDVariable reverseSequence(SDVariable input, SDVariable seqLengths, int seqDim, int batchDim) {
        return reverseSequence(null, input, seqLengths, seqDim, batchDim);
    }

    /** Unnamed two-argument reverseSequence overload; forwards with a {@code null} name. */
    public SDVariable reverseSequence(SDVariable input, SDVariable seqLengths) {
        return reverseSequence(null, input, seqLengths);
    }

    /**
     * Builds a sequenceMask op from {@code lengths} and a variable {@code maxLen} via
     * {@code f().sequenceMask(lengths, maxLen)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable sequenceMask(String name, SDVariable lengths, SDVariable maxLen) {
        SDVariable result = f().sequenceMask(lengths, maxLen);
        return updateVariableNameAndReference(result, name);
    }

    /**
     * Unnamed overload of {@code sequenceMask(String, SDVariable, SDVariable)}. The
     * {@code (String)} cast is required because a bare {@code null} would also match
     * {@code sequenceMask(SDVariable, SDVariable)}.
     */
    public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen) {
        return sequenceMask((String) null, lengths, maxLen);
    }

    /**
     * Builds a sequenceMask op from {@code lengths} and a constant {@code maxLen} via
     * {@code f().sequenceMask(lengths, maxLen)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable sequenceMask(String name, SDVariable lengths, int maxLen) {
        SDVariable result = f().sequenceMask(lengths, maxLen);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed overload of {@code sequenceMask(String, SDVariable, int)}. */
    public SDVariable sequenceMask(SDVariable lengths, int maxLen) {
        return sequenceMask((String) null, lengths, maxLen);
    }

    /**
     * Builds a sequenceMask op from {@code lengths} alone via {@code f().sequenceMask(lengths)} and
     * passes the result through {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable sequenceMask(String name, SDVariable lengths) {
        SDVariable result = f().sequenceMask(lengths);
        return updateVariableNameAndReference(result, name);
    }

    /**
     * Unnamed overload of {@code sequenceMask(String, SDVariable)}.
     *
     * <p>This previously duplicated the op construction inline. It now delegates like every other
     * unnamed overload, so the two overloads cannot drift apart. The {@code (String)} cast is
     * needed because a bare {@code null} would be ambiguous with
     * {@code sequenceMask(SDVariable, SDVariable)}.
     */
    public SDVariable sequenceMask(SDVariable lengths) {
        return sequenceMask((String) null, lengths);
    }

    /** Unnamed assign overload: forwards to the named overload with a {@code null} name. */
    public SDVariable assign(SDVariable first, SDVariable second) {
        return assign(null, first, second);
    }

    /**
     * Builds an assign op via {@code f().assign(first, second)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable assign(String name, SDVariable first, SDVariable second) {
        SDVariable result = f().assign(first, second);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed transpose overload: forwards to the named overload with a {@code null} name. */
    public SDVariable transpose(SDVariable input) {
        return transpose(null, input);
    }

    /** Unnamed permute overload: forwards to the named overload with a {@code null} name. */
    public SDVariable permute(SDVariable input, int... dimensions) {
        return permute(null, input, dimensions);
    }

    /** Unnamed rollAxis overload: forwards to the named overload with a {@code null} name. */
    public SDVariable rollAxis(SDVariable input, int axis) {
        return rollAxis(null, input, axis);
    }

    /** Unnamed concat overload: forwards to the named overload with a {@code null} name. */
    public SDVariable concat(int dimension, SDVariable... inputs) {
        return concat(null, dimension, inputs);
    }

    /** Unnamed moments overload: forwards to {@code moments(String[], ...)} with a {@code null} names array. */
    public SDVariable[] moments(SDVariable input, int... axes) {
        return moments(null, input, axes);
    }

    /**
     * Builds a moments op via {@code f().moments(input, axes)} and passes its outputs through
     * {@code updateVariableNamesAndReferences} together with {@code names}.
     */
    public SDVariable[] moments(String[] names, SDVariable input, int... axes) {
        SDVariable[] results = f().moments(input, axes);
        return updateVariableNamesAndReferences(results, names);
    }

    /** Unnamed normalizeMoments overload: forwards with a {@code null} names array. */
    public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, double shift) {
        return normalizeMoments(null, counts, means, variances, shift);
    }

    /**
     * Builds a normalizeMoments op via {@code f().normalizeMoments(...)} and passes its outputs
     * through {@code updateVariableNamesAndReferences} together with {@code names}.
     */
    public SDVariable[] normalizeMoments(String[] names, SDVariable counts, SDVariable means, SDVariable variances, double shift) {
        SDVariable[] results = f().normalizeMoments(counts, means, variances, shift);
        return updateVariableNamesAndReferences(results, names);
    }

    /** Unnamed tile overload: forwards to the named overload with a {@code null} name. */
    public SDVariable tile(SDVariable input, int[] repeat) {
        return tile(null, input, repeat);
    }

    /** Unnamed fill overload: forwards to the named overload with a {@code null} name. */
    public SDVariable fill(SDVariable input, double value) {
        return fill(null, input, value);
    }

    /** Unnamed dropout overload: forwards to the named overload with a {@code null} name. */
    public SDVariable dropout(SDVariable input, double p) {
        return dropout(null, input, p);
    }

    /**
     * Builds a dropout op via {@code f().dropout(input, p)} and passes the result through
     * {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable dropout(String name, SDVariable input, double p) {
        SDVariable result = f().dropout(input, p);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed xwPlusB overload: forwards to the named overload with a {@code null} name. */
    public SDVariable xwPlusB(SDVariable input, SDVariable weights, SDVariable bias) {
        return xwPlusB(null, input, weights, bias);
    }

    /**
     * Builds an xwPlusB op via {@code f().xwPlusB(input, weights, bias)} and passes the result
     * through {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable xwPlusB(String name, SDVariable input, SDVariable weights, SDVariable bias) {
        SDVariable result = f().xwPlusB(input, weights, bias);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed reluLayer overload: forwards to the named overload with a {@code null} name. */
    public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) {
        return reluLayer(null, input, weights, bias);
    }

    /**
     * Builds a reluLayer op via {@code f().reluLayer(input, weights, bias)} and passes the result
     * through {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, SDVariable bias) {
        SDVariable result = f().reluLayer(input, weights, bias);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed mmul overload with a transpose configuration; forwards with a {@code null} name. */
    public SDVariable mmul(SDVariable first, SDVariable second, MMulTranspose transpose) {
        return mmul(null, first, second, transpose);
    }

    /**
     * Unnamed plain mmul overload. The {@code (String)} cast selects
     * {@code mmul(String, SDVariable, SDVariable)}.
     */
    public SDVariable mmul(SDVariable first, SDVariable second) {
        return mmul((String) null, first, second);
    }

    /** Unnamed tensorMmul overload: forwards to the named overload with a {@code null} name. */
    public SDVariable tensorMmul(SDVariable first, SDVariable second, int[][] dimensions) {
        return tensorMmul(null, first, second, dimensions);
    }

    /**
     * Unnamed cosineSimilarity overload. It generates a name from {@code CosineSimilarity.OP_NAME}
     * via {@code generateNewVarName(..., 0)} and forwards to the named overload.
     *
     * <p>NOTE(review): cosineSimilarity, euclideanDistance and manhattanDistance pre-generate a
     * name, while cosineDistance, hammingDistance and jaccardDistance forward {@code null}. The
     * reason for the difference is not visible here; confirm before unifying.
     */
    public SDVariable cosineSimilarity(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(CosineSimilarity.OP_NAME, 0);
        return cosineSimilarity(generatedName, first, second, dimensions);
    }

    /** Unnamed euclideanDistance overload; forwards with a name generated from {@code EuclideanDistance.OP_NAME}. */
    public SDVariable euclideanDistance(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(EuclideanDistance.OP_NAME, 0);
        return euclideanDistance(generatedName, first, second, dimensions);
    }

    /** Unnamed manhattanDistance overload; forwards with a name generated from {@code ManhattanDistance.OP_NAME}. */
    public SDVariable manhattanDistance(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(ManhattanDistance.OP_NAME, 0);
        return manhattanDistance(generatedName, first, second, dimensions);
    }

    /** Unnamed cosineDistance overload: forwards to the named overload with a {@code null} name. */
    public SDVariable cosineDistance(SDVariable first, SDVariable second, int... dimensions) {
        return cosineDistance(null, first, second, dimensions);
    }

    /**
     * Builds a cosineDistance op via {@code functionFactory.cosineDistance(...)} and passes the
     * result through {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable cosineDistance(String name, SDVariable first, SDVariable second, int... dimensions) {
        SDVariable result = this.functionFactory.cosineDistance(first, second, dimensions);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed hammingDistance overload: forwards to the named overload with a {@code null} name. */
    public SDVariable hammingDistance(SDVariable first, SDVariable second, int... dimensions) {
        return hammingDistance(null, first, second, dimensions);
    }

    /**
     * Builds a hammingDistance op via {@code functionFactory.hammingDistance(...)} and passes the
     * result through {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable hammingDistance(String name, SDVariable first, SDVariable second, int... dimensions) {
        SDVariable result = this.functionFactory.hammingDistance(first, second, dimensions);
        return updateVariableNameAndReference(result, name);
    }

    /** Unnamed jaccardDistance overload: forwards to the named overload with a {@code null} name. */
    public SDVariable jaccardDistance(SDVariable first, SDVariable second, int... dimensions) {
        return jaccardDistance(null, first, second, dimensions);
    }

    /**
     * Builds a jaccardDistance op via {@code functionFactory.jaccardDistance(...)} and passes the
     * result through {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable jaccardDistance(String name, SDVariable first, SDVariable second, int... dimensions) {
        SDVariable result = this.functionFactory.jaccardDistance(first, second, dimensions);
        return updateVariableNameAndReference(result, name);
    }

    /*
     * Unnamed loss overloads. Each one instantiates the corresponding loss op only to read its
     * opName(), derives a variable name with generateNewVarName(opName, 0), and forwards to the
     * named overload of the same method.
     */

    /** Unnamed lossBinaryXENT overload; forwards with a name derived from {@code LossBinaryXENT}. */
    public SDVariable lossBinaryXENT(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(new LossBinaryXENT().opName(), 0);
        return lossBinaryXENT(generatedName, first, second, dimensions);
    }

    /** Unnamed lossCosineSimilarity overload; forwards with a name derived from {@code LossCosineProximity}. */
    public SDVariable lossCosineSimilarity(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(new LossCosineProximity().opName(), 0);
        return lossCosineSimilarity(generatedName, first, second, dimensions);
    }

    /** Unnamed lossHinge overload; forwards with a name derived from {@code LossHinge}. */
    public SDVariable lossHinge(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(new LossHinge().opName(), 0);
        return lossHinge(generatedName, first, second, dimensions);
    }

    /** Unnamed lossKLD overload; forwards with a name derived from {@code LossKLD}. */
    public SDVariable lossKLD(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(new LossKLD().opName(), 0);
        return lossKLD(generatedName, first, second, dimensions);
    }

    /** Unnamed lossL1 overload; forwards with a name derived from {@code LossL1}. */
    public SDVariable lossL1(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(new LossL1().opName(), 0);
        return lossL1(generatedName, first, second, dimensions);
    }

    /** Unnamed lossL2 overload; forwards with a name derived from {@code LossL2}. */
    public SDVariable lossL2(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(new LossL2().opName(), 0);
        return lossL2(generatedName, first, second, dimensions);
    }

    /** Unnamed lossMAE overload; forwards with a name derived from {@code LossMAE}. */
    public SDVariable lossMAE(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(new LossMAE().opName(), 0);
        return lossMAE(generatedName, first, second, dimensions);
    }

    /** Unnamed lossMSE overload; forwards with a name derived from {@code LossMSE}. */
    public SDVariable lossMSE(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(new LossMSE().opName(), 0);
        return lossMSE(generatedName, first, second, dimensions);
    }

    /** Unnamed lossMCXENT overload; forwards with a name derived from {@code LossMCXENT}. */
    public SDVariable lossMCXENT(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(new LossMCXENT().opName(), 0);
        return lossMCXENT(generatedName, first, second, dimensions);
    }

    /** Unnamed lossMSLE overload; forwards with a name derived from {@code LossMSLE}. */
    public SDVariable lossMSLE(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(new LossMSLE().opName(), 0);
        return lossMSLE(generatedName, first, second, dimensions);
    }

    /** Unnamed lossNegativeLogLikelihood overload; forwards with a name derived from {@code LossNegativeLogLikelihood}. */
    public SDVariable lossNegativeLogLikelihood(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(new LossNegativeLogLikelihood().opName(), 0);
        return lossNegativeLogLikelihood(generatedName, first, second, dimensions);
    }

    /** Unnamed lossPoisson overload; forwards with a name derived from {@code LossPoisson}. */
    public SDVariable lossPoisson(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(new LossPoisson().opName(), 0);
        return lossPoisson(generatedName, first, second, dimensions);
    }

    /** Unnamed lossSquaredHinge overload; forwards with a name derived from {@code LossSquaredHinge}. */
    public SDVariable lossSquaredHinge(SDVariable first, SDVariable second, int... dimensions) {
        String generatedName = generateNewVarName(new LossSquaredHinge().opName(), 0);
        return lossSquaredHinge(generatedName, first, second, dimensions);
    }

    /**
     * Builds a gradientBackwardsMarker op via {@code functionFactory.gradientBackwardsMarker(input)}
     * and passes the result through {@code updateVariableNameAndReference} together with
     * {@code name}.
     */
    public SDVariable gradientBackwardsMarker(String name, SDVariable input) {
        SDVariable marker = this.functionFactory.gradientBackwardsMarker(input);
        return updateVariableNameAndReference(marker, name);
    }

    /*
     * Named comparison/logical builders. Each one creates the op through this.functionFactory and
     * passes the resulting variable through updateVariableNameAndReference with the given name.
     */

    /** Builds a scalar neq op via {@code functionFactory.neq(input, value)}. */
    public SDVariable neq(String name, SDVariable input, double value) {
        SDVariable op = this.functionFactory.neq(input, value);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a scalar eq op via {@code functionFactory.eq(input, value)}. */
    public SDVariable eq(String name, SDVariable input, double value) {
        SDVariable op = this.functionFactory.eq(input, value);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a scalar gte op via {@code functionFactory.gte(input, value)}. */
    public SDVariable gte(String name, SDVariable input, double value) {
        SDVariable op = this.functionFactory.gte(input, value);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a scalar lte op via {@code functionFactory.lte(input, value)}. */
    public SDVariable lte(String name, SDVariable input, double value) {
        SDVariable op = this.functionFactory.lte(input, value);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a scalar gt op via {@code functionFactory.gt(input, value)}. */
    public SDVariable gt(String name, SDVariable input, double value) {
        SDVariable op = this.functionFactory.gt(input, value);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a scalar lt op via {@code functionFactory.lt(input, value)}. */
    public SDVariable lt(String name, SDVariable input, double value) {
        SDVariable op = this.functionFactory.lt(input, value);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a pairwise neq op via {@code functionFactory.neq(first, second)}. */
    public SDVariable neq(String name, SDVariable first, SDVariable second) {
        SDVariable op = this.functionFactory.neq(first, second);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a pairwise eq op via {@code functionFactory.eq(first, second)}. */
    public SDVariable eq(String name, SDVariable first, SDVariable second) {
        SDVariable op = this.functionFactory.eq(first, second);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a pairwise gte op via {@code functionFactory.gte(first, second)}. */
    public SDVariable gte(String name, SDVariable first, SDVariable second) {
        SDVariable op = this.functionFactory.gte(first, second);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a pairwise lte op via {@code functionFactory.lte(first, second)}. */
    public SDVariable lte(String name, SDVariable first, SDVariable second) {
        SDVariable op = this.functionFactory.lte(first, second);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a pairwise gt op via {@code functionFactory.gt(first, second)}. */
    public SDVariable gt(String name, SDVariable first, SDVariable second) {
        SDVariable op = this.functionFactory.gt(first, second);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a pairwise lt op via {@code functionFactory.lt(first, second)}. */
    public SDVariable lt(String name, SDVariable first, SDVariable second) {
        SDVariable op = this.functionFactory.lt(first, second);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a pairwise or op via {@code functionFactory.or(first, second)}. */
    public SDVariable or(String name, SDVariable first, SDVariable second) {
        SDVariable op = this.functionFactory.or(first, second);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a neg op via {@code functionFactory.neg(input)}. */
    public SDVariable neg(String name, SDVariable input) {
        SDVariable op = this.functionFactory.neg(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Unnamed isNonDecreasing overload: forwards to the named overload with a {@code null} name. */
    public SDVariable isNonDecreasing(SDVariable input) {
        return isNonDecreasing(null, input);
    }

    /**
     * Builds an isNonDecreasing op via {@code functionFactory.isNonDecreasing(input)} and passes
     * the result through {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable isNonDecreasing(String name, SDVariable input) {
        SDVariable op = this.functionFactory.isNonDecreasing(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Unnamed isStrictlyIncreasing overload: forwards to the named overload with a {@code null} name. */
    public SDVariable isStrictlyIncreasing(SDVariable input) {
        return isStrictlyIncreasing(null, input);
    }

    /**
     * Builds an isStrictlyIncreasing op via {@code functionFactory.isStrictlyIncreasing(input)}
     * and passes the result through {@code updateVariableNameAndReference} together with
     * {@code name}.
     */
    public SDVariable isStrictlyIncreasing(String name, SDVariable input) {
        SDVariable op = this.functionFactory.isStrictlyIncreasing(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Unnamed isNumericTensor overload: forwards to the named overload with a {@code null} name. */
    public SDVariable isNumericTensor(SDVariable input) {
        return isNumericTensor(null, input);
    }

    /**
     * Builds an isNumericTensor op via {@code functionFactory.isNumericTensor(input)} and passes
     * the result through {@code updateVariableNameAndReference} together with {@code name}.
     */
    public SDVariable isNumericTensor(String name, SDVariable input) {
        SDVariable op = this.functionFactory.isNumericTensor(input);
        return updateVariableNameAndReference(op, name);
    }

    /*
     * Named trigonometric / hyperbolic builders. Each one creates the op through
     * this.functionFactory and passes the result through updateVariableNameAndReference with the
     * given name.
     */

    /** Builds a cos op via {@code functionFactory.cos(input)}. */
    public SDVariable cos(String name, SDVariable input) {
        SDVariable op = this.functionFactory.cos(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a sin op via {@code functionFactory.sin(input)}. */
    public SDVariable sin(String name, SDVariable input) {
        SDVariable op = this.functionFactory.sin(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a tan op via {@code functionFactory.tan(input)}. */
    public SDVariable tan(String name, SDVariable input) {
        SDVariable op = this.functionFactory.tan(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds an acos op via {@code functionFactory.acos(input)}. */
    public SDVariable acos(String name, SDVariable input) {
        SDVariable op = this.functionFactory.acos(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds an asin op via {@code functionFactory.asin(input)}. */
    public SDVariable asin(String name, SDVariable input) {
        SDVariable op = this.functionFactory.asin(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds an atan op via {@code functionFactory.atan(input)}. */
    public SDVariable atan(String name, SDVariable input) {
        SDVariable op = this.functionFactory.atan(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a cosh op via {@code functionFactory.cosh(input)}. */
    public SDVariable cosh(String name, SDVariable input) {
        SDVariable op = this.functionFactory.cosh(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a sinh op via {@code functionFactory.sinh(input)}. */
    public SDVariable sinh(String name, SDVariable input) {
        SDVariable op = this.functionFactory.sinh(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a tanh op via {@code functionFactory.tanh(input)}. */
    public SDVariable tanh(String name, SDVariable input) {
        SDVariable op = this.functionFactory.tanh(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds an acosh op via {@code functionFactory.acosh(input)}. */
    public SDVariable acosh(String name, SDVariable input) {
        SDVariable op = this.functionFactory.acosh(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds an asinh op via {@code functionFactory.asinh(input)}. */
    public SDVariable asinh(String name, SDVariable input) {
        SDVariable op = this.functionFactory.asinh(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds an atanh op via {@code functionFactory.atanh(input)}. */
    public SDVariable atanh(String name, SDVariable input) {
        SDVariable op = this.functionFactory.atanh(input);
        return updateVariableNameAndReference(op, name);
    }

    /*
     * Named elementwise builders. Each one creates the op through this.functionFactory and passes
     * the result through updateVariableNameAndReference with the given name.
     */

    /** Builds an exp op via {@code functionFactory.exp(input)}. */
    public SDVariable exp(String name, SDVariable input) {
        SDVariable op = this.functionFactory.exp(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds an expm1 op via {@code functionFactory.expm1(input)}. */
    public SDVariable expm1(String name, SDVariable input) {
        SDVariable op = this.functionFactory.expm1(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds an rsqrt op via {@code functionFactory.rsqrt(input)}. */
    public SDVariable rsqrt(String name, SDVariable input) {
        SDVariable op = this.functionFactory.rsqrt(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a log op via {@code functionFactory.log(input)}. */
    public SDVariable log(String name, SDVariable input) {
        SDVariable op = this.functionFactory.log(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a log1p op via {@code functionFactory.log1p(input)}. */
    public SDVariable log1p(String name, SDVariable input) {
        SDVariable op = this.functionFactory.log1p(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds an isFinite op via {@code functionFactory.isFinite(input)}. */
    public SDVariable isFinite(String name, SDVariable input) {
        SDVariable op = this.functionFactory.isFinite(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds an isInfinite op via {@code functionFactory.isInfinite(input)}. */
    public SDVariable isInfinite(String name, SDVariable input) {
        SDVariable op = this.functionFactory.isInfinite(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds an isNaN op via {@code functionFactory.isNaN(input)}. */
    public SDVariable isNaN(String name, SDVariable input) {
        SDVariable op = this.functionFactory.isNaN(input);
        return updateVariableNameAndReference(op, name);
    }

    /** Builds a round op via {@code functionFactory.round(input)}. */
    public SDVariable round(String name, SDVariable input) {
        SDVariable op = this.functionFactory.round(input);
        return updateVariableNameAndReference(op, name);
    }

    public SDVariable pow(String str, SDVariable sDVariable, double d) {
        return updateVariableNameAndReference(this.functionFactory.pow(sDVariable, d), str);
    }

    public SDVariable cube(String str, SDVariable sDVariable) {
        return updateVariableNameAndReference(this.functionFactory.cube(sDVariable), str);
    }

    public SDVariable sqrt(String str, SDVariable sDVariable) {
        return updateVariableNameAndReference(this.functionFactory.sqrt(sDVariable), str);
    }

    public SDVariable square(String str, SDVariable sDVariable) {
        return updateVariableNameAndReference(this.functionFactory.square(sDVariable), str);
    }

    public SDVariable floor(String str, SDVariable sDVariable) {
        return updateVariableNameAndReference(this.functionFactory.floor(sDVariable), str);
    }

    /** Builds a {@code relu} op on {@code sDVariable}, forwarding {@code d}; output is named {@code str}. */
    public SDVariable relu(String str, SDVariable sDVariable, double d) {
        SDVariable activated = functionFactory.relu(sDVariable, d);
        return updateVariableNameAndReference(activated, str);
    }

    /** Builds a {@code relu6} op on {@code sDVariable}, forwarding {@code d}; output is named {@code str}. */
    public SDVariable relu6(String str, SDVariable sDVariable, double d) {
        SDVariable activated = functionFactory.relu6(sDVariable, d);
        return updateVariableNameAndReference(activated, str);
    }

    /** Builds a {@code softmax} op on {@code sDVariable}; output is named {@code str}. */
    public SDVariable softmax(String str, SDVariable sDVariable) {
        SDVariable activated = functionFactory.softmax(sDVariable);
        return updateVariableNameAndReference(activated, str);
    }

    /** Builds a {@code softmaxDerivative} op from both inputs; output is named {@code str}. */
    public SDVariable softmaxDerivative(String str, SDVariable sDVariable, SDVariable sDVariable2) {
        SDVariable derivative = functionFactory.softmaxDerivative(sDVariable, sDVariable2);
        return updateVariableNameAndReference(derivative, str);
    }

    /** Builds a {@code hardTanh} op on {@code sDVariable}; output is named {@code str}. */
    public SDVariable hardTanh(String str, SDVariable sDVariable) {
        SDVariable activated = functionFactory.hardTanh(sDVariable);
        return updateVariableNameAndReference(activated, str);
    }

    /** Builds a {@code hardTanhDerivative} op on {@code sDVariable}; output is named {@code str}. */
    public SDVariable hardTanhDerivative(String str, SDVariable sDVariable) {
        SDVariable derivative = functionFactory.hardTanhDerivative(sDVariable);
        return updateVariableNameAndReference(derivative, str);
    }

    /** Builds a {@code sigmoid} op on {@code sDVariable}; output is named {@code str}. */
    public SDVariable sigmoid(String str, SDVariable sDVariable) {
        SDVariable activated = functionFactory.sigmoid(sDVariable);
        return updateVariableNameAndReference(activated, str);
    }

    /** Builds a {@code sigmoidDerivative} op from both inputs; output is named {@code str}. */
    public SDVariable sigmoidDerivative(String str, SDVariable sDVariable, SDVariable sDVariable2) {
        SDVariable derivative = functionFactory.sigmoidDerivative(sDVariable, sDVariable2);
        return updateVariableNameAndReference(derivative, str);
    }

    /** Builds a {@code sign} op on {@code sDVariable}; output is named {@code str}. */
    public SDVariable sign(String str, SDVariable sDVariable) {
        SDVariable activated = functionFactory.sign(sDVariable);
        return updateVariableNameAndReference(activated, str);
    }

    /** Builds a {@code softsign} op on {@code sDVariable}; output is named {@code str}. */
    public SDVariable softsign(String str, SDVariable sDVariable) {
        SDVariable activated = functionFactory.softsign(sDVariable);
        return updateVariableNameAndReference(activated, str);
    }

    /** Builds a {@code softsignDerivative} op on {@code sDVariable}; output is named {@code str}. */
    public SDVariable softsignDerivative(String str, SDVariable sDVariable) {
        SDVariable derivative = functionFactory.softsignDerivative(sDVariable);
        return updateVariableNameAndReference(derivative, str);
    }

    /** Builds a {@code softplus} op on {@code sDVariable}; output is named {@code str}. */
    public SDVariable softplus(String str, SDVariable sDVariable) {
        SDVariable activated = functionFactory.softplus(sDVariable);
        return updateVariableNameAndReference(activated, str);
    }

    /** Builds an {@code elu} op on {@code sDVariable}; output is named {@code str}. */
    public SDVariable elu(String str, SDVariable sDVariable) {
        SDVariable activated = functionFactory.elu(sDVariable);
        return updateVariableNameAndReference(activated, str);
    }

    /** Builds an {@code eluDerivative} op on {@code sDVariable}; output is named {@code str}. */
    public SDVariable eluDerivative(String str, SDVariable sDVariable) {
        SDVariable derivative = functionFactory.eluDerivative(sDVariable);
        return updateVariableNameAndReference(derivative, str);
    }

    /** Builds a {@code leakyRelu} op on {@code sDVariable}, forwarding {@code d}; output is named {@code str}. */
    public SDVariable leakyRelu(String str, SDVariable sDVariable, double d) {
        SDVariable activated = functionFactory.leakyRelu(sDVariable, d);
        return updateVariableNameAndReference(activated, str);
    }

    /** Builds a {@code leakyReluDerivative} op, forwarding {@code d}; output is named {@code str}. */
    public SDVariable leakyReluDerivative(String str, SDVariable sDVariable, double d) {
        SDVariable derivative = functionFactory.leakyReluDerivative(sDVariable, d);
        return updateVariableNameAndReference(derivative, str);
    }

    /**
     * Mean of {@code sDVariable} with an empty dimension list (presumably a full reduction —
     * confirm against the factory). Delegates to the varargs overload below.
     */
    public SDVariable mean(String str, SDVariable sDVariable) {
        return mean(str, sDVariable, new int[0]);
    }

    /** Mean of {@code sDVariable} along {@code iArr}; output is named {@code str}. */
    public SDVariable mean(String str, SDVariable sDVariable, int... iArr) {
        SDVariable reduced = functionFactory.mean(sDVariable, iArr);
        return updateVariableNameAndReference(reduced, str);
    }

    /** Standard deviation along {@code iArr}; the flag {@code z} is forwarded to the factory's {@code std}. */
    public SDVariable standardDeviation(String str, SDVariable sDVariable, boolean z, int... iArr) {
        SDVariable reduced = functionFactory.std(sDVariable, z, iArr);
        return updateVariableNameAndReference(reduced, str);
    }

    /** Variance along {@code iArr}; the flag {@code z} is forwarded to the factory's {@code variance}. */
    public SDVariable variance(String str, SDVariable sDVariable, boolean z, int... iArr) {
        SDVariable reduced = functionFactory.variance(sDVariable, z, iArr);
        return updateVariableNameAndReference(reduced, str);
    }

    /** Sum along {@code iArr}; output is named {@code str}. */
    public SDVariable sum(String str, SDVariable sDVariable, int... iArr) {
        SDVariable reduced = functionFactory.sum(sDVariable, iArr);
        return updateVariableNameAndReference(reduced, str);
    }

    /** Product along {@code iArr}; output is named {@code str}. */
    public SDVariable prod(String str, SDVariable sDVariable, int... iArr) {
        SDVariable reduced = functionFactory.prod(sDVariable, iArr);
        return updateVariableNameAndReference(reduced, str);
    }

    /** Maximum along {@code iArr}; output is named {@code str}. */
    public SDVariable max(String str, SDVariable sDVariable, int... iArr) {
        SDVariable reduced = functionFactory.max(sDVariable, iArr);
        return updateVariableNameAndReference(reduced, str);
    }

    /** Minimum along {@code iArr}; output is named {@code str}. */
    public SDVariable min(String str, SDVariable sDVariable, int... iArr) {
        SDVariable reduced = functionFactory.min(sDVariable, iArr);
        return updateVariableNameAndReference(reduced, str);
    }

    /** L1 norm along {@code iArr} (built via {@code f()}); output is named {@code str}. */
    public SDVariable norm1(String str, SDVariable sDVariable, int... iArr) {
        SDVariable reduced = f().norm1(sDVariable, iArr);
        return updateVariableNameAndReference(reduced, str);
    }

    /** L2 norm along {@code iArr} (built via {@code f()}); output is named {@code str}. */
    public SDVariable norm2(String str, SDVariable sDVariable, int... iArr) {
        SDVariable reduced = f().norm2(sDVariable, iArr);
        return updateVariableNameAndReference(reduced, str);
    }

    /** Max-norm along {@code iArr} (built via {@code f()}); output is named {@code str}. */
    public SDVariable normmax(String str, SDVariable sDVariable, int... iArr) {
        SDVariable reduced = f().normmax(sDVariable, iArr);
        return updateVariableNameAndReference(reduced, str);
    }

    /** Reshapes {@code sDVariable} to {@code iArr}; output is named {@code str}. */
    public SDVariable reshape(String str, SDVariable sDVariable, int... iArr) {
        SDVariable reshaped = functionFactory.reshape(sDVariable, iArr);
        return updateVariableNameAndReference(reshaped, str);
    }

    /** Transposes {@code sDVariable}; output is named {@code str}. */
    public SDVariable transpose(String str, SDVariable sDVariable) {
        SDVariable transposed = functionFactory.transpose(sDVariable);
        return updateVariableNameAndReference(transposed, str);
    }

    /** Permutes the dimensions of {@code sDVariable} by {@code iArr}; output is named {@code str}. */
    public SDVariable permute(String str, SDVariable sDVariable, int... iArr) {
        SDVariable permuted = functionFactory.permute(sDVariable, iArr);
        return updateVariableNameAndReference(permuted, str);
    }

    /** Builds a {@code rollAxis} op with axis argument {@code i}; output is named {@code str}. */
    public SDVariable rollAxis(String str, SDVariable sDVariable, int i) {
        SDVariable rolled = functionFactory.rollAxis(sDVariable, i);
        return updateVariableNameAndReference(rolled, str);
    }

    /** Builds a {@code fill} op from {@code sDVariable} and value {@code d}; output is named {@code str}. */
    public SDVariable fill(String str, SDVariable sDVariable, double d) {
        SDVariable filled = functionFactory.fill(sDVariable, d);
        return updateVariableNameAndReference(filled, str);
    }

    /** Concatenates {@code sDVariableArr} along dimension {@code i}; output is named {@code str}. */
    public SDVariable concat(String str, int i, SDVariable... sDVariableArr) {
        SDVariable joined = functionFactory.concat(i, sDVariableArr);
        return updateVariableNameAndReference(joined, str);
    }

    /** Tiles {@code sDVariable} using repetition counts {@code iArr}; output is named {@code str}. */
    public SDVariable tile(String str, SDVariable sDVariable, int[] iArr) {
        SDVariable tiled = functionFactory.tile(sDVariable, iArr);
        return updateVariableNameAndReference(tiled, str);
    }

    /** Matrix multiply with explicit transpose flags {@code mMulTranspose}; output is named {@code str}. */
    public SDVariable mmul(String str, SDVariable sDVariable, SDVariable sDVariable2, MMulTranspose mMulTranspose) {
        SDVariable product = functionFactory.mmul(sDVariable, sDVariable2, mMulTranspose);
        return updateVariableNameAndReference(product, str);
    }

    /** Matrix multiply with {@link MMulTranspose#allFalse()} transpose flags. */
    public SDVariable mmul(String str, SDVariable sDVariable, SDVariable sDVariable2) {
        MMulTranspose noTranspose = MMulTranspose.allFalse();
        return mmul(str, sDVariable, sDVariable2, noTranspose);
    }

    /** Tensor contraction over the axis pairs in {@code iArr}; output is named {@code str}. */
    public SDVariable tensorMmul(String str, SDVariable sDVariable, SDVariable sDVariable2, int[][] iArr) {
        SDVariable product = functionFactory.tensorMmul(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(product, str);
    }

    /** Cosine similarity of the two inputs along {@code iArr}; output is named {@code str}. */
    public SDVariable cosineSimilarity(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable similarity = functionFactory.cosineSimilarity(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(similarity, str);
    }

    /** Euclidean distance of the two inputs along {@code iArr}; output is named {@code str}. */
    public SDVariable euclideanDistance(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable distance = functionFactory.euclideanDistance(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(distance, str);
    }

    /** Manhattan distance of the two inputs along {@code iArr}; output is named {@code str}. */
    public SDVariable manhattanDistance(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable distance = functionFactory.manhattanDistance(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(distance, str);
    }

    /** Unnamed variant: forwards to the named overload with a {@code null} name. */
    public SDVariable sigmoidCrossEntropyWithLogits(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, int i, double d) {
        String unnamed = null;
        return sigmoidCrossEntropyWithLogits(unnamed, sDVariable, sDVariable2, sDVariable3, i, d);
    }

    /** Sigmoid cross-entropy built via {@code f()}; {@code i} and {@code d} are forwarded unchanged. */
    public SDVariable sigmoidCrossEntropyWithLogits(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, int i, double d) {
        SDVariable loss = f().sigmoidCrossEntropyWithLogits(sDVariable, sDVariable2, sDVariable3, i, d);
        return updateVariableNameAndReference(loss, str);
    }

    /** Unnamed variant: forwards to the named overload with a {@code null} name. */
    public SDVariable softmaxCrossEntropyWithLogits(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, int i, double d) {
        String unnamed = null;
        return softmaxCrossEntropyWithLogits(unnamed, sDVariable, sDVariable2, sDVariable3, i, d);
    }

    /** Softmax cross-entropy built via {@code f()}; {@code i} and {@code d} are forwarded unchanged. */
    public SDVariable softmaxCrossEntropyWithLogits(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, int i, double d) {
        SDVariable loss = f().softmaxCrossEntropyWithLogits(sDVariable, sDVariable2, sDVariable3, i, d);
        return updateVariableNameAndReference(loss, str);
    }

    /** Unnamed variant: forwards to the named overload with a {@code null} name. */
    public SDVariable weightedCrossEntropyWithLogits(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        String unnamed = null;
        return weightedCrossEntropyWithLogits(unnamed, sDVariable, sDVariable2, sDVariable3);
    }

    /** Weighted cross-entropy built via {@code f()}; output is named {@code str}. */
    public SDVariable weightedCrossEntropyWithLogits(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        SDVariable loss = f().weightedCrossEntropyWithLogits(sDVariable, sDVariable2, sDVariable3);
        return updateVariableNameAndReference(loss, str);
    }

    /** Binary cross-entropy loss along {@code iArr}; output is named {@code str}. */
    public SDVariable lossBinaryXENT(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable loss = functionFactory.lossBinaryXENT(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(loss, str);
    }

    /** Cosine-similarity loss along {@code iArr}; output is named {@code str}. */
    public SDVariable lossCosineSimilarity(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable loss = functionFactory.lossCosineSimilarity(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(loss, str);
    }

    /** Hinge loss along {@code iArr}; output is named {@code str}. */
    public SDVariable lossHinge(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable loss = functionFactory.lossHinge(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(loss, str);
    }

    /** KL-divergence loss along {@code iArr}; output is named {@code str}. */
    public SDVariable lossKLD(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable loss = functionFactory.lossKLD(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(loss, str);
    }

    /** L1 loss along {@code iArr}; output is named {@code str}. */
    public SDVariable lossL1(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable loss = functionFactory.lossL1(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(loss, str);
    }

    /** L2 loss along {@code iArr}; output is named {@code str}. */
    public SDVariable lossL2(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable loss = functionFactory.lossL2(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(loss, str);
    }

    /** Mean-absolute-error loss along {@code iArr}; output is named {@code str}. */
    public SDVariable lossMAE(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable loss = functionFactory.lossMAE(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(loss, str);
    }

    /** Mean-squared-error loss along {@code iArr}; output is named {@code str}. */
    public SDVariable lossMSE(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable loss = functionFactory.lossMSE(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(loss, str);
    }

    /** Multi-class cross-entropy loss along {@code iArr}; output is named {@code str}. */
    public SDVariable lossMCXENT(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable loss = functionFactory.lossMCXENT(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(loss, str);
    }

    /** Mean-squared-log-error loss along {@code iArr}; output is named {@code str}. */
    public SDVariable lossMSLE(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable loss = functionFactory.lossMSLE(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(loss, str);
    }

    /** Negative log-likelihood loss along {@code iArr}; output is named {@code str}. */
    public SDVariable lossNegativeLogLikelihood(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable loss = functionFactory.lossNegativeLogLikelihood(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(loss, str);
    }

    /** Poisson loss along {@code iArr}; output is named {@code str}. */
    public SDVariable lossPoisson(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable loss = functionFactory.lossPoisson(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(loss, str);
    }

    /** Squared-hinge loss along {@code iArr}; output is named {@code str}. */
    public SDVariable lossSquaredHinge(String str, SDVariable sDVariable, SDVariable sDVariable2, int... iArr) {
        SDVariable loss = functionFactory.lossSquaredHinge(sDVariable, sDVariable2, iArr);
        return updateVariableNameAndReference(loss, str);
    }

    /** Unnamed variant of {@code expandDims}; forwards with a {@code null} name. */
    public SDVariable expandDims(SDVariable sDVariable, int i) {
        return expandDims((String) null, sDVariable, i);
    }

    /** Builds an {@code expandDims} op via {@code f()} at axis argument {@code i}; output is named {@code str}. */
    public SDVariable expandDims(String str, SDVariable sDVariable, int i) {
        SDVariable expanded = f().expandDims(sDVariable, i);
        return updateVariableNameAndReference(expanded, str);
    }

    /** Unnamed variant of {@code squeeze}; forwards with a {@code null} name. */
    public SDVariable squeeze(SDVariable sDVariable, int i) {
        return squeeze((String) null, sDVariable, i);
    }

    /** Builds a {@code squeeze} op via {@code f()} at axis argument {@code i}; output is named {@code str}. */
    public SDVariable squeeze(String str, SDVariable sDVariable, int i) {
        SDVariable squeezed = f().squeeze(sDVariable, i);
        return updateVariableNameAndReference(squeezed, str);
    }

    /** Unnamed two-input confusion matrix; forwards with a {@code null} name. */
    public SDVariable confusionMatrix(SDVariable sDVariable, SDVariable sDVariable2) {
        String unnamed = null;
        return confusionMatrix(unnamed, sDVariable, sDVariable2);
    }

    /** Confusion matrix of the two inputs via {@code f()}; output is named {@code str}. */
    public SDVariable confusionMatrix(String str, SDVariable sDVariable, SDVariable sDVariable2) {
        SDVariable matrix = f().confusionMatrix(sDVariable, sDVariable2);
        return updateVariableNameAndReference(matrix, str);
    }

    /** Unnamed variant taking the extra {@code Integer} argument; forwards with a {@code null} name. */
    public SDVariable confusionMatrix(SDVariable sDVariable, SDVariable sDVariable2, Integer num) {
        String unnamed = null;
        return confusionMatrix(unnamed, sDVariable, sDVariable2, num);
    }

    /** Confusion matrix via {@code f()} with {@code num} forwarded; output is named {@code str}. */
    public SDVariable confusionMatrix(String str, SDVariable sDVariable, SDVariable sDVariable2, Integer num) {
        SDVariable matrix = f().confusionMatrix(sDVariable, sDVariable2, num);
        return updateVariableNameAndReference(matrix, str);
    }

    /** Unnamed variant taking a third variable; forwards with a {@code null} name. */
    public SDVariable confusionMatrix(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        String unnamed = null;
        return confusionMatrix(unnamed, sDVariable, sDVariable2, sDVariable3);
    }

    /** Confusion matrix via {@code f()} with a third variable forwarded; output is named {@code str}. */
    public SDVariable confusionMatrix(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        SDVariable matrix = f().confusionMatrix(sDVariable, sDVariable2, sDVariable3);
        return updateVariableNameAndReference(matrix, str);
    }

    /** Unnamed variant taking {@code num} and a third variable; forwards with a {@code null} name. */
    public SDVariable confusionMatrix(SDVariable sDVariable, SDVariable sDVariable2, Integer num, SDVariable sDVariable3) {
        String unnamed = null;
        return confusionMatrix(unnamed, sDVariable, sDVariable2, num, sDVariable3);
    }

    /** Confusion matrix via {@code f()} with {@code num} and a third variable; output is named {@code str}. */
    public SDVariable confusionMatrix(String str, SDVariable sDVariable, SDVariable sDVariable2, Integer num, SDVariable sDVariable3) {
        SDVariable matrix = f().confusionMatrix(sDVariable, sDVariable2, num, sDVariable3);
        return updateVariableNameAndReference(matrix, str);
    }

    /**
     * Registers {@code sDVariable} in {@code variableMap} under its variable name. The map is
     * created lazily on first use.
     *
     * <p>Re-adding a variable equal to the one already stored under the same name is allowed (it is
     * simply put again). A different variable under the same name is rejected.
     *
     * @param sDVariable the variable to register; must belong to this SameDiff instance
     * @throws IllegalArgumentException if a different variable is already registered under the name
     */
    public void addVariable(SDVariable sDVariable) {
        if (this.variableMap == null) {
            this.variableMap = new HashMap<>();
        }
        // Checked once up front; the original repeated this exact precondition before the put.
        Preconditions.checkState(sDVariable.getSameDiff() == this, "Samediff instance must be the same.");
        String varName = sDVariable.getVarName();
        if (this.variableMap.containsKey(varName) && !this.variableMap.get(varName).equals(sDVariable)) {
            throw new IllegalArgumentException("Variable already found with variable opName " + varName);
        }
        this.variableMap.put(varName, sDVariable);
    }

    /**
     * Returns a variable name derived from {@code str} that {@link #getVariable(String)} does not
     * currently resolve.
     *
     * <p>If {@code str} is unused and {@code i == 0}, {@code str} itself is returned. Otherwise the
     * candidates {@code str_1}, {@code str_2}, ... are tried in order. Each candidate carries a
     * {@code ":" + i} suffix when {@code i > 0}. The first candidate that is not found is returned.
     *
     * @param str base name
     * @param i   output index; when positive it is appended as {@code ":i"}
     * @return a name for which {@code getVariable} returned {@code null}
     */
    public String generateNewVarName(String str, int i) {
        if (getVariable(str) == null && i == 0) {
            return str;
        }
        String outputSuffix = i > 0 ? ":" + i : "";
        int counter = 1;
        String candidate = str + "_" + counter + outputSuffix;
        // The loop only exits on a free name, so the original post-loop
        // "Converged on already generated variable!" check could never fire; it is dropped.
        while (getVariable(candidate) != null) {
            counter++;
            candidate = str + "_" + counter + outputSuffix;
        }
        return candidate;
    }

    /** Builds an {@link LSTMCell} from the configuration and returns its first output, named {@code str}. */
    public SDVariable lstm(String str, LSTMCellConfiguration lSTMCellConfiguration) {
        LSTMCell cell = new LSTMCell(this, lSTMCellConfiguration);
        SDVariable[] outputs = cell.outputVariables(str);
        return outputs[0];
    }

    /** Builds an {@link SRUCell} from the configuration and returns its first output. */
    public SDVariable sruCell(SRUCellConfiguration sRUCellConfiguration) {
        SRUCell cell = new SRUCell(this, sRUCellConfiguration);
        SDVariable[] outputs = cell.outputVariables();
        return outputs[0];
    }

    /** Builds an {@link SRU} op from the configuration and returns its first output. */
    public SDVariable sru(SRUConfiguration sRUConfiguration) {
        SRU layer = new SRU(this, sRUConfiguration);
        SDVariable[] outputs = layer.outputVariables();
        return outputs[0];
    }

    /** Builds a {@link GRUCell} from the configuration and returns its first output. */
    public SDVariable gru(GRUCellConfiguration gRUCellConfiguration) {
        GRUCell cell = new GRUCell(this, gRUCellConfiguration);
        SDVariable[] outputs = cell.outputVariables();
        return outputs[0];
    }

    /** Builds an {@link SRUCell} and returns its first output, named {@code str}. */
    public SDVariable sruCell(String str, SRUCellConfiguration sRUCellConfiguration) {
        SRUCell cell = new SRUCell(this, sRUCellConfiguration);
        SDVariable[] outputs = cell.outputVariables(str);
        return outputs[0];
    }

    /** Builds an {@link SRU} op and returns its first output, named {@code str}. */
    public SDVariable sru(String str, SRUConfiguration sRUConfiguration) {
        SRU layer = new SRU(this, sRUConfiguration);
        SDVariable[] outputs = layer.outputVariables(str);
        return outputs[0];
    }

    /** Builds a {@link GRUCell} and returns its first output, named {@code str}. */
    public SDVariable gru(String str, GRUCellConfiguration gRUCellConfiguration) {
        GRUCell cell = new GRUCell(this, gRUCellConfiguration);
        SDVariable[] outputs = cell.outputVariables(str);
        return outputs[0];
    }

    /** Unnamed slice; forwards with a {@code null} name. */
    public SDVariable slice(SDVariable sDVariable, int[] iArr, int[] iArr2) {
        return slice((String) null, sDVariable, iArr, iArr2);
    }

    /** Slice of {@code sDVariable} built via {@code f()} from {@code iArr}/{@code iArr2}; output is named {@code str}. */
    public SDVariable slice(String str, SDVariable sDVariable, int[] iArr, int[] iArr2) {
        SDVariable sliced = f().slice(sDVariable, iArr, iArr2);
        return updateVariableNameAndReference(sliced, str);
    }

    /** Unnamed strided slice with all five mask arguments set to 0. */
    public SDVariable stridedSlice(SDVariable sDVariable, int[] iArr, int[] iArr2, int[] iArr3) {
        return stridedSlice((String) null, sDVariable, iArr, iArr2, iArr3);
    }

    /** Named strided slice with all five trailing int arguments set to 0. */
    public SDVariable stridedSlice(String str, SDVariable sDVariable, int[] iArr, int[] iArr2, int[] iArr3) {
        final int none = 0;
        return stridedSlice(str, sDVariable, iArr, iArr2, iArr3, none, none, none, none, none);
    }

    /** Unnamed strided slice with explicit trailing int arguments; forwards with a {@code null} name. */
    public SDVariable stridedSlice(SDVariable sDVariable, int[] iArr, int[] iArr2, int[] iArr3, int i, int i2, int i3, int i4, int i5) {
        return stridedSlice((String) null, sDVariable, iArr, iArr2, iArr3, i, i2, i3, i4, i5);
    }

    /** Strided slice built via {@code f()}; every argument is forwarded unchanged and the output is named {@code str}. */
    public SDVariable stridedSlice(String str, SDVariable sDVariable, int[] iArr, int[] iArr2, int[] iArr3, int i, int i2, int i3, int i4, int i5) {
        SDVariable sliced = f().stridedSlice(sDVariable, iArr, iArr2, iArr3, i, i2, i3, i4, i5);
        return updateVariableNameAndReference(sliced, str);
    }

    /** Builds a {@code scatterAdd} op via {@code f()}; output is named {@code str}. */
    public SDVariable scatterAdd(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        SDVariable scattered = f().scatterAdd(sDVariable, sDVariable2, sDVariable3);
        return updateVariableNameAndReference(scattered, str);
    }

    /** Builds a {@code scatterMul} op via {@code f()}; output is named {@code str}. */
    public SDVariable scatterMul(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        SDVariable scattered = f().scatterMul(sDVariable, sDVariable2, sDVariable3);
        return updateVariableNameAndReference(scattered, str);
    }

    /** Builds a {@code scatterSub} op via {@code f()}; output is named {@code str}. */
    public SDVariable scatterSub(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        SDVariable scattered = f().scatterSub(sDVariable, sDVariable2, sDVariable3);
        return updateVariableNameAndReference(scattered, str);
    }

    /** Builds a {@code scatterDiv} op via {@code f()}; output is named {@code str}. */
    public SDVariable scatterDiv(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        SDVariable scattered = f().scatterDiv(sDVariable, sDVariable2, sDVariable3);
        return updateVariableNameAndReference(scattered, str);
    }

    /** Unnamed {@code scatterAdd}; forwards with a {@code null} name. */
    public SDVariable scatterAdd(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        return scatterAdd((String) null, sDVariable, sDVariable2, sDVariable3);
    }

    /** Unnamed {@code scatterMul}; forwards with a {@code null} name. */
    public SDVariable scatterMul(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        return scatterMul((String) null, sDVariable, sDVariable2, sDVariable3);
    }

    /** Unnamed {@code scatterSub}; forwards with a {@code null} name. */
    public SDVariable scatterSub(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        return scatterSub((String) null, sDVariable, sDVariable2, sDVariable3);
    }

    /** Unnamed {@code scatterDiv}; forwards with a {@code null} name. */
    public SDVariable scatterDiv(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        return scatterDiv((String) null, sDVariable, sDVariable2, sDVariable3);
    }

    /**
     * Creates, or looks up, the output variables of {@code differentialFunction}.
     *
     * <p>The base name is resolved as follows:
     * <ul>
     *   <li>{@code str} is used as given, unless it is {@code null}, or it is empty while
     *       {@link #getBaseNameForFunction} has a name for the function. In those cases that
     *       base name is used instead.</li>
     *   <li>If the result is still {@code null}, the op name is used.</li>
     * </ul>
     *
     * <p>Then one of three paths is taken:
     * <ul>
     *   <li>No calculated output shapes and a {@link CustomOp}: the number of outputs comes from the
     *       op descriptor, and the variables are created without a shape.</li>
     *   <li>No calculated output shapes and a {@link BaseOp}: a single shapeless variable.</li>
     *   <li>Otherwise: one variable per calculated shape, named {@code base}, {@code base:1}, ...
     *       An existing variable of that name is reused, and its shape is recorded if none is
     *       stored yet. The exception is when the calculated shape is {@code null} and the name
     *       was not imported; then a fresh suffixed variable is created.</li>
     * </ul>
     *
     * <p>A {@code null} result from {@code calculateOutputShape()} is treated exactly like an empty
     * list, so ops without shape information no longer fail with a NullPointerException.
     *
     * <p>New variables take the array ordering of the first argument when that array is available,
     * and {@code 'c'} otherwise. This includes ops with no arguments at all.
     *
     * @param differentialFunction the op whose outputs are being created
     * @param str requested base name for the outputs (may be {@code null})
     * @return the output variables, in output order
     * @throws ND4JIllegalStateException if a custom op has no shapes and its descriptor reports no outputs
     */
    public SDVariable[] generateOutputVariableForOp(DifferentialFunction differentialFunction, String str) {
        if (str == null || (str.isEmpty() && getBaseNameForFunction(differentialFunction) != null)) {
            str = getBaseNameForFunction(differentialFunction);
        }
        if (str == null) {
            str = differentialFunction.opName();
        }
        List<int[]> outputShapes = differentialFunction.calculateOutputShape();
        if (outputShapes == null) {
            // No shape information at all: handle it exactly like an empty shape list.
            outputShapes = Collections.emptyList();
        }
        if (outputShapes.isEmpty()) {
            if (differentialFunction instanceof CustomOp) {
                CustomOpDescriptor descriptor = ((CustomOp) differentialFunction).getDescriptor();
                if (descriptor == null || descriptor.getNumOutputs() <= 0) {
                    throw new ND4JIllegalStateException("No output variables found!");
                }
                char customOrdering = inferOutputOrdering(differentialFunction);
                SDVariable[] customOutputs = new SDVariable[descriptor.getNumOutputs()];
                for (int i = 0; i < customOutputs.length; i++) {
                    SDVariable variable = getVariable(str);
                    if (variable == null) {
                        variable = var(generateNewVarName(str, i), null, new ZeroInitScheme(customOrdering));
                    } else if (i > 0 && !this.importedVarName.contains(str)) {
                        variable = getVariable(generateNewVarName(str, i));
                    }
                    if (variable == null) {
                        variable = var(generateNewVarName(str, i), null, new ZeroInitScheme(customOrdering));
                    }
                    customOutputs[i] = variable;
                }
                return customOutputs;
            }
            if (differentialFunction instanceof BaseOp) {
                SDVariable variable = getVariable(str);
                char baseOrdering = inferOutputOrdering(differentialFunction);
                if (variable == null) {
                    variable = var(str, null, new ZeroInitScheme(baseOrdering));
                } else if (!this.importedVarName.contains(str)) {
                    variable = var(generateNewVarName(str, 0), null, new ZeroInitScheme(baseOrdering));
                }
                if (variable == null) {
                    variable = var(str, null, new ZeroInitScheme(baseOrdering));
                }
                return new SDVariable[] {variable};
            }
        }
        char ordering = inferOutputOrdering(differentialFunction);
        SDVariable[] outputs = new SDVariable[outputShapes.size()];
        for (int idx = 0; idx < outputs.length; idx++) {
            int[] shape = outputShapes.get(idx);
            String outputSuffix = idx > 0 ? ":" + idx : "";
            String outputName = str + outputSuffix;
            SDVariable variable = getVariable(outputName);
            if (variable == null) {
                variable = var(outputName, shape, new ZeroInitScheme(ordering));
            } else if (shape != null && !shapeAlreadyExistsForVarName(variable.getVarName())) {
                putShapeForVarName(variable.getVarName(), shape);
            } else if ((shape == null || !shapeAlreadyExistsForVarName(variable.getVarName())) && !this.importedVarName.contains(outputName)) {
                // NOTE(review): outputName already ends with ":idx", so the suffixed names produced
                // here look like "x:1_1:1". This is kept unchanged because callers may rely on these names.
                int counter = 1;
                String freshName = outputName + "_1" + outputSuffix;
                while (getVariable(freshName) != null) {
                    counter++;
                    freshName = outputName + "_" + counter + outputSuffix;
                }
                variable = var(freshName, shape, new ZeroInitScheme(ordering));
            }
            if (variable == null) {
                variable = var(outputName + outputSuffix, shape, new ZeroInitScheme(ordering));
            }
            outputs[idx] = variable;
        }
        return outputs;
    }

    /**
     * Ordering for newly created output variables: the ordering of the first argument's array
     * when there is a first argument with an array, otherwise {@code 'c'}.
     */
    private static char inferOutputOrdering(DifferentialFunction differentialFunction) {
        SDVariable[] args = differentialFunction.args();
        if (args == null || args.length == 0 || args[0].getArr() == null) {
            return 'c';
        }
        return args[0].getArr().ordering();
    }

    /** Convenience overload that uses {@code differentialFunction.opName()} as the base name. */
    public SDVariable[] generateOutputVariableForOp(DifferentialFunction differentialFunction) {
        String baseName = differentialFunction.opName();
        return generateOutputVariableForOp(differentialFunction, baseName);
    }

    /** Returns the sub-graph registered under {@code str}, or {@code null} if there is none. */
    public SameDiff getFunction(String str) {
        SameDiff registered = this.sameDiffFunctionInstances.get(str);
        return registered;
    }

    /**
     * Executes every op in {@code list} and returns the {@code z()} array of the last one.
     *
     * @param list ops to execute, in order; each must be an {@link Op}
     * @return the result array of the last op
     * @throws ND4JIllegalStateException if {@code list} is empty (previously an opaque index error)
     */
    public INDArray execAndEndResult(List<DifferentialFunction> list) {
        List<DifferentialFunction> exec = exec(list);
        if (exec.isEmpty()) {
            throw new ND4JIllegalStateException("No ops to execute: op list is empty");
        }
        return ((Op) exec.get(exec.size() - 1)).z();
    }

    /**
     * Resolves variables with no inputs, executes the graph, and returns the array of the first
     * output variable of the last executed op.
     *
     * @throws ND4JIllegalStateException if execution produced no ops (previously an opaque index error)
     */
    public INDArray execAndEndResult() {
        resolveVariablesWith(Collections.emptyMap());
        List<DifferentialFunction> right = exec().getRight();
        if (right.isEmpty()) {
            throw new ND4JIllegalStateException("Execution produced no ops");
        }
        return right.get(right.size() - 1).outputVariables()[0].getArr();
    }

    /**
     * Executes this graph through the native executioner and returns the last array in the result
     * map.
     *
     * <p>On first use the graph is serialized with {@link #asFlatBuffers()} and registered under
     * {@code hashCode()}. The {@code wasRegistered} flag is re-checked inside
     * {@code synchronized (this)}, so concurrent first callers register the graph only once.
     *
     * @param map inputs keyed by the names used in {@code variableMap}
     * @return the last value of the executioner's result map
     * @throws NullPointerException      if {@code map} is null
     * @throws ND4JIllegalStateException if an input name is not a known variable, or execution
     *                                   returned no results
     */
    public INDArray yetAnotherExecMethod(@NonNull Map<String, INDArray> map) {
        if (map == null) {
            throw new NullPointerException("inputs");
        }
        if (!this.wasRegistered.get()) {
            synchronized (this) {
                if (!this.wasRegistered.get()) {
                    Nd4j.getExecutioner().registerGraph(hashCode(), new BytePointer(asFlatBuffers()));
                    this.wasRegistered.set(true);
                }
            }
        }
        Map<String, INDArray> inputsByVarName = new LinkedHashMap<>();
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            SDVariable variable = this.variableMap.get(entry.getKey());
            if (variable == null) {
                // Previously an NPE with no hint about which input was wrong.
                throw new ND4JIllegalStateException("No variable found for input name: " + entry.getKey());
            }
            inputsByVarName.put(variable.getVarName(), entry.getValue());
        }
        Map<String, INDArray> executeGraph = Nd4j.getExecutioner().executeGraph(hashCode(), inputsByVarName);
        if (executeGraph.size() == 0) {
            throw new ND4JIllegalStateException("Execution failed");
        }
        List<INDArray> results = new ArrayList<>(executeGraph.values());
        return results.get(results.size() - 1);
    }

    /**
     * Executes each op in {@code list}, in order, through the Nd4j executioner.
     *
     * @param list ops to run; each is cast to {@link Op}
     * @return {@code list} itself
     */
    public List<DifferentialFunction> exec(List<DifferentialFunction> list) {
        for (DifferentialFunction function : list) {
            Nd4j.getExecutioner().exec((Op) function);
        }
        return list;
    }

    /**
     * Builds a {@link While} block with this instance as parent and a random
     * {@code "while-<uuid>"} block name.
     */
    public While whileStatement(SameDiffConditional sameDiffConditional, SameDiffFunctionDefinition sameDiffFunctionDefinition, SameDiffFunctionDefinition sameDiffFunctionDefinition2, SDVariable[] sDVariableArr) {
        String blockName = "while-" + UUID.randomUUID().toString();
        return While.builder()
                .inputVars(sDVariableArr)
                .condition(sameDiffFunctionDefinition)
                .predicate(sameDiffConditional)
                .trueBody(sameDiffFunctionDefinition2)
                .parent(this)
                .blockName(blockName)
                .build();
    }

    /**
     * Builds an {@link If} block with this instance as parent and a random
     * {@code "if-<uuid>"} block name.
     */
    public If ifStatement(SameDiffConditional sameDiffConditional, SameDiffFunctionDefinition sameDiffFunctionDefinition, SameDiffFunctionDefinition sameDiffFunctionDefinition2, SameDiffFunctionDefinition sameDiffFunctionDefinition3, SDVariable[] sDVariableArr) {
        String blockName = "if-" + UUID.randomUUID().toString();
        return If.builder()
                .conditionBody(sameDiffFunctionDefinition)
                .falseBody(sameDiffFunctionDefinition3)
                .trueBody(sameDiffFunctionDefinition2)
                .predicate(sameDiffConditional)
                .inputVars(sDVariableArr)
                .parent(this)
                .blockName(blockName)
                .build();
    }

    /** Invokes the graph of the function registered under {@code str} on {@code sameDiff}. */
    public SDVariable invokeFunctionOn(String str, SameDiff sameDiff) {
        SameDiff function = this.sameDiffFunctionInstances.get(str);
        return function.invokeGraphOn(sameDiff);
    }

    /**
     * Registers a sub-graph under {@code str} unless one already exists, and returns the registered
     * instance.
     *
     * <p>When a new sub-graph is defined, it shares this instance's workspace. Each input is first
     * copied into it with {@code var(SDVariable)}, and then {@code define} is called with a
     * {@code null} input map.
     */
    public SameDiff defineFunction(String str, SameDiffFunctionDefinition sameDiffFunctionDefinition, SDVariable[] sDVariableArr) {
        if (this.sameDiffFunctionInstances.containsKey(str)) {
            return this.sameDiffFunctionInstances.get(str);
        }
        SameDiff sub = create();
        sub.workspace = this.workspace;
        SDVariable[] copiedInputs = new SDVariable[sDVariableArr.length];
        int idx = 0;
        for (SDVariable input : sDVariableArr) {
            copiedInputs[idx++] = sub.var(input);
        }
        sameDiffFunctionDefinition.define(sub, null, copiedInputs);
        this.sameDiffFunctionInstances.put(str, sub);
        return this.sameDiffFunctionInstances.get(str);
    }

    /** Defines a sub-graph under {@code str} with an empty input map. */
    public void defineFunction(String str, SameDiffFunctionDefinition sameDiffFunctionDefinition) {
        Map<String, INDArray> noInputs = new LinkedHashMap<String, INDArray>();
        defineFunction(str, sameDiffFunctionDefinition, noInputs);
    }

    /**
     * Registers a sub-graph under {@code str} if absent. A new sub-graph shares this instance's
     * workspace and is defined with {@code map} as its inputs and no input variables.
     */
    public void defineFunction(String str, SameDiffFunctionDefinition sameDiffFunctionDefinition, Map<String, INDArray> map) {
        if (!this.sameDiffFunctionInstances.containsKey(str)) {
            SameDiff sub = create();
            sub.workspace = this.workspace;
            sameDiffFunctionDefinition.define(sub, map, null);
            this.sameDiffFunctionInstances.put(str, sub);
        }
    }

    /** Runs {@code execAndEndResult()} on the sub-graph registered under {@code str}. */
    public INDArray execAndEndResult(String str) {
        SameDiff function = this.sameDiffFunctionInstances.get(str);
        return function.execAndEndResult();
    }

    /**
     * Executes the sub-graph registered under {@code str}. When this instance is in debug mode,
     * debug mode is enabled on the sub-graph before it runs.
     */
    public Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> exec(String str) {
        SameDiff function = this.sameDiffFunctionInstances.get(str);
        if (this.debugMode) {
            return function.enableDebugMode().exec();
        }
        return function.exec();
    }

    /** Executes {@code list} on the sub-graph registered under {@code str}. */
    public List<DifferentialFunction> exec(String str, List<DifferentialFunction> list) {
        SameDiff function = this.sameDiffFunctionInstances.get(str);
        return function.exec(list);
    }

    /**
     * Builds (once) and executes the "grad" sub-graph that holds this graph's backward pass.
     *
     * <p>On first use, a "grad" function is defined. It copies this graph into the new instance, seeds
     * the gradient of the last op's first output with a scalar 1.0 ("one-var"), inserts a
     * gradient-backwards marker, and then calls {@code diff(...)} on every op in reverse order with the
     * gradients of that op's outputs. In debug mode, {@code gradient()} is touched on every variable
     * so gradients exist for all of them.
     *
     * @return the result of executing the "grad" sub-graph
     * @throws ND4JIllegalStateException if the copied graph has no ops, or an op output has no gradient
     */
    public Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> execBackwards() {
        if (getFunction("grad") == null) {
            defineFunction("grad", new SameDiffFunctionDefinition() {
                @Override
                public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> map, SDVariable[] sDVariableArr) {
                    if (SameDiff.this.debugMode) {
                        sameDiff.enableDebugMode();
                    }
                    // Must be qualified: a bare `this` here is the anonymous SameDiffFunctionDefinition,
                    // which has no invokeGraphOn method.
                    SameDiff.this.invokeGraphOn(sameDiff);
                    List<DifferentialFunction> forwardOps = new ArrayList<>(sameDiff.functionInstancesById.values());
                    if (forwardOps.isEmpty()) {
                        throw new ND4JIllegalStateException("No ops found!");
                    }
                    // Re-home every op and its inputs/outputs onto the gradient graph.
                    for (DifferentialFunction forwardOp : forwardOps) {
                        if (!(forwardOp instanceof SDVariable)) {
                            for (SDVariable arg : forwardOp.args()) {
                                arg.setSameDiff(sameDiff);
                            }
                            for (SDVariable output : forwardOp.outputVariables()) {
                                output.setSameDiff(sameDiff);
                            }
                            forwardOp.setSameDiff(sameDiff);
                        }
                    }
                    // Seed: d(loss)/d(loss) = 1 for the first output of the last op.
                    SDVariable lossVariable = forwardOps.get(forwardOps.size() - 1).outputVariables()[0];
                    SDVariable one = sameDiff.var("one-var", Nd4j.scalar(1.0d));
                    sameDiff.forwardVarForGrad.put(lossVariable.getVarName(), one);
                    sameDiff.gradients.put(lossVariable.getVarName(), one);
                    sameDiff.gradientBackwardsMarker(lossVariable);

                    List<DifferentialFunction> allOps = new ArrayList<>(sameDiff.functionInstancesById.values());
                    Collections.reverse(allOps);
                    for (DifferentialFunction op : allOps) {
                        if (op instanceof GradientBackwardsMarker) {
                            SameDiff.log.warn("Action op state is null for " + op.opName());
                            continue;
                        }
                        Preconditions.checkState(op.getSameDiff() == sameDiff, "Wrong samediff instance found!");
                        SDVariable[] outputVariables = op.outputVariables();
                        for (SDVariable output : outputVariables) {
                            if (output.getSameDiff() != sameDiff) {
                                output.setSameDiff(sameDiff);
                            }
                        }
                        List<SDVariable> outputGradients = new ArrayList<>();
                        for (SDVariable output : outputVariables) {
                            SDVariable gradient = output.gradient();
                            if (gradient == null) {
                                throw new ND4JIllegalStateException("No gradient found for " + output.getVarName());
                            }
                            outputGradients.add(gradient);
                        }
                        op.diff(outputGradients);
                    }
                    if (sameDiff.isDebugMode()) {
                        // Touch gradient() on every variable of the enclosing graph.
                        for (SDVariable variable : SameDiff.this.variables()) {
                            variable.gradient();
                        }
                    }
                    return new SDVariable[]{sameDiff.var("grad", new int[]{1, 1})};
                }
            });
        }
        Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> result = exec("grad");
        SameDiff gradGraph = getFunction("grad");
        if (gradGraph.isDebugMode()) {
            for (SDVariable variable : gradGraph.variables()) {
                variable.gradient();
            }
        }
        return result;
    }

    /**
     * Runs {@link #execBackwards()} and returns the output array of the last executed op.
     *
     * @return {@code z()} for a legacy {@link Op}, output argument 0 for a {@link DynamicCustomOp},
     *     or {@code null} for any other kind of function
     */
    public INDArray execBackwardAndEndResult() {
        List<DifferentialFunction> executed = execBackwards().getRight();
        DifferentialFunction last = executed.get(executed.size() - 1);
        if (last instanceof Op) {
            return ((Op) last).z();
        }
        return (last instanceof DynamicCustomOp) ? ((DynamicCustomOp) last).getOutputArgument(0) : null;
    }

    /**
     * Binds the given placeholder arrays, then executes the graph and returns its end result.
     *
     * @param map placeholder variable name to array
     * @return the value of {@code execAndEndResult()} after resolution
     */
    public INDArray execWithPlaceHolderAndEndResult(Map<String, INDArray> map) {
        this.resolveVariablesWith(map);
        return this.execAndEndResult();
    }

    /**
     * Records the declared shape of a placeholder variable.
     *
     * <p>A shape may be set once; setting it again is allowed only with an equal array.
     * NOTE(review): the null-check message mentions 0-length arrays, but only {@code null} is rejected here.
     *
     * @param str  placeholder variable name (must already be registered as a placeholder)
     * @param iArr declared shape
     * @throws ND4JIllegalStateException if {@code str} is not a placeholder, {@code iArr} is null,
     *     or a different shape was already recorded
     */
    public void setOriginalPlaceHolderShape(String str, int[] iArr) {
        if (!isPlaceHolder(str)) {
            throw new ND4JIllegalStateException("Vertex id " + str + " does not appear to be a place holder. Did you forget to call addPlaceHolder?");
        }
        if (iArr == null) {
            throw new ND4JIllegalStateException("Null and 0 length shape arrays not allowed");
        }
        boolean conflictsWithRecorded = this.placeHolderOriginalShapes.containsKey(str)
                && !Arrays.equals(this.placeHolderOriginalShapes.get(str), iArr);
        if (conflictsWithRecorded) {
            throw new ND4JIllegalStateException("Unable to add a new shape for vertex id " + str);
        }
        this.placeHolderOriginalShapes.put(str, iArr);
    }

    /**
     * Returns the declared shape recorded for a placeholder.
     *
     * @param str placeholder variable name
     * @return the recorded shape, or {@code null} if none was recorded
     */
    public int[] getOriginalShapeForPlaceHolder(String str) {
        int[] recorded = this.placeHolderOriginalShapes.get(str);
        return recorded;
    }

    /**
     * Tells whether a variable name has been registered as a placeholder.
     *
     * @param str variable name
     * @return {@code true} if {@code str} is in the placeholder name set
     */
    public boolean isPlaceHolder(String str) {
        boolean registered = this.placeHolderVarNames.contains(str);
        return registered;
    }

    /**
     * Registers a variable name as a placeholder. If a variable with that name exists and has a
     * known shape, that shape is recorded as the placeholder's declared shape.
     *
     * @param str variable name to mark as a placeholder (the variable itself need not exist yet)
     */
    public void addAsPlaceHolder(String str) {
        this.placeHolderVarNames.add(str);
        SDVariable variable = getVariable(str);
        if (variable != null) {
            int[] knownShape = variable.getShape();
            if (knownShape != null) {
                this.placeHolderOriginalShapes.put(str, knownShape);
            }
        }
    }

    /**
     * Binds arrays to placeholder variables and marks the graph as resolved.
     *
     * <p>Pass 1 validates every entry before anything is mutated: the variable must exist and, when
     * the declared shape has the same rank, each declared dimension {@code >= 1} must match. Pass 2
     * requires each name to be a placeholder, checks the full shape unless the declared shape is a
     * placeholder shape, and then stores shape and array. Finally, every function listed in
     * {@code propertiesToResolve} that is a {@link CustomOp} repopulates its inputs/outputs.
     *
     * @param map placeholder variable name to array
     * @throws ND4JIllegalStateException on an unknown variable, non-placeholder name, shape mismatch,
     *     or unknown function name in {@code propertiesToResolve}
     */
    public void resolveVariablesWith(Map<String, INDArray> map) {
        for (Map.Entry<String, INDArray> supplied : map.entrySet()) {
            String name = supplied.getKey();
            INDArray array = supplied.getValue();
            if (getVariable(name) == null) {
                throw new ND4JIllegalStateException("No variable name found for " + name);
            }
            if (!this.placeHolderOriginalShapes.containsKey(name)) {
                continue;
            }
            int[] declared = this.placeHolderOriginalShapes.get(name);
            if (declared.length != array.rank()) {
                // Rank mismatches are not rejected here; pass 2 compares full shapes.
                continue;
            }
            for (int d = 0; d < declared.length; d++) {
                // Declared dimensions below 1 act as wildcards.
                if (declared[d] >= 1 && declared[d] != array.shape()[d]) {
                    throw new ND4JIllegalStateException("Incompatible shape passed for variable. " + Arrays.toString(array.shape()));
                }
            }
        }

        for (Map.Entry<String, INDArray> supplied : map.entrySet()) {
            String name = supplied.getKey();
            INDArray array = supplied.getValue();
            if (!this.placeHolderVarNames.contains(name)) {
                throw new ND4JIllegalStateException("Illegal variable " + name + " passed in. Variable found not to be a place holder variable");
            }
            int[] declared = getOriginalShapeForPlaceHolder(name);
            if (!Shape.isPlaceholderShape(declared) && !Shape.shapeEquals(declared, array.shape())) {
                throw new ND4JIllegalStateException("Place holder shape specified was " + Arrays.toString(declared) + " but array shape was " + Arrays.toString(array.shape()));
            }
            updateShapeForVarName(name, array.shape());
            associateArrayWithVariable(array, getVariable(name));
            updateArrayForVarName(name, array);
        }

        for (String functionName : this.propertiesToResolve.keySet()) {
            DifferentialFunction function = this.functionInstancesById.get(functionName);
            if (!this.functionInstancesById.containsKey(functionName)) {
                throw new ND4JIllegalStateException("Unable to resolve function name " + functionName);
            }
            if (function instanceof CustomOp) {
                ((CustomOp) function).populateInputsAndOutputsFromSameDiff();
            }
        }
        this.resolvedVariables = true;
    }

    /**
     * Tells whether every registered placeholder currently has an array bound to it.
     *
     * <p>{@link #addAsPlaceHolder(String)} accepts names for which no variable exists yet. Such names
     * are reported as unresolved here instead of causing a {@link NullPointerException}.
     *
     * @return {@code true} if every placeholder name maps to a variable with a non-null array
     */
    public boolean allPlaceHolderVariablesResolved() {
        for (String placeholderName : this.placeHolderVarNames) {
            SDVariable variable = getVariable(placeholderName);
            if (variable == null || variable.getArr() == null) {
                return false;
            }
        }
        return true;
    }

    /**
     * Appends a group of variable names to the placeholder list kept for {@code str}.
     *
     * @param str    key under which the group is stored
     * @param strArr variable names; each must already exist in {@code variableMap}
     * @throws ND4JIllegalStateException if any name is not a known variable (nothing is stored then)
     */
    public void putPlaceHolderForVariable(String str, String... strArr) {
        for (String varName : strArr) {
            if (!this.variableMap.containsKey(varName)) {
                throw new ND4JIllegalStateException("No variable found for " + varName);
            }
        }
        this.placeHolderMap.computeIfAbsent(str, key -> new ArrayList<>()).add(strArr);
    }

    /**
     * Tells whether any placeholder groups were stored under {@code str}.
     *
     * @param str key used with {@code putPlaceHolderForVariable}
     * @return {@code true} if {@code placeHolderMap} has an entry for {@code str}
     */
    public boolean hasPlaceHolderVariables(String str) {
        boolean present = this.placeHolderMap.containsKey(str);
        return present;
    }

    /**
     * Returns the placeholder groups stored under {@code str}.
     *
     * @param str key used with {@code putPlaceHolderForVariable}
     * @return the stored list (the live instance), or {@code null} if none
     */
    public List<String[]> getPlaceHoldersFor(String str) {
        List<String[]> groups = this.placeHolderMap.get(str);
        return groups;
    }

    /**
     * Binds the given placeholder arrays, then executes the whole graph.
     *
     * @param map placeholder variable name to array
     * @return the result of {@code exec()}
     */
    public Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> execWithPlaceHolder(Map<String, INDArray> map) {
        this.resolveVariablesWith(map);
        return this.exec();
    }

    /**
     * Collects the output variables of the given functions, in order.
     *
     * @param list functions whose outputs are gathered
     * @return a new mutable list with every function's {@code outputVariables()} concatenated
     */
    public List<SDVariable> getVariablesAssociatedWithFunctions(List<DifferentialFunction> list) {
        List<SDVariable> outputs = new ArrayList<>(list.size());
        for (DifferentialFunction function : list) {
            Collections.addAll(outputs, function.outputVariables());
        }
        return outputs;
    }

    /**
     * Renames a variable and updates this graph's references to it.
     *
     * <p>If {@code str} is null and the current name is already taken in {@code variableMap}, a fresh
     * name is generated. If no new name results, or it equals the current one, nothing changes.
     *
     * @param sDVariable variable to rename (must not be null)
     * @param str        requested new name, or null
     * @return the same variable instance
     * @throws NullPointerException if {@code sDVariable} is null
     */
    public SDVariable updateVariableNameAndReference(SDVariable sDVariable, String str) {
        if (sDVariable == null) {
            throw new NullPointerException("Null input: No variable found for updating!");
        }
        String oldName = sDVariable.getVarName();
        String newName = str;
        if (newName == null && this.variableMap.containsKey(oldName)) {
            newName = generateNewVarName(oldName, 0);
        }
        boolean unchanged = newName == null || oldName.equals(newName);
        if (!unchanged) {
            sDVariable.setVarName(newName);
            updateVariableName(oldName, newName);
        }
        return sDVariable;
    }

    /**
     * Applies {@link #updateVariableNameAndReference(SDVariable, String)} element-wise.
     *
     * @param sDVariableArr variables to rename
     * @param strArr        requested names aligned by index with {@code sDVariableArr}, or null to
     *                      request no explicit names
     * @return a new array holding the (same) variable instances
     */
    public SDVariable[] updateVariableNamesAndReferences(SDVariable[] sDVariableArr, String[] strArr) {
        SDVariable[] renamed = new SDVariable[sDVariableArr.length];
        for (int idx = 0; idx < renamed.length; idx++) {
            String requested = (strArr != null) ? strArr[idx] : null;
            renamed[idx] = updateVariableNameAndReference(sDVariableArr[idx], requested);
        }
        return renamed;
    }

    /**
     * Executes every op registered in this instance, in the iteration order of {@code functionInstancesById}.
     *
     * <p>If variables have not been resolved yet, {@link #resolveVariablesWith(Map)} is first called with
     * an empty map. A fresh {@link FlowPath} is installed in {@code localFlowPath`} and tracks node
     * activity, executed state and loop frames for the control-flow ops (LoopCond, Enter, Exit,
     * NextIteration, Merge, Switch). Ops that appear after a {@link GradientBackwardsMarker} are treated
     * as backward-pass ops, which changes how {@link If} and {@link While} are run.
     *
     * @return a pair whose left element is a map this method never populates (always empty) and whose
     *     right element lists the executed If, forward-pass While, custom and legacy ops, in execution order
     */
    public Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> exec() {
        if (!this.resolvedVariables) {
            resolveVariablesWith(new LinkedHashMap<>());
        }
        List<DifferentialFunction> executedOps = new ArrayList<>();
        this.localFlowPath.set(new FlowPath());
        FlowPath flowPath = this.localFlowPath.get();
        Map<SDVariable, DifferentialFunction> opMap = new HashMap<>();
        List<DifferentialFunction> functions = new ArrayList<>(this.functionInstancesById.values());
        String markerOpName = new GradientBackwardsMarker().opName();
        boolean backwardPass = false;
        // Loop frame names registered by Enter ops; the innermost frame is last.
        ArrayDeque<String> frames = new ArrayDeque<>();
        // Set by Exit; the finished frame is popped before the next regular (non-Merge, non-Exit) op.
        boolean frameLeft = false;
        int step = 0;
        // A `for` loop (not `while` with a trailing i++) so that every `continue` still advances to the next
        // op. The previous form never incremented `i` on `continue` and spun forever on SDVariable entries
        // and on inactive nodes. An Exit rewind assigns `i` directly; the update then moves to i + 1.
        for (int i = 0; i < functions.size(); i++) {
            step++;
            DifferentialFunction function = functions.get(i);
            String opName = function.opName();
            if (opName.equals(markerOpName)) {
                backwardPass = true;
                continue;
            }
            String ownName = function.getOwnName();
            flowPath.ensureNodeStateExists(ownName);
            if (function instanceof SDVariable) {
                continue;
            }
            String[] inputNames = getInputsForFunction(function);
            log.debug("Step: {}; Executing op {} for node [{}]", step, opName, ownName);

            boolean skip = false;
            if (function instanceof Merge) {
                String firstInput = inputNames[0];
                String secondInput = inputNames[1];
                // Merge runs as long as at least one input is active.
                skip = !flowPath.isActive(firstInput) && !flowPath.isActive(secondInput);
            } else if (!(function instanceof Exit)) {
                if (frameLeft) {
                    frameLeft = false;
                    String leftFrame = frames.removeLast();
                    flowPath.activateFrame(leftFrame, false);
                    flowPath.forgetFrame(leftFrame);
                }
                for (String inputName : inputNames) {
                    if (!flowPath.isActive(inputName)) {
                        flowPath.markActive(ownName, false);
                        skip = true;
                        break;
                    }
                }
            }
            if (skip) {
                continue;
            }

            function.resolvePropertiesFromSameDiffBeforeExecution();
            flowPath.markActive(ownName, true);
            if (function instanceof LoopCond) {
                INDArray condition = getInputVariablesForFunction(function)[0].getArr();
                this.variableNameToArr.put(ownName, condition.dup(condition.ordering()));
                flowPath.markExecuted(ownName, true);
                if (((int) condition.getDouble(0)) == 1) {
                    flowPath.incrementNumberOfCycles(frames.getLast());
                }
            } else if (function instanceof Enter) {
                INDArray input = getInputVariablesForFunction(function)[0].getArr();
                this.variableNameToArr.put(ownName, input.dup(input.ordering()));
                flowPath.markExecuted(ownName, true);
                String frameName = ((Enter) function).getFrameName();
                if (!flowPath.isRegisteredFrame(frameName)) {
                    flowPath.registerFrame(frameName);
                    frames.addLast(frameName);
                }
            } else if (function instanceof Exit) {
                String frameName = frames.getLast();
                ((Exit) function).setFrameName(frameName);
                if (!flowPath.isFrameActive(frameName)) {
                    flowPath.markActive(ownName, false);
                    frameLeft = true;
                } else if (flowPath.isRewindPlanned(frameName)) {
                    // Jump back into the loop body (the position recorded by the frame's Merge).
                    flowPath.planRewind(frameName, false);
                    i = flowPath.getRewindPosition(frameName);
                    flowPath.setRewindPosition(frameName, -1);
                } else {
                    INDArray input = getInputVariablesForFunction(function)[0].getArr();
                    this.variableNameToArr.put(ownName, input.dup(input.ordering()));
                    flowPath.markExecuted(ownName, true);
                    frameLeft = true;
                }
            } else if (function instanceof NextIteration) {
                SDVariable[] inputs = getInputVariablesForFunction(function);
                String frameName = frames.getLast();
                INDArray input = inputs[0].getArr();
                this.variableNameToArr.put(ownName, input.dup(input.ordering()));
                flowPath.markExecuted(ownName, true);
                if (!flowPath.isRewindPlanned(frameName)) {
                    flowPath.planRewind(frameName, true);
                }
            } else if (function instanceof Merge) {
                SDVariable[] inputs = getInputVariablesForFunction(function);
                String frameName = frames.peekLast();
                if (frameName != null) {
                    flowPath.activateFrame(frameName, true);
                    flowPath.setRewindPositionOnce(frameName, i - 1);
                }
                if (inputs.length == 2) {
                    DifferentialFunction producer = this.functionInstancesById.get(inputs[1].getVarName());
                    if (producer instanceof NextIteration) {
                        ((NextIteration) producer).setFrameName(frameName);
                    }
                }
                if (flowPath.wasExecuted(inputs[1].getVarName())) {
                    // Value coming back around the loop takes precedence.
                    INDArray looped = inputs[1].getArr();
                    this.variableNameToArr.put(ownName, looped.dup(looped.ordering()));
                    flowPath.markExecuted(inputs[1].getVarName(), false);
                } else {
                    INDArray initial = inputs[0].getArr();
                    this.variableNameToArr.put(ownName, initial.dup(initial.ordering()));
                }
                flowPath.markExecuted(ownName, true);
            } else if (function instanceof Switch) {
                ((CustomOp) function).populateInputsAndOutputsFromSameDiff();
                SDVariable[] inputs = getInputVariablesForFunction(function);
                INDArray data = inputs[0].getArr();
                if (((int) inputs[1].getArr().getDouble(0)) == 0) {
                    flowPath.setActiveBranch(ownName, 0);
                    flowPath.markActive(ownName, true);
                    flowPath.markActive(ownName + ":1", false);
                    this.variableNameToArr.put(ownName, data.dup(data.ordering()));
                } else {
                    flowPath.setActiveBranch(ownName, 1);
                    this.variableNameToArr.put(ownName + ":1", data.dup(data.ordering()));
                    flowPath.markActive(ownName, false);
                    flowPath.markActive(ownName + ":1", true);
                }
                flowPath.markExecuted(ownName, true);
            } else if (function instanceof If) {
                If ifOp = (If) function;
                if (!backwardPass) {
                    ifOp.getPredicateExecution().exec();
                    if (ifOp.getTargetBoolean().getArr().sumNumber().doubleValue() > 0.0d) {
                        ifOp.getLoopBodyExecution().exec();
                        ifOp.exectedTrueOrFalse(true);
                    } else {
                        ifOp.getFalseBodyExecution().exec();
                        ifOp.exectedTrueOrFalse(false);
                    }
                } else {
                    if (ifOp.getTrueBodyExecuted() == null) {
                        throw new ND4JIllegalStateException("No body was run.");
                    }
                    List<SDVariable> bodyVariables = ifOp.getTrueBodyExecuted().booleanValue()
                            ? ifOp.getLoopBodyExecution().getVariablesAssociatedWithFunctions(ifOp.getLoopBodyExecution().execBackwards().getRight())
                            : ifOp.getFalseBodyExecution().getVariablesAssociatedWithFunctions(ifOp.getFalseBodyExecution().execBackwards().getRight());
                    for (SDVariable bodyVariable : bodyVariables) {
                        var(bodyVariable);
                    }
                }
                flowPath.markExecuted(ownName, true);
                executedOps.add(function);
            } else if (function instanceof While) {
                While whileOp = (While) function;
                if (backwardPass) {
                    for (SDVariable variable : whileOp.getLoopBodyExecution().execBackwards().getFirst().keySet()) {
                        variable.getArr().muli(Integer.valueOf(whileOp.getNumLooped()));
                    }
                } else {
                    SameDiff loopBodyExecution = whileOp.getLoopBodyExecution();
                    whileOp.getPredicateExecution().exec();
                    while (whileOp.getTargetBoolean().getArr().sumNumber().doubleValue() > 0.0d) {
                        loopBodyExecution.exec();
                        whileOp.getPredicateExecution().exec();
                        whileOp.incrementLoopCounter();
                    }
                    // The While op's outputs are the outputs of the last op in its body.
                    List<DifferentialFunction> bodyOps = new ArrayList<>(loopBodyExecution.functionInstancesById.values());
                    List<SDVariable> loopOutputs = new ArrayList<>(Arrays.asList(bodyOps.get(bodyOps.size() - 1).outputVariables()));
                    whileOp.setOutputVars(loopOutputs.toArray(new SDVariable[loopOutputs.size()]));
                    executedOps.add(function);
                }
                flowPath.markExecuted(ownName, true);
            } else if (function instanceof CustomOp) {
                DynamicCustomOp customOp = (DynamicCustomOp) function;
                customOp.populateInputsAndOutputsFromSameDiff();
                customOp.assertValidForExecution();
                customOp.updateInputsFromSameDiff();
                Nd4j.getExecutioner().exec(customOp);
                flowPath.markExecuted(ownName, true);
                executedOps.add(customOp);
            } else if (function instanceof Op) {
                SDVariable[] inputs = getInputVariablesForFunction(function);
                Op legacyOp = (Op) function;
                legacyOp.setX(inputs[0].getArr());
                if (inputs.length == 2) {
                    legacyOp.setY(inputs[1].getArr());
                }
                if (function.getDimensions() == null) {
                    Nd4j.getExecutioner().exec(legacyOp);
                } else if (legacyOp.isExecSpecial()) {
                    legacyOp.exec();
                } else {
                    int[] dimensions = function.getDimensions();
                    if (function instanceof Accumulation) {
                        Accumulation accumulation = (Accumulation) function;
                        Nd4j.getExecutioner().exec(accumulation, dimensions);
                        if (function.outputVariables()[0].getArr() == null) {
                            SDVariable output = function.outputVariables()[0];
                            updateArrayForVarName(output.getVarName(), accumulation.z());
                            updateShapeForVarName(output.getVarName(), accumulation.z().shape());
                        }
                    } else if (function instanceof BroadcastOp) {
                        Nd4j.getExecutioner().exec((BroadcastOp) function, dimensions);
                    } else if (function instanceof GradientOp) {
                        Nd4j.getExecutioner().exec(legacyOp);
                    } else if (function instanceof IndexAccumulation) {
                        Nd4j.getExecutioner().exec((IndexAccumulation) function, dimensions);
                    } else if (function instanceof TransformOp) {
                        Nd4j.getExecutioner().exec((TransformOp) function, dimensions);
                    }
                }
                flowPath.markExecuted(ownName, true);
                executedOps.add(function);
            }
        }
        return new Pair<>(opMap, executedOps);
    }

    /**
     * Logs, at info level, the declared shapes and the shapes currently registered in this graph for a
     * function's inputs and outputs.
     *
     * <p>Does nothing unless {@code logExecution} is enabled, and ignores {@link SDVariable}s. The
     * previous version built both descriptions but never emitted them.
     *
     * @param differentialFunction function to describe
     */
    public void printFunction(DifferentialFunction differentialFunction) {
        if (!this.logExecution || differentialFunction instanceof SDVariable) {
            return;
        }
        StringBuilder declaredShapes = new StringBuilder();
        for (SDVariable arg : differentialFunction.args()) {
            declaredShapes.append(" Variable ").append(arg.getVarName())
                    .append(" Shape for ").append(Arrays.toString(arg.getShape()));
        }
        for (SDVariable output : differentialFunction.outputVariables()) {
            declaredShapes.append("  Output variable ").append(output.getVarName())
                    .append(" is ").append(Arrays.toString(output.getShape()));
        }
        StringBuilder registeredShapes = new StringBuilder();
        for (SDVariable arg : differentialFunction.args()) {
            registeredShapes.append(" Input shape for ").append(arg.getVarName())
                    .append(" is  ").append(Arrays.toString(getShapeForVarName(arg.getVarName())));
        }
        for (SDVariable output : differentialFunction.outputVariables()) {
            registeredShapes.append(" Output shape for ").append(output.getVarName())
                    .append(" is  ").append(Arrays.toString(getShapeForVarName(output.getVarName())));
        }
        log.info("Function {} ({}) declared shapes:{}", differentialFunction.getOwnName(), differentialFunction.opName(), declaredShapes);
        log.info("Function {} ({}) registered shapes:{}", differentialFunction.getOwnName(), differentialFunction.opName(), registeredShapes);
    }

    /**
     * Computes a 4-element index permutation for a convolution data-format string such as "NHWC".
     *
     * <p>When {@code z} is true, the result is the positions of 'W', 'C', 'N' and 'H' in the
     * (upper-cased) string, in that order. Missing characters yield -1, as before. Otherwise the string
     * must be an exact permutation of "NCHW", and entry {@code i} is the position of the string's
     * i-th character within "NCHW".
     *
     * @param str data-format string (case-insensitive)
     * @param z   whether to compute the weights layout
     * @return a new 4-element index array
     * @throws ND4JIllegalStateException in the non-weights path if {@code str} is not a permutation of
     *     NCHW (wrong length, foreign character, or a repeated character). Previously, short strings
     *     failed with StringIndexOutOfBoundsException and duplicates produced a bogus permutation.
     */
    public static int[] permuteDataFormatForSameDiff(String str, boolean z) {
        String format = str.toUpperCase();
        int[] permutation = new int[4];
        if (z) {
            permutation[0] = format.indexOf('W');
            permutation[1] = format.indexOf('C');
            permutation[2] = format.indexOf('N');
            permutation[3] = format.indexOf('H');
            return permutation;
        }
        String dl4jFormat = "NCHW";
        boolean isPermutation = format.length() == dl4jFormat.length();
        for (int i = 0; isPermutation && i < format.length(); i++) {
            char c = format.charAt(i);
            // Valid character that appears for the first (and only) time at position i.
            isPermutation = dl4jFormat.indexOf(c) >= 0 && format.indexOf(c) == i;
        }
        if (!isPermutation) {
            throw new ND4JIllegalStateException("Illegal convolution data format string passed in " + format + " must be some variant of NCHW");
        }
        for (int i = 0; i < dl4jFormat.length(); i++) {
            permutation[i] = dl4jFormat.indexOf(format.charAt(i));
        }
        return permutation;
    }

    /**
     * Stores an array for a variable name. It updates the existing entry when
     * {@code variableNameToArr} already has one, and adds a new entry otherwise.
     *
     * @param str      variable name
     * @param iNDArray array to associate with it
     */
    public void updateVariable(String str, INDArray iNDArray) {
        if (!this.variableNameToArr.containsKey(str)) {
            putArrayForVarName(str, iNDArray);
            return;
        }
        updateArrayForVarName(str, iNDArray);
    }

    /**
     * Serializes a placeholder-style FlatNode whose name is {@code str}.
     *
     * <p>The interned string offset is passed for the first two FlatNode fields. All remaining
     * arguments are fixed constants. NOTE(review): the meaning of (byte) 119 and 10L is not visible
     * here; confirm against the FlatNode schema.
     *
     * @param str               node name
     * @param sameDiff          owning graph (must not be null; otherwise unused here)
     * @param flatBufferBuilder builder to write into
     * @return offset of the created FlatNode
     * @throws NullPointerException if {@code sameDiff} or {@code flatBufferBuilder} is null
     */
    protected int asFlatNode(String str, @NonNull SameDiff sameDiff, @NonNull FlatBufferBuilder flatBufferBuilder) {
        if (sameDiff == null) {
            throw new NullPointerException("scope");
        }
        if (flatBufferBuilder == null) {
            throw new NullPointerException("bufferBuilder");
        }
        int nameOffset = flatBufferBuilder.createString(str);
        return FlatNode.createFlatNode(flatBufferBuilder, nameOffset, nameOffset,
                (byte) 119, 10L, 0, 0, 0, (byte) 0, 0, 0, 0, 0, -1, 0.0f, 0, 0);
    }

    /**
     * Splits a variable reference of the form {@code name:index} into its name and output index.
     *
     * <p>A string without ':' maps to {@code (str, 0)}. Otherwise the text after the last ':' is parsed
     * as the index, and everything before it (including any inner ':') is the name.
     *
     * @param str variable reference
     * @return pair of (base name, output index)
     * @throws NullPointerException  if {@code str} is null
     * @throws NumberFormatException if the final ':'-separated token is not an integer
     */
    public static Pair<String, Integer> parseVariable(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("varName");
        }
        if (str.indexOf(':') < 0) {
            return Pair.pairOf(str, 0);
        }
        String[] parts = str.split(":");
        Integer outputIndex = Integer.valueOf(parts[parts.length - 1]);
        String baseName = String.join(":", Arrays.copyOfRange(parts, 0, parts.length - 1));
        return Pair.pairOf(baseName, outputIndex);
    }

    /**
     * Serializes one graph function into a {@link FlatNode} inside {@code flatBufferBuilder}.
     *
     * <p>Node and variable ids are drawn from the shared {@code atomicInteger}. An input whose name
     * has not been assigned an id yet is only accepted when its name contains {@code "NextIteration"}
     * (presumably a loop back-edge); an id is reserved for it in both {@code map} and {@code map2},
     * and a node whose own name was reserved in {@code map2} reuses that id.
     *
     * @param differentialFunction function to serialize; must not be null
     * @param flatBufferBuilder    builder the node is written to; must not be null
     * @param list                 variables of the owning graph; each output is encoded as its index
     *                             in this list ({@code -1} when it is not present)
     * @param map                  name -> id of everything emitted so far; updated with this node's id
     * @param map2                 ids reserved for forward-referenced {@code NextIteration} inputs
     * @param map3                 frame name -> frame id, used for {@link Enter} nodes
     * @param atomicInteger        shared id counter
     * @return offset of the created {@code FlatNode} in {@code flatBufferBuilder}
     * @throws ND4JIllegalStateException if an input refers to an unknown variable whose name does not
     *                                   contain {@code "NextIteration"}
     */
    protected int asFlatNode(@NonNull DifferentialFunction differentialFunction, @NonNull FlatBufferBuilder flatBufferBuilder, List<SDVariable> list, Map<String, Integer> map, Map<String, Integer> map2, Map<String, Integer> map3, AtomicInteger atomicInteger) {
        if (differentialFunction == null) {
            throw new NullPointerException("node");
        }
        if (flatBufferBuilder == null) {
            throw new NullPointerException("bufferBuilder");
        }

        Op.Type opType = differentialFunction.opType();
        if (opType == null) {
            // Logged up-front so it is emitted even if a later lookup fails.
            // NOTE(review): getFlatOpType(null) throws, so such a node cannot actually be serialized.
            log.warn("Null-op node: {}", differentialFunction);
        }
        long opNum = getOpNum(differentialFunction.opName(), opType);

        // Extra (floating point) op arguments.
        Object[] extraArgs = differentialFunction.getExtraArgs();
        float[] extraParams = new float[extraArgs == null ? 0 : extraArgs.length];
        for (int i = 0; i < extraParams.length; i++) {
            extraParams[i] = ((Number) extraArgs[i]).floatValue();
        }

        // Extra integer arguments: custom-op iArgs, or the frame id for Enter nodes.
        int[] extraInts;
        if (opType == Op.Type.CUSTOM) {
            extraInts = ((DynamicCustomOp) differentialFunction).iArgs();
        } else if (differentialFunction instanceof Enter) {
            String frameName = ((Enter) differentialFunction).getFrameName();
            Integer frameId = map3.get(frameName);
            if (frameId == null) {
                frameId = atomicInteger.incrementAndGet();
                map3.put(frameName, frameId);
            }
            extraInts = new int[]{frameId};
        } else {
            extraInts = new int[0];
        }

        // Outputs are encoded as positions in the owning graph's variable list.
        SDVariable[] outputVariables = differentialFunction.outputVariables();
        if (outputVariables == null) {
            // The node-name computation below already tolerates null; treat it as "no outputs"
            // instead of failing with an NPE here.
            outputVariables = new SDVariable[0];
        }
        int[] outputIds = new int[outputVariables.length];
        for (int i = 0; i < outputIds.length; i++) {
            outputIds[i] = list.indexOf(outputVariables[i]);
        }

        // Inputs are encoded as (node id, output index) pairs.
        SDVariable[] args = differentialFunction.args();
        List<Integer> inputPairs = new ArrayList<>(args.length);
        for (SDVariable arg : args) {
            Pair<String, Integer> parsed = parseVariable(arg.getVarName());
            String inputName = parsed.getFirst();
            Integer inputId = map.get(inputName);
            if (inputId == null) {
                if (!inputName.contains("NextIteration")) {
                    throw new ND4JIllegalStateException("Unknown variable used in input: [" + inputName + "]");
                }
                // Forward reference: reserve the id now; the NextIteration node picks it up via map2.
                int reserved = atomicInteger.incrementAndGet();
                map2.put(inputName, reserved);
                map.put(inputName, reserved);
                inputId = reserved;
            }
            inputPairs.add(IntPair.createIntPair(flatBufferBuilder, inputId, parsed.getSecond()));
        }

        String ownName = differentialFunction.getOwnName();
        log.debug("Own Name: {}", ownName);
        Integer reservedId = map2.get(ownName);
        int nodeId = reservedId != null ? reservedId : atomicInteger.incrementAndGet();
        map.put(ownName, nodeId);

        int[] dimensions = differentialFunction.getDimensions();

        // Builder objects are created in the same order as before so the encoded bytes do not change.
        int propertiesOffset = FunctionProperties.asFlatProperties(flatBufferBuilder, new ArrayList<>());
        int inputOffset = FlatNode.createInputVector(flatBufferBuilder, new int[0]);
        int inputPairedOffset = FlatNode.createInputPairedVector(flatBufferBuilder, Ints.toArray(inputPairs));
        int outputOffset = FlatNode.createOutputVector(flatBufferBuilder, outputIds);
        int extraParamsOffset = FlatNode.createExtraParamsVector(flatBufferBuilder, extraParams);
        int extraIntegerOffset = FlatNode.createExtraIntegerVector(flatBufferBuilder, extraInts);
        int dimensionsOffset = FlatNode.createDimensionsVector(flatBufferBuilder, dimensions != null ? dimensions : new int[0]);
        String nodeName = (outputVariables.length < 1 || outputVariables[0] == null) ? "" : outputVariables[0].getVarName();
        int nameOffset = flatBufferBuilder.createString(nodeName);
        int trailingStringOffset = flatBufferBuilder.createString("");

        // Only scalar ops carry a scalar value; everything else encodes 0.0f.
        float scalar = 0.0f;
        if (opType == Op.Type.SCALAR) {
            Number scalarValue = differentialFunction.getScalarValue();
            if (scalarValue != null) {
                scalar = scalarValue.floatValue();
            }
        }

        return FlatNode.createFlatNode(flatBufferBuilder, nodeId, nameOffset, getFlatOpType(opType), opNum,
                propertiesOffset, inputOffset, inputPairedOffset, (byte) 0, outputOffset,
                extraParamsOffset, extraIntegerOffset, dimensionsOffset, -1, scalar, 0, trailingStringOffset);
    }

    /**
     * Serializes this graph, including any nested SameDiff function instances, into a FlatBuffers
     * {@code FlatGraph}.
     *
     * <p>Every variable with an array and a shape, and every function, gets an id from one shared
     * counter. Nested-instance variables previously used a separate counter starting at 0, which
     * produced ids that collided with the top-level ones.
     *
     * @param executorConfiguration execution configuration embedded in the graph; must not be null
     * @return the finished buffer; the payload starts at the buffer's current position
     */
    public ByteBuffer asFlatBuffers(@NonNull ExecutorConfiguration executorConfiguration) {
        if (executorConfiguration == null) {
            throw new NullPointerException("configuration");
        }
        Nd4j.getExecutioner().commit();
        FlatBufferBuilder flatBufferBuilder = new FlatBufferBuilder(1024);
        AtomicInteger idCounter = new AtomicInteger(0);

        List<Integer> flatVariables = new ArrayList<>();
        List<Integer> flatOutputs = new ArrayList<>();
        List<Integer> flatNodes = new ArrayList<>();
        // name -> id, NextIteration forward reservations, and frame name -> frame id (see asFlatNode).
        Map<String, Integer> reverseMap = new LinkedHashMap<>();
        Map<String, Integer> forwardMap = new LinkedHashMap<>();
        Map<String, Integer> framesMap = new LinkedHashMap<>();

        List<SDVariable> variableList = new ArrayList<>(variables());
        for (SDVariable variable : variableList) {
            log.debug("Exporting variable: [{}]", variable.getVarName());
            if (variable.getArr() != null && variable.getShape() != null) {
                Pair<String, Integer> parsed = parseVariable(variable.getVarName());
                int varId = idCounter.incrementAndGet();
                reverseMap.put(parsed.getFirst(), varId);
                log.debug("Adding [{}] as [{}]", parsed.getFirst(), varId);
                int idOffset = IntPair.createIntPair(flatBufferBuilder, varId, 0);
                int nameOffset = flatBufferBuilder.createString(variable.getVarName());
                int arrayOffset = variable.getArr().toFlatArray(flatBufferBuilder);
                flatVariables.add(FlatVariable.createFlatVariable(flatBufferBuilder, idOffset, nameOffset, 0, arrayOffset, -1));
            }
        }

        for (DifferentialFunction function : this.functionInstancesById.values()) {
            flatNodes.add(asFlatNode(function, flatBufferBuilder, variableList, reverseMap, forwardMap, framesMap, idCounter));
        }

        for (Map.Entry<String, SameDiff> entry : this.sameDiffFunctionInstances.entrySet()) {
            SameDiff nested = entry.getValue();
            flatNodes.add(asFlatNode(entry.getKey(), nested, flatBufferBuilder));

            List<SDVariable> nestedVariables = new ArrayList<>(nested.variables());
            for (SDVariable nestedVariable : nestedVariables) {
                INDArray arr = nestedVariable.getArr();
                if (arr == null) {
                    continue;
                }
                int nameOffset = flatBufferBuilder.createString(nestedVariable.getVarName());
                int arrayOffset = arr.toFlatArray(flatBufferBuilder);
                // Shared counter: ids must be unique across the whole exported graph.
                int varId = idCounter.incrementAndGet();
                int idOffset = IntPair.createIntPair(flatBufferBuilder, varId, 0);
                Pair<String, Integer> parsed = parseVariable(nestedVariable.getVarName());
                reverseMap.put(parsed.getFirst(), varId);
                log.debug("Adding [{}] as [{}]", parsed.getFirst(), varId);
                flatVariables.add(FlatVariable.createFlatVariable(flatBufferBuilder, idOffset, nameOffset, 0, arrayOffset, -1));
            }

            for (DifferentialFunction function : nested.functionInstancesById.values()) {
                flatNodes.add(asFlatNode(function, flatBufferBuilder, nestedVariables, reverseMap, forwardMap, framesMap, idCounter));
            }
        }

        int variablesVector = FlatGraph.createVariablesVector(flatBufferBuilder, Ints.toArray(flatVariables));
        int nodesVector = FlatGraph.createNodesVector(flatBufferBuilder, Ints.toArray(flatNodes));
        // The outputs list is never populated here, so this is always an empty int vector.
        int outputsVector = FlatGraph.createVariablesVector(flatBufferBuilder, Ints.toArray(flatOutputs));
        int configurationOffset = executorConfiguration.getFlatConfiguration(flatBufferBuilder);
        flatBufferBuilder.finish(FlatGraph.createFlatGraph(flatBufferBuilder, 119L, variablesVector, nodesVector, outputsVector, configurationOffset));
        return flatBufferBuilder.dataBuffer();
    }

    /**
     * Serializes this graph using the default executor configuration: variable-space output,
     * sequential execution, profiling disabled, timings gathered.
     *
     * @return the serialized FlatBuffers graph
     */
    public ByteBuffer asFlatBuffers() {
        ExecutorConfiguration defaults = ExecutorConfiguration.builder()
                .outputMode(OutputMode.VARIABLE_SPACE)
                .executionMode(ExecutionMode.SEQUENTIAL)
                .profilingMode(OpExecutioner.ProfilingMode.DISABLED)
                .gatherTimings(true)
                .build();
        return asFlatBuffers(defaults);
    }

    /**
     * Decodes a byte-order flag: {@code 0} means little-endian, any other value big-endian.
     *
     * @param b encoded byte-order flag
     * @return the corresponding {@link ByteOrder}
     */
    public static ByteOrder getOrderFromByte(byte b) {
        if (b != 0) {
            return ByteOrder.BIG_ENDIAN;
        }
        return ByteOrder.LITTLE_ENDIAN;
    }

    /**
     * Encodes the platform's native byte order: {@code 1} for big-endian, {@code 0} otherwise.
     *
     * @return the byte-order flag for this JVM
     */
    public static byte getOrderAsByte() {
        final boolean bigEndian = ByteOrder.BIG_ENDIAN.equals(ByteOrder.nativeOrder());
        return (byte) (bigEndian ? 1 : 0);
    }

    /* JADX WARN: Failed to calculate best type for var: r13v1 ??
    java.lang.NullPointerException
     */
    /* JADX WARN: Failed to calculate best type for var: r14v0 ??
    java.lang.NullPointerException
     */
    /* JADX WARN: Multi-variable type inference failed. Error: java.lang.NullPointerException
     */
    /* JADX WARN: Not initialized variable reg: 13, insn: 0x00dd: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r13 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) A[TRY_LEAVE], block:B:81:0x00dd */
    /* JADX WARN: Not initialized variable reg: 14, insn: 0x00e2: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r14 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]), block:B:83:0x00e2 */
    /* JADX WARN: Type inference failed for: r13v1, types: [java.io.BufferedOutputStream] */
    /* JADX WARN: Type inference failed for: r14v0, types: [java.lang.Throwable] */
    public void asFlatFile(@NonNull File file) throws IOException {
        ?? r13;
        ?? r14;
        if (file == null) {
            throw new NullPointerException("file");
        }
        ByteBuffer asFlatBuffers = asFlatBuffers();
        int position = asFlatBuffers.position();
        byte[] array = asFlatBuffers.array();
        FileOutputStream fileOutputStream = new FileOutputStream(file);
        Throwable th = null;
        try {
            try {
                BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(fileOutputStream);
                Throwable th2 = null;
                DataOutputStream dataOutputStream = new DataOutputStream(bufferedOutputStream);
                Throwable th3 = null;
                try {
                    try {
                        dataOutputStream.write(array, position, array.length - position);
                        if (dataOutputStream != null) {
                            if (0 != 0) {
                                try {
                                    dataOutputStream.close();
                                } catch (Throwable th4) {
                                    th3.addSuppressed(th4);
                                }
                            } else {
                                dataOutputStream.close();
                            }
                        }
                        if (bufferedOutputStream != null) {
                            if (0 != 0) {
                                try {
                                    bufferedOutputStream.close();
                                } catch (Throwable th5) {
                                    th2.addSuppressed(th5);
                                }
                            } else {
                                bufferedOutputStream.close();
                            }
                        }
                        if (fileOutputStream != null) {
                            if (0 == 0) {
                                fileOutputStream.close();
                                return;
                            }
                            try {
                                fileOutputStream.close();
                            } catch (Throwable th6) {
                                th.addSuppressed(th6);
                            }
                        }
                    } catch (Throwable th7) {
                        th3 = th7;
                        throw th7;
                    }
                } catch (Throwable th8) {
                    if (dataOutputStream != null) {
                        if (th3 != null) {
                            try {
                                dataOutputStream.close();
                            } catch (Throwable th9) {
                                th3.addSuppressed(th9);
                            }
                        } else {
                            dataOutputStream.close();
                        }
                    }
                    throw th8;
                }
            } catch (Throwable th10) {
                if (fileOutputStream != null) {
                    if (0 != 0) {
                        try {
                            fileOutputStream.close();
                        } catch (Throwable th11) {
                            th.addSuppressed(th11);
                        }
                    } else {
                        fileOutputStream.close();
                    }
                }
                throw th10;
            }
        } catch (Throwable th12) {
            if (r13 != 0) {
                if (r14 != 0) {
                    try {
                        r13.close();
                    } catch (Throwable th13) {
                        r14.addSuppressed(th13);
                    }
                } else {
                    r13.close();
                }
            }
            throw th12;
        }
    }

    /* JADX WARN: Failed to calculate best type for var: r14v0 ??
    java.lang.NullPointerException
     */
    /* JADX WARN: Failed to calculate best type for var: r15v0 ??
    java.lang.NullPointerException
     */
    /* JADX WARN: Multi-variable type inference failed. Error: java.lang.NullPointerException
     */
    /* JADX WARN: Not initialized variable reg: 14, insn: 0x00f0: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r14 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) A[TRY_LEAVE], block:B:76:0x00f0 */
    /* JADX WARN: Not initialized variable reg: 15, insn: 0x00f5: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r15 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]), block:B:78:0x00f5 */
    /* JADX WARN: Type inference failed for: r14v0, types: [java.io.BufferedOutputStream] */
    /* JADX WARN: Type inference failed for: r15v0, types: [java.lang.Throwable] */
    public void asFlatFile(@NonNull File file, @NonNull ExecutorConfiguration executorConfiguration) throws IOException {
        ?? r14;
        ?? r15;
        if (file == null) {
            throw new NullPointerException("file");
        }
        if (executorConfiguration == null) {
            throw new NullPointerException("configuration");
        }
        ByteBuffer asFlatBuffers = asFlatBuffers(executorConfiguration);
        int position = asFlatBuffers.position();
        byte[] array = asFlatBuffers.array();
        FileOutputStream fileOutputStream = new FileOutputStream(file);
        Throwable th = null;
        try {
            try {
                BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(fileOutputStream);
                Throwable th2 = null;
                DataOutputStream dataOutputStream = new DataOutputStream(bufferedOutputStream);
                Throwable th3 = null;
                try {
                    try {
                        dataOutputStream.write(array, position, array.length - position);
                        if (dataOutputStream != null) {
                            if (0 != 0) {
                                try {
                                    dataOutputStream.close();
                                } catch (Throwable th4) {
                                    th3.addSuppressed(th4);
                                }
                            } else {
                                dataOutputStream.close();
                            }
                        }
                        if (bufferedOutputStream != null) {
                            if (0 != 0) {
                                try {
                                    bufferedOutputStream.close();
                                } catch (Throwable th5) {
                                    th2.addSuppressed(th5);
                                }
                            } else {
                                bufferedOutputStream.close();
                            }
                        }
                        if (fileOutputStream != null) {
                            if (0 == 0) {
                                fileOutputStream.close();
                                return;
                            }
                            try {
                                fileOutputStream.close();
                            } catch (Throwable th6) {
                                th.addSuppressed(th6);
                            }
                        }
                    } catch (Throwable th7) {
                        th3 = th7;
                        throw th7;
                    }
                } catch (Throwable th8) {
                    if (dataOutputStream != null) {
                        if (th3 != null) {
                            try {
                                dataOutputStream.close();
                            } catch (Throwable th9) {
                                th3.addSuppressed(th9);
                            }
                        } else {
                            dataOutputStream.close();
                        }
                    }
                    throw th8;
                }
            } catch (Throwable th10) {
                if (r14 != 0) {
                    if (r15 != 0) {
                        try {
                            r14.close();
                        } catch (Throwable th11) {
                            r15.addSuppressed(th11);
                        }
                    } else {
                        r14.close();
                    }
                }
                throw th10;
            }
        } catch (Throwable th12) {
            if (fileOutputStream != null) {
                if (0 != 0) {
                    try {
                        fileOutputStream.close();
                    } catch (Throwable th13) {
                        th.addSuppressed(th13);
                    }
                } else {
                    fileOutputStream.close();
                }
            }
            throw th12;
        }
    }

    /**
     * Serializes this graph with the default configuration, decodes it again, and renders a
     * human-readable listing: first the external variables (id, name, shape info, values), then
     * the op sequence (id, name, op type, op number or custom-op name, paired inputs).
     *
     * @return the multi-line description
     */
    public String asFlatPrint() {
        StringBuilder sb = new StringBuilder();
        FlatGraph graph = FlatGraph.getRootAsFlatGraph(asFlatBuffers());

        sb.append("\nExternal variables:\n\n");
        for (int i = 0; i < graph.variablesLength(); i++) {
            FlatVariable variable = graph.variables(i);
            INDArray array = Nd4j.createFromFlatArray(variable.ndarray());
            sb.append(variable.id().first()).append(":<").append(variable.name()).append("> ")
                    .append(Arrays.toString(array.shapeInfoDataBuffer().asInt()))
                    .append("; Values: ").append(Arrays.toString(array.data().asFloat())).append(";\n");
        }

        // Reverse lookup hash -> custom op name, built once rather than scanned for every node.
        // If several names share a hash, the last one in iteration order wins (as before).
        Map<Long, String> customOpNamesByHash = new HashMap<>();
        for (Map.Entry<String, CustomOpDescriptor> entry : Nd4j.getExecutioner().getCustomOperations().entrySet()) {
            customOpNamesByHash.put((long) entry.getValue().getHash(), entry.getKey());
        }

        sb.append("\nOps sequence:\n\n");
        for (int i = 0; i < graph.nodesLength(); i++) {
            FlatNode node = graph.nodes(i);
            log.info("{}:<{}>", Integer.valueOf(node.id()), node.name());
            Op.Type type = getTypeFromByte(node.opType());
            sb.append(node.id()).append(":<").append(node.name()).append("> ").append(type);
            if (type != Op.Type.CUSTOM) {
                sb.append(": ").append(node.opNum());
            } else {
                // Long.valueOf keeps the key type aligned with the map regardless of opNum()'s width.
                String opName = customOpNamesByHash.get(Long.valueOf(node.opNum()));
                sb.append(": ").append(opName == null ? "unknown" : opName);
            }

            sb.append("; Inputs: {");
            for (int j = 0; j < node.inputPairedLength(); j++) {
                IntPair input = node.inputPaired(j);
                sb.append("[").append(input.first()).append(":").append(input.second()).append("]");
                if (j < node.inputPairedLength() - 1) {
                    sb.append(", ");
                }
            }
            sb.append("};");
            sb.append(" OpNum: {").append(node.opNum()).append("};");
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Decodes a FlatBuffers data-type code into a {@link DataBuffer.Type}. This is the inverse of
     * {@link #getDataTypeAsByte(DataBuffer.Type)`}: 3 = HALF, 5 = FLOAT, 6 = DOUBLE, 9 = INT,
     * 10 = LONG. INT and LONG were previously written by the encoder but rejected here.
     *
     * @param b encoded data-type code
     * @return the matching data type
     * @throws UnsupportedOperationException for any other code
     */
    public static DataBuffer.Type getDataTypeFromByte(byte b) {
        switch (b) {
            case 3:
                return DataBuffer.Type.HALF;
            case 5:
                return DataBuffer.Type.FLOAT;
            case 6:
                return DataBuffer.Type.DOUBLE;
            case 9:
                return DataBuffer.Type.INT;
            case 10:
                return DataBuffer.Type.LONG;
            default:
                throw new UnsupportedOperationException("Unsupported DataType: [" + ((int) b) + "]");
        }
    }

    /**
     * Encodes a {@link DataBuffer.Type} as its FlatBuffers data-type code
     * (HALF = 3, FLOAT = 5, DOUBLE = 6, INT = 9, LONG = 10).
     *
     * @param type data type to encode; a null value fails in the switch
     * @return the encoded code
     * @throws ND4JIllegalStateException for any other data type
     */
    public static byte getDataTypeAsByte(DataBuffer.Type type) {
        final byte code;
        switch (type) {
            case HALF:
                code = 3;
                break;
            case FLOAT:
                code = 5;
                break;
            case DOUBLE:
                code = 6;
                break;
            case INT:
                code = 9;
                break;
            case LONG:
                code = 10;
                break;
            default:
                throw new ND4JIllegalStateException("Unknown or unsupported DataType used: [" + type + PropertyAccessor.PROPERTY_KEY_SUFFIX);
        }
        return code;
    }

    /**
     * Resolves the numeric op identifier stored in a FlatNode.
     *
     * <p>Control-flow types map to fixed numbers; custom ops map to their descriptor hash (looked up
     * by lower-cased name, 0 if unknown); everything else, including a null type, is resolved by
     * name through the op factory.
     *
     * @param str  op name
     * @param type op type; may be null
     * @return the op number
     */
    public static long getOpNum(String str, Op.Type type) {
        // A null enum cannot be switched on; it takes the same path as ordinary legacy ops.
        if (type == null) {
            return Nd4j.getOpFactory().getOpNumByName(str);
        }
        switch (type) {
            case LOOP:
                return 0L;
            case CONDITIONAL:
                return 10L;
            case IF:
                return 30L;
            case RETURN:
                return 40L;
            case MERGE:
                return 60L;
            case LOOP_COND:
                return 70L;
            case NEXT_ITERATION:
                return 80L;
            case EXIT:
                return 90L;
            case ENTER:
                return 100L;
            case CUSTOM: {
                CustomOpDescriptor descriptor = Nd4j.getExecutioner().getCustomOperations().get(str.toLowerCase());
                return descriptor == null ? 0L : descriptor.getHash();
            }
            default:
                return Nd4j.getOpFactory().getOpNumByName(str);
        }
    }

    /**
     * Decodes a FlatNode op-type byte into an {@link Op.Type}.
     *
     * @param b encoded op type
     * @return the decoded type; code 119 decodes to {@code META}
     * @throws UnsupportedOperationException for unknown codes
     */
    public static Op.Type getTypeFromByte(byte b) {
        final int code = b;
        if (code == 0) {
            return Op.Type.TRANSFORM;
        } else if (code == 1) {
            return Op.Type.REDUCE;
        } else if (code == 2) {
            return Op.Type.INDEXREDUCE;
        } else if (code == 3) {
            return Op.Type.SCALAR;
        } else if (code == 4) {
            return Op.Type.BROADCAST;
        } else if (code == 5) {
            return Op.Type.PAIRWISE;
        } else if (code == 6) {
            return Op.Type.REDUCE3;
        } else if (code == 7) {
            return Op.Type.SUMMARYSTATS;
        } else if (code == 8) {
            return Op.Type.SHAPE;
        } else if (code == 10) {
            return Op.Type.RANDOM;
        } else if (code == 11) {
            return Op.Type.CUSTOM;
        } else if (code == 119) {
            return Op.Type.META;
        }
        throw new UnsupportedOperationException("Unknown op type passed in: " + code);
    }

    /**
     * Encodes an {@link Op.Type} as a FlatNode op-type byte. All control-flow types share code 119.
     *
     * @param type op type to encode; a null value fails in the switch
     * @return the encoded op type
     * @throws UnsupportedOperationException for types without an encoding
     */
    public static byte getFlatOpType(Op.Type type) {
        final byte encoded;
        switch (type) {
            case TRANSFORM:
            case SPECIAL:
                encoded = 0;
                break;
            case REDUCE:
                encoded = 1;
                break;
            case INDEXREDUCE:
                encoded = 2;
                break;
            case SCALAR:
                encoded = 3;
                break;
            case BROADCAST:
                encoded = 4;
                break;
            case PAIRWISE:
                encoded = 5;
                break;
            case REDUCE3:
                encoded = 6;
                break;
            case SUMMARYSTATS:
                encoded = 7;
                break;
            case SHAPE:
                encoded = 8;
                break;
            case RANDOM:
                encoded = 10;
                break;
            case CUSTOM:
                encoded = 11;
                break;
            case MERGE:
            case CONDITIONAL:
            case LOOP:
            case RETURN:
            case ENTER:
            case EXIT:
            case NEXT_ITERATION:
            case LOOP_COND:
            case IF:
                encoded = 119;
                break;
            default:
                throw new UnsupportedOperationException("Unknown op type passed in: " + type);
        }
        return encoded;
    }

    /**
     * Creates a new, unconfigured {@link SameDiffBuilder}.
     *
     * @return a fresh builder instance
     */
    public static SameDiffBuilder builder() {
        final SameDiffBuilder fresh = new SameDiffBuilder();
        return fresh;
    }

    /**
     * All-arguments constructor: assigns each argument, by position, to the field named in the
     * corresponding assignment below (e.g. {@code map} -> {@code incomingArgs},
     * {@code map20} -> {@code functionInstancesById}, {@code z4} -> {@code logExecution}).
     *
     * <p>NOTE(review): presumably generated for {@link #builder()}; argument order must match the
     * builder exactly — confirm before changing the parameter list.
     */
    public SameDiff(Map<String[], DifferentialFunction> map, Map<String[], DifferentialFunction> map2, Map<String, String[]> map3, Map<String, String[]> map4, Map<String, int[]> map5, boolean z, Set<String> set, Map<String, String> map6, DifferentialFunctionFactory differentialFunctionFactory, Map<String, SDVariable> map7, Map<String, int[]> map8, Map<String, SDVariable> map9, Map<String, SDVariable> map10, Map<String, INDArray> map11, Map<String, List<DifferentialFunction>> map12, Map<String, List<DifferentialFunction>> map13, ThreadLocal<FlowPath> threadLocal, Map<String, List<String>> map14, Map<String, Map<String, Object>> map15, Map<String, List<String[]>> map16, Map<String, int[]> map17, Set<String> set2, IdentityHashMap<INDArray, SDVariable> identityHashMap, MemoryWorkspace memoryWorkspace, Map<String, SameDiffFunctionDefinition> map18, Map<String, SameDiff> map19, Set<String> set3, Map<String, DifferentialFunction> map20, Table<String, String, String> table, AtomicBoolean atomicBoolean, boolean z2, Map<int[], Op> map21, boolean z3, boolean z4) {
        // Default values (field initializers); each is overwritten by an argument further down.
        this.shouldBootStrap = true;
        this.localFlowPath = new ThreadLocal<>();
        this.wasRegistered = new AtomicBoolean(false);
        this.resolvedVariables = false;
        this.logExecution = true;
        // Positional assignment of every argument to its field.
        this.incomingArgs = map;
        this.outgoingArgs = map2;
        this.incomingArgsReverse = map3;
        this.outgoingArgsReverse = map4;
        this.permuteOrder = map5;
        this.shouldBootStrap = z;
        this.importedVarName = set;
        this.baseNameForFunctionInstanceId = map6;
        this.functionFactory = differentialFunctionFactory;
        this.variableMap = map7;
        this.variableNameToShape = map8;
        this.gradients = map9;
        this.forwardVarForGrad = map10;
        this.variableNameToArr = map11;
        this.functionsArgsFor = map12;
        this.functionOutputFor = map13;
        this.localFlowPath = threadLocal;
        this.propertiesToResolve = map14;
        this.propertiesForFunction = map15;
        this.placeHolderMap = map16;
        this.placeHolderOriginalShapes = map17;
        this.placeHolderVarNames = set2;
        this.reverseArrayLookup = identityHashMap;
        this.workspace = memoryWorkspace;
        this.sameDiffFunctionDefinitionMap = map18;
        this.sameDiffFunctionInstances = map19;
        this.placeHolderFunctions = set3;
        this.functionInstancesById = map20;
        this.fieldVariableResolutionMapping = table;
        this.wasRegistered = atomicBoolean;
        this.debugMode = z2;
        this.opsForResult = map21;
        this.resolvedVariables = z3;
        this.logExecution = z4;
    }

    /**
     * @return the current value of the {@code debugMode} flag
     */
    public boolean isDebugMode() {
        final boolean flag = debugMode;
        return flag;
    }

    /**
     * @return the current value of the {@code logExecution} flag
     */
    public boolean isLogExecution() {
        final boolean flag = logExecution;
        return flag;
    }

    /**
     * Sets the {@code logExecution} flag.
     *
     * @param logExecution new flag value
     */
    public void setLogExecution(boolean logExecution) {
        this.logExecution = logExecution;
    }

    static {
        for (Method method : SameDiff.class.getDeclaredMethods()) {
            if (method.getReturnType().equals(SDVariable.class)) {
                opMethods.put(method.getName(), method);
            }
        }
    }
}
