package org.jetbrains.kotlinx.dl.api.inference.onnx;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import kotlin.Metadata;
import kotlin.NotImplementedError;
import kotlin.Pair;
import kotlin.Unit;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.MapsKt;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import mu.KLogger;
import mu.KotlinLogging;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape;
import org.jetbrains.kotlinx.dl.api.extension.FloatArrayExtensionFunctionsKt;
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel;
import org.jetbrains.kotlinx.dl.api.inference.TensorFlowInferenceModel;

/* compiled from: OnnxInferenceModel.kt */
@Metadata(mv = {1, 5, 1}, k = 1, xi = 48, d1 = {"��n\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0016\n\u0002\b\u0006\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n��\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0010\b\n��\n\u0002\u0010\u0014\n\u0002\b\u0003\n\u0002\u0010 \n\u0002\u0010\u0011\n��\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\t\n\u0002\b\u0003\b\u0016\u0018�� -2\u00020\u0001:\u0001-B\u0005¢\u0006\u0002\u0010\u0002J\b\u0010\u0012\u001a\u00020\u0013H\u0016J\"\u0010\u0014\u001a\u00020\u00152\b\u0010\u0016\u001a\u0004\u0018\u00010\u00172\u0006\u0010\u0018\u001a\u00020\u00192\u0006\u0010\u001a\u001a\u00020\u0019H\u0016J\u0010\u0010\u001b\u001a\u00020\u001c2\u0006\u0010\u001d\u001a\u00020\u001eH\u0016J\u001e\u0010\u001b\u001a\u00020\u001c2\u0006\u0010\u001d\u001a\u00020\u001e2\u0006\u0010\u001f\u001a\u00020\u00172\u0006\u0010 \u001a\u00020\u0017J\u0018\u0010!\u001a\f\u0012\b\u0012\u0006\u0012\u0002\b\u00030#0\"2\u0006\u0010\u001d\u001a\u00020\u001eJ 
\u0010$\u001a\u0014\u0012\u0010\u0012\u000e\u0012\u0004\u0012\u00020&\u0012\u0004\u0012\u00020\u00060%0\"2\u0006\u0010\u001d\u001a\u00020\u001eJ\u000e\u0010'\u001a\u00020\u001e2\u0006\u0010\u001d\u001a\u00020\u001eJ\u0018\u0010'\u001a\u00020\u001e2\u0006\u0010\u001d\u001a\u00020\u001e2\u0006\u0010(\u001a\u00020\u0017H\u0016J\u0014\u0010)\u001a\u00020\u00132\n\u0010*\u001a\u00020\u0006\"\u00020+H\u0016J\b\u0010,\u001a\u00020\u0017H\u0016R\u000e\u0010\u0003\u001a\u00020\u0004X\u0082.¢\u0006\u0002\n��R\u0014\u0010\u0005\u001a\u00020\u00068VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\u0007\u0010\bR\u001e\u0010\n\u001a\u00020\u00062\u0006\u0010\t\u001a\u00020\u0006@BX\u0086.¢\u0006\b\n��\u001a\u0004\b\u000b\u0010\bR\u000e\u0010\f\u001a\u00020\rX\u0082\u0004¢\u0006\u0002\n��R\u001e\u0010\u000e\u001a\u00020\u00062\u0006\u0010\t\u001a\u00020\u0006@BX\u0086.¢\u0006\b\n��\u001a\u0004\b\u000f\u0010\bR\u000e\u0010\u0010\u001a\u00020\u0011X\u0082.¢\u0006\u0002\n��¨\u0006."}, d2 = {"Lorg/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel;", "Lorg/jetbrains/kotlinx/dl/api/inference/InferenceModel;", "()V", "env", "Lai/onnxruntime/OrtEnvironment;", "inputDimensions", "", "getInputDimensions", "()[J", "<set-?>", "inputShape", "getInputShape", "logger", "Lmu/KLogger;", "outputShape", "getOutputShape", "session", "Lai/onnxruntime/OrtSession;", "close", "", "copy", "Lorg/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel;", "copiedModelName", "", "saveOptimizerState", "", "copyWeights", "predict", "", "inputData", "", "inputTensorName", "outputTensorName", "predictRaw", "", "", "predictRawWithShapes", "Lkotlin/Pair;", "Ljava/nio/FloatBuffer;", "predictSoftly", "predictionTensorName", "reshape", "dims", "", "toString", "Companion", "onnx"})
/* loaded from: input_file:org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.class */
public class OnnxInferenceModel extends InferenceModel {

    /** Singleton companion instance exposing the factory methods ({@code load}, {@code initializeONNXModel$onnx}). */
    @NotNull
    public static final Companion Companion = new Companion(null);

    /**
     * Logger for this model, obtained from {@code KotlinLogging.INSTANCE.logger(...)}.
     * The anonymous {@code Function0} passed in has an empty body; it is only handed to the
     * logging factory (presumably so the factory can derive the logger name from the lambda's
     * class - TODO confirm against kotlin-logging).
     */
    @NotNull
    private final KLogger logger = KotlinLogging.INSTANCE.logger(new Function0<Unit>() { // from class: org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel$logger$1
        public final void invoke() {
        }

        /* renamed from: invoke, reason: collision with other method in class */
        public /* bridge */ /* synthetic */ Object m10invoke() {
            invoke();
            return Unit.INSTANCE;
        }
    });
    // The four fields below start out null and are populated by Companion.initializeONNXModel$onnx;
    // accessors throw an "uninitialized property" exception while they are still null.
    /** ONNX Runtime environment used to create the session and the input tensors. */
    private OrtEnvironment env;
    /** ONNX Runtime session that the predict methods run against. */
    private OrtSession session;
    /** Shape passed to {@code OnnxTensor.createTensor} for the input; also replaced by {@code reshape}. */
    private long[] inputShape;
    /** Shape derived from the first entry of the session's output info during initialization. */
    private long[] outputShape;

    /* compiled from: OnnxInferenceModel.kt */
    @Metadata(mv = {1, 5, 1}, k = 1, xi = 48, d1 = {"��\u001c\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000e\n\u0002\b\u0003\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J\u001d\u0010\u0003\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00042\u0006\u0010\u0006\u001a\u00020\u0007H��¢\u0006\u0002\b\bJ\u000e\u0010\t\u001a\u00020\u00042\u0006\u0010\u0006\u001a\u00020\u0007¨\u0006\n"}, d2 = {"Lorg/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel$Companion;", "", "()V", "initializeONNXModel", "Lorg/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel;", "model", "pathToModel", "", "initializeONNXModel$onnx", "load", "onnx"})
    /* loaded from: input_file:org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel$Companion.class */
    public static final class Companion {
        private Companion() {
        }

        /**
         * Creates a fresh {@link OnnxInferenceModel} and initializes it from the model at the given path.
         *
         * @param str path to the ONNX model, passed to {@code OrtEnvironment.createSession}
         * @return the initialized model
         */
        @NotNull
        public final OnnxInferenceModel load(@NotNull String str) {
            Intrinsics.checkNotNullParameter(str, "pathToModel");
            return initializeONNXModel$onnx(new OnnxInferenceModel(), str);
        }

        /**
         * Initializes the given, not yet initialized model: obtains the ONNX Runtime environment,
         * creates a session for the model file and derives the input and output shapes from the
         * first entry of the session's input and output info.
         *
         * <p>The model's fields are assigned only after every step has succeeded. If a step fails,
         * the freshly created session is closed and the model is left untouched, so the caller can
         * retry initialization.
         *
         * @param onnxInferenceModel the model instance to populate
         * @param str                path to the ONNX model
         * @return the same {@code onnxInferenceModel} instance, now initialized
         * @throws IllegalArgumentException if any of the model's env, session, inputShape or
         *                                  outputShape fields is already set
         */
        @NotNull
        public final OnnxInferenceModel initializeONNXModel$onnx(@NotNull OnnxInferenceModel onnxInferenceModel, @NotNull String str) {
            Intrinsics.checkNotNullParameter(onnxInferenceModel, "model");
            Intrinsics.checkNotNullParameter(str, "pathToModel");
            boolean alreadyInitialized = onnxInferenceModel.env != null
                    || onnxInferenceModel.session != null
                    || onnxInferenceModel.inputShape != null
                    || onnxInferenceModel.outputShape != null;
            if (alreadyInitialized) {
                throw new IllegalArgumentException("The model " + onnxInferenceModel + " is initialized!");
            }

            OrtEnvironment environment = OrtEnvironment.getEnvironment();
            Intrinsics.checkNotNullExpressionValue(environment, "getEnvironment()");

            // The options object is only needed while the session is being created;
            // close it afterwards so it is not leaked.
            OrtSession createdSession;
            try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
                createdSession = environment.createSession(str, options);
            }
            Intrinsics.checkNotNullExpressionValue(createdSession, "model.env.createSession(pathToModel, OrtSession.SessionOptions())");

            boolean committed = false;
            try {
                Map<String, NodeInfo> inputInfo = createdSession.getInputInfo();
                Intrinsics.checkNotNullExpressionValue(inputInfo, "model.session.inputInfo");
                long[] resolvedInputShape = shapeOfFirstNode(
                        inputInfo, "model.session.inputInfo.toList()[0].second.info as TensorInfo).shape");

                Map<String, NodeInfo> outputInfo = createdSession.getOutputInfo();
                Intrinsics.checkNotNullExpressionValue(outputInfo, "model.session.outputInfo");
                long[] resolvedOutputShape = shapeOfFirstNode(
                        outputInfo, "model.session.outputInfo.toList()[0].second.info as TensorInfo).shape");

                // Commit everything at once so a failure above never leaves the model half-initialized.
                onnxInferenceModel.env = environment;
                onnxInferenceModel.session = createdSession;
                onnxInferenceModel.inputShape = resolvedInputShape;
                onnxInferenceModel.outputShape = resolvedOutputShape;
                committed = true;
                return onnxInferenceModel;
            } finally {
                if (!committed) {
                    createdSession.close();
                }
            }
        }

        /**
         * Builds a model shape from the first entry of a node-info map: keeps at most the last
         * three dimensions reported by the tensor info and hands them to {@code TensorShape}
         * together with a leading {@code 1L}.
         */
        private static long[] shapeOfFirstNode(Map<String, NodeInfo> nodeInfo, String shapeExpression) {
            TensorInfo info = (TensorInfo) MapsKt.toList(nodeInfo).get(0).getSecond().getInfo();
            if (info == null) {
                throw new NullPointerException("null cannot be cast to non-null type ai.onnxruntime.TensorInfo");
            }
            long[] shape = info.getShape();
            Intrinsics.checkNotNullExpressionValue(shape, shapeExpression);
            long[] trailingDims = CollectionsKt.toLongArray(ArraysKt.takeLast(shape, 3));
            return new TensorShape(1L, Arrays.copyOf(trailingDims, trailingDims.length)).dims();
        }

        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    /**
     * Returns the current input shape.
     *
     * @return the stored input shape array
     * @throws RuntimeException the "uninitialized property" exception raised by
     *                          {@code Intrinsics} when the shape has not been set yet
     */
    @NotNull
    public final long[] getInputShape() {
        final long[] currentShape = this.inputShape;
        if (currentShape == null) {
            Intrinsics.throwUninitializedPropertyAccessException("inputShape");
            throw null;
        }
        return currentShape;
    }

    /**
     * Returns the current output shape.
     *
     * @return the stored output shape array
     * @throws RuntimeException the "uninitialized property" exception raised by
     *                          {@code Intrinsics} when the shape has not been set yet
     */
    @NotNull
    public final long[] getOutputShape() {
        final long[] currentShape = this.outputShape;
        if (currentShape == null) {
            Intrinsics.throwUninitializedPropertyAccessException("outputShape");
            throw null;
        }
        return currentShape;
    }

    /**
     * Replaces the input shape with one built from the given dimensions.
     * The dimensions are copied and passed to {@code TensorShape} together with a leading
     * {@code 1L}; the resulting {@code dims()} array becomes the new input shape.
     *
     * @param jArr the dimensions to use after the leading {@code 1L}
     */
    public void reshape(@NotNull long... jArr) {
        Intrinsics.checkNotNullParameter(jArr, "dims");
        final long[] dimsCopy = Arrays.copyOf(jArr, jArr.length);
        final TensorShape reshaped = new TensorShape(1L, dimsCopy);
        this.inputShape = reshaped.dims();
    }

    /**
     * Copying is not supported by this model; the method always throws.
     *
     * @param str unused
     * @param z   unused
     * @param z2  unused
     * @return never returns normally
     * @throws NotImplementedError always
     */
    @NotNull
    public TensorFlowInferenceModel copy(@Nullable String str, boolean z, boolean z2) {
        final String reason = "An operation is not implemented: Not yet implemented";
        throw new NotImplementedError(reason);
    }

    /**
     * Returns {@code tail()} of a {@code TensorShape} built from the current input shape
     * (presumably the shape without its leading dimension - TODO confirm against TensorShape).
     *
     * @return the array produced by {@code TensorShape.tail()}
     */
    @NotNull
    public long[] getInputDimensions() {
        final TensorShape currentInputShape = new TensorShape(getInputShape());
        return currentInputShape.tail();
    }

    /**
     * Runs the model on the input and returns the index of the largest output value.
     *
     * @param fArr flat input data for the model
     * @return the argmax over the array returned by {@code predictSoftly(fArr)}
     */
    public int predict(@NotNull float[] fArr) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        final float[] scores = predictSoftly(fArr);
        return FloatArrayExtensionFunctionsKt.argmax(scores);
    }

    /**
     * Variant of {@code predictSoftly(float[])} that also accepts a tensor name.
     * The name is only null-checked; the call is delegated to the single-argument overload.
     *
     * @param fArr flat input data for the model
     * @param str  prediction tensor name (not used to select an output)
     * @return the result of {@code predictSoftly(fArr)}
     */
    @NotNull
    public float[] predictSoftly(@NotNull float[] fArr, @NotNull String str) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        Intrinsics.checkNotNullParameter(str, "predictionTensorName");
        final float[] scores = predictSoftly(fArr);
        return scores;
    }

    /**
     * Runs the session on the given input and returns the first row of the first output.
     *
     * <p>The input is wrapped in a tensor with the current input shape and bound to the
     * session's first input name. The tensor and the run result are always closed,
     * including when the run or the result extraction fails.
     *
     * @param fArr flat input data for the model
     * @return element 0 of the first output's value, which is cast to {@code float[][]}
     * @throws IllegalArgumentException if the input shape has not been set
     * @throws NullPointerException     if the first output has no value
     */
    @NotNull
    public final float[] predictSoftly(@NotNull float[] fArr) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        if (this.inputShape == null) {
            throw new IllegalArgumentException("Reshape functions is missed! Define and set up the reshape function to transform initial data to the model input.");
        }
        OrtEnvironment ortEnvironment = this.env;
        if (ortEnvironment == null) {
            Intrinsics.throwUninitializedPropertyAccessException("env");
            throw null;
        }
        // Validate the session before allocating the tensor so nothing is created on this failure path.
        OrtSession ortSession = this.session;
        if (ortSession == null) {
            Intrinsics.throwUninitializedPropertyAccessException("session");
            throw null;
        }
        Set<String> inputNames = ortSession.getInputNames();
        Intrinsics.checkNotNullExpressionValue(inputNames, "session.inputNames");
        String firstInputName = CollectionsKt.toList(inputNames).get(0);

        // try-with-resources closes the result first and then the tensor, on every exit path.
        try (OnnxTensor inputTensor = OnnxTensor.createTensor(ortEnvironment, FloatBuffer.wrap(fArr), getInputShape());
             OrtSession.Result result = ortSession.run(Collections.singletonMap(firstInputName, inputTensor))) {
            Object value = result.get(0).getValue();
            if (value == null) {
                throw new NullPointerException("null cannot be cast to non-null type kotlin.Array<kotlin.FloatArray>");
            }
            return ((float[][]) value)[0];
        }
    }

    /**
     * Runs the session on the given input and returns the value of every output, each cast to
     * {@code Object[]}, in the iteration order of the run result.
     *
     * <p>The input tensor and the run result are always closed, including when the run or
     * the extraction of a value fails.
     *
     * @param fArr flat input data for the model
     * @return a list with one array per model output
     * @throws IllegalArgumentException if the input shape has not been set
     * @throws NullPointerException     if an output has no value
     */
    @NotNull
    public final List<Object[]> predictRaw(@NotNull float[] fArr) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        if (this.inputShape == null) {
            throw new IllegalArgumentException("Reshape functions is missed! Define and set up the reshape function to transform initial data to the model input.");
        }
        OrtEnvironment ortEnvironment = this.env;
        if (ortEnvironment == null) {
            Intrinsics.throwUninitializedPropertyAccessException("env");
            throw null;
        }
        // Validate the session before allocating the tensor so nothing is created on this failure path.
        OrtSession ortSession = this.session;
        if (ortSession == null) {
            Intrinsics.throwUninitializedPropertyAccessException("session");
            throw null;
        }
        Set<String> inputNames = ortSession.getInputNames();
        Intrinsics.checkNotNullExpressionValue(inputNames, "session.inputNames");
        String firstInputName = CollectionsKt.toList(inputNames).get(0);

        // try-with-resources closes the result first and then the tensor, on every exit path.
        try (OnnxTensor inputTensor = OnnxTensor.createTensor(ortEnvironment, FloatBuffer.wrap(fArr), getInputShape());
             OrtSession.Result output = ortSession.run(Collections.singletonMap(firstInputName, inputTensor))) {
            Intrinsics.checkNotNullExpressionValue(output, "output");
            List<Object[]> values = new ArrayList<>();
            for (Map.Entry<String, OnnxValue> entry : output) {
                Object value = entry.getValue().getValue();
                if (value == null) {
                    throw new NullPointerException("null cannot be cast to non-null type kotlin.Array<*>");
                }
                values.add((Object[]) value);
            }
            return CollectionsKt.toList(values);
        }
    }

    /**
     * Runs the session on the given input and returns, for every output, a pair of the
     * output's float buffer and the shape reported by its tensor info, in the iteration
     * order of the run result.
     *
     * <p>The input tensor and the run result are always closed, including when the run or
     * the extraction of an output fails.
     *
     * @param fArr flat input data for the model
     * @return a list with one (buffer, shape) pair per model output
     * @throws IllegalArgumentException if the input shape has not been set
     * @throws NullPointerException     if an output has no tensor info or no value
     */
    @NotNull
    public final List<Pair<FloatBuffer, long[]>> predictRawWithShapes(@NotNull float[] fArr) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        if (this.inputShape == null) {
            throw new IllegalArgumentException("Reshape functions is missed! Define and set up the reshape function to transform initial data to the model input.");
        }
        OrtEnvironment ortEnvironment = this.env;
        if (ortEnvironment == null) {
            Intrinsics.throwUninitializedPropertyAccessException("env");
            throw null;
        }
        // Validate the session before allocating the tensor so nothing is created on this failure path.
        OrtSession ortSession = this.session;
        if (ortSession == null) {
            Intrinsics.throwUninitializedPropertyAccessException("session");
            throw null;
        }
        Set<String> inputNames = ortSession.getInputNames();
        Intrinsics.checkNotNullExpressionValue(inputNames, "session.inputNames");
        String firstInputName = CollectionsKt.toList(inputNames).get(0);

        // try-with-resources closes the result first and then the tensor, on every exit path.
        try (OnnxTensor inputTensor = OnnxTensor.createTensor(ortEnvironment, FloatBuffer.wrap(fArr), getInputShape());
             OrtSession.Result output = ortSession.run(Collections.singletonMap(firstInputName, inputTensor))) {
            Intrinsics.checkNotNullExpressionValue(output, "output");
            List<Pair<FloatBuffer, long[]>> buffersWithShapes = new ArrayList<>();
            for (Map.Entry<String, OnnxValue> entry : output) {
                TensorInfo info = (TensorInfo) entry.getValue().getInfo();
                if (info == null) {
                    throw new NullPointerException("null cannot be cast to non-null type ai.onnxruntime.TensorInfo");
                }
                long[] shape = info.getShape();
                OnnxValue value = entry.getValue();
                if (value == null) {
                    throw new NullPointerException("null cannot be cast to non-null type ai.onnxruntime.OnnxTensor");
                }
                buffersWithShapes.add(new Pair<>(((OnnxTensor) value).getFloatBuffer(), shape));
            }
            return CollectionsKt.toList(buffersWithShapes);
        }
    }

    /**
     * Prediction with explicitly named input and output tensors is not supported;
     * after null-checking its arguments the method always throws.
     *
     * @param fArr input data (only null-checked)
     * @param str  input tensor name (only null-checked)
     * @param str2 output tensor name (only null-checked)
     * @return never returns normally
     * @throws NotImplementedError always
     */
    public final int predict(@NotNull float[] fArr, @NotNull String str, @NotNull String str2) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        Intrinsics.checkNotNullParameter(str, "inputTensorName");
        Intrinsics.checkNotNullParameter(str2, "outputTensorName");
        final String reason = "An operation is not implemented: ONNX doesn't support extraction outputs from the intermediate levels of the model.";
        throw new NotImplementedError(reason);
    }

    /**
     * Closes the session and then the environment held by this model.
     * Each of them must have been initialized; otherwise the "uninitialized property"
     * exception from {@code Intrinsics} is raised at that step.
     */
    public void close() {
        final OrtSession activeSession = this.session;
        if (activeSession != null) {
            activeSession.close();
        } else {
            Intrinsics.throwUninitializedPropertyAccessException("session");
            throw null;
        }

        final OrtEnvironment activeEnvironment = this.env;
        if (activeEnvironment != null) {
            activeEnvironment.close();
        } else {
            Intrinsics.throwUninitializedPropertyAccessException("env");
            throw null;
        }
    }

    /**
     * Returns a short description of this model in the form {@code OnnxModel(session=...)}.
     *
     * <p>The method has no side effects: it does not print the session's input/output
     * names and info to standard output. It also does not throw when the session has not
     * been initialized yet (the session is then rendered as {@code null}), so it is safe
     * to use while building log and exception messages.
     *
     * @return the textual representation of this model
     */
    @NotNull
    public String toString() {
        return "OnnxModel(session=" + this.session + ")";
    }
}
