package org.bigml.mimir.deepnet.layers.twod;

import org.bigml.mimir.math.Matrices;
import org.bigml.mimir.math.gpu.GPURunnable;
import org.bigml.mimir.math.gpu.Program;

/* loaded from: input_file:org/bigml/mimir/deepnet/layers/twod/SeparableConvolution2D.class */
/**
 * Composite 2D layer built from two stages applied in sequence: a
 * {@link DepthwiseConvolution2D} followed by a {@link Convolution2D}.
 *
 * <p>All shape handling and propagation is delegated either to the static
 * helpers on {@link Layer2D} or to the first / last stage of the chain.
 */
public class SeparableConvolution2D implements Layer2D, GPURunnable {
    private static final long serialVersionUID = 1;

    /** The two chained stages: depthwise at position 0, convolution at position 1. */
    private final Layer2D[] _layers;

    /** Value stored by {@link #setIndex(int)} and reported by {@link #getIndex()}. */
    private int _index;

    /**
     * Builds the two-stage chain.
     *
     * @param dArr  first argument handed to the depthwise stage (presumably its
     *              kernel weights; confirm against {@code DepthwiseConvolution2D})
     * @param dArr2 first argument handed to the convolution stage (presumably its
     *              kernel weights; confirm against {@code Convolution2D})
     * @param dArr3 second argument handed to the convolution stage
     * @param iArr  third argument handed to the depthwise stage; the convolution
     *              stage always receives {@code {1, 1}} in this position
     * @param z     flag forwarded unchanged to both stages
     */
    public SeparableConvolution2D(double[][][][] dArr, double[][][][] dArr2, double[] dArr3, int[] iArr, boolean z) {
        // The depthwise stage receives a freshly allocated, all-zero vector whose
        // length is taken from the third dimension of dArr.
        final double[] zeros = new double[dArr[0][0].length];
        final Layer2D depthwiseStage = new DepthwiseConvolution2D(dArr, zeros, iArr, z);
        final Layer2D convolutionStage = new Convolution2D(dArr2, dArr3, new int[]{1, 1}, z);
        this._layers = new Layer2D[]{depthwiseStage, convolutionStage};
    }

    /** Returns the index previously stored with {@link #setIndex(int)}. */
    @Override
    public int getIndex() {
        return _index;
    }

    /** Stores the given index on this layer. */
    @Override
    public void setIndex(int i) {
        _index = i;
    }

    /**
     * Initializes the chain by handing both stages and the given shape to
     * {@code Layer2D.initialize}, returning whatever that helper returns.
     */
    @Override
    public int[] initialize(int[] iArr) {
        final int[] result = Layer2D.initialize(_layers, iArr);
        return result;
    }

    /** Output shape of the composite, i.e. the output shape of the second stage. */
    @Override
    public int[] getOutputShape() {
        final Layer2D lastStage = _layers[1];
        return lastStage.getOutputShape();
    }

    /** Input shape of the composite, i.e. the input shape of the first stage. */
    @Override
    public int[] getInputShape() {
        final Layer2D firstStage = _layers[0];
        return firstStage.getInputShape();
    }

    /**
     * Always {@code null}: this composite reports no program type of its own.
     * NOTE(review): presumably the stages supply their own; confirm with callers.
     */
    @Override
    public Program.Type getProgramType() {
        return null;
    }

    /**
     * Pushes the input through both stages via {@code Layer2D.propagate2D},
     * forwarding {@code i} unchanged.
     */
    @Override
    public float[] run(float[] fArr, int i) {
        final float[] output = Layer2D.propagate2D(_layers, fArr, i);
        return output;
    }

    /**
     * Convenience form of {@link #run(float[], int)} that passes {@code -1} as
     * the second argument.
     */
    @Override
    public float[] propagate2D(float[] fArr) {
        return run(fArr, -1);
    }

    /**
     * Array-based propagation: unrolls the input with {@code Matrices.unroll},
     * propagates it, then reshapes the result to {@link #getOutputShape()}.
     */
    @Override
    public double[][][] propagateArray(double[][][] dArr) {
        // Propagate first, then query the output shape, matching the original
        // evaluation order.
        final float[] flatOutput = propagate2D(Matrices.unroll(dArr));
        final int[] outputShape = getOutputShape();
        return Matrices.reshape(flatOutput, outputShape);
    }

    /** Last output recorded by the second stage. */
    @Override
    public float[] getLastOutput() {
        final Layer2D lastStage = _layers[1];
        return lastStage.getLastOutput();
    }
}
