package org.nd4j.onnxruntime.runner;

import java.io.Closeable;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.UUID;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.CharPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.onnxruntime.AllocatorWithDefaultOptions;
import org.bytedeco.onnxruntime.Env;
import org.bytedeco.onnxruntime.MemoryInfo;
import org.bytedeco.onnxruntime.RunOptions;
import org.bytedeco.onnxruntime.Session;
import org.bytedeco.onnxruntime.SessionOptions;
import org.bytedeco.onnxruntime.Value;
import org.bytedeco.onnxruntime.ValueVector;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.onnxruntime.util.ONNXUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs inference on an ONNX model through the bytedeco ONNX Runtime bindings, converting ND4J
 * {@link INDArray} inputs to ONNX Runtime {@link Value}s and the resulting output values back to
 * {@link INDArray}s.
 *
 * <p>A single {@link Env} is lazily created, retained, and shared by every runner in the JVM; it is
 * never released. Each runner owns its own {@link Session} and the native objects created for it,
 * which are released by {@link #close()}.
 *
 * <p>NOTE(review): nothing here synchronizes {@link #exec(Map)}; whether one instance may be used
 * from several threads at once has not been established — confirm before sharing an instance.
 */
public class OnnxRuntimeRunner implements Closeable {
    private static final Logger log = LoggerFactory.getLogger(OnnxRuntimeRunner.class);

    /** ONNX Runtime {@code GraphOptimizationLevel.ORT_ENABLE_EXTENDED}. */
    private static final int ORT_ENABLE_EXTENDED = 2;
    /** ONNX Runtime {@code OrtAllocatorType.OrtArenaAllocator}. */
    private static final int ORT_ARENA_ALLOCATOR = 1;
    /** ONNX Runtime {@code OrtMemType.OrtMemTypeDefault}. */
    private static final int ORT_MEM_TYPE_DEFAULT = 0;

    /** JVM-wide environment; only accessed through the synchronized {@link #sharedEnv()}. */
    private static Env env;

    private Session session;
    private RunOptions runOptions;
    private MemoryInfo memoryInfo;
    private AllocatorWithDefaultOptions allocator;
    private SessionOptions sessionOptions;
    /** Native copy of the model path handed to the {@link Session} constructor. */
    private Pointer bp;

    /** Fluent builder; {@link #build()} is equivalent to {@code new OnnxRuntimeRunner(modelUri)}. */
    public static class OnnxRuntimeRunnerBuilder {
        private String modelUri;

        OnnxRuntimeRunnerBuilder() {
        }

        /**
         * Sets the model location passed to {@link OnnxRuntimeRunner#OnnxRuntimeRunner(String)}.
         *
         * @param str the model location
         * @return this builder
         */
        public OnnxRuntimeRunnerBuilder modelUri(String str) {
            this.modelUri = str;
            return this;
        }

        /**
         * Creates the runner, loading the model.
         *
         * @return a new runner for the configured model location
         */
        public OnnxRuntimeRunner build() {
            return new OnnxRuntimeRunner(this.modelUri);
        }

        @Override
        public String toString() {
            return "OnnxRuntimeRunner.OnnxRuntimeRunnerBuilder(modelUri=" + this.modelUri + ")";
        }
    }

    /**
     * Creates a runner and loads the model into a new ONNX Runtime session.
     *
     * <p>The session uses extended graph optimizations and a single intra-op thread. Input tensors
     * are created against a CPU memory info (arena allocator, default memory type).
     *
     * @param str location of the model, passed to the {@link Session} constructor as its model path
     *     (as a wide {@link CharPointer} on Windows, as a {@link BytePointer} elsewhere)
     * @throws RuntimeException if any native object, including the session, cannot be created; the
     *     per-runner native objects allocated so far are released before the exception propagates
     */
    public OnnxRuntimeRunner(String str) {
        Env sharedEnv = sharedEnv();
        try {
            this.sessionOptions = new SessionOptions();
            this.sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_EXTENDED);
            this.sessionOptions.SetIntraOpNumThreads(1);
            this.sessionOptions.retainReference();
            this.allocator = new AllocatorWithDefaultOptions();
            this.allocator.retainReference();
            // The session's model path is a wide string on Windows, a narrow one elsewhere.
            this.bp = Loader.getPlatform().toLowerCase().startsWith("windows")
                    ? new CharPointer(str)
                    : new BytePointer(str);
            this.runOptions = new RunOptions();
            this.memoryInfo = MemoryInfo.CreateCpu(ORT_ARENA_ALLOCATOR, ORT_MEM_TYPE_DEFAULT);
            this.session = new Session(sharedEnv, this.bp, this.sessionOptions);
            // Retain the session so it is not released before close() is called.
            this.session.retainReference();
        } catch (RuntimeException e) {
            // Don't leak the natives already allocated when construction fails part-way.
            releaseNativeResources();
            throw e;
        }
    }

    /**
     * Returns the JVM-wide {@link Env}, creating and retaining it on first use.
     *
     * <p>Synchronized so that concurrent first constructions cannot race on the static field and
     * create more than one environment.
     *
     * @return the shared environment
     */
    private static synchronized Env sharedEnv() {
        if (env == null) {
            env = new Env(ONNXUtils.getOnnxLogLevelFromLogger(log), new BytePointer("nd4j-serving-onnx-session-" + UUID.randomUUID().toString()));
            env.retainReference();
        }
        return env;
    }

    /**
     * Releases the session and the native objects created for this runner.
     *
     * <p>Safe to call more than once. The shared {@link Env} is not released.
     */
    @Override
    public void close() {
        releaseNativeResources();
    }

    /** Frees every per-runner native object that is still held and clears its field. */
    private void releaseNativeResources() {
        if (this.session != null) {
            this.session.close();
            this.session = null;
        }
        if (this.sessionOptions != null) {
            this.sessionOptions.releaseReference();
            this.sessionOptions = null;
        }
        if (this.allocator != null) {
            this.allocator.releaseReference();
            this.allocator = null;
        }
        if (this.runOptions != null) {
            this.runOptions.releaseReference();
            this.runOptions = null;
        }
        if (this.memoryInfo != null) {
            this.memoryInfo.close();
            this.memoryInfo = null;
        }
        if (this.bp != null) {
            this.bp.close();
            this.bp = null;
        }
    }

    /**
     * Runs the session once.
     *
     * <p>For each model input, the array stored in {@code map} under that input's name is converted
     * with {@link ONNXUtils#getTensor}. Keys that do not name a model input are ignored. Every model
     * output is returned.
     *
     * @param map input arrays keyed by model input name
     * @return output arrays keyed by model output name, in the model's output order; each output is
     *     reshaped to the shape ONNX Runtime reports when that shape is available
     * @throws IllegalStateException if this runner has been closed, or (via
     *     {@link Preconditions#checkState}) if a converted input is not a tensor
     */
    public Map<String, INDArray> exec(Map<String, INDArray> map) {
        if (this.session == null) {
            throw new IllegalStateException("OnnxRuntimeRunner has already been closed");
        }
        long inputCount = this.session.GetInputCount();
        long outputCount = this.session.GetOutputCount();
        // NOTE(review): the name strings returned by GetInputName/GetOutputName come from the ORT
        // allocator and are never freed here; confirm whether they should be returned to it.
        PointerPointer<BytePointer> inputNames = new PointerPointer<>(inputCount);
        PointerPointer<BytePointer> outputNames = new PointerPointer<>(outputCount);

        // One native Value slot per model input; each converted tensor is put into its slot.
        Value inputValues = new Value(inputCount);
        for (int i = 0; i < inputCount; i++) {
            BytePointer inputName = this.session.GetInputName(i, this.allocator.asOrtAllocator());
            inputNames.put(i, inputName);
            // NOTE(review): a missing key passes null to ONNXUtils.getTensor; how that helper treats
            // null is not visible here — confirm before relying on omitted inputs.
            Value tensor = ONNXUtils.getTensor(map.get(inputName.getString()), this.memoryInfo);
            Preconditions.checkState(tensor.IsTensor(), "Input must be a tensor.");
            inputValues.position(i).put(tensor);
        }
        // JavaCPP passes the pointer offset by its position, so rewind to the first slot for Run().
        inputValues.position(0L);

        String[] outputNameStrings = new String[(int) outputCount];
        for (int i = 0; i < outputCount; i++) {
            BytePointer outputName = this.session.GetOutputName(i, this.allocator.asOrtAllocator());
            outputNames.put(i, outputName);
            outputNameStrings[i] = outputName.getString();
        }

        ValueVector outputs = this.session.Run(this.runOptions, inputNames, inputValues, inputCount, outputNames, outputCount);
        // NOTE(review): the output vector and each output value are retained and never released
        // here, presumably because ONNXUtils.getDataBuffer wraps their native memory without
        // copying — confirm before adding a release, which could free memory the returned arrays
        // still point at.
        outputs.retainReference();

        Map<String, INDArray> result = new LinkedHashMap<>();
        for (int i = 0; i < outputCount; i++) {
            Value output = outputs.get(i);
            output.retainReference();
            DataBuffer buffer = ONNXUtils.getDataBuffer(output);
            // The reported shape can be null; the array is then left as created from the buffer.
            LongPointer shapePointer = output.GetTensorTypeAndShapeInfo().GetShape();
            INDArray array;
            if (shapePointer != null) {
                long[] shape = new long[(int) shapePointer.capacity()];
                shapePointer.get(shape);
                array = Nd4j.create(buffer).reshape(shape);
            } else {
                array = Nd4j.create(buffer);
            }
            result.put(outputNameStrings[i], array);
        }
        return result;
    }

    /**
     * Creates a builder for configuring and constructing a runner.
     *
     * @return a new builder
     */
    public static OnnxRuntimeRunnerBuilder builder() {
        return new OnnxRuntimeRunnerBuilder();
    }
}
