package org.nd4j.linalg.jcublas.gpumetrics;

import java.beans.ConstructorProperties;
import jcuda.driver.CUoccupancyB2DSize;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaDeviceProp;
import jcuda.utils.KernelLauncher;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.jcublas.kernel.KernelFunctionLoader;
import org.nd4j.linalg.jcublas.util.PointerUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/gpumetrics/GpuMetrics.class */
/**
 * Holds a kernel launch configuration (grid size, block size and shared memory
 * size) together with static helpers that derive such a configuration from a
 * problem size.
 */
public class GpuMetrics {
    /** Thread limit handed to {@link #getThreadsAndBlocks(int, int, int)} by {@link #blockAndThreads(String, int)}. */
    public static final int MAX_THREADS = 256;
    /** Block limit handed to {@link #getThreadsAndBlocks(int, int, int)} by {@link #blockAndThreads(String, int)}. */
    public static final int MAX_BLOCKS = 64;

    private static final Logger log = LoggerFactory.getLogger(GpuMetrics.class);

    /** Maps a block size to a dynamic shared memory size of 8 bytes per thread. */
    private static final CUoccupancyB2DSize DOUBLE = new CUoccupancyB2DSize() {
        @Override
        public long call(int blockSize) {
            return (long) blockSize * 8;
        }
    };

    /** Maps a block size to a dynamic shared memory size of 4 bytes per thread. */
    private static final CUoccupancyB2DSize FLOAT = new CUoccupancyB2DSize() {
        @Override
        public long call(int blockSize) {
            return (long) blockSize * 4;
        }
    };

    private int gridSize;
    private int blockSize;
    private int sharedMemory;

    /** Creates an instance with all three values set to zero. */
    public GpuMetrics() {
    }

    /**
     * Creates an instance with the given values.
     *
     * @param i  the grid size
     * @param i2 the block size
     * @param i3 the shared memory size
     */
    @ConstructorProperties({"gridSize", "blockSize", "sharedMemory"})
    public GpuMetrics(int i, int i2, int i3) {
        this.gridSize = i;
        this.blockSize = i2;
        this.sharedMemory = i3;
    }

    /**
     * Packs this configuration into an array.
     *
     * @return {@code {blockSize, gridSize, sharedMemory, maxSharedMemoryPerBlock}} where the
     *         last element is read from the current GPU information of the {@link ContextHolder}
     */
    public int[] getGpuDefinitionInfo() {
        return new int[]{
            getBlockSize(),
            getGridSize(),
            getSharedMemory(),
            ContextHolder.getInstance().getCurrentGpuInformation().getMaxSharedMemoryPerBlock()
        };
    }

    /** @return the grid size */
    public int getGridSize() {
        return this.gridSize;
    }

    /** @return the block size */
    public int getBlockSize() {
        return this.blockSize;
    }

    /** @return the shared memory size */
    public int getSharedMemory() {
        return this.sharedMemory;
    }

    /**
     * Computes a thread count and a block count for a problem of {@code i} elements,
     * using the properties of the device returned by {@code cudaGetDevice}.
     *
     * @param i  the number of elements
     * @param i2 the thread count used when {@code i >= 2 * i2}; for smaller inputs the
     *           thread count is {@code PointerUtil.nextPow2((i + 1) / 2)}
     * @param i3 upper bound applied to the returned block count
     * @return {@code {threads, min(i3, blocks)}}
     * @throws IllegalStateException if {@code threads * blocks} exceeds
     *         {@code maxGridSize[0] * maxThreadsPerBlock} of the device
     */
    public static int[] getThreadsAndBlocks(int i, int i2, int i3) {
        cudaDeviceProp deviceProperties = new cudaDeviceProp();
        int[] device = new int[1];
        JCuda.cudaGetDevice(device);
        JCuda.cudaGetDeviceProperties(deviceProperties, device[0]);

        int maxGridX = deviceProperties.maxGridSize[0];
        int threads = i < i2 * 2 ? PointerUtil.nextPow2((i + 1) / 2) : i2;

        // Each block covers 2 * threads elements; round up. Computed in long so that
        // an element count close to Integer.MAX_VALUE cannot wrap around.
        long elementsPerBlock = (long) threads * 2;
        int blocks = (int) ((i + elementsPerBlock - 1) / elementsPerBlock);

        // Compare in long: maxGridSize[0] * maxThreadsPerBlock does not fit in an int when
        // the device reports a grid limit of 2^31 - 1; the wrapped product is negative and
        // made this check fail for every input.
        long requested = (long) threads * blocks;
        long capacity = (long) maxGridX * deviceProperties.maxThreadsPerBlock;
        if (requested > capacity) {
            throw new IllegalStateException("n is too large, please choose a smaller number!\n");
        }

        if (blocks > maxGridX) {
            // SLF4J substitutes "{}" placeholders, not printf-style "%d".
            log.warn("Grid size <{}> exceeds the device capability <{}>, set block size as {} (original {})",
                new Object[]{blocks, maxGridX, threads * 2, threads});
            blocks /= 2;
            threads *= 2;
        }
        return new int[]{threads, Math.min(i3, blocks)};
    }

    /**
     * Builds a configuration from {@link #getThreadsAndBlocks(int, int, int)} using
     * {@link #MAX_THREADS} and {@link #MAX_BLOCKS}.
     *
     * <p>The shared memory size is {@code threads * elementSize}, doubled when
     * {@code threads <= 32}. The element size is 8 for {@code "double"} and 4 otherwise.
     *
     * @param str the data type name
     * @param i   the number of elements
     * @return the resulting configuration
     */
    public static GpuMetrics blockAndThreads(String str, int i) {
        int elementSize = str.equals("double") ? 8 : 4;
        int[] threadsAndBlocks = getThreadsAndBlocks(i, MAX_THREADS, MAX_BLOCKS);
        int threads = threadsAndBlocks[0];
        int blocks = threadsAndBlocks[1];
        int sharedMemSize = threads <= 32 ? 2 * threads * elementSize : threads * elementSize;
        // NOTE(review): the thread count lands in the constructor's gridSize slot and the block
        // count in its blockSize slot. This order is kept as-is; confirm it is what callers expect.
        return new GpuMetrics(threads, blocks, sharedMemSize);
    }

    /**
     * Builds a configuration from the block size suggested by
     * {@code cuOccupancyMaxPotentialBlockSize} for the given kernel, clamped to the limits
     * reported by the current GPU information of the {@link ContextHolder}.
     *
     * @param str  the kernel function name passed to {@code KernelFunctionLoader.launcher}
     * @param str2 the data type name; {@code "float"} selects 4 bytes per thread, anything else 8
     * @param i    the number of elements
     * @return a configuration with grid size {@code ceil(i / suggestedBlockSize)} (at most the
     *         maximum grid dimension x), block size {@code min(suggestedBlockSize, i,
     *         maxThreadsPerBlock)} and shared memory {@code blockSize * bytesPerThread}
     *         (at most the maximum shared memory per block)
     */
    public static GpuMetrics blocksAndThreadsOccupancy(String str, String str2, int i) {
        int[] minGridSize = new int[1];
        int[] suggestedBlockSize = new int[1];
        boolean isFloat = str2.equals("float");

        KernelLauncher launcher = KernelFunctionLoader.launcher(str, str2);
        JCudaDriver.cuOccupancyMaxPotentialBlockSize(
            minGridSize, suggestedBlockSize, launcher.getFunction(), isFloat ? FLOAT : DOUBLE, 0L, 0);

        // Round up so that grid * suggested block size covers all i elements.
        int grid = ((i + suggestedBlockSize[0]) - 1) / suggestedBlockSize[0];

        int block = Math.min(suggestedBlockSize[0], i);
        block = Math.min(block, ContextHolder.getInstance().getCurrentGpuInformation().getMaxThreadsPerBlock());

        grid = Math.min(grid, ContextHolder.getInstance().getCurrentGpuInformation().getMaxGrimDimX());

        int sharedMem = Math.min(
            block * (isFloat ? 4 : 8),
            ContextHolder.getInstance().getCurrentGpuInformation().getMaxSharedMemoryPerBlock());

        return new GpuMetrics(grid, block, sharedMem);
    }

    /**
     * Checks this configuration against the limits of the current GPU information.
     *
     * @throws IllegalArgumentException if the grid size exceeds the maximum threads per block,
     *         the block size exceeds the maximum block dimension x, or the shared memory
     *         exceeds the maximum shared memory per block
     */
    public void validate() {
        int maxThreadsPerBlock = ContextHolder.getInstance().getCurrentGpuInformation().getMaxThreadsPerBlock();
        int maxBlockDimX = ContextHolder.getInstance().getCurrentGpuInformation().getMaxBlockDimx();
        int maxSharedMemoryPerBlock = ContextHolder.getInstance().getCurrentGpuInformation().getMaxSharedMemoryPerBlock();

        // NOTE(review): the grid size is bounded by maxThreadsPerBlock here, whereas
        // blocksAndThreadsOccupancy bounds it by the maximum grid dimension x - confirm intent.
        if (this.gridSize > maxThreadsPerBlock) {
            throw new IllegalArgumentException("Maximum grid size is " + maxThreadsPerBlock + " but was specified as " + this.gridSize);
        }
        if (this.blockSize > maxBlockDimX) {
            throw new IllegalArgumentException("Maximum block size is " + maxBlockDimX + " but was specified as " + this.blockSize);
        }
        if (this.sharedMemory > maxSharedMemoryPerBlock) {
            throw new IllegalArgumentException("Maximum shared memory size per block is " + maxSharedMemoryPerBlock + " but was specified as " + this.sharedMemory);
        }
    }

    /**
     * Sets the shared memory size to {@code min(i, 1024)}.
     *
     * @param i the requested shared memory size
     */
    public void setSharedMemoryNotOverMax(int i) {
        // NOTE(review): the cap is a fixed 1024 rather than the device's
        // maxSharedMemoryPerBlock used by validate() - confirm this is intended.
        setSharedMemory(Math.min(i, 1024));
    }

    /**
     * Sets the grid size to {@code i}, capped at the maximum threads per block of the current GPU.
     *
     * @param i the requested grid size
     */
    public void setGridSizeNotOverMax(int i) {
        setGridSize(Math.min(i, ContextHolder.getInstance().getCurrentGpuInformation().getMaxThreadsPerBlock()));
    }

    /**
     * Sets the block size to {@code i}, capped at the maximum block dimension x of the current GPU.
     *
     * @param i the requested block size
     */
    public void setBlockSizeNotOverMax(int i) {
        setBlockSize(Math.min(i, ContextHolder.getInstance().getCurrentGpuInformation().getMaxBlockDimx()));
    }

    /** @param i the new grid size (not range-checked) */
    public void setGridSize(int i) {
        this.gridSize = i;
    }

    /** @param i the new block size (not range-checked) */
    public void setBlockSize(int i) {
        this.blockSize = i;
    }

    /** @param i the new shared memory size (not range-checked) */
    public void setSharedMemory(int i) {
        this.sharedMemory = i;
    }

    /**
     * Two instances are equal when their grid size, block size and shared memory match and
     * the other instance accepts this one via {@link #canEqual(Object)}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof GpuMetrics)) {
            return false;
        }
        GpuMetrics other = (GpuMetrics) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return getGridSize() == other.getGridSize()
            && getBlockSize() == other.getBlockSize()
            && getSharedMemory() == other.getSharedMemory();
    }

    /**
     * Hook used by {@link #equals(Object)}; subclasses may override it to refuse equality
     * with instances of other types.
     *
     * @param obj the object attempting the comparison
     * @return {@code true} if {@code obj} is a {@code GpuMetrics}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof GpuMetrics;
    }

    /** Combines grid size, block size and shared memory with multiplier 59. */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + getGridSize();
        result = result * 59 + getBlockSize();
        result = result * 59 + getSharedMemory();
        return result;
    }

    @Override
    public String toString() {
        return "GpuMetrics(gridSize=" + getGridSize() + ", blockSize=" + getBlockSize() + ", sharedMemory=" + getSharedMemory() + ")";
    }
}
