package org.nd4j.linalg.jcublas.compression;

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.util.FastMath;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.nd4j.compression.impl.AbstractCompressor;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionType;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/compression/CudaThreshold.class */
public class CudaThreshold extends AbstractCompressor {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) CudaThreshold.class);
    protected float threshold = 0.001f;

    @Override // org.nd4j.linalg.compression.NDArrayCompressor
    public String getDescriptor() {
        return "THRESHOLD";
    }

    @Override // org.nd4j.compression.impl.AbstractCompressor, org.nd4j.linalg.compression.NDArrayCompressor
    public void configure(Object... objArr) {
        if (!(objArr[0] instanceof Number)) {
            throw new ND4JIllegalStateException("Threshold value should be Number");
        }
        this.threshold = FastMath.abs(((Number) objArr[0]).floatValue());
        log.info("Setting threshold to [{}]", Float.valueOf(this.threshold));
    }

    @Override // org.nd4j.compression.impl.AbstractCompressor, org.nd4j.linalg.compression.NDArrayCompressor
    public INDArray compress(INDArray iNDArray) {
        Nd4j.getExecutioner().commit();
        DataBuffer compress = compress(iNDArray.data());
        if (compress == null) {
            return null;
        }
        INDArray createArrayFromShapeBuffer = Nd4j.createArrayFromShapeBuffer(compress, iNDArray.shapeInfoDataBuffer());
        createArrayFromShapeBuffer.markAsCompressed(true);
        return createArrayFromShapeBuffer;
    }

    @Override // org.nd4j.linalg.compression.NDArrayCompressor
    public CompressionType getCompressionType() {
        return CompressionType.LOSSLESS;
    }

    @Override // org.nd4j.compression.impl.AbstractCompressor, org.nd4j.linalg.compression.NDArrayCompressor
    public DataBuffer decompress(DataBuffer dataBuffer) {
        if (dataBuffer.dataType() != DataBuffer.Type.INT) {
            throw new UnsupportedOperationException();
        }
        long j = dataBuffer.getInt(0L);
        DataBuffer createBuffer = Nd4j.createBuffer(dataBuffer.getInt(1L));
        NativeOpsHolder.getInstance().getDeviceNativeOps().decodeThresholdFloat(new PointerPointer(32L).put(1L, ((CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext()).getOldStream()), AtomicAllocator.getInstance().getPointer(dataBuffer), j, (FloatPointer) AtomicAllocator.getInstance().getPointer(createBuffer));
        AtomicAllocator.getInstance().getAllocationPoint(createBuffer).tickDeviceWrite();
        return createBuffer;
    }

    @Override // org.nd4j.compression.impl.AbstractCompressor, org.nd4j.linalg.compression.NDArrayCompressor
    public DataBuffer compress(DataBuffer dataBuffer) {
        int length = (int) ((dataBuffer.length() / 1024) + (dataBuffer.length() % ((long) 1024) == 0 ? 0 : 1));
        CudaContext cudaContext = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
        DataBuffer createInt = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(length + 1, true) : Nd4j.getDataBufferFactory().createInt(length + 1, true, Nd4j.getMemoryManager().getCurrentWorkspace());
        PointerPointer put = new PointerPointer(32L).put(1L, cudaContext.getOldStream());
        NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP1Float(put, (FloatPointer) AtomicAllocator.getInstance().getPointer(dataBuffer), dataBuffer.length(), (IntPointer) AtomicAllocator.getInstance().getPointer(createInt), this.threshold);
        AtomicAllocator.getInstance().getAllocationPoint(createInt).tickDeviceWrite();
        int i = createInt.getInt(0L);
        DataBuffer createInt2 = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(3 + i, false) : Nd4j.getDataBufferFactory().createInt(3 + i, false, Nd4j.getMemoryManager().getCurrentWorkspace());
        createInt2.put(0L, i);
        createInt2.put(1L, (int) dataBuffer.length());
        createInt2.put(2L, Float.floatToIntBits(this.threshold));
        AtomicAllocator.getInstance().getAllocationPoint(createInt2).tickHostWrite();
        int i2 = length;
        int i3 = 0;
        ArrayList arrayList = new ArrayList();
        do {
            int max = Math.max(1, (int) Math.ceil(i2 / (2.0f * 512)));
            if (length > 1) {
                i3++;
            }
            i2 = max;
        } while (i2 > 1);
        long[] jArr = new long[i3];
        int i4 = 0;
        int i5 = length;
        DataBuffer createDouble = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createDouble(jArr.length, false) : Nd4j.getDataBufferFactory().createDouble(jArr.length, false, Nd4j.getMemoryManager().getCurrentWorkspace());
        do {
            int max2 = Math.max(1, (int) Math.ceil(i5 / (2.0f * 512)));
            if (max2 > 1) {
                DataBuffer createInt3 = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(max2, false) : Nd4j.getDataBufferFactory().createInt(max2, false, Nd4j.getMemoryManager().getCurrentWorkspace());
                arrayList.add(createInt3);
                int i6 = i4;
                i4++;
                jArr[i6] = AtomicAllocator.getInstance().getPointer(createInt3).address();
            }
            i5 = max2;
        } while (i5 > 1);
        AtomicAllocator.getInstance().memcpyBlocking(createDouble, new LongPointer(jArr), jArr.length * 8, 0L);
        put.put(2L, AtomicAllocator.getInstance().getPointer(createDouble));
        DataBuffer createInt4 = Nd4j.getMemoryManager().getCurrentWorkspace() == null ? Nd4j.getDataBufferFactory().createInt(length, true) : Nd4j.getDataBufferFactory().createInt(length, true, Nd4j.getMemoryManager().getCurrentWorkspace());
        NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP2Int(put, (IntPointer) AtomicAllocator.getInstance().getPointer(createInt), length, (IntPointer) AtomicAllocator.getInstance().getPointer(createInt4));
        AtomicAllocator.getInstance().getAllocationPoint(createInt4).tickDeviceWrite();
        NativeOpsHolder.getInstance().getDeviceNativeOps().encodeThresholdP3Float(put, (FloatPointer) AtomicAllocator.getInstance().getPointer(dataBuffer), (IntPointer) AtomicAllocator.getInstance().getPointer(createInt4), dataBuffer.length(), (IntPointer) AtomicAllocator.getInstance().getPointer(createInt2));
        AtomicAllocator.getInstance().getAllocationPoint(createInt2).tickDeviceWrite();
        AtomicAllocator.getInstance().getAllocationPoint(dataBuffer).tickDeviceWrite();
        put.address();
        createDouble.address();
        return createInt2;
    }

    @Override // org.nd4j.compression.impl.AbstractCompressor
    protected CompressedDataBuffer compressPointer(DataBuffer.TypeEx typeEx, Pointer pointer, int i, int i2) {
        throw new UnsupportedOperationException();
    }

    public float getThreshold() {
        return this.threshold;
    }

    public void setThreshold(float f) {
        this.threshold = f;
    }
}
