package org.nd4j.tvm.util;

import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.ShortPointer;
import org.bytedeco.javacpp.indexer.Bfloat16Indexer;
import org.bytedeco.javacpp.indexer.ByteIndexer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.bytedeco.javacpp.indexer.ShortIndexer;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacpp.indexer.UIntIndexer;
import org.bytedeco.javacpp.indexer.UShortIndexer;
import org.bytedeco.tvm.DLDataType;
import org.bytedeco.tvm.DLDevice;
import org.bytedeco.tvm.DLTensor;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

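/**
 * Static helpers for converting between ND4J types ({@link DataType}, {@link INDArray},
 * {@link DataBuffer}) and their TVM counterparts ({@link DLDataType}, {@link DLTensor}).
 *
 * <p>Decompiled from: input_file:org/nd4j/tvm/util/TVMUtils.class
 */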
public class TVMUtils {

    /**
     * Switch lookup table for {@link DataType}, recovered by the decompiler from the
     * compiler-generated class {@code org.nd4j.tvm.util.TVMUtils$1}
     * (input_file:org/nd4j/tvm/util/TVMUtils$1.class).
     *
     * <p>JADX notes: the original class was package-private and had an invalid (numeric)
     * class name, so it was renamed and its access modifier was changed.
     *
     * <p>The array is indexed by {@code DataType.ordinal()} and holds a dense case number
     * (1..12) for each constant listed below. Constants that are not listed keep the
     * array's default value of 0, so a switch over the looked-up value sends them to its
     * {@code default} branch.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // One slot per DataType constant known at class-initialization time; all start at 0.
        static final /* synthetic */ int[] $SwitchMap$org$nd4j$linalg$api$buffer$DataType = new int[DataType.values().length];

        static {
            // Each assignment is wrapped separately so that a constant missing from the
            // DataType class present at runtime (NoSuchFieldError) only leaves its own
            // slot at 0 instead of aborting the initialization of the whole table.
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.BYTE.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // deliberately ignored: slot stays 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.SHORT.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // deliberately ignored: slot stays 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.INT.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // deliberately ignored: slot stays 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.LONG.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
                // deliberately ignored: slot stays 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.UBYTE.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
                // deliberately ignored: slot stays 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.UINT16.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
                // deliberately ignored: slot stays 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.UINT32.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
                // deliberately ignored: slot stays 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.UINT64.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
                // deliberately ignored: slot stays 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.HALF.ordinal()] = 9;
            } catch (NoSuchFieldError e9) {
                // deliberately ignored: slot stays 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.FLOAT.ordinal()] = 10;
            } catch (NoSuchFieldError e10) {
                // deliberately ignored: slot stays 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.DOUBLE.ordinal()] = 11;
            } catch (NoSuchFieldError e11) {
                // deliberately ignored: slot stays 0
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.BFLOAT16.ordinal()] = 12;
            } catch (NoSuchFieldError e12) {
                // deliberately ignored: slot stays 0
            }
        }
    }

    /**
     * Maps a TVM element type descriptor to the matching ND4J {@link DataType}.
     *
     * <p>Only the type code and the bit width are inspected; the lane count is not read.
     * Supported combinations:
     * <ul>
     *   <li>code 0 with 8/16/32/64 bits: INT8, INT16, INT32, INT64</li>
     *   <li>code 1 with 8/16/32/64 bits: UINT8, UINT16, UINT32, UINT64</li>
     *   <li>code 2 with 16/32/64 bits: FLOAT16, FLOAT, DOUBLE</li>
     *   <li>code 4 with 16 bits: BFLOAT16</li>
     * </ul>
     *
     * @param dLDataType the TVM type descriptor to translate
     * @return the corresponding ND4J data type
     * @throws IllegalArgumentException if the code/bits combination is not listed above
     */
    public static DataType dataTypeForTvmType(DLDataType dLDataType) {
        final int typeCode = dLDataType.code();
        final int bitWidth = dLDataType.bits();

        switch (typeCode) {
            case 0:
                switch (bitWidth) {
                    case 8:
                        return DataType.INT8;
                    case 16:
                        return DataType.INT16;
                    case 32:
                        return DataType.INT32;
                    case 64:
                        return DataType.INT64;
                    default:
                        break;
                }
                break;
            case 1:
                switch (bitWidth) {
                    case 8:
                        return DataType.UINT8;
                    case 16:
                        return DataType.UINT16;
                    case 32:
                        return DataType.UINT32;
                    case 64:
                        return DataType.UINT64;
                    default:
                        break;
                }
                break;
            case 2:
                switch (bitWidth) {
                    case 16:
                        return DataType.FLOAT16;
                    case 32:
                        return DataType.FLOAT;
                    case 64:
                        return DataType.DOUBLE;
                    default:
                        break;
                }
                break;
            case 4:
                if (bitWidth == 16) {
                    return DataType.BFLOAT16;
                }
                break;
            default:
                break;
        }

        // Every unsupported code/bits pair ends up here.
        throw new IllegalArgumentException("Illegal data type code " + typeCode + " with bits " + bitWidth);
    }

    /**
     * Builds the TVM element type descriptor for an ND4J {@link DataType}.
     *
     * <p>This is the inverse of {@link #dataTypeForTvmType(DLDataType)}: signed integers
     * use code 0, unsigned integers code 1, FLOAT16/FLOAT/DOUBLE code 2 and BFLOAT16
     * code 4. The lane count of the returned descriptor is always 1.
     *
     * @param dataType the ND4J data type to translate
     * @return a newly allocated descriptor carrying the matching code, bit width and one lane
     * @throws IllegalArgumentException if {@code dataType} has no mapping (including {@code null})
     */
    public static DLDataType tvmTypeForDataType(DataType dataType) {
        final byte typeCode;
        final byte bitWidth;

        if (dataType == DataType.INT8) {
            typeCode = 0;
            bitWidth = 8;
        } else if (dataType == DataType.INT16) {
            typeCode = 0;
            bitWidth = 16;
        } else if (dataType == DataType.INT32) {
            typeCode = 0;
            bitWidth = 32;
        } else if (dataType == DataType.INT64) {
            typeCode = 0;
            bitWidth = 64;
        } else if (dataType == DataType.UINT8) {
            typeCode = 1;
            bitWidth = 8;
        } else if (dataType == DataType.UINT16) {
            typeCode = 1;
            bitWidth = 16;
        } else if (dataType == DataType.UINT32) {
            typeCode = 1;
            bitWidth = 32;
        } else if (dataType == DataType.UINT64) {
            typeCode = 1;
            bitWidth = 64;
        } else if (dataType == DataType.FLOAT16) {
            typeCode = 2;
            bitWidth = 16;
        } else if (dataType == DataType.FLOAT) {
            typeCode = 2;
            bitWidth = 32;
        } else if (dataType == DataType.DOUBLE) {
            typeCode = 2;
            bitWidth = 64;
        } else if (dataType == DataType.BFLOAT16) {
            typeCode = 4;
            bitWidth = 16;
        } else {
            throw new IllegalArgumentException("Illegal data type " + dataType);
        }

        // The native descriptor is only allocated once a mapping has been found.
        return new DLDataType().code(typeCode).bits(bitWidth).lanes((short) 1);
    }

    /**
     * Builds an ND4J array over the data pointer of a TVM tensor.
     *
     * <p>Shape and strides are copied out of the tensor. A missing shape is treated as
     * {@code [1]}, and missing strides are derived from the shape via
     * {@link Nd4j#getStrides(long[])}.
     *
     * <p>The backing buffer is sized in <em>elements</em> (the product of the shape).
     * Earlier code multiplied that count by the element width in bytes, but
     * {@link #getDataBuffer(DLTensor, long)} uses its size argument as the capacity of a
     * typed pointer and as the buffer length, so the buffer claimed more elements than
     * the tensor holds for every type wider than one byte.
     *
     * <p>NOTE(review): the tensor's {@code byte_offset} is not applied here, and strided
     * layouts with gaps may address elements beyond the product of the shape. Confirm
     * whether callers ever pass such tensors.
     *
     * @param dLTensor the TVM tensor to expose as an ND4J array
     * @return an array with the tensor's shape and strides, starting at buffer offset 0
     * @throws IllegalArgumentException if the tensor's element type is not supported
     * @throws IllegalStateException if the created buffer's type differs from the tensor's type
     */
    public static INDArray getArray(DLTensor dLTensor) {
        DataType expectedType = dataTypeForTvmType(dLTensor.dtype());

        LongPointer nativeShape = dLTensor.shape();
        long[] shape;
        if (nativeShape != null) {
            shape = new long[dLTensor.ndim()];
            nativeShape.get(shape);
        } else {
            shape = new long[]{1};
        }

        LongPointer nativeStrides = dLTensor.strides();
        long[] strides;
        if (nativeStrides != null) {
            strides = new long[dLTensor.ndim()];
            nativeStrides.get(strides);
        } else {
            strides = Nd4j.getStrides(shape);
        }

        // Number of elements, not bytes: pointer capacity and buffer length are element counts.
        long elementCount = 1;
        for (long dim : shape) {
            elementCount *= dim;
        }

        DataBuffer dataBuffer = getDataBuffer(dLTensor, elementCount);
        Preconditions.checkState(expectedType.equals(dataBuffer.dataType()), "Data type must be equivalent as specified by the tvm metadata.");
        return Nd4j.create(dataBuffer, shape, strides, 0L);
    }

    /**
     * Describes an ND4J array as a TVM tensor that points at the array's existing buffer.
     *
     * <p>The tensor receives the buffer's pointer, the given device, and the array's rank,
     * element type, shape and strides. Shape and strides are copied into newly allocated
     * native {@link LongPointer}s.
     *
     * <p>{@code byte_offset} is expressed in bytes, whereas {@link INDArray#offset()} is an
     * element offset into the buffer, so the offset is scaled by the buffer's element
     * size. Passing the element offset unscaled would address the wrong memory for any
     * view with a non-zero offset whose elements are wider than one byte.
     *
     * <p>NOTE(review): the returned tensor holds no Java reference to the shape/strides
     * pointers created here. Confirm that they cannot be reclaimed while the tensor is
     * still in use.
     *
     * @param iNDArray the array whose memory the tensor should reference
     * @param dLDevice the device to record on the tensor
     * @return a new tensor descriptor for the array
     * @throws IllegalArgumentException if the array's data type has no TVM equivalent
     */
    public static DLTensor getTensor(INDArray iNDArray, DLDevice dLDevice) {
        DLTensor dLTensor = new DLTensor();
        dLTensor.data(iNDArray.data().pointer());
        dLTensor.device(dLDevice);
        dLTensor.ndim(iNDArray.rank());
        dLTensor.dtype(tvmTypeForDataType(iNDArray.dataType()));
        dLTensor.shape(new LongPointer(iNDArray.shape()));
        dLTensor.strides(new LongPointer(iNDArray.stride()));
        // Convert the element offset into the byte offset the tensor descriptor expects.
        dLTensor.byte_offset(iNDArray.offset() * iNDArray.data().getElementSize());
        return dLTensor;
    }

    /**
     * Creates an ND4J {@link DataBuffer} over the data pointer of a TVM tensor.
     *
     * <p>The tensor's element type selects the JavaCPP pointer class and the indexer used
     * to access it. The 16-bit types (SHORT, UINT16, HALF, BFLOAT16) all go through a
     * {@link ShortPointer} and differ only in their indexer; UINT64 uses the same
     * {@link LongIndexer} as LONG.
     *
     * @param dLTensor the tensor whose data pointer is wrapped
     * @param j size used both as the capacity of the typed pointer and as the length
     *          handed to {@code Nd4j.createBuffer}; presumably an element count rather
     *          than a byte count (confirm against callers)
     * @return a buffer of the tensor's element type backed by the tensor's data pointer
     * @throws IllegalArgumentException if the tensor's element type has no ND4J equivalent
     * @throws RuntimeException if the ND4J type is not one of the twelve handled here
     */
    public static DataBuffer getDataBuffer(DLTensor dLTensor, long j) {
        final DataType elementType = dataTypeForTvmType(dLTensor.dtype());
        switch (elementType) {
            case BYTE: {
                BytePointer bytes = new BytePointer(dLTensor.data()).capacity(j);
                return Nd4j.createBuffer(bytes, elementType, j, ByteIndexer.create(bytes));
            }
            case SHORT: {
                ShortPointer shorts = new ShortPointer(dLTensor.data()).capacity(j);
                return Nd4j.createBuffer(shorts, elementType, j, ShortIndexer.create(shorts));
            }
            case INT: {
                IntPointer ints = new IntPointer(dLTensor.data()).capacity(j);
                return Nd4j.createBuffer(ints, elementType, j, IntIndexer.create(ints));
            }
            case LONG: {
                LongPointer longs = new LongPointer(dLTensor.data()).capacity(j);
                return Nd4j.createBuffer(longs, elementType, j, LongIndexer.create(longs));
            }
            case UBYTE: {
                BytePointer unsignedBytes = new BytePointer(dLTensor.data()).capacity(j);
                return Nd4j.createBuffer(unsignedBytes, elementType, j, UByteIndexer.create(unsignedBytes));
            }
            case UINT16: {
                ShortPointer unsignedShorts = new ShortPointer(dLTensor.data()).capacity(j);
                return Nd4j.createBuffer(unsignedShorts, elementType, j, UShortIndexer.create(unsignedShorts));
            }
            case UINT32: {
                IntPointer unsignedInts = new IntPointer(dLTensor.data()).capacity(j);
                return Nd4j.createBuffer(unsignedInts, elementType, j, UIntIndexer.create(unsignedInts));
            }
            case UINT64: {
                // Accessed through the signed long indexer, exactly like LONG.
                LongPointer unsignedLongs = new LongPointer(dLTensor.data()).capacity(j);
                return Nd4j.createBuffer(unsignedLongs, elementType, j, LongIndexer.create(unsignedLongs));
            }
            case HALF: {
                ShortPointer halves = new ShortPointer(dLTensor.data()).capacity(j);
                return Nd4j.createBuffer(halves, elementType, j, HalfIndexer.create(halves));
            }
            case FLOAT: {
                FloatPointer floats = new FloatPointer(dLTensor.data()).capacity(j);
                return Nd4j.createBuffer(floats, elementType, j, FloatIndexer.create(floats));
            }
            case DOUBLE: {
                DoublePointer doubles = new DoublePointer(dLTensor.data()).capacity(j);
                return Nd4j.createBuffer(doubles, elementType, j, DoubleIndexer.create(doubles));
            }
            case BFLOAT16: {
                ShortPointer bfloats = new ShortPointer(dLTensor.data()).capacity(j);
                return Nd4j.createBuffer(bfloats, elementType, j, Bfloat16Indexer.create(bfloats));
            }
            default:
                throw new RuntimeException("Unsupported data type encountered");
        }
    }
}
