package org.bigml.mimir.math.gpu;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.bigml.mimir.deepnet.layers.Activation;
import org.bigml.mimir.deepnet.layers.twod.ConvolutionBlock2D;
import org.bigml.mimir.math.gpu.Program;

/* loaded from: input_file:org/bigml/mimir/math/gpu/ConvolutionBlock2DKernel.class */
/**
 * GPU kernel wrapper for a {@link ConvolutionBlock2D}.
 *
 * <p>In addition to the convolution arguments (filters, biases, shapes and
 * strides) this kernel is handed the block's mean / standard deviation /
 * beta / gamma vectors and an integer code identifying the activation
 * function.
 */
public class ConvolutionBlock2DKernel extends Convolution2DKernel {
    /** Program variant keyed by the divisor that selects it. */
    private static final Map<Integer, Program.Type> PROGRAMS_FOR_SHAPE;

    /**
     * Integer code passed to the kernel for each supported activation.
     * A {@code null} activation is treated the same as IDENTITY (code 0).
     */
    private static final Map<Activation.ActivationFn, Integer> ACTIVATION_CODES;

    /** Keys of {@link #PROGRAMS_FOR_SHAPE}, sorted from largest to smallest. */
    private static final List<Integer> SHAPE_DIVISORS;

    static {
        Map<Integer, Program.Type> programs = new HashMap<Integer, Program.Type>();
        programs.put(16, Program.Type.CONV_BLOCK_2D_16);
        programs.put(4, Program.Type.CONV_BLOCK_2D_4);
        PROGRAMS_FOR_SHAPE = Collections.unmodifiableMap(programs);

        // Largest divisor first, so the lookup prefers the 16-variant over the 4-variant.
        List<Integer> divisors = new ArrayList<Integer>(programs.keySet());
        Collections.sort(divisors, Collections.reverseOrder());
        SHAPE_DIVISORS = Collections.unmodifiableList(divisors);

        // HashMap (not EnumMap) because the table deliberately contains a null key.
        Map<Activation.ActivationFn, Integer> codes = new HashMap<Activation.ActivationFn, Integer>();
        codes.put(null, 0);
        codes.put(Activation.ActivationFn.IDENTITY, 0);
        codes.put(Activation.ActivationFn.RELU, 1);
        codes.put(Activation.ActivationFn.RELU6, 2);
        codes.put(Activation.ActivationFn.LEAKY_RELU, 3);
        ACTIVATION_CODES = Collections.unmodifiableMap(codes);
    }

    /**
     * Selects the program variant for the given size.
     *
     * @param i the size to match (presumably a channel count of the layer;
     *          confirm against callers)
     * @return the program type registered for the largest divisor (16, then 4)
     *         that evenly divides {@code i}, or {@code null} when neither does
     */
    public static Program.Type getProgramType(int i) {
        for (Integer divisor : SHAPE_DIVISORS) {
            if (i % divisor == 0) {
                return PROGRAMS_FOR_SHAPE.get(divisor);
            }
        }
        return null;
    }

    /**
     * Creates the kernel by delegating to the {@link Convolution2DKernel}
     * constructor.
     *
     * @param device the device handed to the superclass
     * @param convolutionBlock2D the block handed to the superclass
     */
    public ConvolutionBlock2DKernel(Device device, ConvolutionBlock2D convolutionBlock2D) {
        super(device, convolutionBlock2D);
    }

    /**
     * Computes the global/local work sizes and the batch parameters for the
     * given block, then binds kernel arguments 2 through 11.
     *
     * @param program not read by this implementation; the device limits are
     *                taken from {@code this._program}
     * @param obj the layer to run; must be a {@link ConvolutionBlock2D}
     * @throws ClassCastException if {@code obj} is not a {@link ConvolutionBlock2D}
     * @throws IllegalArgumentException if the block's activation function has
     *         no kernel code registered
     */
    @Override
    protected void setArgsAndSize(Program program, Object obj) {
        ConvolutionBlock2D block = (ConvolutionBlock2D) obj;
        float[] filters = block.getFilters();
        float[] biases = block.getBiases();
        int[] paddedInputShape = block.getPaddedInputShape();
        int[] outputShape = block.getOutputShape();
        int[] filtersShape = block.getFiltersShape();
        int[] strides = block.getStrides();

        // Widen BEFORE multiplying: an int product of three dimensions can
        // overflow and would only afterwards be converted to long.
        long outputVolume = (long) outputShape[0] * outputShape[1] * outputShape[2];

        int lastOutputDim = outputShape[2];
        long maxLocalWorkSize = this._program.getDevice().getMaxLocalWorkSize();

        // Pack mean, stdev, beta and gamma back to back; each segment holds
        // outputShape[2] values.
        float[] normParams = new float[lastOutputDim * 4];
        System.arraycopy(block.getMean(), 0, normParams, 0, lastOutputDim);
        System.arraycopy(block.getStDev(), 0, normParams, lastOutputDim, lastOutputDim);
        System.arraycopy(block.getBeta(), 0, normParams, lastOutputDim * 2, lastOutputDim);
        System.arraycopy(block.getGamma(), 0, normParams, lastOutputDim * 3, lastOutputDim);

        // Fail with a descriptive error instead of an NPE from unboxing when
        // the activation has no code in the table.
        Integer activationCode = ACTIVATION_CODES.get(block.getFunction());
        if (activationCode == null) {
            throw new IllegalArgumentException(
                "Unsupported activation function for ConvolutionBlock2DKernel: "
                + block.getFunction());
        }
        int[] activationArg = {activationCode.intValue()};

        // Local work size starts at outputShape[2] and is reduced through the
        // inherited findChunkSize when it exceeds the device limit.
        long localSize = lastOutputDim;
        if (localSize > maxLocalWorkSize) {
            localSize = findChunkSize(localSize, maxLocalWorkSize);
        }

        long filterWindow = filtersShape[0] * filtersShape[1];
        // The factor 4 is presumably the byte size of a float -- TODO confirm.
        long maxChunk = this._program.getDevice().getMaxLocalMemory() / (4 * filterWindow);

        // Split paddedInputShape[2] into chunks when it exceeds what the
        // local-memory budget computed above allows.
        int chunkSize = paddedInputShape[2];
        int chunkCount = 1;
        if (chunkSize > maxChunk) {
            chunkSize = (int) findChunkSize(chunkSize, maxChunk);
            chunkCount = paddedInputShape[2] / chunkSize;
        }

        this._batchParams = new int[]{chunkCount, chunkSize};
        this._gSize = new long[]{outputVolume * chunkCount};
        this._lSize = new long[]{localSize};

        this._kernel.addArg(2, filters);
        this._kernel.addArg(3, biases);
        this._kernel.addArg(4, paddedInputShape);
        this._kernel.addArg(5, outputShape);
        this._kernel.addArg(6, filtersShape);
        this._kernel.addArg(7, strides);
        this._kernel.addArg(8, this._batchParams);
        this._kernel.addArg(9, normParams);
        this._kernel.addArg(10, activationArg);
        // Passed as a long, exactly as before (filter window * chunk size).
        this._kernel.addArg(11, filterWindow * chunkSize);
    }
}
