package org.bigml.mimir.deepnet.layers.twod;

import org.bigml.mimir.math.gpu.Device;
import org.bigml.mimir.math.gpu.GPUExecutable;
import org.bigml.mimir.math.gpu.GPURunnable;
import org.bigml.mimir.math.gpu.Program;

/* loaded from: input_file:org/bigml/mimir/deepnet/layers/twod/AbstractConvolution2D.class */
/**
 * Base class for 2D convolution layers that can be evaluated either on the CPU or through a
 * per-device GPU kernel.
 *
 * <p>Geometry (strides, padding, output buffers, limits) is inherited from {@link AbstractPool2D};
 * subclasses supply the per-output-position computation via {@link #kernelsForPixel}.
 */
public abstract class AbstractConvolution2D extends AbstractPool2D implements GPURunnable {
    /** Flattened filter weights; {@link #unrollFilters} produces one such flattened layout. */
    protected float[] _filters;
    /** Bias values, exposed through {@link #getBiases()}. */
    protected float[] _biases;
    /** Shape of the filter tensor; element 3 is forwarded to the superclass by the copy constructor. */
    protected int[] _kernelShape;
    /** Program type reported to GPU callers; reset to {@code null} by the constructor here. */
    protected Program.Type _programType;
    /**
     * GPU kernels indexed by device number, built in {@link #initialize(int[])}. Because the field
     * is {@code transient}, it is {@code null} after Java deserialization until initialization runs.
     */
    protected transient GPUExecutable[] _kernels;
    private static final long serialVersionUID = 1;

    /**
     * Creates a convolution layer; weights and biases are assigned separately.
     *
     * @param iArr strides, forwarded to {@link AbstractPool2D}
     * @param i value forwarded to {@link AbstractPool2D} (the copy constructor passes filter-shape
     *     element 3 here — presumably the number of output filters; confirm in the superclass)
     * @param z whether "same" padding is used, forwarded to {@link AbstractPool2D}
     */
    public AbstractConvolution2D(int[] iArr, int i, boolean z) {
        super(iArr, i, z);
        // Kept explicit: runs after super(...), so it overrides anything assigned during superclass
        // construction.
        this._programType = null;
    }

    /**
     * Copy constructor. The filter, bias and shape arrays are shared with {@code
     * abstractConvolution2D}, not copied.
     *
     * @param abstractConvolution2D layer whose configuration and weights are reused
     */
    public AbstractConvolution2D(AbstractConvolution2D abstractConvolution2D) {
        this(abstractConvolution2D.getStrides(), abstractConvolution2D.getFiltersShape()[3], abstractConvolution2D.isSamePadding());
        this._biases = abstractConvolution2D.getBiases();
        this._filters = abstractConvolution2D.getFilters();
        this._kernelShape = abstractConvolution2D.getFiltersShape();
        this._filterH = abstractConvolution2D._filterH;
        this._filterW = abstractConvolution2D._filterW;
    }

    /**
     * Flattens a rank-4 weight array into a {@code float[]} whose outermost axis is the LAST axis
     * of {@code dArr}.
     *
     * <p>For {@code dArr} of shape {@code [A][B][C][D]}, element {@code dArr[a][b][c][d]} is stored
     * at index {@code ((d * A + a) * B + b) * C + c}. The array is assumed to be rectangular (all
     * sub-arrays sized like {@code dArr[0]...}) — TODO confirm callers guarantee this.
     *
     * @param dArr non-empty rank-4 weight array
     * @return the flattened weights, narrowed to {@code float}
     */
    protected float[] unrollFilters(double[][][][] dArr) {
        int dimA = dArr.length;
        int dimB = dArr[0].length;
        int dimC = dArr[0][0].length;
        int dimD = dArr[0][0][0].length;
        float[] flat = new float[dimD * dimA * dimB * dimC];
        int pos = 0;
        for (int d = 0; d < dimD; d++) {
            for (double[][][] slice : dArr) {
                for (int b = 0; b < dimB; b++) {
                    for (int c = 0; c < dimC; c++) {
                        flat[pos++] = (float) slice[b][c][d];
                    }
                }
            }
        }
        return flat;
    }

    /**
     * Initializes the superclass geometry and then builds the per-device GPU kernels.
     *
     * @param iArr input shape, forwarded to {@link AbstractPool2D#initialize(int[])}
     * @return the value returned by the superclass initialization
     */
    @Override // org.bigml.mimir.deepnet.layers.twod.AbstractPool2D, org.bigml.mimir.deepnet.layers.twod.Layer2D
    public int[] initialize(int[] iArr) {
        int[] initialize = super.initialize(iArr);
        this._kernels = GPUExecutable.createGPUKernels(this);
        return initialize;
    }

    /**
     * Evaluates the convolution on {@code fArr}.
     *
     * <p>If a padded-input helper is configured, the input is first copied into it. When {@code i}
     * names an existing device that has a GPU kernel, that kernel computes the result. Otherwise
     * (including negative {@code i}, or kernels not yet built) the CPU path is used.
     *
     * @param fArr flattened input activations
     * @param i device index, or a negative value to force the CPU path
     * @return the buffer returned by the GPU kernel, or the {@code _output} buffer filled on the CPU
     */
    @Override // org.bigml.mimir.math.gpu.GPURunnable
    public float[] run(float[] fArr, int i) {
        float[] output = this._output.get();
        float[] input = fArr;
        if (this._paddedInput != null) {
            input = this._paddedInput.copyToPad(input);
        }
        if (hasGpuKernel(i)) {
            return this._kernels[i].execute(input, output);
        }
        convolveOnCpu(input, output);
        return output;
    }

    /**
     * Returns true when {@code device} is a valid device index and a GPU kernel was built for it.
     * Guards against {@code _kernels} being null (e.g. after deserialization) or shorter than the
     * device count.
     */
    private boolean hasGpuKernel(int device) {
        return device >= 0
                && device < Device.numberOfDevices()
                && this._kernels != null
                && device < this._kernels.length
                && this._kernels[device] != null;
    }

    /**
     * CPU fallback: visits every strided output position in row-major order and delegates to
     * {@link #kernelsForPixel}, advancing the output offset by {@code _outputDepth} per position.
     */
    private void convolveOnCpu(float[] input, float[] output) {
        int rowEnd = this._hLimit * this._rowLength;
        int colEnd = this._wLimit * this._inputChannels;
        int rowStep = this._strideH * this._rowLength;
        int colStep = this._strideW * this._inputChannels;
        int outIndex = 0;
        for (int rowOffset = 0; rowOffset < rowEnd; rowOffset += rowStep) {
            for (int colOffset = 0; colOffset < colEnd; colOffset += colStep) {
                kernelsForPixel(input, rowOffset, colOffset, output, outIndex);
                outIndex += this._outputDepth;
            }
        }
    }

    /**
     * Propagates {@code fArr} through the layer on the CPU.
     *
     * @param fArr flattened input activations
     * @return the layer output
     */
    @Override // org.bigml.mimir.deepnet.layers.twod.AbstractPool2D, org.bigml.mimir.deepnet.layers.twod.Layer2D
    public float[] propagate2D(float[] fArr) {
        return run(fArr, -1);
    }

    /** @return the program type used when building GPU kernels; may be {@code null} */
    @Override // org.bigml.mimir.math.gpu.GPURunnable
    public Program.Type getProgramType() {
        return this._programType;
    }

    /** @return the filter shape array (shared, not copied) */
    public int[] getFiltersShape() {
        return this._kernelShape;
    }

    /** @return the flattened filter weights (shared, not copied) */
    public float[] getFilters() {
        return this._filters;
    }

    /** @return the bias values (shared, not copied) */
    public float[] getBiases() {
        return this._biases;
    }

    /**
     * Computes the output values for one output position on the CPU.
     *
     * @param fArr (possibly padded) input activations
     * @param i offset of the position's starting row in {@code fArr} (a multiple of the row length)
     * @param i2 offset of the position's starting column within that row (a multiple of the input
     *     channel count)
     * @param fArr2 output buffer to write into
     * @param i3 index in {@code fArr2} at which this position's outputs begin
     */
    protected abstract void kernelsForPixel(float[] fArr, int i, int i2, float[] fArr2, int i3);
}
