package org.bigml.mimir.math.gpu;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.bigml.mimir.deepnet.layers.twod.AbstractConvolution2D;
import org.bigml.mimir.deepnet.layers.twod.Convolution2D;
import org.bigml.mimir.math.gpu.Program;
import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.cl_command_queue;
import org.jocl.cl_mem;

/* loaded from: input_file:org/bigml/mimir/math/gpu/Convolution2DKernel.class */
/**
 * Wraps the OpenCL {@code "convolve"} kernel for a {@link Convolution2D} layer.
 *
 * <p>The input-channel dimension of the padded input may be split into several batches so that
 * the per-work-group scratch data fits in the device's local memory. When that happens, the kernel
 * writes one partial result per batch and {@link #execute} sums them on the host.
 */
public class Convolution2DKernel extends KernelFunction {
    // {batchCount, channelsPerBatch}; passed to the kernel as argument 8.
    // Intentionally has no initializer: it is assigned in setArgsAndSize, which appears to be
    // invoked from the superclass constructor, and an initializer would overwrite that value
    // after super(...) returns.
    protected int[] _batchParams;
    private static final String KERNEL_FN = "convolve";
    // Size in bytes of a single-precision float, used for buffer and local-memory sizing.
    private static final int FLOAT_BYTES = 4;
    private static final Map<Integer, Program.Type> PROGRAMS_FOR_SHAPE =
            new HashMap<Integer, Program.Type>();
    private static final List<Integer> SHAPE_DIVISORS;

    /**
     * Selects the specialised convolution program for a given size.
     *
     * <p>The registered divisors (16, then 4) are tried from largest to smallest. The program of
     * the first one that evenly divides {@code i} is returned.
     *
     * @param i the size to test (presumably a channel/shape dimension; confirm against callers)
     * @return the matching program type, or {@code null} if no registered divisor divides {@code i}
     */
    public static Program.Type getProgramType(int i) {
        for (int divisor : SHAPE_DIVISORS) {
            if (i % divisor == 0) {
                return PROGRAMS_FOR_SHAPE.get(divisor);
            }
        }
        return null;
    }

    /**
     * Builds the kernel using the program type chosen by the layer itself.
     *
     * @param device the device that supplies the compiled program
     * @param abstractConvolution2D the layer whose parameters are bound as kernel arguments
     */
    public Convolution2DKernel(Device device, AbstractConvolution2D abstractConvolution2D) {
        super(device.getProgram(abstractConvolution2D.getProgramType()), KERNEL_FN, abstractConvolution2D);
    }

    /**
     * Binds the layer's static kernel arguments (2..9) and computes the global and local work
     * sizes.
     *
     * <p>Arguments 0 (input) and 1 (output) are bound per call in {@link #execute}.
     *
     * @param program unused here; the device is read from {@code this._program}
     * @param obj the layer, which must be a {@link Convolution2D}
     */
    @Override // org.bigml.mimir.math.gpu.KernelFunction
    protected void setArgsAndSize(Program program, Object obj) {
        Convolution2D convolution2D = (Convolution2D) obj;
        float[] filters = convolution2D.getFilters();
        float[] biases = convolution2D.getBiases();
        int[] paddedInputShape = convolution2D.getPaddedInputShape();
        int[] outputShape = convolution2D.getOutputShape();
        int[] filtersShape = convolution2D.getFiltersShape();
        int[] strides = convolution2D.getStrides();

        // Widen before multiplying: the int product can overflow for large outputs.
        long outputSize = (long) outputShape[0] * outputShape[1] * outputShape[2];

        // The local work size follows the last output dimension, reduced to a chunk that fits
        // the device's work-group limit.
        long localSize = outputShape[2];
        long maxLocalWorkSize = this._program.getDevice().getMaxLocalWorkSize();
        if (localSize > maxLocalWorkSize) {
            localSize = findChunkSize(localSize, maxLocalWorkSize);
        }

        // Limit the number of input channels processed at once so that
        // filterArea * channels floats fit in the device's local memory.
        long filterArea = (long) filtersShape[0] * filtersShape[1];
        long maxChannels = this._program.getDevice().getMaxLocalMemory() / (FLOAT_BYTES * filterArea);
        int channelsPerBatch = paddedInputShape[2];
        int batchCount = 1;
        if (channelsPerBatch > maxChannels) {
            channelsPerBatch = (int) findChunkSize(channelsPerBatch, maxChannels);
            batchCount = paddedInputShape[2] / channelsPerBatch;
        }

        this._batchParams = new int[]{batchCount, channelsPerBatch};
        this._gSize = new long[]{outputSize * batchCount};
        this._lSize = new long[]{localSize};
        this._kernel.addArg(2, filters);
        this._kernel.addArg(3, biases);
        this._kernel.addArg(4, paddedInputShape);
        this._kernel.addArg(5, outputShape);
        this._kernel.addArg(6, filtersShape);
        this._kernel.addArg(7, strides);
        this._kernel.addArg(8, this._batchParams);
        // Passed as a long; it looks like the size of a local scratch buffer (confirm against the
        // kernel source).
        this._kernel.addArg(9, filterArea * channelsPerBatch);
    }

    /**
     * Runs the convolution on the device.
     *
     * @param fArr the (padded) input values, bound as kernel argument 0
     * @param fArr2 the destination for the output values; it is filled and returned
     * @return {@code fArr2}
     */
    @Override // org.bigml.mimir.math.gpu.KernelFunction, org.bigml.mimir.math.gpu.GPUExecutable
    public float[] execute(float[] fArr, float[] fArr2) {
        cl_command_queue queue = this._program.getDevice().getQueue();
        cl_mem readArg = this._kernel.setReadArg(0, fArr);
        try {
            int batches = this._batchParams[0];
            if (batches > 1) {
                // The kernel writes one partial result per channel batch; reduce them here.
                float[] partials = new float[fArr2.length * batches];
                // NOTE(review): the buffer returned by setWriteArg is handed to runKernel,
                // presumably to be read back and released there; confirm.
                runKernel(queue, this._kernel.get(), FLOAT_BYTES * partials.length,
                        Pointer.to(partials), this._kernel.setWriteArg(1, partials));
                sumBatches(partials, fArr2, batches);
            } else {
                runKernel(queue, this._kernel.get(), FLOAT_BYTES * fArr2.length,
                        Pointer.to(fArr2), this._kernel.setWriteArg(1, fArr2));
            }
        } finally {
            // Release the input buffer even if the kernel launch fails.
            CL.clReleaseMemObject(readArg);
        }
        return fArr2;
    }

    /**
     * Sums each run of {@code batches} consecutive values of {@code partials} into one element of
     * {@code out}: {@code out[i] = partials[i*batches] + ... + partials[i*batches + batches - 1]}.
     * Values are summed in index order.
     */
    private static void sumBatches(float[] partials, float[] out, int batches) {
        for (int i = 0; i < out.length; i++) {
            int start = i * batches;
            float sum = 0.0f;
            for (int k = start; k < start + batches; k++) {
                sum += partials[k];
            }
            out[i] = sum;
        }
    }

    static {
        PROGRAMS_FOR_SHAPE.put(16, Program.Type.CONV_2D_16);
        PROGRAMS_FOR_SHAPE.put(4, Program.Type.CONV_2D_4);
        // Largest divisor first, so getProgramType prefers the widest specialised program.
        List<Integer> divisors = new ArrayList<Integer>(PROGRAMS_FOR_SHAPE.keySet());
        Collections.sort(divisors, Collections.reverseOrder());
        SHAPE_DIVISORS = Collections.unmodifiableList(divisors);
    }
}
