package org.nd4j.onnxruntime.util;

import org.bytedeco.javacpp.BoolPointer;
import org.bytedeco.javacpp.BooleanPointer;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.PointerScope;
import org.bytedeco.javacpp.ShortPointer;
import org.bytedeco.javacpp.indexer.BooleanIndexer;
import org.bytedeco.javacpp.indexer.ByteIndexer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.bytedeco.javacpp.indexer.ShortIndexer;
import org.bytedeco.onnxruntime.MemoryInfo;
import org.bytedeco.onnxruntime.Value;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;

/* loaded from: input_file:org/nd4j/onnxruntime/util/ONNXUtils.class */
/**
 * Static helpers for converting between ND4J arrays/buffers and ONNX Runtime
 * {@link Value} tensors.
 *
 * <p>The integer type codes used throughout this class are the ONNX Runtime
 * tensor element type codes (1 = float, 2 = uint8, 3 = int8, 4 = uint16,
 * 5 = int16, 6 = int32, 7 = int64, 9 = bool, 10 = float16, 11 = double,
 * 12 = uint32, 13 = uint64, 16 = bfloat16).
 */
public class ONNXUtils {
    /**
     * Verifies that an array has the data type required by the ONNX model.
     *
     * @param dataType the data type the array is required to have
     * @param iNDArray the array to check
     * @throws RuntimeException if the array's data type differs from {@code dataType}
     */
    public static void validateType(DataType dataType, INDArray iNDArray) {
        if (!iNDArray.dataType().equals(dataType)) {
            throw new RuntimeException("INDArray data type (" + iNDArray.dataType() + ") does not match required ONNX data type (" + dataType + ")");
        }
    }

    /**
     * Maps an ONNX tensor element type code to the matching ND4J {@link DataType}.
     * This is the inverse of {@link #onnxTypeForDataType(DataType)}.
     *
     * @param i the ONNX tensor element type code
     * @return the corresponding ND4J data type
     * @throws IllegalArgumentException if the code has no ND4J equivalent
     */
    public static DataType dataTypeForOnnxType(int i) {
        switch (i) {
            // Previously this first check compared the argument with itself, so every
            // code was reported as FLOAT and the remaining mappings were unreachable.
            case 1:
                return DataType.FLOAT;
            case 2:
                return DataType.UINT8;
            case 3:
                return DataType.INT8;
            case 4:
                return DataType.UINT16;
            case 5:
                return DataType.INT16;
            case 6:
                return DataType.INT32;
            case 7:
                return DataType.INT64;
            case 9:
                return DataType.BOOL;
            case 10:
                return DataType.FLOAT16;
            case 11:
                return DataType.DOUBLE;
            case 12:
                return DataType.UINT32;
            case 13:
                return DataType.UINT64;
            case 16:
                return DataType.BFLOAT16;
            default:
                throw new IllegalArgumentException("Illegal data type " + i);
        }
    }

    /**
     * Maps an ND4J {@link DataType} to the matching ONNX tensor element type code.
     * This is the inverse of {@link #dataTypeForOnnxType(int)}.
     *
     * @param dataType the ND4J data type
     * @return the corresponding ONNX tensor element type code
     * @throws IllegalArgumentException if the data type (including {@code null})
     *                                  has no ONNX equivalent
     */
    public static int onnxTypeForDataType(DataType dataType) {
        // Kept as an identity-comparison chain so that a null argument still ends up
        // in the IllegalArgumentException below rather than a NullPointerException.
        if (dataType == DataType.FLOAT) {
            return 1;
        }
        if (dataType == DataType.UINT8) {
            return 2;
        }
        if (dataType == DataType.INT8) {
            return 3;
        }
        if (dataType == DataType.UINT16) {
            return 4;
        }
        if (dataType == DataType.INT16) {
            return 5;
        }
        if (dataType == DataType.INT32) {
            return 6;
        }
        if (dataType == DataType.INT64) {
            return 7;
        }
        if (dataType == DataType.BOOL) {
            return 9;
        }
        if (dataType == DataType.FLOAT16) {
            return 10;
        }
        if (dataType == DataType.DOUBLE) {
            return 11;
        }
        if (dataType == DataType.UINT32) {
            return 12;
        }
        if (dataType == DataType.UINT64) {
            return 13;
        }
        if (dataType == DataType.BFLOAT16) {
            return 16;
        }
        throw new IllegalArgumentException("Illegal data type " + dataType);
    }

    /**
     * Wraps the data of an ONNX tensor value in an {@link INDArray}.
     *
     * <p>The shape is read from the tensor's type-and-shape info; when no shape is
     * available a single-element shape {@code [1]} is used.
     *
     * @param value the ONNX Runtime tensor value
     * @return an array created over the buffer returned by {@link #getDataBuffer(Value)}
     * @throws IllegalArgumentException if the value is null or its element type is unsupported
     * @throws IllegalStateException    if the created buffer's type does not match the
     *                                  tensor's declared element type
     */
    public static INDArray getArray(Value value) {
        // Use the tensor *element* type here. The value's ONNX type only says whether
        // it is a tensor, sequence, map, ... and is not an element type code.
        DataType expectedType = dataTypeForOnnxType(value.GetTensorTypeAndShapeInfo().GetElementType());

        long[] shape;
        LongPointer shapePointer = value.GetTensorTypeAndShapeInfo().GetShape();
        if (shapePointer != null) {
            shape = new long[(int) value.GetTensorTypeAndShapeInfo().GetDimensionsCount()];
            shapePointer.get(shape);
        } else {
            shape = new long[]{1};
        }

        DataBuffer dataBuffer = getDataBuffer(value);
        Preconditions.checkState(expectedType.equals(dataBuffer.dataType()), "Data type must be equivalent as specified by the onnx metadata.");
        return Nd4j.create(dataBuffer, shape, Nd4j.getStrides(shape), 0L);
    }

    /**
     * Derives an integer log level from the most verbose level enabled on an SLF4J logger.
     *
     * <p>NOTE(review): the returned numbers presumably correspond to ONNX Runtime's
     * logging levels (0 = verbose, 1 = info, 2 = warning, 3 = error) - confirm against
     * the caller.
     *
     * @param logger the logger to inspect
     * @return 0 if trace or debug is enabled, 1 for info, 2 for warn, 3 for error,
     *         and 1 if no level is enabled at all
     */
    public static int getOnnxLogLevelFromLogger(Logger logger) {
        if (logger.isTraceEnabled() || logger.isDebugEnabled()) {
            return 0;
        }
        if (logger.isInfoEnabled()) {
            return 1;
        }
        if (logger.isWarnEnabled()) {
            return 2;
        }
        if (logger.isErrorEnabled()) {
            return 3;
        }
        // Nothing enabled: fall back to the info level.
        return 1;
    }

    /**
     * Creates an ONNX Runtime tensor from an {@link INDArray}.
     *
     * <p>The array's data pointer is passed to ONNX Runtime together with its byte
     * length (element count times element size), shape, rank and element type code.
     *
     * @param iNDArray   the array to expose as a tensor
     * @param memoryInfo the memory info describing where the tensor data lives
     * @return the created tensor value
     * @throws IllegalArgumentException if the array's data type has no ONNX equivalent
     */
    public static Value getTensor(INDArray iNDArray, MemoryInfo memoryInfo) {
        long sizeInBytes = iNDArray.length() * iNDArray.data().getElementSize();
        LongPointer shape = new LongPointer(iNDArray.shape());
        return Value.CreateTensor(
                memoryInfo.asOrtMemoryInfo(),
                iNDArray.data().pointer(),
                sizeInBytes,
                shape,
                iNDArray.rank(),
                onnxTypeForDataType(iNDArray.dataType()));
    }

    /**
     * Creates an ND4J {@link DataBuffer} over the mutable data of an ONNX tensor value.
     *
     * <p>The buffer type is chosen from the tensor's element type code and the buffer
     * length is the tensor's element count.
     *
     * @param value the ONNX Runtime tensor value
     * @return a data buffer created from the tensor's data pointer
     * @throws IllegalArgumentException if the underlying native value is null
     * @throws RuntimeException         if the element type is not supported
     */
    public static DataBuffer getDataBuffer(Value value) {
        if (value.isNull()) {
            throw new IllegalArgumentException("Native underlying tensor value was null!");
        }
        try (PointerScope scope = new PointerScope()) {
            int elementType = value.GetTensorTypeAndShapeInfo().GetElementType();
            long elementCount = value.GetTensorTypeAndShapeInfo().GetElementCount();
            switch (elementType) {
                case 1: {
                    FloatPointer pointer = value.GetTensorMutableDataFloat().capacity(elementCount);
                    return Nd4j.createBuffer(pointer, DataType.FLOAT, elementCount, FloatIndexer.create(pointer));
                }
                case 2: {
                    BytePointer pointer = value.GetTensorMutableDataUByte().capacity(elementCount);
                    return Nd4j.createBuffer(pointer, DataType.UINT8, elementCount, ByteIndexer.create(pointer));
                }
                case 3: {
                    // Code 3 is int8 in both mapping functions of this class; the buffer
                    // was previously labelled UINT8, which contradicted them.
                    BytePointer pointer = value.GetTensorMutableDataByte().capacity(elementCount);
                    return Nd4j.createBuffer(pointer, DataType.INT8, elementCount, ByteIndexer.create(pointer));
                }
                case 4: {
                    ShortPointer pointer = value.GetTensorMutableDataUShort().capacity(elementCount);
                    return Nd4j.createBuffer(pointer, DataType.UINT16, elementCount, ShortIndexer.create(pointer));
                }
                case 5: {
                    ShortPointer pointer = value.GetTensorMutableDataShort().capacity(elementCount);
                    return Nd4j.createBuffer(pointer, DataType.INT16, elementCount, ShortIndexer.create(pointer));
                }
                case 6: {
                    IntPointer pointer = value.GetTensorMutableDataInt().capacity(elementCount);
                    return Nd4j.createBuffer(pointer, DataType.INT32, elementCount, IntIndexer.create(pointer));
                }
                case 7: {
                    LongPointer pointer = value.GetTensorMutableDataLong().capacity(elementCount);
                    return Nd4j.createBuffer(pointer, DataType.INT64, elementCount, LongIndexer.create(pointer));
                }
                case 8: {
                    // NOTE(review): code 8 is the string element type in the ONNX Runtime
                    // C API and has no entry in dataTypeForOnnxType; exposing it as raw
                    // INT8 bytes is kept as-is - confirm this is intended.
                    BytePointer pointer = value.GetTensorMutableDataByte().capacity(elementCount);
                    return Nd4j.createBuffer(pointer, DataType.INT8, elementCount, ByteIndexer.create(pointer));
                }
                case 9: {
                    BoolPointer pointer = value.GetTensorMutableDataBool().capacity(elementCount);
                    return Nd4j.createBuffer(pointer, DataType.BOOL, elementCount, BooleanIndexer.create(new BooleanPointer(pointer)));
                }
                case 10: {
                    // Half-precision values are accessed through a 16-bit short pointer.
                    ShortPointer pointer = value.GetTensorMutableDataShort().capacity(elementCount);
                    return Nd4j.createBuffer(pointer, DataType.FLOAT16, elementCount, ShortIndexer.create(pointer));
                }
                case 11: {
                    DoublePointer pointer = value.GetTensorMutableDataDouble().capacity(elementCount);
                    return Nd4j.createBuffer(pointer, DataType.DOUBLE, elementCount, DoubleIndexer.create(pointer));
                }
                case 12: {
                    IntPointer pointer = value.GetTensorMutableDataUInt().capacity(elementCount);
                    return Nd4j.createBuffer(pointer, DataType.UINT32, elementCount, IntIndexer.create(pointer));
                }
                case 13: {
                    LongPointer pointer = value.GetTensorMutableDataULong().capacity(elementCount);
                    return Nd4j.createBuffer(pointer, DataType.UINT64, elementCount, LongIndexer.create(pointer));
                }
                case 16: {
                    // bfloat16 values are likewise accessed through a 16-bit short pointer.
                    ShortPointer pointer = value.GetTensorMutableDataShort().capacity(elementCount);
                    return Nd4j.createBuffer(pointer, DataType.BFLOAT16, elementCount, ShortIndexer.create(pointer));
                }
                default:
                    // Includes codes 14 and 15, which have no buffer mapping here.
                    throw new RuntimeException("Unsupported data type encountered");
            }
        }
    }
}
