package org.nd4j.linalg.jcublas.buffer.allocation;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import jcuda.Pointer;
import jcuda.runtime.JCuda;
import org.apache.commons.lang3.tuple.Triple;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.jcublas.buffer.DevicePointerInfo;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.util.NioUtil;

/* loaded from: input_file:org/nd4j/linalg/jcublas/buffer/allocation/PinnedMemoryStrategy.class */
/**
 * {@link MemoryStrategy} that stores buffer data in page-locked host memory mapped into the
 * device address space.
 *
 * <p>{@link #alloc} reserves the memory with {@code cudaHostAlloc(..., cudaHostAllocMapped)} and
 * obtains the device-side alias with {@code cudaHostGetDevicePointer}; {@link #free} releases it
 * with {@code cudaFreeHost}. All pointer lookups go through the buffer's
 * {@code getPointersToContexts()} table, keyed by the current thread's name and an
 * {@code (offset, length, stride)} triple. NOTE(review): registration of entries in that table is
 * not done in this class — presumably the caller stores the {@link DevicePointerInfo} returned by
 * {@link #alloc}; confirm.
 */
public class PinnedMemoryStrategy implements MemoryStrategy {

    /**
     * Copies {@code length} elements from {@code buffer} into {@code get} via
     * {@link DataBuffer#copyAtStride}.
     *
     * @param buffer      source buffer
     * @param offset      offset forwarded to {@code copyAtStride}
     * @param stride      stride forwarded to {@code copyAtStride}
     * @param length      number of elements forwarded to {@code copyAtStride}
     * @param get         buffer passed as the first argument of {@code copyAtStride}
     * @param ctx         unused by this strategy
     * @param otherStride second stride forwarded to {@code copyAtStride}
     * @param otherOffset second offset forwarded to {@code copyAtStride}
     */
    @Override // org.nd4j.linalg.jcublas.buffer.allocation.MemoryStrategy
    public void getData(DataBuffer buffer, int offset, int stride, int length, DataBuffer get,
            CudaContext ctx, int otherStride, int otherOffset) {
        // NOTE(review): argument order assumes copyAtStride(buf, n, stride, yStride, offset, yOffset);
        // confirm against DataBuffer.
        buffer.copyAtStride(get, length, stride, otherStride, offset, otherOffset);
    }

    /**
     * Copies the whole of {@code buffer} (starting at {@code offset}, stride 1) into {@code get},
     * using stride 1 and offset 0 on the {@code get} side.
     */
    @Override // org.nd4j.linalg.jcublas.buffer.allocation.MemoryStrategy
    public void getData(DataBuffer buffer, int offset, DataBuffer get, CudaContext ctx) {
        getData(buffer, offset, 1, buffer.length(), get, ctx, 1, 0);
    }

    /** Does nothing in this strategy; all arguments are ignored. */
    @Override // org.nd4j.linalg.jcublas.buffer.allocation.MemoryStrategy
    public void setData(Pointer pointer, int offset, int stride, int length, Pointer hostPointer) {
    }

    /** Does nothing in this strategy; all arguments are ignored. */
    @Override // org.nd4j.linalg.jcublas.buffer.allocation.MemoryStrategy
    public void setData(DataBuffer buffer, int offset, int stride, int length) {
    }

    /** Does nothing in this strategy; all arguments are ignored. */
    @Override // org.nd4j.linalg.jcublas.buffer.allocation.MemoryStrategy
    public void setData(DataBuffer buffer, int offset) {
    }

    /**
     * Copies the pinned host allocation registered for {@code (offset, copy.length(), 1)} on the
     * current thread back into {@code copy}'s NIO view, byte-for-byte from position 0.
     *
     * @param copy    buffer to receive the data; must be a {@link JCudaBuffer}
     * @param offset  offset component of the lookup key
     * @param context unused by this strategy
     * @return the {@link DevicePointerInfo} that was read from
     */
    @Override // org.nd4j.linalg.jcublas.buffer.allocation.MemoryStrategy
    public Object copyToHost(DataBuffer copy, int offset, CudaContext context) {
        DevicePointerInfo info = lookupPointerInfo(copy, offset, copy.length(), 1);
        long byteLength = (long) copy.getElementSize() * copy.length();
        // Do not flip() this view: it starts at position 0, so flipping would set its limit to 0
        // and the put below would copy nothing.
        ByteBuffer source = info.getPointers().getHostPointer().getByteBuffer(0L, byteLength)
                .order(ByteOrder.nativeOrder());
        // Relative put advances the target's position; use a duplicate so the buffer returned by
        // asNio() keeps its own position, and start writing at index 0.
        ByteBuffer target = copy.asNio().duplicate();
        target.rewind();
        target.put(source);
        return info;
    }

    /**
     * Copies {@code length} elements from the pinned host allocation registered for
     * {@code (offset, length, stride)} on the current thread into {@code copy}'s NIO view.
     *
     * @param copy         buffer to receive the data; must be a {@link JCudaBuffer}
     * @param offset       offset component of the lookup key, also used as the read offset
     * @param stride       stride component of the lookup key, also used as the read stride
     * @param length       length component of the lookup key and number of elements copied
     * @param context      unused by this strategy
     * @param targetOffset element offset into {@code copy}'s NIO view
     * @param targetStride element stride into {@code copy}'s NIO view
     * @return the {@link DevicePointerInfo} that was read from
     */
    @Override // org.nd4j.linalg.jcublas.buffer.allocation.MemoryStrategy
    public Object copyToHost(DataBuffer copy, int offset, int stride, int length, CudaContext context,
            int targetOffset, int targetStride) {
        ByteBuffer target = copy.asNio();
        DevicePointerInfo info = lookupPointerInfo(copy, offset, length, stride);
        // NOTE(review): alloc() fills the host allocation contiguously (offset 0, stride 1) and sizes
        // it by `length`, yet this reads it at (offset, stride) through a view of copy.length()
        // elements. That can read past the allocation when offset > 0 or stride > 1 — confirm intent.
        ByteBuffer source = info.getPointers().getHostPointer()
                .getByteBuffer(0L, (long) copy.length() * copy.getElementSize());
        NioUtil.copyAtStride(length, getBufferType(copy), source, offset, stride, target,
                targetOffset, targetStride);
        return info;
    }

    /**
     * Allocates {@code length} elements of mapped pinned host memory and resolves its device alias.
     * The CUDA context is set first via {@code ContextHolder}.
     *
     * @param buffer   buffer whose element size and data type determine the allocation
     * @param stride   stride recorded in the result; also the read stride when initializing
     * @param offset   offset recorded in the result; also the read offset when initializing
     * @param length   number of elements to allocate (and to copy when {@code initData})
     * @param initData when true, copies {@code length} elements of {@code buffer} starting at
     *                 {@code offset} with {@code stride} into the new host memory contiguously
     * @return a new, not-yet-freed {@link DevicePointerInfo} describing the allocation
     */
    @Override // org.nd4j.linalg.jcublas.buffer.allocation.MemoryStrategy
    public Object alloc(DataBuffer buffer, int stride, int offset, int length, boolean initData) {
        ContextHolder.getInstance().setContext();
        long byteLength = (long) buffer.getElementSize() * length;
        Pointer hostPointer = new Pointer();
        Pointer devicePointer = new Pointer();
        JCuda.cudaHostAlloc(hostPointer, byteLength, JCuda.cudaHostAllocMapped);
        JCuda.cudaHostGetDevicePointer(devicePointer, hostPointer, 0);
        DevicePointerInfo info = new DevicePointerInfo(
                new HostDevicePointer(hostPointer, devicePointer), length, stride, offset, false);
        if (initData) {
            // View and copy exactly the `length` elements that were allocated; sizing either by
            // buffer.length() would run past the allocation when length < buffer.length().
            ByteBuffer hostBuffer = hostPointer.getByteBuffer(0L, byteLength);
            hostBuffer.order(ByteOrder.nativeOrder());
            NioUtil.copyAtStride(length, getBufferType(buffer), buffer.asNio(), offset, stride,
                    hostBuffer, 0, 1);
        }
        return info;
    }

    /**
     * Maps a {@link DataBuffer.Type} to the {@link NioUtil.BufferType} used for element copies.
     *
     * @throws UnsupportedOperationException for any type other than DOUBLE, FLOAT or INT
     */
    private NioUtil.BufferType getBufferType(DataBuffer buffer) {
        switch (buffer.dataType()) {
            case DOUBLE:
                return NioUtil.BufferType.DOUBLE;
            case INT:
                // NOTE(review): INT is copied as FLOAT, presumably relying on both being 4-byte
                // elements; confirm whether a dedicated INT buffer type should be used.
                return NioUtil.BufferType.FLOAT;
            case FLOAT:
                return NioUtil.BufferType.FLOAT;
            default:
                throw new UnsupportedOperationException("Unsupported data buffer type");
        }
    }

    /**
     * Releases the pinned host allocation registered for {@code (offset, length, 1)} on the current
     * thread, unless it is already marked freed; then marks it freed.
     */
    @Override // org.nd4j.linalg.jcublas.buffer.allocation.MemoryStrategy
    public void free(DataBuffer buffer, int offset, int length) {
        DevicePointerInfo info = lookupPointerInfo(buffer, offset, length, 1);
        if (info.isFreed()) {
            return;
        }
        // cudaFreeHost must receive the pointer returned by cudaHostAlloc (the host pointer), not
        // the mapped device alias obtained from cudaHostGetDevicePointer.
        JCuda.cudaFreeHost(info.getPointers().getHostPointer());
        info.setFreed(true);
    }

    /** Performs no validation in this strategy. */
    @Override // org.nd4j.linalg.jcublas.buffer.allocation.MemoryStrategy
    public void validate(DataBuffer dataBuffer, CudaContext cudaContext) throws Exception {
    }

    /**
     * Looks up the pointer info stored in {@code buffer}'s pointer table for the current thread's
     * name and the key {@code (offset, length, stride)}. Returns whatever the table holds (possibly
     * null if nothing was registered).
     */
    private static DevicePointerInfo lookupPointerInfo(DataBuffer buffer, int offset, int length,
            int stride) {
        return (DevicePointerInfo) ((JCudaBuffer) buffer).getPointersToContexts()
                .get(Thread.currentThread().getName(), Triple.of(offset, length, stride));
    }
}
