package org.bigml.mimir.math.gpu;

import java.util.Arrays;
import org.bigml.mimir.deepnet.layers.twod.AbstractConvolution2D;
import org.bigml.mimir.deepnet.layers.twod.DepthwiseConvolution2D;
import org.jocl.CL;
import org.jocl.Pointer;
import org.jocl.cl_mem;

/* loaded from: input_file:org/bigml/mimir/math/gpu/DepthwiseConvolution2DKernel.class */
/**
 * {@link KernelFunction} that binds the arguments of, and launches, the
 * {@code "convolve"} kernel for a {@link DepthwiseConvolution2D} layer.
 *
 * <p>Only 3 x 3 filters are accepted; any other filter shape is rejected when the
 * kernel arguments are set up.
 */
public class DepthwiseConvolution2DKernel extends KernelFunction {
    /** Name of the kernel entry point looked up in the compiled program. */
    private static final String KERNEL_FN = "convolve";

    /** Size in bytes of one {@code float}, used to size the output transfer. */
    private static final int BYTES_PER_FLOAT = 4;

    /**
     * Creates the kernel function from the program that {@code device} provides for
     * the layer's program type.
     *
     * @param device the device supplying the program
     * @param abstractConvolution2D the layer handed to the superclass constructor;
     *     {@link #setArgsAndSize} casts the object it receives to
     *     {@link DepthwiseConvolution2D}
     */
    public DepthwiseConvolution2DKernel(Device device, AbstractConvolution2D abstractConvolution2D) {
        super(device.getProgram(abstractConvolution2D.getProgramType()), KERNEL_FN, abstractConvolution2D);
    }

    /**
     * Computes the global and local work sizes and binds kernel arguments 2 to 7
     * (filters, biases, padded input shape, output shape, filter shape, strides).
     *
     * @param program not read by this implementation; the device is obtained from
     *     {@code this._program}
     * @param obj the layer to configure for; must be a {@link DepthwiseConvolution2D}
     * @throws ClassCastException if {@code obj} is not a {@link DepthwiseConvolution2D}
     * @throws IllegalArgumentException if the first two filter dimensions are not 3 x 3
     */
    @Override // org.bigml.mimir.math.gpu.KernelFunction
    protected void setArgsAndSize(Program program, Object obj) {
        DepthwiseConvolution2D layer = (DepthwiseConvolution2D) obj;
        float[] filters = layer.getFilters();
        float[] biases = layer.getBiases();
        int[] paddedInputShape = layer.getPaddedInputShape();
        int[] outputShape = layer.getOutputShape();
        int[] filtersShape = layer.getFiltersShape();
        int[] strides = layer.getStrides();
        if (filtersShape[0] != 3 || filtersShape[1] != 3) {
            throw new IllegalArgumentException("Depthwise kernel can only be 3 x 3 but is " + Arrays.toString(filtersShape));
        }

        // The local work size starts as the third output dimension and is reduced
        // through findChunkSize when it exceeds what the device reports as its maximum.
        long localSize = outputShape[2];
        long maxLocalWorkSize = this._program.getDevice().getMaxLocalWorkSize();
        if (localSize > maxLocalWorkSize) {
            localSize = findChunkSize(localSize, maxLocalWorkSize);
        }

        // Promote to long before multiplying: an int product of the three output
        // dimensions would wrap around silently for large outputs.
        long globalSize = (long) outputShape[0] * outputShape[1] * outputShape[2];

        this._gSize = new long[]{globalSize};
        this._lSize = new long[]{localSize};
        this._kernel.addArg(2, filters);
        this._kernel.addArg(3, biases);
        this._kernel.addArg(4, paddedInputShape);
        this._kernel.addArg(5, outputShape);
        this._kernel.addArg(6, filtersShape);
        this._kernel.addArg(7, strides);
    }

    /**
     * Binds {@code fArr} as kernel argument 0 and {@code fArr2} as kernel argument 1,
     * runs the kernel on the device queue, and releases the buffer created for
     * {@code fArr}.
     *
     * @param fArr values bound as the kernel's read argument
     * @param fArr2 array bound as the kernel's write argument; presumably filled with
     *     the kernel output by {@code runKernel} (confirm against KernelFunction)
     * @return the same {@code fArr2} instance that was passed in
     */
    @Override // org.bigml.mimir.math.gpu.KernelFunction, org.bigml.mimir.math.gpu.GPUExecutable
    public float[] execute(float[] fArr, float[] fArr2) {
        cl_mem readArg = this._kernel.setReadArg(0, fArr);
        try {
            runKernel(this._program.getDevice().getQueue(), this._kernel.get(), BYTES_PER_FLOAT * fArr2.length, Pointer.to(fArr2), this._kernel.setWriteArg(1, fArr2));
        } finally {
            // Release the input buffer even when the launch fails, so a failed run
            // does not leak device memory.
            CL.clReleaseMemObject(readArg);
        }
        return fArr2;
    }
}
