package org.nd4j.linalg.api.ops.impl.controlflow;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import onnx.Onnx;
import org.apache.camel.util.URISupport;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/controlflow/While.class */
/**
 * SameDiff {@code while} loop op.
 *
 * <p>An instance is either built programmatically via {@link #builder()} (predicate and body supplied
 * as {@link SameDiffConditional} / {@link SameDiffFunctionDefinition}), or reconstructed from a
 * TensorFlow graph via {@link #initFromTensorFlow}. The TensorFlow import scans the graph's node list
 * sequentially through the Enter / Merge / LoopCond / Switch / Identity / body / NextIteration / Exit
 * node runs.
 *
 * <p>The {@link CustomOp} argument accessors are stubs: they store nothing and report empty arrays,
 * zero counts, or {@code null}.
 */
public class While extends DifferentialFunction implements CustomOp {
    private static final Logger log = LoggerFactory.getLogger(While.class);

    /** Shared cursor into the TF node list; only assigned, never read in this class. */
    private AtomicInteger startPosition;
    /** Scope holding the loop body. */
    protected SameDiff loopBodyExecution;
    /** Scope holding the loop predicate. */
    protected SameDiff predicateExecution;
    protected SameDiffConditional predicate;
    protected SameDiffFunctionDefinition trueBody;
    protected String blockName;
    protected String trueBodyName;
    protected SDVariable[] inputVars;
    /** Variable produced by the predicate; the loop continues while it holds. */
    protected SDVariable targetBoolean;
    /** Placeholder output registered as this op's outgoing variable by {@link #init}. */
    protected SDVariable dummyResult;
    protected SDVariable[] outputVars;
    /** Number of iterations executed so far, incremented by {@link #incrementLoopCounter()}. */
    protected int numLooped = 0;

    /** Fluent builder collecting the arguments of the six-argument {@link While} constructor. */
    public static class WhileBuilder {
        private String blockName;
        private SameDiff parent;
        private SDVariable[] inputVars;
        private SameDiffConditional predicate;
        private SameDiffFunctionDefinition condition;
        private SameDiffFunctionDefinition trueBody;

        WhileBuilder() {
        }

        public WhileBuilder blockName(String str) {
            this.blockName = str;
            return this;
        }

        public WhileBuilder parent(SameDiff sameDiff) {
            this.parent = sameDiff;
            return this;
        }

        public WhileBuilder inputVars(SDVariable[] sDVariableArr) {
            this.inputVars = sDVariableArr;
            return this;
        }

        public WhileBuilder predicate(SameDiffConditional sameDiffConditional) {
            this.predicate = sameDiffConditional;
            return this;
        }

        public WhileBuilder condition(SameDiffFunctionDefinition sameDiffFunctionDefinition) {
            this.condition = sameDiffFunctionDefinition;
            return this;
        }

        public WhileBuilder trueBody(SameDiffFunctionDefinition sameDiffFunctionDefinition) {
            this.trueBody = sameDiffFunctionDefinition;
            return this;
        }

        /** Creates the {@link While} op, which registers itself with {@code parent} (see {@link While#init}). */
        public While build() {
            return new While(this.blockName, this.parent, this.inputVars, this.predicate, this.condition, this.trueBody);
        }

        @Override
        public String toString() {
            return "While.WhileBuilder(blockName=" + this.blockName + ", parent=" + this.parent
                    + ", inputVars=" + Arrays.deepToString(this.inputVars) + ", predicate=" + this.predicate
                    + ", condition=" + this.condition + ", trueBody=" + this.trueBody + ")";
        }
    }

    /**
     * Creates an empty loop bound to a shared position in a TF node list; used for nested loops
     * discovered during import.
     *
     * @param atomicInteger cursor into the TF node list
     */
    public While(AtomicInteger atomicInteger) {
        this.startPosition = atomicInteger;
    }

    /**
     * Partial copy constructor.
     *
     * <p>Copies the graph, scopes, predicate, input/output variables and loop counter from
     * {@code r12}, then allocates a NEW dummy result variable in the shared graph. It does not copy
     * {@code trueBody}, {@code blockName}, {@code trueBodyName}, {@code targetBoolean} or
     * {@code startPosition}. NOTE(review): confirm whether omitting those is intentional.
     *
     * @param r12 loop to copy from
     */
    public While(While r12) {
        this.sameDiff = r12.sameDiff;
        this.outputVars = r12.outputVars;
        this.loopBodyExecution = r12.loopBodyExecution;
        this.numLooped = r12.numLooped;
        this.predicate = r12.predicate;
        this.predicateExecution = r12.predicateExecution;
        this.inputVars = r12.inputVars;
        this.dummyResult = this.sameDiff.var("dummyresult-" + UUID.randomUUID().toString(), new ZeroInitScheme('f'), DataType.FLOAT, 1);
    }

    /**
     * Builds a loop and registers it with {@code sameDiff}; see {@link #init} for details.
     */
    public While(String str, SameDiff sameDiff, SDVariable[] sDVariableArr, SameDiffConditional sameDiffConditional, SameDiffFunctionDefinition sameDiffFunctionDefinition, SameDiffFunctionDefinition sameDiffFunctionDefinition2) {
        init(str, sameDiff, sDVariableArr, sameDiffConditional, sameDiffFunctionDefinition, sameDiffFunctionDefinition2);
    }

    /**
     * Wires this loop into {@code parent}, in this order:
     * <ol>
     *   <li>Creates a float dummy result variable.</li>
     *   <li>Registers this op, its inputs and the dummy output.</li>
     *   <li>Evaluates the predicate in a fresh scope to obtain {@link #targetBoolean}.</li>
     *   <li>Defines the body under a random {@code true-body-*} name and the condition under
     *       {@code blockName}.</li>
     * </ol>
     *
     * @param blockName name under which the condition function is defined
     * @param parent    graph the loop lives in
     * @param inputs    loop inputs
     * @param condition predicate evaluator
     * @param condFn    condition function definition
     * @param bodyFn    loop body function definition
     */
    private void init(String blockName, SameDiff parent, SDVariable[] inputs, SameDiffConditional condition, SameDiffFunctionDefinition condFn, SameDiffFunctionDefinition bodyFn) {
        this.sameDiff = parent;
        this.inputVars = inputs;
        this.predicate = condition;
        this.trueBody = bodyFn;
        this.blockName = blockName;
        this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(), new ZeroInitScheme('f'), DataType.FLOAT, 1);
        parent.putOpForId(getOwnName(), this);
        parent.addArgsFor(inputs, this);
        parent.addOutgoingFor(new SDVariable[]{this.dummyResult}, this);

        SameDiff predicateScope = SameDiff.create();
        this.targetBoolean = condition.eval(predicateScope, condFn, inputs);
        this.predicateExecution = predicateScope;

        String bodyName = "true-body-" + UUID.randomUUID().toString();
        this.trueBodyName = bodyName;
        parent.defineFunction(bodyName, bodyFn, inputs);
        parent.defineFunction(blockName, condFn, inputs);
        // NOTE(review): fixed key; a second While in the same graph would replace this entry.
        parent.putSubFunction("predicate-eval-body", predicateScope);
        this.loopBodyExecution = parent.getFunction(bodyName);
    }

    /** Returns the single dummy result variable as this op's output. */
    @Override
    public SDVariable[] outputVariables(String str) {
        return new SDVariable[]{this.dummyResult};
    }

    /** Returns the outputs of a {@link WhileDerivative} built from this loop; {@code list} is ignored. */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> list) {
        return new ArrayList<>(Arrays.asList(new WhileDerivative(this).outputVariables()));
    }

    public void incrementLoopCounter() {
        this.numLooped++;
    }

    /** Imports a TF while loop starting at the first node of {@code graphDef}'s node list. */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef) {
        doImport(nodeDef, sameDiff, map, graphDef, new LinkedHashSet<>(), new AtomicInteger(0));
    }

    /** Returns {@code true} for Const, VariableV* and Placeholder* nodes. */
    private static boolean isConstantLike(NodeDef node) {
        String op = node.getOp();
        return op.equalsIgnoreCase("const") || op.startsWith("VariableV") || op.startsWith("Placeholder");
    }

    /** Normalizes a TF node/input name via {@link TFGraphMapper#getNodeName}. */
    private static String nodeName(String tfName) {
        return TFGraphMapper.getInstance().getNodeName(tfName);
    }

    /** Creates a zero-initialized variable with no declared shape. */
    private static SDVariable zeroVar(SameDiff sd, String name) {
        return sd.var(name, (LongShapeDescriptor) null, new ZeroInitScheme());
    }

    /** Returns {@code sd}'s variable {@code name}, creating a zero-initialized one if absent. */
    private static SDVariable getOrCreateVariable(SameDiff sd, String name) {
        SDVariable existing = sd.getVariable(name);
        return existing != null ? existing : zeroVar(sd, name);
    }

    /** Instantiates the op class mapped from the TF op of {@code node}. */
    private static DifferentialFunction newOpFor(NodeDef node) {
        return DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(node.getOp()).opName());
    }

    /**
     * Scans {@code graph}'s node list from {@code position} and rebuilds one while loop.
     *
     * <p>Creates two sub-scopes (condition and body) in {@code parent}. Consumes nodes in order
     * through these phases:
     * <ol>
     *   <li>Enter</li>
     *   <li>Merge</li>
     *   <li>condition nodes up to LoopCond</li>
     *   <li>Switch</li>
     *   <li>Identity</li>
     *   <li>body nodes up to NextIteration (recursing on a nested Enter)</li>
     *   <li>NextIteration</li>
     *   <li>Exit</li>
     * </ol>
     * Every node consumed is added to {@code visited}.
     *
     * @param whileNode  node that triggered the import; its attribute map is passed to inner ops
     * @param parent     graph receiving the loop
     * @param attributes forwarded unchanged to nested loops
     * @param graph      TF graph being imported
     * @param visited    names of nodes already handled (shared across nested loops)
     * @param position   shared cursor into the node list
     * @throws ND4JIllegalArgumentException if the condition scope contains no ops
     */
    private void doImport(NodeDef whileNode, SameDiff parent, Map<String, AttrValue> attributes, GraphDef graph, Set<String> visited, AtomicInteger position) {
        String scopeId = UUID.randomUUID().toString();
        visited.add(whileNode.getName());
        SameDiff conditionScope = SameDiff.create();
        SameDiff bodyScope = SameDiff.create();
        parent.putSubFunction("condition-" + scopeId, conditionScope);
        parent.putSubFunction("loopbody-" + scopeId, bodyScope);
        this.loopBodyExecution = bodyScope;
        this.predicateExecution = conditionScope;
        this.startPosition = position;
        log.info("Adding 2 new scopes for WHILE {}", whileNode.getName());

        List<NodeDef> nodes = graph.getNodeList();

        // Phase 1: Enter nodes. Their inputs are shared with both scopes.
        while (position.get() < nodes.size()) {
            NodeDef enter = nodes.get(position.get());
            if (!enter.getOp().equalsIgnoreCase(Enter.OP_NAME)) {
                break;
            }
            visited.add(enter.getName());
            SDVariable[] enterInputs = new SDVariable[enter.getInputCount()];
            for (int i = 0; i < enterInputs.length; i++) {
                enterInputs[i] = getOrCreateVariable(parent, nodeName(enter.getInput(i)));
                conditionScope.var(enterInputs[i]);
                bodyScope.var(enterInputs[i]);
            }
            // Each Enter overwrites inputVars; phase 7 overwrites it again.
            this.inputVars = enterInputs;
            position.incrementAndGet();
        }

        // Phase 2: Merge nodes become body variables shared with the condition scope and parent.
        while (position.get() < nodes.size()) {
            NodeDef merge = nodes.get(position.get());
            if (!merge.getOp().equalsIgnoreCase(Merge.OP_NAME)) {
                // The first non-Merge node also gets a body variable, but the cursor is not
                // advanced, so phase 3 processes the same node again.
                zeroVar(bodyScope, nodeName(merge.getName()));
                break;
            }
            visited.add(merge.getName());
            SDVariable merged = zeroVar(bodyScope, nodeName(merge.getName()));
            conditionScope.var(merged);
            parent.var(merged);
            position.incrementAndGet();
        }

        // Phase 3: condition nodes, ending with (and consuming) LoopCond.
        while (position.get() < nodes.size()) {
            NodeDef condNode = nodes.get(position.get());
            if (condNode.getOp().equalsIgnoreCase("LoopCond")) {
                visited.add(condNode.getName());
                position.incrementAndGet();
                break;
            }
            if (isConstantLike(condNode)) {
                SDVariable condVar = zeroVar(conditionScope, condNode.getName());
                bodyScope.var(condVar);
                parent.var(condVar);
                log.info("Adding condition var [{}]", condVar.getVarName());
            } else if (!visited.contains(condNode.getName())) {
                DifferentialFunction op = newOpFor(condNode);
                // NOTE(review): passes the While node's attrs rather than condNode's, and rebinds the
                // op to the body scope after importing it into the condition scope; confirm intent.
                op.initFromTensorFlow(condNode, conditionScope, whileNode.getAttrMap(), graph);
                op.setSameDiff(bodyScope);
            }
            visited.add(condNode.getName());
            position.incrementAndGet();
        }

        // Phase 4: Switch nodes are consumed without creating anything.
        while (position.get() < nodes.size()) {
            NodeDef switchNode = nodes.get(position.get());
            if (!switchNode.getOp().equalsIgnoreCase("Switch")) {
                break;
            }
            visited.add(switchNode.getName());
            position.incrementAndGet();
        }

        // Phase 5: Identity ops, with inputs resolved (or created) in the parent graph.
        while (position.get() < nodes.size()) {
            NodeDef identity = nodes.get(position.get());
            if (!identity.getOp().equalsIgnoreCase("Identity")) {
                break;
            }
            DifferentialFunction op = newOpFor(identity);
            op.initFromTensorFlow(identity, parent, whileNode.getAttrMap(), graph);
            op.setSameDiff(bodyScope);
            SDVariable[] identityInputs = new SDVariable[identity.getInputCount()];
            for (int i = 0; i < identityInputs.length; i++) {
                // Create missing inputs under the normalized name (as the other phases do), so the
                // lookup and the creation agree.
                identityInputs[i] = getOrCreateVariable(parent, nodeName(identity.getInput(i)));
                conditionScope.var(identityInputs[i]);
                bodyScope.var(identityInputs[i]);
            }
            bodyScope.addArgsFor(identityInputs, op);
            visited.add(identity.getName());
            position.incrementAndGet();
        }

        // Phase 6: loop body, up to (not including) the first unvisited NextIteration node.
        while (position.get() < nodes.size()) {
            NodeDef bodyNode = nodes.get(position.get());
            if (visited.contains(bodyNode.getName())) {
                log.info("Skipping: {}", bodyNode.getName());
            } else {
                if (bodyNode.getOp().equalsIgnoreCase("NextIteration")) {
                    break;
                }
                if (isConstantLike(bodyNode)) {
                    log.info("Adding body var [{}]", zeroVar(bodyScope, bodyNode.getName()).getVarName());
                } else {
                    log.info("starting on [{}]: {}", bodyNode.getName(), bodyNode.getOp());
                    if (bodyNode.getOp().equalsIgnoreCase(Enter.OP_NAME)) {
                        // Nested loop: it advances the shared cursor itself.
                        log.info("NEW LOOP ----------------------------------------");
                        While nested = new While(position);
                        nested.doImport(whileNode, parent, attributes, graph, visited, position);
                        nested.setSameDiff(parent);
                        log.info("END LOOP ----------------------------------------");
                    } else {
                        DifferentialFunction op = newOpFor(bodyNode);
                        op.initFromTensorFlow(bodyNode, parent, whileNode.getAttrMap(), graph);
                        // NOTE(review): op is bound to the condition scope but its args are registered
                        // in the body scope; confirm intent.
                        op.setSameDiff(conditionScope);
                        SDVariable[] opInputs = new SDVariable[bodyNode.getInputCount()];
                        for (int i = 0; i < opInputs.length; i++) {
                            String inputName = nodeName(bodyNode.getInput(i));
                            SDVariable arg = conditionScope.getVariable(inputName);
                            if (arg == null) {
                                SDVariable bodyVar = bodyScope.getVariable(inputName);
                                arg = bodyVar != null ? bodyVar : conditionScope.var(parent.getVariable(inputName));
                            }
                            opInputs[i] = arg;
                        }
                        bodyScope.addArgsFor(opInputs, op);
                    }
                }
                visited.add(bodyNode.getName());
            }
            position.incrementAndGet();
        }

        // Phase 7: NextIteration nodes become this op's inputs in the parent graph.
        List<SDVariable> nextIterationVars = new ArrayList<>();
        while (position.get() < nodes.size()) {
            NodeDef next = nodes.get(position.get());
            if (!next.getOp().equalsIgnoreCase("NextIteration")) {
                break;
            }
            visited.add(next.getName());
            nextIterationVars.add(getOrCreateVariable(parent, nodeName(next.getName())));
            position.incrementAndGet();
        }
        // NOTE(review): nothing here collects outputs, so outputVars is always empty after import.
        this.outputVars = new SDVariable[0];
        this.inputVars = nextIterationVars.toArray(new SDVariable[0]);
        parent.addArgsFor(this.inputVars, this);
        parent.addOutgoingFor(this.outputVars, this);

        // Phase 8: Exit nodes. Only ensures a matching variable exists in the parent graph.
        while (position.get() < nodes.size()) {
            NodeDef exit = nodes.get(position.get());
            if (!exit.getOp().equalsIgnoreCase("Exit")) {
                break;
            }
            visited.add(exit.getName());
            getOrCreateVariable(parent, nodeName(exit.getName()));
            position.incrementAndGet();
        }

        // The predicate is taken to be the first output of the last op in the condition scope.
        DifferentialFunction[] conditionOps = conditionScope.ops();
        if (conditionOps.length < 1) {
            throw new ND4JIllegalArgumentException("No functions found!");
        }
        this.targetBoolean = conditionOps[conditionOps.length - 1].outputVariables()[0];
        log.info("-------------------------------------------");
    }

    /** ONNX import is not supported; this is a no-op. */
    @Override
    public void initFromOnnx(Onnx.NodeProto nodeProto, SameDiff sameDiff, Map<String, Onnx.AttributeProto> map, Onnx.GraphProto graphProto) {
    }

    @Override
    public String toString() {
        return opName();
    }

    @Override
    public String opName() {
        return "while";
    }

    @Override
    public long opHash() {
        return opName().hashCode();
    }

    // ---- CustomOp argument accessors: stubs that hold no state ----

    @Override
    public boolean isInplaceCall() {
        return false;
    }

    @Override
    public INDArray[] outputArguments() {
        return new INDArray[0];
    }

    @Override
    public INDArray[] inputArguments() {
        return new INDArray[0];
    }

    @Override
    public long[] iArgs() {
        return new long[0];
    }

    @Override
    public double[] tArgs() {
        return new double[0];
    }

    @Override
    public void addIArgument(int... iArr) {
    }

    @Override
    public void addIArgument(long... jArr) {
    }

    @Override
    public void removeIArgument(Integer num) {
    }

    @Override
    public Long getIArgument(int i) {
        return null;
    }

    @Override
    public int numIArguments() {
        return 0;
    }

    @Override
    public void addTArgument(double... dArr) {
    }

    @Override
    public void removeTArgument(Double d) {
    }

    @Override
    public Double getTArgument(int i) {
        return null;
    }

    @Override
    public int numTArguments() {
        return 0;
    }

    @Override
    public int numBArguments() {
        return 0;
    }

    @Override
    public void addInputArgument(INDArray... iNDArrayArr) {
    }

    @Override
    public void removeInputArgument(INDArray iNDArray) {
    }

    @Override
    public boolean[] bArgs() {
        return new boolean[0];
    }

    @Override
    public void addBArgument(boolean... zArr) {
    }

    @Override
    public Boolean getBArgument(int i) {
        return null;
    }

    @Override
    public INDArray getInputArgument(int i) {
        return null;
    }

    @Override
    public int numInputArguments() {
        return 0;
    }

    @Override
    public void addOutputArgument(INDArray... iNDArrayArr) {
    }

    @Override
    public void removeOutputArgument(INDArray iNDArray) {
    }

    @Override
    public INDArray getOutputArgument(int i) {
        return null;
    }

    @Override
    public int numOutputArguments() {
        return 0;
    }

    /** Output shapes mirror the shapes of this op's arguments, in argument order. */
    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        List<LongShapeDescriptor> shapes = new ArrayList<>();
        for (SDVariable arg : args()) {
            shapes.add(this.sameDiff.getShapeDescriptorForVarName(arg.getVarName()));
        }
        return shapes;
    }

    @Override
    public CustomOpDescriptor getDescriptor() {
        return CustomOpDescriptor.builder().build();
    }

    @Override
    public void assertValidForExecution() {
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No *singular (eg: use tensorflowNames() found for this op " + opName());
    }

    @Override
    public String[] tensorflowNames() {
        throw new NoOpNameFoundException("This operation has no TF counterpart");
    }

    @Override
    public Op.Type opType() {
        return Op.Type.LOOP;
    }

    public static WhileBuilder builder() {
        return new WhileBuilder();
    }

    public While() {
    }

    public SameDiff getLoopBodyExecution() {
        return this.loopBodyExecution;
    }

    public SameDiff getPredicateExecution() {
        return this.predicateExecution;
    }

    public SameDiffConditional getPredicate() {
        return this.predicate;
    }

    public SameDiffFunctionDefinition getTrueBody() {
        return this.trueBody;
    }

    public String getBlockName() {
        return this.blockName;
    }

    public String getTrueBodyName() {
        return this.trueBodyName;
    }

    public SDVariable[] getInputVars() {
        return this.inputVars;
    }

    public SDVariable getTargetBoolean() {
        return this.targetBoolean;
    }

    public SDVariable[] getOutputVars() {
        return this.outputVars;
    }

    public void setOutputVars(SDVariable[] sDVariableArr) {
        this.outputVars = sDVariableArr;
    }

    public int getNumLooped() {
        return this.numLooped;
    }
}
