package org.nd4j.imports.graphmapper.tf.tensors;

import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import org.bytedeco.cuda.global.nppc;
import org.bytedeco.javacpp.indexer.Bfloat16ArrayIndexer;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.TensorProto;

/* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers.class */
public class TFTensorMappers {

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$BFloat16TensorMapper.class */
    /**
     * Mapper for {@code DT_BFLOAT16} tensors. Values are staged in a {@code float[]} and the
     * resulting {@link INDArray} is always of type {@link DataType#BFLOAT16}.
     */
    public static class BFloat16TensorMapper extends BaseTensorMapper<float[], ShortBuffer> {
        public BFloat16TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        /** @return number of entries in the proto's half-value list */
        @Override
        public int valueCount() {
            return tfTensor.getHalfValCount();
        }

        /** @return a new {@code float[]} of length {@code i} used to stage the values */
        @Override
        public float[] newArray(int i) {
            return new float[i];
        }

        /** @return a {@link ShortBuffer} view over the raw tensor content */
        @Override
        public ShortBuffer getBuffer(ByteBuffer byteBuffer) {
            return byteBuffer.asShortBuffer();
        }

        /** Converts the {@code i}-th half-value entry of the proto to float and stores it at {@code fArr[i]}. */
        @Override
        public void getValue(float[] fArr, int i) {
            fArr[i] = Bfloat16ArrayIndexer.toFloat(tfTensor.getHalfVal(i));
        }

        /**
         * Reading from the binary buffer is not implemented for this type.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public void getValue(float[] fArr, ShortBuffer shortBuffer, int i) {
            throw new UnsupportedOperationException("Not yet implemnted: BFP16 reading from buffer");
        }

        /**
         * Builds the BFLOAT16 array for the given shape.
         *
         * @param jArr target shape
         * @param fArr staged values
         * @return a BFLOAT16 array of shape {@code jArr}
         */
        @Override
        public INDArray arrayFor(long[] jArr, float[] fArr) {
            // A single staged value with a shape holding more than one element means
            // "fill the whole array with that value".
            if (fArr.length == 1 && ArrayUtil.prod(jArr) > 1) {
                // Previously allocated as DataType.HALF, which produced an array of the wrong
                // type for a bfloat16 tensor; both branches must yield BFLOAT16.
                return Nd4j.createUninitialized(DataType.BFLOAT16, jArr).assign(Float.valueOf(fArr[0]));
            }
            return Nd4j.create(fArr, jArr, 'c').castTo(DataType.BFLOAT16);
        }
    }

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$BaseTensorMapper.class */
    public static abstract class BaseTensorMapper<T, U extends Buffer> implements TFTensorMapper<T, U> {
        protected TensorProto tfTensor;

        public BaseTensorMapper(TensorProto tensorProto) {
            this.tfTensor = tensorProto;
        }

        @Override // org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper
        public DataType dataType() {
            return ArrayOptionsHelper.convertToDataType(this.tfTensor.getDtype());
        }

        @Override // org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper
        public long[] shape() {
            int dimCount = this.tfTensor.getTensorShape().getDimCount();
            long[] jArr = new long[dimCount];
            for (int i = 0; i < dimCount; i++) {
                jArr[i] = this.tfTensor.getTensorShape().getDim(i).getSize();
            }
            return jArr;
        }

        @Override // org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper
        public boolean isEmpty() {
            return valueSource() == TFTensorMapper.ValueSource.EMPTY;
        }

        @Override // org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper
        public TFTensorMapper.ValueSource valueSource() {
            return valueCount() > 0 ? TFTensorMapper.ValueSource.VALUE_COUNT : (this.tfTensor.getTensorContent() == null || this.tfTensor.getTensorContent().size() <= 0) ? TFTensorMapper.ValueSource.EMPTY : TFTensorMapper.ValueSource.BINARY;
        }

        /* JADX WARN: Type inference failed for: r0v10, types: [java.nio.Buffer] */
        @Override // org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper
        public INDArray toNDArray() {
            INDArray arrayFor;
            DataType dataType = dataType();
            TFTensorMapper.ValueSource valueSource = valueSource();
            long[] shape = shape();
            switch (valueSource) {
                case EMPTY:
                    arrayFor = Nd4j.create(dataType, shape);
                    break;
                case VALUE_COUNT:
                    int valueCount = valueCount();
                    T newArray = newArray(valueCount);
                    for (int i = 0; i < valueCount; i++) {
                        getValue(newArray, i);
                    }
                    arrayFor = arrayFor(shape, newArray);
                    break;
                case BINARY:
                    ?? buffer = getBuffer(this.tfTensor.getTensorContent().asReadOnlyByteBuffer().order(ByteOrder.nativeOrder()));
                    int capacity = buffer.capacity();
                    T newArray2 = newArray(capacity);
                    for (int i2 = 0; i2 < capacity; i2++) {
                        getValue(newArray2, buffer, i2);
                    }
                    arrayFor = arrayFor(shape, newArray2);
                    break;
                default:
                    throw new RuntimeException("Error converting TF tensor to INDArray");
            }
            return arrayFor;
        }
    }

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$BoolTensorMapper.class */
    /**
     * Mapper for {@code DT_BOOL} tensors. Only the proto's typed boolean value list is
     * supported; reading from binary tensor content is rejected.
     */
    public static class BoolTensorMapper extends BaseTensorMapper<boolean[], ByteBuffer> {
        public BoolTensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        /** @return number of entries in the proto's boolean value list */
        @Override
        public int valueCount() {
            return tfTensor.getBoolValCount();
        }

        /** @return a new {@code boolean[]} of length {@code i} used to stage the values */
        @Override
        public boolean[] newArray(int i) {
            return new boolean[i];
        }

        /**
         * Binary content is not supported for boolean tensors.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public ByteBuffer getBuffer(ByteBuffer byteBuffer) {
            // Message previously said "String types" (copy-paste from the string mapper).
            throw new UnsupportedOperationException("Not supported for boolean types");
        }

        /** Copies the {@code i}-th boolean value of the proto into {@code zArr[i]}. */
        @Override
        public void getValue(boolean[] zArr, int i) {
            zArr[i] = tfTensor.getBoolVal(i);
        }

        /**
         * Binary content is not supported for boolean tensors.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public void getValue(boolean[] zArr, ByteBuffer byteBuffer, int i) {
            throw new UnsupportedOperationException("Not supported for boolean types");
        }

        /** @return an array created from {@code zArr} and reshaped to {@code jArr} */
        @Override
        public INDArray arrayFor(long[] jArr, boolean[] zArr) {
            return Nd4j.create(zArr).reshape(jArr);
        }
    }

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$Float16TensorMapper.class */
    /**
     * Mapper for {@code DT_HALF} tensors. Values are staged as {@code float} and the result is
     * of type {@link DataType#HALF}. Only the typed value list is supported; binary content
     * reading throws.
     */
    public static class Float16TensorMapper extends BaseTensorMapper<float[], Buffer> {
        public Float16TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        /** @return number of entries in the proto's half-value list */
        @Override
        public int valueCount() {
            return tfTensor.getHalfValCount();
        }

        /** @return a new {@code float[]} of length {@code i} */
        @Override
        public float[] newArray(int i) {
            return new float[i];
        }

        /** @throws UnsupportedOperationException always; buffer reading is not implemented */
        @Override
        public Buffer getBuffer(ByteBuffer byteBuffer) {
            throw new UnsupportedOperationException("Not yet implemnted: FP16 reading from buffer");
        }

        /** Converts the {@code i}-th half-value entry to float and stores it at {@code fArr[i]}. */
        @Override
        public void getValue(float[] fArr, int i) {
            final int halfBits = tfTensor.getHalfVal(i);
            fArr[i] = HalfIndexer.toFloat(halfBits);
        }

        /** @throws UnsupportedOperationException always; buffer reading is not implemented */
        @Override
        public void getValue(float[] fArr, Buffer buffer, int i) {
            throw new UnsupportedOperationException("Not yet implemnted: FP16 reading from buffer");
        }

        /**
         * Builds the HALF array for shape {@code jArr}. A lone staged value combined with a
         * shape of more than one element is broadcast over the whole array.
         */
        @Override
        public INDArray arrayFor(long[] jArr, float[] fArr) {
            if (fArr.length == 1 && ArrayUtil.prod(jArr) > 1) {
                return Nd4j.createUninitialized(DataType.HALF, jArr).assign(Float.valueOf(fArr[0]));
            }
            return Nd4j.create(fArr, jArr, 'c').castTo(DataType.HALF);
        }
    }

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$Float32TensorMapper.class */
    /**
     * Mapper for {@code DT_FLOAT} tensors, reading either the proto's float value list or a
     * {@link FloatBuffer} view of its binary content.
     */
    public static class Float32TensorMapper extends BaseTensorMapper<float[], FloatBuffer> {
        public Float32TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        /** @return number of entries in the proto's float value list */
        @Override
        public int valueCount() {
            return tfTensor.getFloatValCount();
        }

        /** @return a new {@code float[]} of length {@code i} */
        @Override
        public float[] newArray(int i) {
            return new float[i];
        }

        /** @return a {@link FloatBuffer} view over the raw tensor content */
        @Override
        public FloatBuffer getBuffer(ByteBuffer byteBuffer) {
            return byteBuffer.asFloatBuffer();
        }

        /** Copies the {@code i}-th float value of the proto into {@code fArr[i]}. */
        @Override
        public void getValue(float[] fArr, int i) {
            fArr[i] = tfTensor.getFloatVal(i);
        }

        /** Copies element {@code i} of the buffer into {@code fArr[i]}. */
        @Override
        public void getValue(float[] fArr, FloatBuffer floatBuffer, int i) {
            fArr[i] = floatBuffer.get(i);
        }

        /**
         * Builds the array for shape {@code jArr}. A lone staged value combined with a shape of
         * more than one element yields an array filled with that value.
         */
        @Override
        public INDArray arrayFor(long[] jArr, float[] fArr) {
            if (fArr.length == 1 && ArrayUtil.prod(jArr) > 1) {
                return Nd4j.valueArrayOf(jArr, fArr[0]);
            }
            return Nd4j.create(fArr, jArr, 'c');
        }
    }

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$Float64TensorMapper.class */
    /**
     * Mapper for {@code DT_DOUBLE} tensors, reading either the proto's double value list or a
     * {@link DoubleBuffer} view of its binary content.
     */
    public static class Float64TensorMapper extends BaseTensorMapper<double[], DoubleBuffer> {
        public Float64TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        /** @return number of entries in the proto's double value list */
        @Override
        public int valueCount() {
            return tfTensor.getDoubleValCount();
        }

        /** @return a new {@code double[]} of length {@code i} */
        @Override
        public double[] newArray(int i) {
            return new double[i];
        }

        /** @return a {@link DoubleBuffer} view over the raw tensor content */
        @Override
        public DoubleBuffer getBuffer(ByteBuffer byteBuffer) {
            return byteBuffer.asDoubleBuffer();
        }

        /** Copies the {@code i}-th double value of the proto into {@code dArr[i]}. */
        @Override
        public void getValue(double[] dArr, int i) {
            dArr[i] = tfTensor.getDoubleVal(i);
        }

        /** Copies element {@code i} of the buffer into {@code dArr[i]}. */
        @Override
        public void getValue(double[] dArr, DoubleBuffer doubleBuffer, int i) {
            dArr[i] = doubleBuffer.get(i);
        }

        /**
         * Builds the array for shape {@code jArr}. A lone staged value combined with a shape of
         * more than one element yields an array filled with that value.
         */
        @Override
        public INDArray arrayFor(long[] jArr, double[] dArr) {
            if (dArr.length == 1 && ArrayUtil.prod(jArr) > 1) {
                return Nd4j.valueArrayOf(jArr, dArr[0]);
            }
            return Nd4j.create(dArr, jArr, 'c');
        }
    }

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$Int16TensorMapper.class */
    /**
     * Mapper for {@code DT_INT16} tensors. Values are staged as {@code int}; binary content is
     * read through a {@link ShortBuffer} with sign extension.
     */
    public static class Int16TensorMapper extends BaseTensorMapper<int[], ShortBuffer> {
        public Int16TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        /** @return number of entries in the proto's int value list */
        @Override
        public int valueCount() {
            return tfTensor.getIntValCount();
        }

        /** @return a new {@code int[]} of length {@code i} */
        @Override
        public int[] newArray(int i) {
            return new int[i];
        }

        /** @return a {@link ShortBuffer} view over the raw tensor content */
        @Override
        public ShortBuffer getBuffer(ByteBuffer byteBuffer) {
            return byteBuffer.asShortBuffer();
        }

        /** Copies the {@code i}-th int value of the proto into {@code iArr[i]}. */
        @Override
        public void getValue(int[] iArr, int i) {
            iArr[i] = tfTensor.getIntVal(i);
        }

        /** Stores buffer element {@code i}, widened (signed) to int, at {@code iArr[i]}. */
        @Override
        public void getValue(int[] iArr, ShortBuffer shortBuffer, int i) {
            final short raw = shortBuffer.get(i);
            iArr[i] = raw;
        }

        /** @return a c-order array of shape {@code jArr} backed by a buffer typed as {@link #dataType()} */
        @Override
        public INDArray arrayFor(long[] jArr, int[] iArr) {
            final DataType targetType = dataType();
            return Nd4j.create(
                    Nd4j.createTypedBuffer(iArr, targetType),
                    jArr,
                    Nd4j.getStrides(jArr, 'c'),
                    0L,
                    'c',
                    targetType);
        }
    }

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$Int32TensorMapper.class */
    /**
     * Mapper for {@code DT_INT32} tensors, reading either the proto's int value list or an
     * {@link IntBuffer} view of its binary content.
     */
    public static class Int32TensorMapper extends BaseTensorMapper<int[], IntBuffer> {
        public Int32TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        /** @return number of entries in the proto's int value list */
        @Override
        public int valueCount() {
            return tfTensor.getIntValCount();
        }

        /** @return a new {@code int[]} of length {@code i} */
        @Override
        public int[] newArray(int i) {
            return new int[i];
        }

        /** @return an {@link IntBuffer} view over the raw tensor content */
        @Override
        public IntBuffer getBuffer(ByteBuffer byteBuffer) {
            return byteBuffer.asIntBuffer();
        }

        /** Copies the {@code i}-th int value of the proto into {@code iArr[i]}. */
        @Override
        public void getValue(int[] iArr, int i) {
            iArr[i] = tfTensor.getIntVal(i);
        }

        /** Copies element {@code i} of the buffer into {@code iArr[i]}. */
        @Override
        public void getValue(int[] iArr, IntBuffer intBuffer, int i) {
            iArr[i] = intBuffer.get(i);
        }

        /** @return a c-order array of shape {@code jArr} backed by a buffer typed as {@link #dataType()} */
        @Override
        public INDArray arrayFor(long[] jArr, int[] iArr) {
            final DataType targetType = dataType();
            return Nd4j.create(
                    Nd4j.createTypedBuffer(iArr, targetType),
                    jArr,
                    Nd4j.getStrides(jArr, 'c'),
                    0L,
                    'c',
                    targetType);
        }
    }

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$Int64TensorMapper.class */
    /**
     * Mapper for {@code DT_INT64} tensors, reading either the proto's int64 value list or a
     * {@link LongBuffer} view of its binary content.
     */
    public static class Int64TensorMapper extends BaseTensorMapper<long[], LongBuffer> {
        public Int64TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        /** @return number of entries in the proto's int64 value list */
        @Override
        public int valueCount() {
            return tfTensor.getInt64ValCount();
        }

        /** @return a new {@code long[]} of length {@code i} */
        @Override
        public long[] newArray(int i) {
            return new long[i];
        }

        /** @return a {@link LongBuffer} view over the raw tensor content */
        @Override
        public LongBuffer getBuffer(ByteBuffer byteBuffer) {
            return byteBuffer.asLongBuffer();
        }

        /** Copies the {@code i}-th int64 value of the proto into {@code jArr[i]}. */
        @Override
        public void getValue(long[] jArr, int i) {
            jArr[i] = tfTensor.getInt64Val(i);
        }

        /** Copies element {@code i} of the buffer into {@code jArr[i]}. */
        @Override
        public void getValue(long[] jArr, LongBuffer longBuffer, int i) {
            jArr[i] = longBuffer.get(i);
        }

        /**
         * @param jArr target shape
         * @param jArr2 staged values
         * @return a c-order array of shape {@code jArr} backed by a buffer typed as {@link #dataType()}
         */
        @Override
        public INDArray arrayFor(long[] jArr, long[] jArr2) {
            final DataType targetType = dataType();
            return Nd4j.create(
                    Nd4j.createTypedBuffer(jArr2, targetType),
                    jArr,
                    Nd4j.getStrides(jArr, 'c'),
                    0L,
                    'c',
                    targetType);
        }
    }

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$Int8TensorMapper.class */
    /**
     * Mapper for {@code DT_INT8} tensors. Values are staged as {@code int}; binary content is
     * read byte by byte with sign extension.
     */
    public static class Int8TensorMapper extends BaseTensorMapper<int[], ByteBuffer> {
        public Int8TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        /** @return number of entries in the proto's int value list */
        @Override
        public int valueCount() {
            return tfTensor.getIntValCount();
        }

        /** @return a new {@code int[]} of length {@code i} */
        @Override
        public int[] newArray(int i) {
            return new int[i];
        }

        /** @return the given byte buffer itself; one byte is one element */
        @Override
        public ByteBuffer getBuffer(ByteBuffer byteBuffer) {
            return byteBuffer;
        }

        /** Copies the {@code i}-th int value of the proto into {@code iArr[i]}. */
        @Override
        public void getValue(int[] iArr, int i) {
            iArr[i] = tfTensor.getIntVal(i);
        }

        /** Stores byte {@code i} of the buffer, widened (signed) to int, at {@code iArr[i]}. */
        @Override
        public void getValue(int[] iArr, ByteBuffer byteBuffer, int i) {
            final byte raw = byteBuffer.get(i);
            iArr[i] = raw;
        }

        /** @return a c-order array of shape {@code jArr} backed by a buffer typed as {@link #dataType()} */
        @Override
        public INDArray arrayFor(long[] jArr, int[] iArr) {
            final DataType targetType = dataType();
            return Nd4j.create(
                    Nd4j.createTypedBuffer(iArr, targetType),
                    jArr,
                    Nd4j.getStrides(jArr, 'c'),
                    0L,
                    'c',
                    targetType);
        }
    }

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$StringTensorMapper.class */
    /**
     * Mapper for {@code DT_STRING} tensors. Only the proto's string value list is supported;
     * both binary-content methods throw.
     */
    public static class StringTensorMapper extends BaseTensorMapper<String[], ByteBuffer> {
        public StringTensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        /** @return number of entries in the proto's string value list */
        @Override
        public int valueCount() {
            return tfTensor.getStringValCount();
        }

        /** @return a new {@code String[]} of length {@code i} */
        @Override
        public String[] newArray(int i) {
            return new String[i];
        }

        /** @throws UnsupportedOperationException always */
        @Override
        public ByteBuffer getBuffer(ByteBuffer byteBuffer) {
            throw new UnsupportedOperationException("Not supported for String types");
        }

        /** Decodes the {@code i}-th string value of the proto as UTF-8 into {@code strArr[i]}. */
        @Override
        public void getValue(String[] strArr, int i) {
            final String decoded = tfTensor.getStringVal(i).toStringUtf8();
            strArr[i] = decoded;
        }

        /** @throws UnsupportedOperationException always */
        @Override
        public void getValue(String[] strArr, ByteBuffer byteBuffer, int i) {
            throw new UnsupportedOperationException("Not supported for String types");
        }

        /** @return an array created from {@code strArr} and reshaped to {@code jArr} */
        @Override
        public INDArray arrayFor(long[] jArr, String[] strArr) {
            final INDArray flat = Nd4j.create(strArr);
            return flat.reshape(jArr);
        }
    }

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$UInt16TensorMapper.class */
    /**
     * Mapper for {@code DT_UINT16} tensors. Values are staged as {@code int}; binary content
     * is read through a {@link ShortBuffer} and widened without sign extension.
     */
    public static class UInt16TensorMapper extends BaseTensorMapper<int[], ShortBuffer> {
        public UInt16TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        /** @return number of entries in the proto's int value list */
        @Override
        public int valueCount() {
            return tfTensor.getIntValCount();
        }

        /** @return a new {@code int[]} of length {@code i} */
        @Override
        public int[] newArray(int i) {
            return new int[i];
        }

        /** @return a {@link ShortBuffer} view over the raw tensor content */
        @Override
        public ShortBuffer getBuffer(ByteBuffer byteBuffer) {
            return byteBuffer.asShortBuffer();
        }

        /** Copies the {@code i}-th int value of the proto into {@code iArr[i]}. */
        @Override
        public void getValue(int[] iArr, int i) {
            iArr[i] = tfTensor.getIntVal(i);
        }

        /** Stores buffer element {@code i} at {@code iArr[i]}, keeping only its low 16 bits. */
        @Override
        public void getValue(int[] iArr, ShortBuffer shortBuffer, int i) {
            final short raw = shortBuffer.get(i);
            // Mask off the sign-extended upper bits so the value is treated as unsigned.
            iArr[i] = raw & 0xFFFF;
        }

        /** @return a c-order array of shape {@code jArr} backed by a buffer typed as {@link #dataType()} */
        @Override
        public INDArray arrayFor(long[] jArr, int[] iArr) {
            final DataType targetType = dataType();
            return Nd4j.create(
                    Nd4j.createTypedBuffer(iArr, targetType),
                    jArr,
                    Nd4j.getStrides(jArr, 'c'),
                    0L,
                    'c',
                    targetType);
        }
    }

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$UInt32TensorMapper.class */
    /**
     * Mapper for {@code DT_UINT32} tensors. Values are staged as {@code long} so the full
     * unsigned 32-bit range fits; binary content is read through an {@link IntBuffer} and
     * widened without sign extension.
     */
    public static class UInt32TensorMapper extends BaseTensorMapper<long[], IntBuffer> {
        /** Mask selecting the low 32 bits of a {@code long}. */
        private static final long UINT32_MASK = 0xFFFFFFFFL;

        public UInt32TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        /** @return number of entries in the proto's int64 value list */
        @Override
        public int valueCount() {
            // NOTE(review): this reads the int64 value list for uint32 data — confirm this
            // matches the TensorProto schema version in use.
            return tfTensor.getInt64ValCount();
        }

        /** @return a new {@code long[]} of length {@code i} */
        @Override
        public long[] newArray(int i) {
            return new long[i];
        }

        /** @return an {@link IntBuffer} view over the raw tensor content */
        @Override
        public IntBuffer getBuffer(ByteBuffer byteBuffer) {
            return byteBuffer.asIntBuffer();
        }

        /** Copies the {@code i}-th int64 value of the proto into {@code jArr[i]}. */
        @Override
        public void getValue(long[] jArr, int i) {
            jArr[i] = tfTensor.getInt64Val(i);
        }

        /** Stores buffer element {@code i} at {@code jArr[i]} as an unsigned 32-bit value. */
        @Override
        public void getValue(long[] jArr, IntBuffer intBuffer, int i) {
            // Use a local literal mask rather than a constant from the CUDA NPP presets:
            // this code has nothing to do with CUDA and must not depend on that class.
            jArr[i] = intBuffer.get(i) & UINT32_MASK;
        }

        /**
         * @param jArr target shape
         * @param jArr2 staged values
         * @return a c-order array of shape {@code jArr} backed by a buffer typed as {@link #dataType()}
         */
        @Override
        public INDArray arrayFor(long[] jArr, long[] jArr2) {
            DataType dataType = dataType();
            return Nd4j.create(Nd4j.createTypedBuffer(jArr2, dataType), jArr, Nd4j.getStrides(jArr, 'c'), 0L, 'c', dataType);
        }
    }

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$UInt64TensorMapper.class */
    /**
     * Mapper for {@code DT_UINT64} tensors. Values are staged as {@code long} and copied
     * bit-for-bit from either the proto's int64 value list or a {@link LongBuffer} view of its
     * binary content.
     */
    public static class UInt64TensorMapper extends BaseTensorMapper<long[], LongBuffer> {
        public UInt64TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        /** @return number of entries in the proto's int64 value list */
        @Override
        public int valueCount() {
            return tfTensor.getInt64ValCount();
        }

        /** @return a new {@code long[]} of length {@code i} */
        @Override
        public long[] newArray(int i) {
            return new long[i];
        }

        /** @return a {@link LongBuffer} view over the raw tensor content */
        @Override
        public LongBuffer getBuffer(ByteBuffer byteBuffer) {
            return byteBuffer.asLongBuffer();
        }

        /** Copies the {@code i}-th int64 value of the proto into {@code jArr[i]}. */
        @Override
        public void getValue(long[] jArr, int i) {
            jArr[i] = tfTensor.getInt64Val(i);
        }

        /** Copies element {@code i} of the buffer into {@code jArr[i]}. */
        @Override
        public void getValue(long[] jArr, LongBuffer longBuffer, int i) {
            jArr[i] = longBuffer.get(i);
        }

        /**
         * @param jArr target shape
         * @param jArr2 staged values
         * @return a c-order array of shape {@code jArr} backed by a buffer typed as {@link #dataType()}
         */
        @Override
        public INDArray arrayFor(long[] jArr, long[] jArr2) {
            final DataType targetType = dataType();
            return Nd4j.create(
                    Nd4j.createTypedBuffer(jArr2, targetType),
                    jArr,
                    Nd4j.getStrides(jArr, 'c'),
                    0L,
                    'c',
                    targetType);
        }
    }

    /* loaded from: input_file:org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers$UInt8TensorMapper.class */
    /**
     * Mapper for {@code DT_UINT8} tensors. Values are staged as {@code int}; binary content is
     * read byte by byte and widened without sign extension.
     */
    public static class UInt8TensorMapper extends BaseTensorMapper<int[], ByteBuffer> {
        public UInt8TensorMapper(TensorProto tensorProto) {
            super(tensorProto);
        }

        /** @return number of entries in the proto's int value list */
        @Override
        public int valueCount() {
            return tfTensor.getIntValCount();
        }

        /** @return a new {@code int[]} of length {@code i} */
        @Override
        public int[] newArray(int i) {
            return new int[i];
        }

        /** @return the given byte buffer itself; one byte is one element */
        @Override
        public ByteBuffer getBuffer(ByteBuffer byteBuffer) {
            return byteBuffer;
        }

        /** Copies the {@code i}-th int value of the proto into {@code iArr[i]}. */
        @Override
        public void getValue(int[] iArr, int i) {
            iArr[i] = tfTensor.getIntVal(i);
        }

        /** Stores byte {@code i} of the buffer at {@code iArr[i]}, keeping only its low 8 bits. */
        @Override
        public void getValue(int[] iArr, ByteBuffer byteBuffer, int i) {
            final byte raw = byteBuffer.get(i);
            // Mask off the sign-extended upper bits so the value is treated as unsigned.
            iArr[i] = raw & 0xFF;
        }

        /** @return a c-order array of shape {@code jArr} backed by a buffer typed as {@link #dataType()} */
        @Override
        public INDArray arrayFor(long[] jArr, int[] iArr) {
            final DataType targetType = dataType();
            return Nd4j.create(
                    Nd4j.createTypedBuffer(iArr, targetType),
                    jArr,
                    Nd4j.getStrides(jArr, 'c'),
                    0L,
                    'c',
                    targetType);
        }
    }

    /** Private constructor: this class only hosts nested mapper types and a static factory. */
    private TFTensorMappers() {
    }

    /**
     * Creates the mapper matching the dtype of the given tensor.
     *
     * @param tensorProto the TensorFlow tensor to be converted
     * @return a mapper instance wrapping {@code tensorProto}
     * @throws IllegalStateException if the dtype is quantized, complex, a reference type, or
     *         otherwise has no mapper
     */
    public static TFTensorMapper<?, ?> newMapper(TensorProto tensorProto) {
        // Supported dtypes return straight from the switch; every unsupported dtype only
        // records which category it falls in, and a single throw below builds the message.
        final String unsupportedKind;
        switch (tensorProto.getDtype()) {
            // Floating point
            case DT_HALF:
                return new Float16TensorMapper(tensorProto);
            case DT_FLOAT:
                return new Float32TensorMapper(tensorProto);
            case DT_DOUBLE:
                return new Float64TensorMapper(tensorProto);
            case DT_BFLOAT16:
                return new BFloat16TensorMapper(tensorProto);

            // Signed integers
            case DT_INT8:
                return new Int8TensorMapper(tensorProto);
            case DT_INT16:
                return new Int16TensorMapper(tensorProto);
            case DT_INT32:
                return new Int32TensorMapper(tensorProto);
            case DT_INT64:
                return new Int64TensorMapper(tensorProto);

            // Strings and booleans
            case DT_STRING:
                return new StringTensorMapper(tensorProto);
            case DT_BOOL:
                return new BoolTensorMapper(tensorProto);

            // Unsigned integers
            case DT_UINT8:
                return new UInt8TensorMapper(tensorProto);
            case DT_UINT16:
                return new UInt16TensorMapper(tensorProto);
            case DT_UINT32:
                return new UInt32TensorMapper(tensorProto);
            case DT_UINT64:
                return new UInt64TensorMapper(tensorProto);

            case DT_QINT8:
            case DT_QUINT8:
            case DT_QINT32:
            case DT_QINT16:
            case DT_QUINT16:
                unsupportedKind = "quantized type";
                break;

            case DT_COMPLEX64:
            case DT_COMPLEX128:
                unsupportedKind = "complex type";
                break;

            case DT_FLOAT_REF:
            case DT_DOUBLE_REF:
            case DT_INT32_REF:
            case DT_UINT8_REF:
            case DT_INT16_REF:
            case DT_INT8_REF:
            case DT_STRING_REF:
            case DT_COMPLEX64_REF:
            case DT_INT64_REF:
            case DT_BOOL_REF:
            case DT_QINT8_REF:
            case DT_QUINT8_REF:
            case DT_QINT32_REF:
            case DT_BFLOAT16_REF:
            case DT_QINT16_REF:
            case DT_QUINT16_REF:
            case DT_UINT16_REF:
            case DT_COMPLEX128_REF:
            case DT_HALF_REF:
            case DT_RESOURCE_REF:
            case DT_VARIANT_REF:
            case DT_UINT32_REF:
            case DT_UINT64_REF:
                unsupportedKind = "reference type";
                break;

            // UNRECOGNIZED, DT_RESOURCE, DT_VARIANT, DT_INVALID and anything else
            default:
                unsupportedKind = "type";
                break;
        }
        throw new IllegalStateException("Unable to map " + unsupportedKind + ": " + tensorProto.getDtype());
    }
}
