package org.nd4j.tensorflow.conversion.graphrunner;

import java.io.Closeable;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.tensorflow.TF_Buffer;
import org.bytedeco.tensorflow.TF_Graph;
import org.bytedeco.tensorflow.TF_Operation;
import org.bytedeco.tensorflow.TF_Output;
import org.bytedeco.tensorflow.TF_Session;
import org.bytedeco.tensorflow.TF_SessionOptions;
import org.bytedeco.tensorflow.TF_Status;
import org.bytedeco.tensorflow.TF_Tensor;
import org.bytedeco.tensorflow.global.tensorflow;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
import org.nd4j.shade.protobuf.util.JsonFormat;
import org.nd4j.tensorflow.conversion.TensorflowConversion;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.class */
/**
 * Runs a TensorFlow graph through the TensorFlow C API bindings, feeding and fetching
 * {@link INDArray} instances keyed by tensor name.
 *
 * <p>A runner is created either from a frozen graph (file path, raw protobuf bytes, or an already
 * loaded {@link TF_Graph} plus its {@link GraphDef}) or from a {@link SavedModelConfig}. When no
 * output names are supplied for a frozen graph, they are guessed from the graph definition (see
 * {@link #initSessionAndStatusIfNeeded(GraphDef)}).
 *
 * <p>Tensor names may carry an output index suffix, e.g. {@code "name:1"}; a missing suffix means
 * index 0.
 *
 * <p>NOTE(review): {@link #close()} releases the session and status only; the session options and
 * the graph are not released by this class — confirm who owns them.
 */
public class GraphRunner implements Closeable {
    private static final Logger log = LoggerFactory.getLogger(GraphRunner.class);

    private final TensorflowConversion conversion = TensorflowConversion.getInstance();
    private SavedModelConfig savedModelConfig;
    private TF_Graph graph;
    private TF_Session session;
    private TF_SessionOptions options;
    private TF_Status status;
    private List<String> inputOrder;
    private List<String> outputOrder;
    private ConfigProto protoBufConfigProto;

    /**
     * Creates a runner on an already loaded graph without any session configuration.
     *
     * @param inputNames  names of the graph inputs, in feed order
     * @param outputNames names of the graph outputs, or {@code null} to resolve them automatically
     * @param graph       the loaded graph
     * @param graphDef    the definition matching {@code graph}
     */
    public GraphRunner(List<String> inputNames, List<String> outputNames, TF_Graph graph, GraphDef graphDef) {
        this(inputNames, outputNames, graph, graphDef, null);
    }

    /**
     * Creates a runner on an already loaded graph.
     *
     * @param inputNames  names of the graph inputs, in feed order
     * @param outputNames names of the graph outputs, or {@code null} to resolve them automatically
     * @param graph       the loaded graph
     * @param graphDef    the definition matching {@code graph}
     * @param configProto session configuration, or {@code null} for none
     * @throws IllegalStateException if the session cannot be opened
     */
    public GraphRunner(List<String> inputNames, List<String> outputNames, TF_Graph graph, GraphDef graphDef, ConfigProto configProto) {
        this.graph = graph;
        this.protoBufConfigProto = configProto;
        this.inputOrder = inputNames;
        this.outputOrder = outputNames;
        initSessionAndStatusIfNeeded(graphDef);
    }

    /** Same as {@link #GraphRunner(byte[], List, List, ConfigProto)} using {@link #getAlignedWithNd4j()}. */
    public GraphRunner(byte[] graphBytes, List<String> inputNames, List<String> outputNames) {
        this(graphBytes, inputNames, outputNames, getAlignedWithNd4j());
    }

    /** Same as {@link #GraphRunner(String, List, List, ConfigProto)} using {@link #getAlignedWithNd4j()}. */
    public GraphRunner(String filePath, List<String> inputNames, List<String> outputNames) {
        this(filePath, inputNames, outputNames, getAlignedWithNd4j());
    }

    /**
     * Creates a runner from a serialized graph stored in a file.
     *
     * @param filePath    path of the file holding the serialized graph
     * @param inputNames  names of the graph inputs, in feed order
     * @param outputNames names of the graph outputs, or {@code null} to resolve them automatically
     * @param configProto session configuration, or {@code null} for none
     * @throws IllegalArgumentException if the file cannot be read or the runner cannot be initialized
     */
    public GraphRunner(String filePath, List<String> inputNames, List<String> outputNames, ConfigProto configProto) {
        this(readGraphBytes(filePath), inputNames, outputNames, configProto);
    }

    /**
     * Creates a runner from a serialized graph.
     *
     * @param graphBytes  the serialized graph
     * @param inputNames  names of the graph inputs, in feed order
     * @param outputNames names of the graph outputs, or {@code null} to resolve them automatically
     * @param configProto session configuration, or {@code null} for none
     * @throws IllegalArgumentException if any step of the initialization fails
     */
    public GraphRunner(byte[] graphBytes, List<String> inputNames, List<String> outputNames, ConfigProto configProto) {
        try {
            this.inputOrder = inputNames;
            this.outputOrder = outputNames;
            this.protoBufConfigProto = configProto;
            initOptionsIfNeeded();
            this.graph = this.conversion.loadGraph(graphBytes, this.status);
            initSessionAndStatusIfNeeded(graphBytes);
        } catch (Exception e) {
            throw new IllegalArgumentException("Unable to parse protobuf", e);
        }
    }

    /** Same as {@link #GraphRunner(List, List, SavedModelConfig, ConfigProto)} using {@link #getAlignedWithNd4j()}. */
    public GraphRunner(List<String> inputNames, List<String> outputNames, SavedModelConfig savedModelConfig) {
        this(inputNames, outputNames, savedModelConfig, getAlignedWithNd4j());
    }

    /**
     * Creates a runner from a saved model.
     *
     * <p>NOTE(review): the supplied {@code inputNames} and {@code outputNames} have no effect; the
     * input and output order is always replaced by the keys of the maps populated while loading the
     * saved model. This mirrors the previous behavior — confirm whether that is intended.
     *
     * @param inputNames       ignored (see note)
     * @param outputNames      ignored (see note)
     * @param savedModelConfig describes the saved model to load
     * @param configProto      session configuration, or {@code null} for none
     * @throws IllegalArgumentException if any step of the initialization fails
     */
    public GraphRunner(List<String> inputNames, List<String> outputNames, SavedModelConfig savedModelConfig, ConfigProto configProto) {
        this(savedModelConfig, configProto);
    }

    /** Creates a runner on an already loaded graph, resolving the output names automatically. */
    public GraphRunner(List<String> inputNames, TF_Graph graph, GraphDef graphDef) {
        this(inputNames, null, graph, graphDef, null);
    }

    /** Creates a configured runner on an already loaded graph, resolving the output names automatically. */
    public GraphRunner(List<String> inputNames, TF_Graph graph, GraphDef graphDef, ConfigProto configProto) {
        this(inputNames, null, graph, graphDef, configProto);
    }

    /** Same as {@link #GraphRunner(byte[], List, ConfigProto)} using {@link #getAlignedWithNd4j()}. */
    public GraphRunner(byte[] graphBytes, List<String> inputNames) {
        this(graphBytes, inputNames, getAlignedWithNd4j());
    }

    /** Same as {@link #GraphRunner(String, List, ConfigProto)} using {@link #getAlignedWithNd4j()}. */
    public GraphRunner(String filePath, List<String> inputNames) {
        this(filePath, inputNames, getAlignedWithNd4j());
    }

    /** Creates a runner from a graph file, resolving the output names automatically. */
    public GraphRunner(String filePath, List<String> inputNames, ConfigProto configProto) {
        this(filePath, inputNames, (List<String>) null, configProto);
    }

    /** Creates a runner from a serialized graph, resolving the output names automatically. */
    public GraphRunner(byte[] graphBytes, List<String> inputNames, ConfigProto configProto) {
        this(graphBytes, inputNames, (List<String>) null, configProto);
    }

    /** Same as {@link #GraphRunner(SavedModelConfig, ConfigProto)} using {@link #getAlignedWithNd4j()}. */
    public GraphRunner(SavedModelConfig savedModelConfig) {
        this(savedModelConfig, getAlignedWithNd4j());
    }

    /**
     * Creates a runner from a saved model. The input and output order is taken from the keys of the
     * maps populated while loading the model, and the corresponding map values are stored back on
     * {@code savedModelConfig} as the names used to look operations up in the graph.
     *
     * @param savedModelConfig describes the saved model to load
     * @param configProto      session configuration, or {@code null} for none
     * @throws IllegalArgumentException if any step of the initialization fails
     */
    public GraphRunner(SavedModelConfig savedModelConfig, ConfigProto configProto) {
        try {
            this.savedModelConfig = savedModelConfig;
            this.protoBufConfigProto = configProto;
            initOptionsIfNeeded();
            // Insertion-ordered so that key order and value order stay aligned.
            Map<String, String> inputsMap = new LinkedHashMap<>();
            Map<String, String> outputsMap = new LinkedHashMap<>();
            this.graph = tensorflow.TF_NewGraph();
            this.session = this.conversion.loadSavedModel(savedModelConfig, this.options, null, this.graph, inputsMap, outputsMap, this.status);
            this.inputOrder = new ArrayList<>(inputsMap.keySet());
            this.outputOrder = new ArrayList<>(outputsMap.keySet());
            savedModelConfig.setSavedModelInputOrder(new ArrayList<>(inputsMap.values()));
            savedModelConfig.setSaveModelOutputOrder(new ArrayList<>(outputsMap.values()));
        } catch (Exception e) {
            throw new IllegalArgumentException("Unable to parse protobuf", e);
        }
    }

    /**
     * Runs the graph once.
     *
     * <p>For a saved model, operations are looked up by the names stored on the
     * {@link SavedModelConfig}, while the {@code inputs} map and the returned map are keyed by the
     * input/output order of this runner (falling back to the saved-model names when that order is
     * null or empty). For a frozen graph both roles are played by the input/output order.
     *
     * @param inputs the arrays to feed, keyed by input name; must contain exactly one entry per input
     * @return the fetched arrays keyed by output name, in output order
     * @throws IllegalStateException    if the graph, session or names are not initialized, or if
     *                                  TensorFlow reports an error while running the session
     * @throws IllegalArgumentException if the number of arrays does not match the number of inputs,
     *                                  an array is missing, or a name does not resolve to an operation
     */
    public Map<String, INDArray> run(Map<String, INDArray> inputs) {
        if (this.graph == null) {
            throw new IllegalStateException("Graph not initialized.");
        }
        if (this.session == null) {
            throw new IllegalStateException("Session not initialized or already closed.");
        }

        final boolean fromSavedModel = this.savedModelConfig != null;
        // Names used to look operations up in the graph.
        final List<String> graphInputNames = fromSavedModel ? this.savedModelConfig.getSavedModelInputOrder() : this.inputOrder;
        final List<String> graphOutputNames = fromSavedModel ? this.savedModelConfig.getSaveModelOutputOrder() : this.outputOrder;
        if (graphInputNames == null || graphOutputNames == null) {
            throw new IllegalStateException("Input and output names are not initialized.");
        }
        // Names used as keys of the caller-facing maps.
        final List<String> inputKeys = isNullOrEmpty(this.inputOrder) ? graphInputNames : this.inputOrder;
        final List<String> outputKeys = isNullOrEmpty(this.outputOrder) ? graphOutputNames : this.outputOrder;

        if (inputs.size() != inputKeys.size()) {
            throw new IllegalArgumentException("Number of inputs specified do not match number of arrays specified.");
        }

        TF_Output inputOps = resolveOperations(graphInputNames, "input");
        TF_Tensor[] inputTensors = new TF_Tensor[graphInputNames.size()];
        for (int i = 0; i < inputTensors.length; i++) {
            String key = inputKeys.get(i);
            INDArray array = inputs.get(key);
            if (array == null) {
                throw new IllegalArgumentException("No array specified for input " + key);
            }
            inputTensors[i] = this.conversion.tensorFromNDArray(array);
        }
        TF_Output outputOps = resolveOperations(graphOutputNames, "output");

        PointerPointer<TF_Tensor> inputValues = new PointerPointer<TF_Tensor>(inputTensors);
        PointerPointer<TF_Tensor> outputValues = new PointerPointer<TF_Tensor>(graphOutputNames.size());
        tensorflow.TF_SessionRun(this.session, (TF_Buffer) null, inputOps, inputValues, inputTensors.length, outputOps, outputValues, graphOutputNames.size(), (PointerPointer) null, 0, (TF_Buffer) null, this.status);
        // A status code of 0 means success.
        if (tensorflow.TF_GetCode(this.status) != 0) {
            throw new IllegalStateException("ERROR: Unable to run session " + tensorflow.TF_Message(this.status).getString());
        }

        Map<String, INDArray> results = new LinkedHashMap<>();
        // Never read past the number of tensors that were actually fetched.
        int fetched = Math.min(outputKeys.size(), graphOutputNames.size());
        for (int i = 0; i < fetched; i++) {
            results.put(outputKeys.get(i), this.conversion.ndArrayFromTensor(new TF_Tensor(outputValues.get(i))));
        }
        return results;
    }

    /**
     * Builds the native array of operation outputs for the given tensor names.
     *
     * @param names tensor names, optionally suffixed with {@code ":index"}
     * @param role  {@code "input"} or {@code "output"}, used only in the error message
     * @return a {@link TF_Output} array positioned at its first element
     * @throws IllegalArgumentException if a name does not resolve to an operation of the graph
     */
    private TF_Output resolveOperations(List<String> names, String role) {
        TF_Output resolved = new TF_Output(names.size());
        for (int i = 0; i < names.size(); i++) {
            String name = names.get(i);
            String[] parts = name.split(":");
            TF_Operation operation = tensorflow.TF_GraphOperationByName(this.graph, parts[0]);
            if (operation == null) {
                throw new IllegalArgumentException("Illegal " + role + " found " + name + " - no op found! Mis specified name perhaps?");
            }
            int outputIndex = parts.length > 1 ? Integer.parseInt(parts[1]) : 0;
            resolved.position(i).oper(operation).index(outputIndex);
        }
        // Rewind so the native call sees the array from its first element.
        resolved.position(0L);
        return resolved;
    }

    private static boolean isNullOrEmpty(List<String> names) {
        return names == null || names.isEmpty();
    }

    /** Reads a serialized graph from disk, reporting failures like the constructors do. */
    private static byte[] readGraphBytes(String filePath) {
        try {
            return IOUtils.toByteArray(new File(filePath).toURI());
        } catch (Exception e) {
            throw new IllegalArgumentException("Unable to parse protobuf", e);
        }
    }

    /**
     * Lazily creates the status object and the session options, applying the session configuration
     * when one was supplied.
     *
     * @throws IllegalStateException if TensorFlow rejects the configuration
     */
    private void initOptionsIfNeeded() {
        if (this.status == null) {
            this.status = tensorflow.TF_NewStatus();
        }
        if (this.options == null) {
            this.options = tensorflow.TF_NewSessionOptions();
            if (this.protoBufConfigProto != null) {
                // Pass the exact serialized length: the proto bytes may contain zeros, so a
                // string-length computation on the native buffer would be wrong.
                byte[] serializedConfig = this.protoBufConfigProto.toByteArray();
                tensorflow.TF_SetConfig(this.options, new BytePointer(serializedConfig), serializedConfig.length, this.status);
                if (tensorflow.TF_GetCode(this.status) != 0) {
                    throw new IllegalStateException("ERROR: Unable to set value configuration:" + tensorflow.TF_Message(this.status).getString());
                }
            }
        }
    }

    /**
     * Resolves the output names when none were given and opens the session when none exists yet.
     *
     * <p>Automatic resolution keeps every node that is neither listed as an input of another node
     * nor a {@code Placeholder}; when more than one candidate remains, candidates whose name
     * contains {@code "/"} are dropped.
     *
     * @param graphDef the definition of the graph held by this runner
     * @throws IllegalStateException if the session cannot be opened
     */
    private void initSessionAndStatusIfNeeded(GraphDef graphDef) {
        LinkedHashSet<String> consumedNames = new LinkedHashSet<>();
        for (int i = 0; i < graphDef.getNodeCount(); i++) {
            NodeDef node = graphDef.getNode(i);
            for (int j = 0; j < node.getInputCount(); j++) {
                consumedNames.add(node.getInput(j));
            }
        }
        if (this.outputOrder == null) {
            this.outputOrder = new ArrayList<>();
            log.trace("Attempting to automatically resolve tensorflow output names..");
            for (int i = 0; i < graphDef.getNodeCount(); i++) {
                NodeDef node = graphDef.getNode(i);
                if (!consumedNames.contains(node.getName()) && !node.getOp().equals("Placeholder")) {
                    this.outputOrder.add(node.getName());
                }
            }
            if (this.outputOrder.size() > 1) {
                HashSet<String> scopedNames = new HashSet<>();
                for (String name : this.outputOrder) {
                    if (name.contains("/")) {
                        scopedNames.add(name);
                    }
                }
                this.outputOrder.removeAll(scopedNames);
            }
        }
        if (this.session == null) {
            initOptionsIfNeeded();
            this.session = tensorflow.TF_NewSession(this.graph, this.options, this.status);
            if (tensorflow.TF_GetCode(this.status) != 0) {
                throw new IllegalStateException("ERROR: Unable to open session " + tensorflow.TF_Message(this.status).getString());
            }
        }
    }

    /**
     * Parses the serialized graph and delegates to {@link #initSessionAndStatusIfNeeded(GraphDef)}.
     *
     * @throws IllegalArgumentException if the bytes are not a valid graph definition; failing here
     *                                  avoids leaving a runner without session or output names
     */
    private void initSessionAndStatusIfNeeded(byte[] graphBytes) {
        try {
            initSessionAndStatusIfNeeded(GraphDef.parseFrom(graphBytes));
        } catch (InvalidProtocolBufferException e) {
            throw new IllegalArgumentException("Unable to parse protobuf", e);
        }
    }

    /**
     * Builds the default session configuration: a device filter for
     * {@link TensorflowConversion#defaultDeviceForThread()} and, when the ND4J backend class name
     * contains {@code "jcu"}, GPU options with memory growth enabled and a per-process memory
     * fraction of 0.5.
     *
     * @return the configuration; GPU options are left unset if the backend cannot be inspected
     */
    public static ConfigProto getAlignedWithNd4j() {
        ConfigProto.Builder builder = ConfigProto.getDefaultInstance().toBuilder().addDeviceFilters(TensorflowConversion.defaultDeviceForThread());
        try {
            if (Nd4j.getBackend().getClass().getName().toLowerCase().contains("jcu")) {
                builder.setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true).setPerProcessGpuMemoryFraction(0.5d).build());
            }
        } catch (Exception e) {
            // Best effort: fall back to a configuration without GPU options.
            log.warn("Unable to inspect the ND4J backend; GPU options not set", e);
        }
        return builder.build();
    }

    /**
     * Parses a session configuration from its JSON representation.
     *
     * @param json the JSON text
     * @return the parsed configuration, or {@code null} if parsing fails
     */
    public static ConfigProto fromJson(String json) {
        ConfigProto.Builder builder = ConfigProto.newBuilder();
        try {
            JsonFormat.parser().merge(json, builder);
            return ConfigProto.parseFrom(builder.build().toByteString().toByteArray());
        } catch (Exception e) {
            log.error("Unable to parse session configuration from JSON", e);
            return null;
        }
    }

    /**
     * Renders the session configuration of this runner as JSON.
     *
     * @return the JSON text, or {@code null} if the configuration cannot be printed
     */
    public String sessionOptionsToJson() {
        try {
            return JsonFormat.printer().print(this.protoBufConfigProto);
        } catch (Exception e) {
            log.error("Unable to convert session configuration to JSON", e);
            return null;
        }
    }

    /**
     * Closes and deletes the session, then deletes the status object. Calling this method more than
     * once is harmless: the released handles are cleared so they are never reused.
     *
     * @throws IllegalStateException if TensorFlow reports an error while releasing the session
     */
    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() {
        if (this.status == null) {
            return;
        }
        try {
            if (this.session != null) {
                tensorflow.TF_CloseSession(this.session, this.status);
                tensorflow.TF_DeleteSession(this.session, this.status);
                this.session = null;
            }
            if (tensorflow.TF_GetCode(this.status) != 0) {
                throw new IllegalStateException("ERROR: Unable to delete session " + tensorflow.TF_Message(this.status).getString());
            }
        } finally {
            // Release the status even when the error above is raised.
            tensorflow.TF_DeleteStatus(this.status);
            this.status = null;
        }
    }

    public List<String> getInputOrder() {
        return this.inputOrder;
    }

    public List<String> getOutputOrder() {
        return this.outputOrder;
    }

    public void setInputOrder(List<String> inputOrder) {
        this.inputOrder = inputOrder;
    }

    public void setOutputOrder(List<String> outputOrder) {
        this.outputOrder = outputOrder;
    }

    public ConfigProto getProtoBufConfigProto() {
        return this.protoBufConfigProto;
    }
}
