package org.jetbrains.kotlinx.dl.onnx.inference;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import ai.onnxruntime.ValueInfo;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import kotlin.Metadata;
import kotlin.NotImplementedError;
import kotlin.Pair;
import kotlin.TuplesKt;
import kotlin.UByte;
import kotlin.collections.ArraysKt;
import kotlin.collections.CollectionsKt;
import kotlin.collections.IntIterator;
import kotlin.collections.MapsKt;
import kotlin.jdk7.AutoCloseableKt;
import kotlin.jvm.functions.Function0;
import kotlin.jvm.functions.Function1;
import kotlin.jvm.internal.DefaultConstructorMarker;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.LongSpreadBuilder;
import kotlin.ranges.IntRange;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.kotlinx.dl.api.core.shape.TensorShape;
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel;
import org.jetbrains.kotlinx.dl.api.summary.ModelSummary;
import org.jetbrains.kotlinx.dl.api.summary.ModelWithSummary;
import org.jetbrains.kotlinx.dl.impl.util.FloatArrayExtensionFunctionsKt;
import org.jetbrains.kotlinx.dl.onnx.inference.executionproviders.ExecutionProvider;

/* compiled from: OnnxInferenceModel.kt */
@Metadata(mv = {1, 7, 1}, k = 1, xi = 48, d1 = {"��¬\u0001\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n\u0002\b\u0002\n\u0002\u0010\u0012\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010 \n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0010\u0016\n\u0002\b\f\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0010\u0011\n\u0002\b\u0004\n\u0002\u0010\u000b\n\u0002\b\u0005\n\u0002\u0010\b\n��\n\u0002\u0010\u0014\n��\n\u0002\u0010$\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0010\t\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\b\u0016\u0018�� O2\u00020\u00012\u00020\u00022\u00020\u0003:\u0002OPB\u000f\b\u0016\u0012\u0006\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006B\u000f\b\u0016\u0012\u0006\u0010\u0007\u001a\u00020\b¢\u0006\u0002\u0010\tB\u0015\b\u0016\u0012\f\u0010\n\u001a\b\u0012\u0004\u0012\u00020\b0\u000b¢\u0006\u0002\u0010\fB\u000f\b\u0002\u0012\u0006\u0010\r\u001a\u00020\u000e¢\u0006\u0002\u0010\u000fJ\u0016\u0010*\u001a\u00020+2\f\u0010,\u001a\b\u0012\u0004\u0012\u00020\u00150\u0014H\u0002J\b\u0010-\u001a\u00020.H\u0016J#\u0010/\u001a\b\u0012\u0004\u0012\u00020\u00150\u00142\u000e\u00100\u001a\n\u0012\u0006\b\u0001\u0012\u00020\u001501H\u0002¢\u0006\u0002\u00102J\"\u00103\u001a\u00020��2\b\u00104\u001a\u0004\u0018\u00010\u00052\u0006\u00105\u001a\u0002062\u0006\u00107\u001a\u000206H\u0016J\b\u00108\u001a\u00020.H\u0002J!\u00109\u001a\u00020.2\u0012\u00100\u001a\n\u0012\u0006\b\u0001\u0012\u00020\u001501\"\u00020\u0015H\u0016¢\u0006\u0002\u0010:J\u0010\u0010;\u001a\u00020<2\u0006\u0010=\u001a\u00020>H\u0016J\u001a\u0010?\u001a\u000e\u0012\u0004\u0012\u00020\u0005\u0012\u0004\u0012\u00020A0@2\u0006\u0010=\u001a\u00020>J-\u0010?\u001a\u0002HB\"\u0004\b��\u0010B2\u0006\u0010=\u001a\u00020
>2\u0012\u0010C\u001a\u000e\u0012\u0004\u0012\u00020E\u0012\u0004\u0012\u0002HB0D¢\u0006\u0002\u0010FJ\u000e\u0010G\u001a\u00020>2\u0006\u0010=\u001a\u00020>J\u0018\u0010G\u001a\u00020>2\u0006\u0010=\u001a\u00020>2\u0006\u0010H\u001a\u00020\u0005H\u0016J\u0014\u0010I\u001a\u00020.2\n\u0010J\u001a\u00020\u001c\"\u00020KH\u0016J\b\u0010L\u001a\u00020MH\u0016J\b\u0010N\u001a\u00020\u0005H\u0016R\u0016\u0010\u0010\u001a\n \u0012*\u0004\u0018\u00010\u00110\u0011X\u0082\u0004¢\u0006\u0002\n��R\u0014\u0010\u0013\u001a\b\u0012\u0004\u0012\u00020\u00150\u0014X\u0082.¢\u0006\u0002\n��R\u001e\u0010\u0018\u001a\u00020\u00172\u0006\u0010\u0016\u001a\u00020\u0017@BX\u0086.¢\u0006\b\n��\u001a\u0004\b\u0019\u0010\u001aR\u0014\u0010\u001b\u001a\u00020\u001c8VX\u0096\u0004¢\u0006\u0006\u001a\u0004\b\u001d\u0010\u001eR\u000e\u0010\u001f\u001a\u00020\u001cX\u0082.¢\u0006\u0002\n��R\u000e\u0010\r\u001a\u00020\u000eX\u0082\u0004¢\u0006\u0002\n��R\u001c\u0010 \u001a\u0004\u0018\u00010\u0005X\u0086\u000e¢\u0006\u000e\n��\u001a\u0004\b!\u0010\"\"\u0004\b#\u0010\u0006R\u001e\u0010$\u001a\u00020\u00172\u0006\u0010\u0016\u001a\u00020\u0017@BX\u0086.¢\u0006\b\n��\u001a\u0004\b%\u0010\u001aR\u001e\u0010&\u001a\u00020\u001c2\u0006\u0010\u0016\u001a\u00020\u001c@BX\u0086.¢\u0006\b\n��\u001a\u0004\b'\u0010\u001eR\u000e\u0010(\u001a\u00020)X\u0082.¢\u0006\u0002\n��¨\u0006Q"}, d2 = {"Lorg/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel;", "Lorg/jetbrains/kotlinx/dl/api/inference/InferenceModel;", "Lorg/jetbrains/kotlinx/dl/onnx/inference/ExecutionProviderCompatible;", "Lorg/jetbrains/kotlinx/dl/api/summary/ModelWithSummary;", "modelPath", "", "(Ljava/lang/String;)V", "modelBytes", "", "([B)V", "loadBytes", "Lkotlin/Function0;", "(Lkotlin/jvm/functions/Function0;)V", "modelSource", "Lorg/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel$ModelSource;", "(Lorg/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel$ModelSource;)V", "env", "Lai/onnxruntime/OrtEnvironment;", 
"kotlin.jvm.PlatformType", "executionProvidersInUse", "", "Lorg/jetbrains/kotlinx/dl/onnx/inference/executionproviders/ExecutionProvider;", "<set-?>", "Lai/onnxruntime/OnnxJavaType;", "inputDataType", "getInputDataType", "()Lai/onnxruntime/OnnxJavaType;", "inputDimensions", "", "getInputDimensions", "()[J", "inputShape", "name", "getName", "()Ljava/lang/String;", "setName", "outputDataType", "getOutputDataType", "outputShape", "getOutputShape", "session", "Lai/onnxruntime/OrtSession;", "buildSessionOptions", "Lai/onnxruntime/OrtSession$SessionOptions;", "uniqueProviders", "close", "", "collectProviders", "executionProviders", "", "([Lorg/jetbrains/kotlinx/dl/onnx/inference/executionproviders/ExecutionProvider;)Ljava/util/List;", "copy", "copiedModelName", "saveOptimizerState", "", "copyWeights", "initInputOutputInfo", "initializeWith", "([Lorg/jetbrains/kotlinx/dl/onnx/inference/executionproviders/ExecutionProvider;)V", "predict", "", "inputData", "", "predictRaw", "", "", "R", "extractResult", "Lkotlin/Function1;", "Lai/onnxruntime/OrtSession$Result;", "([FLkotlin/jvm/functions/Function1;)Ljava/lang/Object;", "predictSoftly", "predictionTensorName", "reshape", "dims", "", "summary", "Lorg/jetbrains/kotlinx/dl/api/summary/ModelSummary;", "toString", "Companion", "ModelSource", "onnx"})
/* loaded from: input_file:org/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel.class */
public class OnnxInferenceModel implements InferenceModel, ExecutionProviderCompatible, ModelWithSummary {

    /** Kotlin companion instance exposing the static-style {@code load} factories and tensor helpers. */
    @NotNull
    public static final Companion Companion = new Companion(null);

    // Where the model is read from (file path or byte supplier); used by initializeWith() to build a session.
    @NotNull
    private final ModelSource modelSource;
    // ONNX Runtime environment, obtained once in the constructor via OrtEnvironment.getEnvironment().
    private final OrtEnvironment env;
    // Inference session; null until initializeWith() builds it (Kotlin lateinit semantics).
    private OrtSession session;
    // Input shape with a leading batch dimension of 1; set by reshape() or derived from the model in
    // initInputOutputInfo() when still null.
    private long[] inputShape;
    // Element type of the model's first input; refreshed by initInputOutputInfo().
    private OnnxJavaType inputDataType;
    // Shape of the model's first output with a leading batch dimension of 1; derived once in initInputOutputInfo().
    private long[] outputShape;
    // Element type of the model's first output; refreshed by initInputOutputInfo().
    private OnnxJavaType outputDataType;
    // Normalized provider list the current session was built with; compared in initializeWith() to skip rebuilds.
    private List<? extends ExecutionProvider> executionProvidersInUse;

    // Optional user-assigned model name.
    @Nullable
    private String name;

    /* compiled from: OnnxInferenceModel.kt */
    @Metadata(mv = {1, 7, 1}, k = 1, xi = 48, d1 = {"��P\n\u0002\u0018\u0002\n\u0002\u0010��\n\u0002\b\u0002\n\u0002\u0010\u0002\n��\n\u0002\u0010\u0014\n��\n\u0002\u0010\u0016\n��\n\u0002\u0018\u0002\n��\n\u0002\u0010\u0012\n��\n\u0002\u0010\u0011\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010\u000e\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\b\u0086\u0003\u0018��2\u00020\u0001B\u0007\b\u0002¢\u0006\u0002\u0010\u0002J\u0018\u0010\u0003\u001a\u00020\u00042\u0006\u0010\u0005\u001a\u00020\u00062\u0006\u0010\u0007\u001a\u00020\bH\u0002J)\u0010\t\u001a\u00020\n2\u0006\u0010\u000b\u001a\u00020\f2\u0014\b\u0002\u0010\r\u001a\n\u0012\u0006\b\u0001\u0012\u00020\u000f0\u000e\"\u00020\u000f¢\u0006\u0002\u0010\u0010J)\u0010\t\u001a\u00020\n2\u0006\u0010\u0011\u001a\u00020\u00122\u0014\b\u0002\u0010\r\u001a\n\u0012\u0006\b\u0001\u0012\u00020\u000f0\u000e\"\u00020\u000f¢\u0006\u0002\u0010\u0013J$\u0010\u0014\u001a\u00020\u0015*\u00020\u00162\u0006\u0010\u0005\u001a\u00020\u00062\u0006\u0010\u0017\u001a\u00020\u00182\u0006\u0010\u0019\u001a\u00020\bH\u0002¨\u0006\u001a"}, d2 = {"Lorg/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel$Companion;", "", "()V", "checkTensorMatchesInputShape", "", "data", "", "inputShape", "", "load", "Lorg/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel;", "modelBytes", "", "executionProviders", "", "Lorg/jetbrains/kotlinx/dl/onnx/inference/executionproviders/ExecutionProvider;", "([B[Lorg/jetbrains/kotlinx/dl/onnx/inference/executionproviders/ExecutionProvider;)Lorg/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel;", "pathToModel", "", "(Ljava/lang/String;[Lorg/jetbrains/kotlinx/dl/onnx/inference/executionproviders/ExecutionProvider;)Lorg/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel;", "createTensor", "Lai/onnxruntime/OnnxTensor;", "Lai/onnxruntime/OrtEnvironment;", "dataType", "Lai/onnxruntime/OnnxJavaType;", "shape", "onnx"})
    /* loaded from: input_file:org/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel$Companion.class */
    /**
     * Companion of {@link OnnxInferenceModel}: factory methods that build and initialize a model, plus the
     * helper that converts float input data into an {@link OnnxTensor} of the model's input element type.
     */
    public static final class Companion {

        private Companion() {
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Wraps {@code fArr} into a tensor of element type {@code onnxJavaType} and shape {@code jArr}.
         * <p>
         * Non-float element types are produced with Java primitive conversions of each float value
         * (widening for {@code DOUBLE}, narrowing casts for the integer types).
         *
         * @param ortEnvironment environment that owns the created tensor
         * @param fArr           input values; their count must match the product of {@code jArr}
         * @param onnxJavaType   element type the tensor must have
         * @param jArr           tensor shape
         * @return the created tensor
         * @throws IllegalArgumentException if the number of values does not match {@code jArr}
         * @throws NotImplementedError      for element types other than FLOAT, DOUBLE, INT8, INT16, INT32,
         *                                  INT64 and UINT8
         */
        public final OnnxTensor createTensor(OrtEnvironment ortEnvironment, float[] fArr, OnnxJavaType onnxJavaType, long[] jArr) {
            checkTensorMatchesInputShape(fArr, jArr);
            OnnxTensor inputTensor;
            switch (onnxJavaType) {
                case FLOAT:
                    inputTensor = OnnxTensor.createTensor(ortEnvironment, FloatBuffer.wrap(fArr), jArr);
                    break;
                case DOUBLE:
                    inputTensor = OnnxTensor.createTensor(ortEnvironment, DoubleBuffer.wrap(toDoubles(fArr)), jArr);
                    break;
                case INT8:
                    inputTensor = OnnxTensor.createTensor(ortEnvironment, ByteBuffer.wrap(toBytes(fArr)), jArr);
                    break;
                case INT16:
                    inputTensor = OnnxTensor.createTensor(ortEnvironment, ShortBuffer.wrap(toShorts(fArr)), jArr);
                    break;
                case INT32:
                    inputTensor = OnnxTensor.createTensor(ortEnvironment, IntBuffer.wrap(toInts(fArr)), jArr);
                    break;
                case INT64:
                    inputTensor = OnnxTensor.createTensor(ortEnvironment, LongBuffer.wrap(toLongs(fArr)), jArr);
                    break;
                case UINT8:
                    // Same byte payload as INT8, but the element type is passed explicitly as UINT8.
                    inputTensor = OnnxTensor.createTensor(ortEnvironment, ByteBuffer.wrap(toBytes(fArr)), jArr, OnnxJavaType.UINT8);
                    break;
                default:
                    // STRING, UNKNOWN and any other element type cannot be produced from float data here.
                    throw new NotImplementedError("An operation is not implemented: input data type " + onnxJavaType + " is not supported.");
            }
            Intrinsics.checkNotNullExpressionValue(inputTensor, "inputTensor");
            return inputTensor;
        }

        /** Widens every value to {@code double}. */
        private static double[] toDoubles(float[] values) {
            double[] converted = new double[values.length];
            for (int i = 0; i < values.length; i++) {
                converted[i] = values[i];
            }
            return converted;
        }

        /** Narrows every value with Java's {@code (byte)} primitive cast. */
        private static byte[] toBytes(float[] values) {
            byte[] converted = new byte[values.length];
            for (int i = 0; i < values.length; i++) {
                converted[i] = (byte) values[i];
            }
            return converted;
        }

        /** Narrows every value with Java's {@code (short)} primitive cast. */
        private static short[] toShorts(float[] values) {
            short[] converted = new short[values.length];
            for (int i = 0; i < values.length; i++) {
                converted[i] = (short) values[i];
            }
            return converted;
        }

        /** Narrows every value with Java's {@code (int)} primitive cast. */
        private static int[] toInts(float[] values) {
            int[] converted = new int[values.length];
            for (int i = 0; i < values.length; i++) {
                converted[i] = (int) values[i];
            }
            return converted;
        }

        /** Narrows every value with Java's {@code (long)} primitive cast. */
        private static long[] toLongs(float[] values) {
            long[] converted = new long[values.length];
            for (int i = 0; i < values.length; i++) {
                converted[i] = (long) values[i];
            }
            return converted;
        }

        /**
         * Verifies that {@code fArr} holds exactly as many values as the shape {@code jArr} describes.
         *
         * @throws UnsupportedOperationException if {@code jArr} is empty
         * @throws IllegalArgumentException      on a size mismatch; the message adds a grayscale hint when a
         *                                       {@code [1, 3, h, w]}/{@code [1, h, w, 3]} shape receives exactly
         *                                       one third of the expected values
         */
        private final void checkTensorMatchesInputShape(float[] fArr, long[] jArr) {
            if (jArr.length == 0) {
                throw new UnsupportedOperationException("Empty array can't be reduced.");
            }
            // Compare in long arithmetic: narrowing the product to int could make an oversized shape
            // appear to match the data length.
            long expectedElements = 1L;
            for (long dim : jArr) {
                expectedElements *= dim;
            }
            long actualElements = fArr.length;
            if (actualElements == expectedElements) {
                return;
            }
            String shapeText = Arrays.toString(jArr);
            boolean looksLikeGrayscaleForRgbModel = jArr.length == 4
                    && jArr[0] == 1
                    && (jArr[1] == 3 || jArr[3] == 3)
                    && actualElements * 3 == expectedElements;
            if (looksLikeGrayscaleForRgbModel) {
                throw new IllegalArgumentException("The number of elements (N=" + fArr.length + ") in the input tensor does not match the model input shape - " + shapeText + ". It looks like you are trying to use a 1-channel (grayscale) image as an input, but the model expects a 3-channel image.");
            }
            throw new IllegalArgumentException("The number of elements (N=" + fArr.length + ") in the input tensor does not match the model input shape - " + shapeText + '.');
        }

        /**
         * Creates a model for the ONNX file at {@code str} and initializes it with the given providers.
         *
         * @param str                  path to the model file
         * @param executionProviderArr execution providers passed to {@code initializeWith}
         * @return the initialized model
         */
        @NotNull
        public final OnnxInferenceModel load(@NotNull String str, @NotNull ExecutionProvider... executionProviderArr) {
            Intrinsics.checkNotNullParameter(str, "pathToModel");
            Intrinsics.checkNotNullParameter(executionProviderArr, "executionProviders");
            OnnxInferenceModel onnxInferenceModel = new OnnxInferenceModel(str);
            onnxInferenceModel.initializeWith((ExecutionProvider[]) Arrays.copyOf(executionProviderArr, executionProviderArr.length));
            return onnxInferenceModel;
        }

        /** Kotlin default-argument bridge: bit 2 of {@code i} means "use {@code CPU(true)} as the provider". */
        public static /* synthetic */ OnnxInferenceModel load$default(Companion companion, String str, ExecutionProvider[] executionProviderArr, int i, Object obj) {
            if ((i & 2) != 0) {
                executionProviderArr = new ExecutionProvider.CPU[]{new ExecutionProvider.CPU(true)};
            }
            return companion.load(str, executionProviderArr);
        }

        /**
         * Creates a model from in-memory ONNX bytes and initializes it with the given providers.
         *
         * @param bArr                 serialized model
         * @param executionProviderArr execution providers passed to {@code initializeWith}
         * @return the initialized model
         */
        @NotNull
        public final OnnxInferenceModel load(@NotNull byte[] bArr, @NotNull ExecutionProvider... executionProviderArr) {
            Intrinsics.checkNotNullParameter(bArr, "modelBytes");
            Intrinsics.checkNotNullParameter(executionProviderArr, "executionProviders");
            OnnxInferenceModel onnxInferenceModel = new OnnxInferenceModel(bArr);
            onnxInferenceModel.initializeWith((ExecutionProvider[]) Arrays.copyOf(executionProviderArr, executionProviderArr.length));
            return onnxInferenceModel;
        }

        /** Kotlin default-argument bridge: bit 2 of {@code i} means "use {@code CPU(true)} as the provider". */
        public static /* synthetic */ OnnxInferenceModel load$default(Companion companion, byte[] bArr, ExecutionProvider[] executionProviderArr, int i, Object obj) {
            if ((i & 2) != 0) {
                executionProviderArr = new ExecutionProvider.CPU[]{new ExecutionProvider.CPU(true)};
            }
            return companion.load(bArr, executionProviderArr);
        }

        public /* synthetic */ Companion(DefaultConstructorMarker defaultConstructorMarker) {
            this();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* compiled from: OnnxInferenceModel.kt */
    @Metadata(mv = {1, 7, 1}, k = 1, xi = 48, d1 = {"��(\n\u0002\u0018\u0002\n\u0002\u0010��\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\br\u0018��2\u00020\u0001:\u0002\b\tJ\u0018\u0010\u0002\u001a\u00020\u00032\u0006\u0010\u0004\u001a\u00020\u00052\u0006\u0010\u0006\u001a\u00020\u0007H&\u0082\u0001\u0002\n\u000b¨\u0006\f"}, d2 = {"Lorg/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel$ModelSource;", "", "buildSession", "Lai/onnxruntime/OrtSession;", "environment", "Lai/onnxruntime/OrtEnvironment;", "options", "Lai/onnxruntime/OrtSession$SessionOptions;", "Bytes", "File", "Lorg/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel$ModelSource$Bytes;", "Lorg/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel$ModelSource$File;", "onnx"})
    /* loaded from: input_file:org/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel$ModelSource.class */
    /**
     * Origin of the ONNX model. Each implementation knows how to open an {@link OrtSession} for its model:
     * {@link File} reads from a path, {@link Bytes} from a byte-array supplier.
     */
    public interface ModelSource {

        /**
         * Opens a new session for this model.
         *
         * @param ortEnvironment environment that creates the session
         * @param sessionOptions options (e.g. execution providers) for the session
         * @return the newly created session
         */
        @NotNull
        OrtSession buildSession(@NotNull OrtEnvironment ortEnvironment, @NotNull OrtSession.SessionOptions sessionOptions);

        /* compiled from: OnnxInferenceModel.kt */
        @Metadata(mv = {1, 7, 1}, k = 1, xi = 48, d1 = {"��$\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0010\u000e\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\u0018��2\u00020\u0001B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003¢\u0006\u0002\u0010\u0004J\u0018\u0010\u0005\u001a\u00020\u00062\u0006\u0010\u0007\u001a\u00020\b2\u0006\u0010\t\u001a\u00020\nH\u0016R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n��¨\u0006\u000b"}, d2 = {"Lorg/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel$ModelSource$File;", "Lorg/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel$ModelSource;", "pathToModel", "", "(Ljava/lang/String;)V", "buildSession", "Lai/onnxruntime/OrtSession;", "environment", "Lai/onnxruntime/OrtEnvironment;", "options", "Lai/onnxruntime/OrtSession$SessionOptions;", "onnx"})
        /* loaded from: input_file:org/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel$ModelSource$File.class */
        /** Model stored in a file; the path is handed to ONNX Runtime each time a session is built. */
        public static final class File implements ModelSource {

            @NotNull
            private final String pathToModel;

            public File(@NotNull String str) {
                Intrinsics.checkNotNullParameter(str, "pathToModel");
                this.pathToModel = str;
            }

            @Override // org.jetbrains.kotlinx.dl.onnx.inference.OnnxInferenceModel.ModelSource
            @NotNull
            public OrtSession buildSession(@NotNull OrtEnvironment ortEnvironment, @NotNull OrtSession.SessionOptions sessionOptions) {
                Intrinsics.checkNotNullParameter(ortEnvironment, "environment");
                Intrinsics.checkNotNullParameter(sessionOptions, "options");
                final String path = this.pathToModel;
                final OrtSession opened = ortEnvironment.createSession(path, sessionOptions);
                Intrinsics.checkNotNullExpressionValue(opened, "environment.createSession(pathToModel, options)");
                return opened;
            }
        }

        /* compiled from: OnnxInferenceModel.kt */
        @Metadata(mv = {1, 7, 1}, k = 1, xi = 48, d1 = {"��(\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\u0010\u0012\n\u0002\b\u0002\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n��\u0018��2\u00020\u0001B\u0013\u0012\f\u0010\u0002\u001a\b\u0012\u0004\u0012\u00020\u00040\u0003¢\u0006\u0002\u0010\u0005J\u0018\u0010\u0006\u001a\u00020\u00072\u0006\u0010\b\u001a\u00020\t2\u0006\u0010\n\u001a\u00020\u000bH\u0016R\u0014\u0010\u0002\u001a\b\u0012\u0004\u0012\u00020\u00040\u0003X\u0082\u0004¢\u0006\u0002\n��¨\u0006\f"}, d2 = {"Lorg/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel$ModelSource$Bytes;", "Lorg/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel$ModelSource;", "loadBytes", "Lkotlin/Function0;", "", "(Lkotlin/jvm/functions/Function0;)V", "buildSession", "Lai/onnxruntime/OrtSession;", "environment", "Lai/onnxruntime/OrtEnvironment;", "options", "Lai/onnxruntime/OrtSession$SessionOptions;", "onnx"})
        /* loaded from: input_file:org/jetbrains/kotlinx/dl/onnx/inference/OnnxInferenceModel$ModelSource$Bytes.class */
        /** Model held as bytes; the supplier is invoked anew every time a session is built. */
        public static final class Bytes implements ModelSource {

            @NotNull
            private final Function0<byte[]> loadBytes;

            public Bytes(@NotNull Function0<byte[]> function0) {
                Intrinsics.checkNotNullParameter(function0, "loadBytes");
                this.loadBytes = function0;
            }

            @Override // org.jetbrains.kotlinx.dl.onnx.inference.OnnxInferenceModel.ModelSource
            @NotNull
            public OrtSession buildSession(@NotNull OrtEnvironment ortEnvironment, @NotNull OrtSession.SessionOptions sessionOptions) {
                Intrinsics.checkNotNullParameter(ortEnvironment, "environment");
                Intrinsics.checkNotNullParameter(sessionOptions, "options");
                final byte[] modelBytes = (byte[]) this.loadBytes.invoke();
                final OrtSession opened = ortEnvironment.createSession(modelBytes, sessionOptions);
                Intrinsics.checkNotNullExpressionValue(opened, "environment.createSession(loadBytes(), options)");
                return opened;
            }
        }
    }

    /**
     * Primary constructor: records where the model comes from and grabs the ONNX Runtime environment.
     * No session is created here.
     */
    private OnnxInferenceModel(ModelSource modelSource) {
        this.env = OrtEnvironment.getEnvironment();
        this.modelSource = modelSource;
    }

    /**
     * Returns the element type of the model's first input.
     * Throws Kotlin's uninitialized-property error if it has not been determined yet.
     */
    @NotNull
    public final OnnxJavaType getInputDataType() {
        final OnnxJavaType type = this.inputDataType;
        if (type == null) {
            Intrinsics.throwUninitializedPropertyAccessException("inputDataType");
            return null;
        }
        return type;
    }

    /**
     * Returns the shape of the model's first output (with a leading batch dimension of 1).
     * Throws Kotlin's uninitialized-property error if it has not been determined yet.
     */
    @NotNull
    public final long[] getOutputShape() {
        final long[] shape = this.outputShape;
        if (shape == null) {
            Intrinsics.throwUninitializedPropertyAccessException("outputShape");
            return null;
        }
        return shape;
    }

    /**
     * Returns the element type of the model's first output.
     * Throws Kotlin's uninitialized-property error if it has not been determined yet.
     */
    @NotNull
    public final OnnxJavaType getOutputDataType() {
        final OnnxJavaType type = this.outputDataType;
        if (type == null) {
            Intrinsics.throwUninitializedPropertyAccessException("outputDataType");
            return null;
        }
        return type;
    }

    /** Returns the user-assigned model name, or {@code null} when none has been set. */
    @Nullable
    public final String getName() {
        final String currentName = this.name;
        return currentName;
    }

    /** Sets the model name; {@code null} clears it. */
    public final void setName(@Nullable String str) {
        final String newName = str;
        this.name = newName;
    }

    /**
     * Creates a model backed by the ONNX file at {@code str}.
     * <p>
     * No session is built here; {@code Companion.load} calls {@code initializeWith} right after construction.
     *
     * @param str path to the {@code .onnx} model file
     */
    public OnnxInferenceModel(@NotNull String str) {
        // The null check runs before delegation (as in the Kotlin source), so a null path is reported
        // under this constructor's own parameter name.
        this(new ModelSource.File(requireModelPath(str)));
    }

    /** Kotlin-style non-null check for the {@code modelPath} argument; returns it unchanged. */
    private static String requireModelPath(String modelPath) {
        Intrinsics.checkNotNullParameter(modelPath, "modelPath");
        return modelPath;
    }

    /**
     * Creates a model backed by an in-memory serialized ONNX model.
     * <p>
     * The array is captured by reference (not copied) and handed to ONNX Runtime whenever a session is built.
     *
     * @param bArr serialized model bytes
     */
    public OnnxInferenceModel(@NotNull final byte[] bArr) {
        // The decompiled anonymous Function0 was not valid Java (interface "super(0)" call, no invoke());
        // a lambda implements Function0.invoke() directly, and the null check now runs first.
        this(new ModelSource.Bytes(constantModelBytes(bArr)));
    }

    /** Validates {@code modelBytes} and returns a supplier that always yields that same array. */
    private static Function0<byte[]> constantModelBytes(final byte[] modelBytes) {
        Intrinsics.checkNotNullParameter(modelBytes, "modelBytes");
        return () -> modelBytes;
    }

    /**
     * Creates a model whose serialized bytes are produced on demand by {@code function0}.
     * The supplier is invoked each time a session is built.
     *
     * @param function0 supplier of the serialized model bytes
     */
    public OnnxInferenceModel(@NotNull Function0<byte[]> function0) {
        this((ModelSource) new ModelSource.Bytes(function0));
        // Mirrors the Kotlin parameter check; ModelSource.Bytes performs the same check under the same name.
        Intrinsics.checkNotNullParameter(function0, "loadBytes");
    }

    /**
     * (Re)builds the inference session so that it runs on {@code executionProviderArr}.
     * <p>
     * The providers are validated and normalized by {@code collectProviders}. If the result equals the
     * providers already in use, nothing happens. Otherwise any existing session is closed, a new one is
     * built from {@code modelSource}, and the input/output metadata is refreshed.
     *
     * @param executionProviderArr requested execution providers
     * @throws IllegalArgumentException if a provider is unavailable or the CPU providers conflict
     */
    @Override // org.jetbrains.kotlinx.dl.onnx.inference.ExecutionProviderCompatible
    public void initializeWith(@NotNull ExecutionProvider... executionProviderArr) {
        Intrinsics.checkNotNullParameter(executionProviderArr, "executionProviders");
        List<ExecutionProvider> collectProviders = collectProviders(executionProviderArr);
        List<? extends ExecutionProvider> providersInUse = this.executionProvidersInUse;
        if (providersInUse != null && Intrinsics.areEqual(collectProviders, providersInUse)) {
            return;
        }
        OrtSession previousSession = this.session;
        if (previousSession != null) {
            // Forget the old session and its providers before closing it. Otherwise a failing rebuild
            // below would leave a closed session registered under the old providers, and a retry with
            // those providers would hit the early return above and keep using the closed session.
            this.session = null;
            this.executionProvidersInUse = null;
            previousSession.close();
        }
        OrtEnvironment ortEnvironment = this.env;
        Intrinsics.checkNotNullExpressionValue(ortEnvironment, "env");
        // NOTE(review): the SessionOptions built here are never closed; confirm whether they can be
        // released once the session exists (custom-op libraries registered on them may need to outlive it).
        this.session = this.modelSource.buildSession(ortEnvironment, buildSessionOptions(collectProviders));
        this.executionProvidersInUse = collectProviders;
        initInputOutputInfo();
    }

    /**
     * Reads metadata of the model's first input and first output from the current session.
     * <p>
     * The data types are always refreshed. A shape is derived only while it is still {@code null}, so an
     * input shape set earlier through {@code reshape} is kept.
     *
     * @throws IllegalStateException if the model has no inputs/outputs or the first one is not a tensor
     */
    private final void initInputOutputInfo() {
        OrtSession ortSession = this.session;
        if (ortSession == null) {
            Intrinsics.throwUninitializedPropertyAccessException("session");
            return;
        }

        TensorInfo inputTensorInfo = firstTensorInfo(ortSession.getInputInfo(), "input");
        if (this.inputShape == null) {
            this.inputShape = withUnitBatch(inputTensorInfo.getShape());
        }
        OnnxJavaType inputType = inputTensorInfo.type;
        Intrinsics.checkNotNullExpressionValue(inputType, "inputTensorInfo.type");
        this.inputDataType = inputType;

        TensorInfo outputTensorInfo = firstTensorInfo(ortSession.getOutputInfo(), "output");
        if (this.outputShape == null) {
            this.outputShape = withUnitBatch(outputTensorInfo.getShape());
        }
        OnnxJavaType outputType = outputTensorInfo.type;
        Intrinsics.checkNotNullExpressionValue(outputType, "outputTensorInfo.type");
        this.outputDataType = outputType;
    }

    /**
     * Returns the tensor metadata of the first entry of {@code nodes} (map iteration order).
     *
     * @param nodes node metadata as returned by the session
     * @param kind  "input" or "output", used only in error messages
     * @throws IllegalStateException if {@code nodes} is empty or its first node is not a tensor
     */
    private static TensorInfo firstTensorInfo(Map<String, NodeInfo> nodes, String kind) {
        if (nodes.isEmpty()) {
            throw new IllegalStateException("The model has no " + kind + " nodes.");
        }
        NodeInfo firstNode = nodes.values().iterator().next();
        ValueInfo info = firstNode.getInfo();
        if (!(info instanceof TensorInfo)) {
            throw new IllegalStateException("The first " + kind + " '" + firstNode.getName() + "' is not a tensor: " + info);
        }
        return (TensorInfo) info;
    }

    /**
     * Builds {@code [1, ...]} from the last (up to) three dimensions of {@code modelShape}.
     * NOTE(review): this assumes a rank-4 model tensor whose first dimension is the batch; for other ranks
     * the result keeps or drops the wrong dimensions — confirm before relying on it for such models.
     */
    private static long[] withUnitBatch(long[] modelShape) {
        int from = Math.max(0, modelShape.length - 3);
        long[] trailingDims = Arrays.copyOfRange(modelShape, from, modelShape.length);
        return new TensorShape(1L, trailingDims).dims();
    }

    /**
     * Creates fresh session options and lets each provider, in list order, add its settings to them.
     *
     * @param list normalized execution providers
     * @return the populated session options
     */
    private final OrtSession.SessionOptions buildSessionOptions(List<? extends ExecutionProvider> list) {
        final OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        for (ExecutionProvider provider : list) {
            provider.addOptionsTo(options);
        }
        return options;
    }

    /**
     * Validates and normalizes the requested execution providers.
     * <p>
     * Every provider must be available in the current ONNX Runtime build. Duplicates are removed while
     * keeping the first occurrence. The CPU provider always ends up last: a {@code CPU(true)} is appended
     * when none was requested, and a single requested CPU provider is moved to the end.
     *
     * @param executionProviderArr requested providers
     * @return a new mutable list of unique providers, CPU last
     * @throws IllegalArgumentException if a provider is unavailable or more than one distinct CPU
     *                                  provider was requested
     */
    private final List<ExecutionProvider> collectProviders(ExecutionProvider[] executionProviderArr) {
        // Query the runtime once instead of once per requested provider.
        Set<?> availableProviders = OrtEnvironment.getAvailableProviders();
        for (ExecutionProvider executionProvider : executionProviderArr) {
            if (!availableProviders.contains(executionProvider.getInternalProviderId())) {
                throw new IllegalArgumentException("The optimized execution provider " + executionProvider + " is not available in the current environment!");
            }
        }
        List<ExecutionProvider> uniqueProviders = CollectionsKt.toMutableList(ArraysKt.distinct(executionProviderArr));

        ExecutionProvider cpuProvider = null;
        int cpuCount = 0;
        for (ExecutionProvider provider : uniqueProviders) {
            if (provider instanceof ExecutionProvider.CPU) {
                if (cpuProvider == null) {
                    cpuProvider = provider;
                }
                cpuCount++;
            }
        }

        switch (cpuCount) {
            case 0:
                uniqueProviders.add(new ExecutionProvider.CPU(true));
                break;
            case 1:
                // Move the CPU provider behind all accelerators. The decompiled version broke only out of
                // the search loop and then always threw NoSuchElementException here.
                uniqueProviders.remove(cpuProvider);
                uniqueProviders.add(cpuProvider);
                break;
            default:
                throw new IllegalArgumentException("Unable to use CPU(useArena = true) and CPU(useArena = false) at the same time!");
        }
        return uniqueProviders;
    }

    /**
     * Sets the model input shape to {@code [1, dims...]}: a leading dimension of size 1 followed by
     * the given dimensions.
     *
     * @param jArr the input dimensions, excluding the leading dimension
     */
    public void reshape(@NotNull long... jArr) {
        Intrinsics.checkNotNullParameter(jArr, "dims");
        final long[] shapeWithLeadingOne = new long[jArr.length + 1];
        shapeWithLeadingOne[0] = 1L;
        System.arraycopy(jArr, 0, shapeWithLeadingOne, 1, jArr.length);
        this.inputShape = shapeWithLeadingOne;
    }

    /**
     * Returns the stored input shape without its first dimension, as computed by
     * {@code TensorShape.tail()}.
     *
     * <p>Throws Kotlin's {@code UninitializedPropertyAccessException} if no input shape has been set.
     */
    @NotNull
    public long[] getInputDimensions() {
        final long[] fullShape = this.inputShape;
        if (fullShape == null) {
            Intrinsics.throwUninitializedPropertyAccessException("inputShape");
        }
        // NOTE(review): tail() presumably drops the leading 1 that reshape() prepends — confirm.
        final TensorShape shape = new TensorShape(fullShape);
        return shape.tail();
    }

    /**
     * Runs inference on {@code fArr} and returns the index of the largest value in the first model
     * output.
     *
     * @param fArr the flattened input data
     * @return the argmax of the first output tensor
     */
    public int predict(@NotNull float[] fArr) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        final float[] scores = predictSoftly(fArr);
        return FloatArrayExtensionFunctionsKt.argmax(scores);
    }

    /**
     * Runs inference on {@code fArr} and returns the named output tensor as a flat float array.
     *
     * @param fArr the flattened input data
     * @param str the output tensor name; an empty string selects the session's first output name
     * @return the float values of the selected output
     * @throws IllegalArgumentException if the session has no output with the resolved name
     */
    @NotNull
    public float[] predictSoftly(@NotNull float[] fArr, @NotNull String str) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        Intrinsics.checkNotNullParameter(str, "predictionTensorName");

        // Read the lateinit session once; the original re-read and re-checked it on every access.
        final OrtSession activeSession = this.session;
        if (activeSession == null) {
            Intrinsics.throwUninitializedPropertyAccessException("session");
        }

        final String outputName;
        if (str.length() == 0) {
            Set<String> outputNames = activeSession.getOutputNames();
            Intrinsics.checkNotNullExpressionValue(outputNames, "session.outputNames");
            outputName = CollectionsKt.first(outputNames);
        } else {
            outputName = str;
        }

        Map<String, NodeInfo> outputInfo = activeSession.getOutputInfo();
        Intrinsics.checkNotNullExpressionValue(outputInfo, "session.outputInfo");
        NodeInfo outputNode = outputInfo.get(outputName);
        if (outputNode == null) {
            throw new IllegalArgumentException("There is no output with name '" + outputName + "'. The model only has following outputs - " + outputInfo.keySet());
        }

        ValueInfo info = outputNode.getInfo();
        Intrinsics.checkNotNullExpressionValue(info, "outputInfo");
        Intrinsics.checkNotNullExpressionValue(outputName, "outputTensorName");
        OrtSessionResultConversions.INSTANCE.throwIfOutputNotSupported$onnx(info, outputName, "predictSoftly", OnnxJavaType.FLOAT);

        // A plain lambda implements kotlin Function1; the decompiled anonymous class with
        // `super(1)` is not valid Java for an interface.
        return predictRaw(fArr, output -> {
            Intrinsics.checkNotNullParameter(output, "output");
            return OrtSessionResultConversions.INSTANCE.getFloatArray(output, outputName);
        });
    }

    /**
     * Runs inference on {@code fArr} and returns the session's first output as a flat float array.
     *
     * @param fArr the flattened input data
     * @return the float values of the first output tensor
     */
    @NotNull
    public final float[] predictSoftly(@NotNull float[] fArr) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        final OrtSession activeSession = this.session;
        if (activeSession == null) {
            Intrinsics.throwUninitializedPropertyAccessException("session");
        }
        final Set<String> names = activeSession.getOutputNames();
        Intrinsics.checkNotNullExpressionValue(names, "session.outputNames");
        final String firstOutputName = CollectionsKt.first(names);
        Intrinsics.checkNotNullExpressionValue(firstOutputName, "session.outputNames.first()");
        return predictSoftly(fArr, firstOutputName);
    }

    /**
     * Runs inference on {@code fArr} and returns every output of the run, converted by
     * {@code OrtSessionResultConversions.getValues}.
     *
     * @param fArr the flattened input data
     * @return the converted outputs keyed by name
     */
    @NotNull
    public final Map<String, Object> predictRaw(@NotNull float[] fArr) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        final Function1<OrtSession.Result, Map<String, Object>> collectAllOutputs = result -> {
            Intrinsics.checkNotNullParameter(result, "it");
            return OrtSessionResultConversions.INSTANCE.getValues(result);
        };
        return predictRaw(fArr, collectAllOutputs);
    }

    /**
     * Runs the model once on {@code fArr} and passes the session result to {@code function1}.
     *
     * <p>The input tensor is created by {@code Companion.createTensor} from {@code fArr}, using
     * {@code getInputDataType()} and the stored input shape. It is fed under the session's first input
     * name. Both the input tensor and the {@link OrtSession.Result} are closed before this method
     * returns, so the value returned by {@code function1} should not depend on the result remaining
     * open.
     *
     * @param fArr the flattened input data
     * @param function1 extracts the value to return from the session result
     * @return whatever {@code function1} returns
     * @throws IllegalArgumentException if no input shape has been set
     */
    public final <R> R predictRaw(@NotNull float[] fArr, @NotNull Function1<? super OrtSession.Result, ? extends R> function1) {
        Intrinsics.checkNotNullParameter(fArr, "inputData");
        Intrinsics.checkNotNullParameter(function1, "extractResult");
        final long[] shape = this.inputShape;
        if (shape == null) {
            throw new IllegalArgumentException("Model input shape is not defined. Call reshape() to set input shape.");
        }
        OrtEnvironment ortEnvironment = this.env;
        Intrinsics.checkNotNullExpressionValue(ortEnvironment, "env");

        // try-with-resources closes the result, then the tensor, exactly once each. A failure in
        // close() is recorded as a suppressed exception instead of masking the original error. The
        // decompiled manual version could close the result twice and dropped the in-flight exception.
        try (OnnxTensor inputTensor = Companion.createTensor(ortEnvironment, fArr, getInputDataType(), shape)) {
            final OrtSession activeSession = this.session;
            if (activeSession == null) {
                Intrinsics.throwUninitializedPropertyAccessException("session");
            }
            Set<String> inputNames = activeSession.getInputNames();
            Intrinsics.checkNotNullExpressionValue(inputNames, "session.inputNames");
            String inputName = CollectionsKt.first(inputNames);
            try (OrtSession.Result output = activeSession.run(MapsKt.mapOf(TuplesKt.to(inputName, inputTensor)))) {
                Intrinsics.checkNotNullExpressionValue(output, "output");
                return function1.invoke(output);
            }
        }
    }

    /**
     * Creates a new model from the same {@code modelSource} with the given name.
     *
     * <p>If this model has an input shape, the copy is reshaped to the same input dimensions. If this
     * model has a session, the copy is initialized with the execution providers currently in use. The
     * two boolean parameters are not read by this implementation.
     */
    @NotNull
    /* renamed from: copy, reason: merged with bridge method [inline-methods] */
    public OnnxInferenceModel m59copy(@Nullable String str, boolean z, boolean z2) {
        final OnnxInferenceModel copied = new OnnxInferenceModel(this.modelSource);
        copied.name = str;

        if (this.inputShape != null) {
            final long[] dims = getInputDimensions();
            copied.reshape(Arrays.copyOf(dims, dims.length));
        }

        if (this.session == null) {
            return copied;
        }

        final List<? extends ExecutionProvider> providersInUse = this.executionProvidersInUse;
        if (providersInUse == null) {
            Intrinsics.throwUninitializedPropertyAccessException("executionProvidersInUse");
        }
        // toArray(T[]) already returns a fresh array, so no extra copy is needed.
        final ExecutionProvider[] providers = providersInUse.toArray(new ExecutionProvider[0]);
        copied.initializeWith(providers);
        return copied;
    }

    /**
     * Closes the session, if one was created, and then closes the ONNX Runtime environment.
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        final OrtSession activeSession = this.session;
        if (activeSession != null) {
            activeSession.close();
        }
        this.env.close();
    }

    /**
     * Builds a summary of the session's inputs and outputs.
     *
     * <p>Each input and output node is summarized via {@code OnnxModelSummaryKt.summary}, preserving
     * the iteration order of the session's info maps.
     */
    @NotNull
    public ModelSummary summary() {
        final OrtSession activeSession = this.session;
        if (activeSession == null) {
            Intrinsics.throwUninitializedPropertyAccessException("session");
        }
        final Map<String, NodeInfo> inputs = activeSession.getInputInfo();
        Intrinsics.checkNotNullExpressionValue(inputs, "session.inputInfo");
        final List inputSummaries = summarizeNodes(inputs);

        final Map<String, NodeInfo> outputs = activeSession.getOutputInfo();
        Intrinsics.checkNotNullExpressionValue(outputs, "session.outputInfo");
        final List outputSummaries = summarizeNodes(outputs);

        return new OnnxModelSummary(inputSummaries, outputSummaries);
    }

    /** Maps each node to its summary and returns the (name, summary) pairs in map iteration order. */
    private static List summarizeNodes(Map<String, NodeInfo> nodes) {
        final LinkedHashMap<String, Object> summaries = new LinkedHashMap<>(MapsKt.mapCapacity(nodes.size()));
        for (Map.Entry<String, NodeInfo> entry : nodes.entrySet()) {
            final ValueInfo info = entry.getValue().getInfo();
            Intrinsics.checkNotNullExpressionValue(info, "node.info");
            summaries.put(entry.getKey(), OnnxModelSummaryKt.summary(info));
        }
        return MapsKt.toList(summaries);
    }

    /** Returns {@code "OnnxModel(session=<session>)"}. */
    @NotNull
    public String toString() {
        final OrtSession activeSession = this.session;
        if (activeSession == null) {
            Intrinsics.throwUninitializedPropertyAccessException("session");
        }
        return "OnnxModel(session=" + activeSession + ')';
    }
}
