package org.bigml.mimir.deepnet.layers.twod;

import org.bigml.mimir.deepnet.layers.Activation;
import org.bigml.mimir.deepnet.layers.BatchNormalize;
import org.bigml.mimir.math.gpu.ConvolutionBlock2DKernel;
import org.bigml.mimir.math.gpu.Device;

/**
 * A 2D convolution fused with a batch-normalization step and an activation
 * function, evaluated as a single layer.
 *
 * <p>The convolution is configured by handing the {@link Convolution2D} to the
 * superclass constructor. The batch-normalization arrays (mean, standard
 * deviation, gamma, beta) come from a {@link BatchNormalize} layer. The
 * activation function and this layer's index come from an {@link Activation}
 * layer.
 */
public class ConvolutionBlock2D extends AbstractConvolution2D {
    private static final long serialVersionUID = 1L;

    // Batch-normalization parameters. kernelsForPixel indexes each of these in
    // parallel with _biases (one entry per bias / output slot).
    private float[] _mean;
    private float[] _stdev;
    private float[] _gamma;
    private float[] _beta;

    // Activation applied by run() on the non-device (CPU) path.
    private Activation.ActivationFn _afn;

    /**
     * Builds the fused layer from its three constituent layers.
     *
     * @param convolution2D  the convolution layer; passed to the superclass constructor
     * @param batchNormalize source of the mean, standard deviation, beta and gamma
     *                       arrays; the returned arrays are stored as-is, not copied
     * @param activation     source of the activation function and of this layer's index
     */
    public ConvolutionBlock2D(Convolution2D convolution2D, BatchNormalize batchNormalize, Activation activation) {
        super(convolution2D);

        // The kernel program variant is selected from the third kernel-shape dimension.
        this._programType = ConvolutionBlock2DKernel.getProgramType(this._kernelShape[2]);

        this._mean = batchNormalize.getMean();
        this._stdev = batchNormalize.getStDev();
        this._beta = batchNormalize.getBeta();
        this._gamma = batchNormalize.getGamma();

        this._afn = activation.getFunction();
        // NOTE(review): the fused block takes its index from the activation layer,
        // presumably because it stands in for the whole conv/batch-norm/activation
        // sequence; confirm.
        this._index = activation.getIndex();
    }

    /**
     * Runs the convolution + batch normalization via the superclass, then applies
     * the activation function unless a valid device index was given.
     *
     * <p>A device index counts as valid when it is in {@code [0, Device.numberOfDevices())}.
     * In that case the superclass result is returned untouched. Presumably the
     * device kernel already applies the activation; TODO confirm.
     *
     * @param input  the layer input
     * @param device device index; any out-of-range value (e.g. negative) selects the CPU path
     * @return the layer output
     */
    @Override
    public float[] run(float[] input, int device) {
        float[] result = super.run(input, device);

        boolean onDevice = device >= 0 && device < Device.numberOfDevices();
        if (onDevice) {
            return result;
        }
        return Activation.activate(result, this._afn);
    }

    /**
     * Computes every output value for one pixel position: convolution plus bias,
     * followed by batch normalization. The activation is NOT applied here.
     *
     * <p>For each output slot {@code out} in {@code [0, _biases.length)}:
     * <ul>
     *   <li>The accumulator starts at {@code _biases[out]}.</li>
     *   <li>For each of the {@code _filterH} filter rows, it adds
     *       {@code _filterW * _inputChannels} contiguous input values, each
     *       multiplied by the next weight in {@code _filters}.</li>
     *   <li>The weight cursor runs sequentially across all outputs.</li>
     *   <li>The result stored is
     *       {@code _gamma[out] * ((acc - _mean[out]) / _stdev[out]) + _beta[out]}.</li>
     * </ul>
     *
     * @param input     source values
     * @param rowStart  input index of the first filter row, before {@code colOffset}
     *                  is added; advanced by {@code _rowLength} for each following row
     * @param colOffset offset added to every row start
     * @param output    destination array
     * @param outOffset index in {@code output} where this pixel's values begin
     */
    @Override
    protected void kernelsForPixel(float[] input, int rowStart, int colOffset, float[] output, int outOffset) {
        int weight = 0;
        for (int out = 0; out < this._biases.length; out++) {
            float acc = this._biases[out];
            int row = rowStart;
            for (int fy = 0; fy < this._filterH; fy++, row += this._rowLength) {
                int src = row + colOffset;
                for (int fx = 0; fx < this._filterW; fx++) {
                    for (int ch = 0; ch < this._inputChannels; ch++) {
                        acc += input[src++] * this._filters[weight++];
                    }
                }
            }
            output[outOffset + out] = this._gamma[out] * ((acc - this._mean[out]) / this._stdev[out]) + this._beta[out];
        }
    }

    /** @return the batch-norm mean array (the internal array, not a copy) */
    public float[] getMean() {
        return this._mean;
    }

    /** @return the batch-norm standard-deviation array (the internal array, not a copy) */
    public float[] getStDev() {
        return this._stdev;
    }

    /** @return the batch-norm beta (shift) array (the internal array, not a copy) */
    public float[] getBeta() {
        return this._beta;
    }

    /** @return the batch-norm gamma (scale) array (the internal array, not a copy) */
    public float[] getGamma() {
        return this._gamma;
    }

    /** @return the activation function applied on the CPU path of {@link #run} */
    public Activation.ActivationFn getFunction() {
        return this._afn;
    }
}
