package org.bigml.mimir.math.gpu;

import java.util.EnumMap;
import java.util.HashMap;
import java.util.Map;
import org.bigml.mimir.utils.ResourceLoader;
import org.jocl.BuildProgramFunction;
import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.Sizeof;
import org.jocl.cl_context;
import org.jocl.cl_device_id;
import org.jocl.cl_kernel;
import org.jocl.cl_mem;
import org.jocl.cl_program;

/**
 * An OpenCL program compiled from one of the bundled kernel source files.
 *
 * <p>The source text for a {@link Type} is read through {@link ResourceLoader}, compiled against
 * the context of the supplied {@link Device}, and exposed via {@link #getProgram()} so kernels can
 * be created from it. {@link #setArg} allocates device buffers in the same context and binds them
 * to kernel arguments.
 */
public class Program {
    /**
     * Buffer flags used when host data is supplied to {@link #setArg}:
     * {@code CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR} (numerically 36).
     * Left non-final for backward compatibility; treat it as a constant.
     */
    public static long READ_CPY = CL.CL_MEM_READ_ONLY | CL.CL_MEM_COPY_HOST_PTR;

    /**
     * Buffer flags used when no host data is supplied to {@link #setArg}:
     * {@code CL_MEM_READ_WRITE} (numerically 1).
     * Left non-final for backward compatibility; treat it as a constant.
     */
    public static long READ_WRT = CL.CL_MEM_READ_WRITE;

    /** Kernel source file name for each program type, resolved through {@link ResourceLoader}. */
    private static final Map<Type, String> sourcePaths = new EnumMap<>(Type.class);

    static {
        sourcePaths.put(Type.DEPTH_2D, "depthwise_convolution_2d.cl");
        sourcePaths.put(Type.CONV_BLOCK_2D_16, "convolution_block_2d_16.cl");
        sourcePaths.put(Type.CONV_BLOCK_2D_4, "convolution_block_2d_4.cl");
        sourcePaths.put(Type.CONV_2D_16, "convolution_2d_16.cl");
        sourcePaths.put(Type.CONV_2D_4, "convolution_2d_4.cl");
    }

    private final Device _device;
    private final cl_program _program;

    /** The bundled kernel programs; each maps to one {@code .cl} source file. */
    public enum Type {
        CONV_2D_16,
        CONV_2D_4,
        CONV_BLOCK_2D_16,
        CONV_BLOCK_2D_4,
        DEPTH_2D
    }

    /**
     * Creates and builds the program for {@code type} in {@code device}'s context.
     *
     * @param type which bundled kernel source to compile
     * @param device device whose context the program is created in
     * @throws IllegalArgumentException if no source file is registered for {@code type}
     * @throws IllegalStateException if OpenCL reports an error creating or building the program
     *     (when JOCL exceptions are enabled, JOCL may throw its own exception first)
     */
    public Program(Type type, Device device) {
        this._device = device;
        cl_context context = device.getContext();
        String source = readProgram(type);

        int[] status = new int[1];
        cl_program program = CL.clCreateProgramWithSource(
                context, 1, new String[] {source}, (long[]) null, status);
        checkStatus("clCreateProgramWithSource", status[0]);

        // A zero device count with a null device list builds for every device in the context.
        int buildStatus = CL.clBuildProgram(
                program, 0, (cl_device_id[]) null, (String) null, (BuildProgramFunction) null, (Object) null);
        if (buildStatus != CL.CL_SUCCESS) {
            // Don't leak the native program object when the kernel source fails to compile.
            CL.clReleaseProgram(program);
            throw clError("clBuildProgram (" + type + ")", buildStatus);
        }
        this._program = program;
    }

    /**
     * Allocates a device buffer in this program's context and binds it to a kernel argument.
     *
     * <p>If {@code hostData} is non-null the buffer is created with {@link #READ_CPY} (its contents
     * are copied from {@code hostData}); otherwise it is created with {@link #READ_WRT}. The buffer
     * is not tracked by this object — presumably the caller releases it when done.
     *
     * @param kernel kernel whose argument is set
     * @param index zero-based argument index
     * @param hostData host memory to copy into the buffer, or {@code null} for an uninitialised buffer
     * @param size buffer size in bytes
     * @return the newly created buffer, already bound to argument {@code index}
     * @throws IllegalStateException if OpenCL reports an error creating or binding the buffer
     */
    public cl_mem setArg(cl_kernel kernel, int index, Pointer hostData, long size) {
        long flags = (hostData == null) ? READ_WRT : READ_CPY;

        int[] status = new int[1];
        cl_mem buffer = CL.clCreateBuffer(this._device.getContext(), flags, size, hostData, status);
        checkStatus("clCreateBuffer", status[0]);

        int argStatus = CL.clSetKernelArg(kernel, index, Sizeof.cl_mem, Pointer.to(buffer));
        if (argStatus != CL.CL_SUCCESS) {
            // The caller never receives the buffer on failure, so release it here.
            CL.clReleaseMemObject(buffer);
            throw clError("clSetKernelArg (index " + index + ")", argStatus);
        }
        return buffer;
    }

    /** Returns the built OpenCL program handle. */
    public cl_program getProgram() {
        return this._program;
    }

    /** Returns the device this program was built for. */
    public Device getDevice() {
        return this._device;
    }

    /** Loads the kernel source text registered for {@code type}. */
    private static String readProgram(Type type) {
        String path = sourcePaths.get(type);
        if (path == null) {
            throw new IllegalArgumentException("No kernel source registered for program type: " + type);
        }
        return ResourceLoader.stringForFile(path);
    }

    /** Throws if {@code status} is not {@code CL_SUCCESS}. */
    private static void checkStatus(String operation, int status) {
        if (status != CL.CL_SUCCESS) {
            throw clError(operation, status);
        }
    }

    /** Builds a descriptive exception for a failed OpenCL call. */
    private static IllegalStateException clError(String operation, int status) {
        return new IllegalStateException(
                operation + " failed: " + CL.stringFor_errorCode(status) + " (" + status + ")");
    }
}
