package org.nd4j.imports.graphmapper;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.commons.io.IOUtils;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.shade.protobuf.TextFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/imports/graphmapper/BaseGraphMapper.class */
/**
 * Base {@link GraphMapper} implementation holding the format-independent parts of graph import.
 *
 * <p>Reads a serialized graph (binary protobuf, with a protobuf text-format fallback), creates SameDiff
 * placeholders/constants/variables for the graph's tensors, maps each node to an op, wires up control
 * dependencies, infers missing data types and validates the result. Subclasses supply the format-specific
 * accessors declared by {@link GraphMapper}.
 *
 * @param <GRAPH_TYPE>  the graph type produced by {@link #parseGraphFrom} / {@link #getNewGraphBuilder}
 * @param <NODE_TYPE>   the graph node type
 * @param <ATTR_TYPE>   the node attribute type
 * @param <TENSOR_TYPE> the tensor type returned by {@link #variablesForGraph}
 */
public abstract class BaseGraphMapper<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE, TENSOR_TYPE> implements GraphMapper<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE, TENSOR_TYPE> {
    private static final Logger log = LoggerFactory.getLogger(BaseGraphMapper.class);

    /** Matches variable names carrying an output-index suffix, e.g. {@code "matmul:1"}. */
    private static final Pattern OUTPUT_INDEX_SUFFIX = Pattern.compile(".*:\\d+");

    /**
     * Returns the {@link Op.Type} of the op the given node is mapped to.
     *
     * @param node graph node
     * @return the mapped op's type
     * @throws NoOpNameFoundException if no op is registered for the node's op type
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public Op.Type opTypeForNode(NODE_TYPE node) {
        String opType = getOpType(node);
        DifferentialFunction mappedOp = getMappedOp(opType);
        if (mappedOp == null) {
            throw new NoOpNameFoundException("No op found with name " + opType);
        }
        return mappedOp.opType();
    }

    /**
     * Applies every property mapping registered for the node's op type to {@code on}.
     * Does nothing when no mappings exist for that op type.
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappings) {
        Map<String, PropertyMapping> mappingsForOp = propertyMappings.get(getOpType(node));
        if (mappingsForOp == null || mappingsForOp.isEmpty()) {
            return;
        }
        for (String propertyName : mappingsForOp.keySet()) {
            mapProperty(propertyName, on, node, graph, sameDiff, propertyMappings);
        }
    }

    /** Imports a graph from a stream with no op overrides and no op filter. The stream is not closed. */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public SameDiff importGraph(InputStream inputStream) {
        return importGraph(inputStream, Collections.<String, OpImportOverride<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE>>emptyMap(), null);
    }

    /**
     * Reads a graph from the stream (see {@link #readGraph}) and imports it. The stream is not closed.
     *
     * @throws ND4JIllegalStateException if the stream cannot be read or parsed
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public SameDiff importGraph(InputStream inputStream, Map<String, ? extends OpImportOverride<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE>> opImportOverrides, OpImportFilter<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE> opFilter) {
        GRAPH_TYPE graph = readGraph(inputStream, opImportOverrides);
        return importGraph(graph, opImportOverrides, opFilter);
    }

    /**
     * Reads the whole stream and parses it as a graph: first as a binary protobuf via {@link #parseGraphFrom},
     * then, if that fails with an {@link IOException}, as a UTF-8 protobuf text-format message.
     *
     * @param inputStream       source of the serialized graph; read fully but not closed
     * @param opImportOverrides not used by this implementation
     * @return the parsed graph, never null
     * @throws ND4JIllegalStateException if the stream cannot be read, or the bytes parse in neither format
     *                                   (the binary parse failure is attached as a suppressed exception)
     */
    @SuppressWarnings("unchecked")
    protected GRAPH_TYPE readGraph(InputStream inputStream, Map<String, ? extends OpImportOverride<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE>> opImportOverrides) {
        final byte[] bytes;
        try {
            bytes = IOUtils.toByteArray(inputStream);
        } catch (IOException e) {
            throw new ND4JIllegalStateException("Error reading graph from input stream", e);
        }

        try {
            return parseGraphFrom(bytes);
        } catch (IOException binaryParseError) {
            // Not a binary protobuf: retry as text format. Decode the full content (newlines included) so that
            // line comments and line-separated tokens keep their meaning.
            try {
                Message.Builder builder = getNewGraphBuilder();
                TextFormat.getParser().merge(new String(bytes, StandardCharsets.UTF_8), builder);
                return (GRAPH_TYPE) builder.build();
            } catch (Exception textParseError) {
                textParseError.addSuppressed(binaryParseError);
                throw new ND4JIllegalStateException("Unable to parse graph as either a binary or a text-format protobuf", textParseError);
            }
        }
    }

    /** Imports a graph file with no op overrides and no op filter. */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public SameDiff importGraph(File graphFile) {
        return importGraph(graphFile, Collections.<String, OpImportOverride<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE>>emptyMap(), null);
    }

    /**
     * Imports a graph file. The file stream is always closed, including when import fails.
     *
     * @throws ND4JIllegalStateException wrapping any failure, with the file's absolute path in the message
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public SameDiff importGraph(File graphFile, Map<String, ? extends OpImportOverride<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE>> opImportOverrides, OpImportFilter<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE> opFilter) {
        try (InputStream in = new FileInputStream(graphFile)) {
            return importGraph(in, opImportOverrides, opFilter);
        } catch (Exception e) {
            throw new ND4JIllegalStateException("Error encountered loading graph file: " + graphFile.getAbsolutePath(), e);
        }
    }

    /** Indexes the graph's nodes by name (unordered; later duplicates replace earlier ones). */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public Map<String, NODE_TYPE> nameIndexForGraph(GRAPH_TYPE graph) {
        Map<String, NODE_TYPE> index = new HashMap<>();
        for (NODE_TYPE node : getNodeList(graph)) {
            index.put(getName(node), node);
        }
        return index;
    }

    /** Indexes the graph's nodes by name, preserving node-list order (later duplicates replace earlier values). */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public Map<String, NODE_TYPE> nodesByName(GRAPH_TYPE graph) {
        Map<String, NODE_TYPE> byName = new LinkedHashMap<>();
        for (NODE_TYPE node : getNodeList(graph)) {
            byName.put(getName(node), node);
        }
        return byName;
    }

    /** Imports an already-parsed graph with no op overrides and no op filter. */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public SameDiff importGraph(GRAPH_TYPE graph) {
        return importGraph(graph, Collections.<String, OpImportOverride<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE>>emptyMap(), null);
    }

    /**
     * Imports an already-parsed graph into a new {@link SameDiff} instance.
     *
     * @param graph             the graph to import
     * @param opImportOverrides op type -> override; may be null. Variables of overridden op types are not
     *                          created up front (the override is responsible for them)
     * @param opFilter          optional filter; ops (and their variables) it selects are skipped
     * @return the imported, validated SameDiff instance
     * @throws IllegalStateException if an op input or control dependency has no corresponding variable
     */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public SameDiff importGraph(GRAPH_TYPE graph, Map<String, ? extends OpImportOverride<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE>> opImportOverrides, OpImportFilter<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE> opFilter) {
        SameDiff sd = SameDiff.create();
        ImportState<GRAPH_TYPE, TENSOR_TYPE> importState = new ImportState<>();
        importState.setSameDiff(sd);
        importState.setGraph(graph);
        Map<String, TENSOR_TYPE> variablesForGraph = variablesForGraph(graph);
        importState.setVariables(variablesForGraph);

        importVariables(importState, variablesForGraph, opImportOverrides, opFilter);
        importOps(importState, opImportOverrides, opFilter);
        linkOutputVariablesToOps(sd);
        for (SameDiffOp op : sd.getOps().values()) {
            initOutputVariables(sd, op.getOp());
        }
        registerVariableControlDependencies(sd);
        registerOpControlDependencies(sd);
        inferMissingDataTypes(sd);
        validateGraphStructure(sd);
        return sd;
    }

    /** Creates a placeholder, constant or variable for every graph tensor that is not skipped, filtered or overridden. */
    @SuppressWarnings("unchecked")
    private void importVariables(ImportState<GRAPH_TYPE, TENSOR_TYPE> importState, Map<String, TENSOR_TYPE> variables,
                                 Map<String, ? extends OpImportOverride<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE>> opImportOverrides,
                                 OpImportFilter<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE> opFilter) {
        SameDiff sd = importState.getSameDiff();
        GRAPH_TYPE graph = importState.getGraph();
        for (Map.Entry<String, TENSOR_TYPE> entry : variables.entrySet()) {
            String varName = entry.getKey();
            TENSOR_TYPE tensor = entry.getValue();
            // NOTE(review): the tensor is also queried through node accessors; this assumes the concrete mapper
            // uses the same object type for tensors and nodes - confirm for each subclass.
            NODE_TYPE node = (NODE_TYPE) tensor;
            if (shouldSkip(node)) {
                continue;
            }
            String opType = getOpType(node);
            String name = getName(node);
            if (opFilter != null && opFilter.skipOp(node, sd, null, graph)) {
                log.info("Skipping variables for op: {} (name: {})", opType, name);
                continue;
            }
            if (opImportOverrides != null && opImportOverrides.containsKey(opType)) {
                // Let the override create these variables: creating them here could give them the wrong type.
                log.info("Skipping variables for op due to presence of OpImportOverride: {} (name: {})", opType, name);
                continue;
            }

            DataType dataType = dataTypeForTensor(tensor, 0);
            INDArray array = getNDArrayFromTensor(varName, tensor, graph);
            long[] shape = hasShape(node) ? getShape(node) : null;
            // Unknown types stay null here and are inferred after all ops are imported (inferMissingDataTypes).
            if (dataType == DataType.UNKNOWN) {
                dataType = null;
            }

            if (isPlaceHolder(tensor)) {
                sd.placeHolder(varName, dataType, shape);
            } else if (isConstant(tensor)) {
                Preconditions.checkNotNull(array, "Array is null for constant variable %s", varName);
                sd.constant(varName, array);
            } else {
                SDVariable var = shape == null
                        ? sd.var(varName, VariableType.ARRAY, null, dataType, (long[]) null)
                        : sd.var(varName, dataType, shape);
                if (array != null) {
                    sd.associateArrayWithVariable(array, var);
                }
            }

            List<String> controlDeps = getControlDependencies(node);
            if (controlDeps != null) {
                sd.getVariables().get(varName).setControlDeps(controlDeps);
            }
        }
    }

    /** Maps every graph node to a SameDiff op, honouring the op filter, overrides and {@link #opsToIgnore()}. */
    private void importOps(ImportState<GRAPH_TYPE, TENSOR_TYPE> importState,
                           Map<String, ? extends OpImportOverride<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE>> opImportOverrides,
                           OpImportFilter<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE> opFilter) {
        GRAPH_TYPE graph = importState.getGraph();
        for (NODE_TYPE node : getNodeList(graph)) {
            String opType = getOpType(node);
            OpImportOverride<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE> override = opImportOverrides == null ? null : opImportOverrides.get(opType);
            // The filter receives the same (sd, null, graph) arguments as in importVariables.
            if (opFilter != null && opFilter.skipOp(node, importState.getSameDiff(), null, graph)) {
                log.info("Skipping op due to op filter: {} (name: {})", opType, getName(node));
            } else if (!opsToIgnore().contains(opType) || isOpIgnoreException(node)) {
                mapNodeType(node, importState, override, opFilter);
            }
        }
    }

    /**
     * For each non-placeholder, non-constant variable whose name, with any {@code :index} suffix removed, is an
     * op name: records that op as the variable's producer and forces the variable type to ARRAY.
     */
    private static void linkOutputVariablesToOps(SameDiff sd) {
        for (Variable variable : sd.getVariables().values()) {
            SDVariable sdVar = variable.getVariable();
            if (sdVar.isPlaceHolder() || sdVar.isConstant()) {
                continue;
            }
            String name = variable.getName();
            String opName = OUTPUT_INDEX_SUFFIX.matcher(name).matches() ? name.substring(0, name.lastIndexOf(':')) : name;
            if (sd.getOps().containsKey(opName)) {
                variable.setOutputOfOp(opName);
                if (sdVar.getVariableType() != VariableType.ARRAY) {
                    sdVar.setVariableType(VariableType.ARRAY);
                }
            }
        }
    }

    /** For every variable V with control dependency D, adds V's name to D's controlDepsForVar (once). */
    private static void registerVariableControlDependencies(SameDiff sd) {
        Map<String, Variable> variables = sd.getVariables();
        for (Map.Entry<String, Variable> entry : variables.entrySet()) {
            List<String> deps = entry.getValue().getControlDeps();
            if (deps == null) {
                continue;
            }
            for (String depName : deps) {
                Variable dep = requireControlDepTarget(variables, depName, entry.getKey());
                if (dep.getControlDepsForVar() == null) {
                    dep.setControlDepsForVar(new ArrayList<String>());
                }
                if (!dep.getControlDepsForVar().contains(entry.getKey())) {
                    dep.getControlDepsForVar().add(entry.getKey());
                }
            }
        }
    }

    /** For every op O with control dependency D, adds O's name to D's controlDepsForOp (once). */
    private static void registerOpControlDependencies(SameDiff sd) {
        Map<String, Variable> variables = sd.getVariables();
        for (Map.Entry<String, SameDiffOp> entry : sd.getOps().entrySet()) {
            List<String> deps = entry.getValue().getControlDeps();
            if (deps == null) {
                continue;
            }
            for (String depName : deps) {
                Variable dep = requireControlDepTarget(variables, depName, entry.getKey());
                if (dep.getControlDepsForOp() == null) {
                    dep.setControlDepsForOp(new ArrayList<String>());
                }
                if (!dep.getControlDepsForOp().contains(entry.getKey())) {
                    dep.getControlDepsForOp().add(entry.getKey());
                }
            }
        }
    }

    /** Looks up a control-dependency target, failing with a descriptive message instead of an NPE if it is missing. */
    private static Variable requireControlDepTarget(Map<String, Variable> variables, String depName, String dependent) {
        Variable target = variables.get(depName);
        if (target == null) {
            throw new IllegalStateException("Import validation failed: \"" + dependent + "\" has control dependency \"" + depName + "\" that does not have a corresponding variable in the graph");
        }
        return target;
    }

    /** Fills in data types left null during import using {@link SameDiff#calculateOutputDataTypes()}. */
    private static void inferMissingDataTypes(SameDiff sd) {
        boolean anyMissing = false;
        for (SDVariable v : sd.variables()) {
            if (v.dataType() == null) {
                anyMissing = true;
                break;
            }
        }
        if (!anyMissing) {
            return;
        }
        Map<String, DataType> inferred = sd.calculateOutputDataTypes();
        for (SDVariable v : sd.variables()) {
            if (v.dataType() == null) {
                v.setDataType(inferred.get(v.getVarName()));
            }
        }
    }

    /**
     * Ensures the op has output variables and that each output records the op as its producer. If SameDiff has
     * no outputs for the op yet, they are generated (named after the op's own name, or its op name when unset)
     * and registered on the op.
     */
    protected void initOutputVariables(SameDiff sd, DifferentialFunction op) {
        String[] outputNames = sd.getOutputsForOp(op);
        if (outputNames == null) {
            String baseName = op.getOwnName() != null ? op.getOwnName() : op.opName();
            SDVariable[] outputs = sd.generateOutputVariableForOp(op, baseName, true);
            outputNames = new String[outputs.length];
            for (int i = 0; i < outputs.length; i++) {
                outputNames[i] = outputs[i].getVarName();
            }
            sd.getOps().get(op.getOwnName()).setOutputsOfOp(Arrays.asList(outputNames));
        }
        for (String outputName : outputNames) {
            sd.getVariables().get(outputName).setOutputOfOp(op.getOwnName());
        }
    }

    /** Returns true when the tensor's first output has a known data type. */
    @Override // org.nd4j.imports.graphmapper.GraphMapper
    public boolean validTensorDataType(TENSOR_TYPE tensor) {
        return dataTypeForTensor(tensor, 0) != DataType.UNKNOWN;
    }

    /**
     * Checks that every input of every op refers to an existing variable.
     *
     * @throws IllegalStateException naming the first op input that has no corresponding variable
     */
    public void validateGraphStructure(SameDiff sameDiff) {
        for (SameDiffOp op : sameDiff.getOps().values()) {
            List<String> inputs = op.getInputsToOp();
            if (inputs == null) {
                continue;
            }
            for (String input : inputs) {
                if (sameDiff.getVariable(input) == null) {
                    throw new IllegalStateException("Import validation failed: op \"" + op.getName() + "\" of type " + op.getOp().getClass().getSimpleName() + " has input \"" + input + "\" that does not have a corresponding variable in the graph");
                }
            }
        }
    }
}
