package org.deeplearning4j.nn.layers.convolution;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.class */
/**
 * 2d convolutional layer implemented as im2col followed by a matrix multiplication.
 *
 * <p>Layouts used by this implementation (fixed by the indexing and permutes below):
 * <ul>
 *   <li>input / epsilon-out: {@code [miniBatch, inDepth, inH, inW]}</li>
 *   <li>weights "W": {@code [outDepth, inDepth, kH, kW]}</li>
 *   <li>bias "b": added as a row vector of length {@code outDepth}</li>
 *   <li>pre-output / activations: {@code [miniBatch, outDepth, outH, outW]}</li>
 * </ul>
 */
public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.ConvolutionLayer> {
    public ConvolutionLayer(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    public ConvolutionLayer(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
    }

    /**
     * Returns the L2 penalty {@code 0.5 * l2 * ||W||_2^2}, or 0 when regularization is disabled or
     * {@code l2 <= 0}. Only the weights are penalised; the bias is not.
     */
    @Override
    public double calcL2() {
        if (!conf.isUseRegularization() || conf.getLayer().getL2() <= 0.0) {
            return 0.0;
        }
        double weightNorm2 = getParam("W").norm2Number().doubleValue();
        return 0.5 * conf.getLayer().getL2() * weightNorm2 * weightNorm2;
    }

    /**
     * Returns the L1 penalty {@code l1 * ||W||_1}, or 0 when regularization is disabled or
     * {@code l1 <= 0}. Only the weights are penalised; the bias is not.
     */
    @Override
    public double calcL1() {
        if (!conf.isUseRegularization() || conf.getLayer().getL1() <= 0.0) {
            return 0.0;
        }
        return conf.getLayer().getL1() * getParam("W").norm1Number().doubleValue();
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.CONVOLUTIONAL;
    }

    /**
     * Backpropagates {@code epsilon} (dL/dActivation, shape {@code [miniBatch, outDepth, outH, outW]})
     * through this layer.
     *
     * @param epsilon gradient of the loss with respect to this layer's activations
     * @return the parameter gradients ("b", "W") and dL/dInput of shape
     *     {@code [miniBatch, inDepth, inH, inW]}
     * @throws IllegalArgumentException if the input is null, not rank 4, or its depth does not match W
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        INDArray weights = getParam("W");
        validateInput(weights);

        int miniBatch = input.size(0);
        int inH = input.size(2);
        int inW = input.size(3);
        int outDepth = weights.size(0);
        int inDepth = weights.size(1);
        int kH = weights.size(2);
        int kW = weights.size(3);

        int[] kernel = layerConf().getKernelSize();
        int[] strides = layerConf().getStride();
        int[] pad = layerConf().getPadding();
        int outH = Convolution.outSize(inH, kernel[0], strides[0], pad[0], false);
        int outW = Convolution.outSize(inW, kernel[1], strides[1], pad[1], false);

        // delta = dL/dz = sigma'(z) * epsilon; for the identity activation sigma' == 1, so skip the multiply.
        INDArray delta;
        String activationFunction = conf.getLayer().getActivationFunction();
        if ("identity".equals(activationFunction)) {
            delta = epsilon;
        } else {
            // NOTE(review): preOutput(true) may re-apply drop-connect, so the mask here can differ from the
            // one used in the forward pass.
            INDArray sigmaPrimeZ = preOutput(true);
            // The returned array is ignored: this relies on the derivative transform writing into
            // sigmaPrimeZ in place.
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory()
                    .createTransform(activationFunction, sigmaPrimeZ, conf.getExtraArgs()).derivative());
            delta = sigmaPrimeZ.muli(epsilon);
        }

        // [miniBatch, outDepth, outH, outW] -> [outDepth, miniBatch*outH*outW]
        INDArray delta2d = delta.permute(new int[] {1, 0, 2, 3})
                .reshape('c', new int[] {outDepth, miniBatch * outH * outW});

        // Weight gradient: gemm with both operands transposed gives [inDepth*kH*kW, outDepth];
        // transposing yields [outDepth, inDepth*kH*kW], viewed as [outDepth, inDepth, kH, kW].
        INDArray col2d = im2col2d(miniBatch, inDepth, kH, kW, outH, outW, strides, pad);
        INDArray weightGrad2d = Nd4j.gemm(col2d, delta2d, true, true).transpose();
        INDArray weightGrad = Shape.newShapeNoCopy(weightGrad2d, new int[] {outDepth, inDepth, kH, kW}, false);

        // Input gradient in patch form: W2d [inDepth*kH*kW, outDepth] x delta2d, reshaped (f order) to
        // [kW, kH, inDepth, outW, outH, miniBatch] and permuted to [miniBatch, inDepth, kH, kW, outH, outW].
        INDArray weights2d = weights.permute(new int[] {3, 2, 1, 0}).reshape('f', inDepth * kH * kW, outDepth);
        INDArray epsilonPatches = Shape.newShapeNoCopy(weights2d.mmul(delta2d),
                new int[] {kW, kH, inDepth, outW, outH, miniBatch}, true).permute(new int[] {5, 2, 1, 0, 4, 3});

        // col2im sums the patches back into a [miniBatch, inDepth, inH, inW] view of a c-order
        // [inDepth, miniBatch, inH, inW] buffer.
        INDArray epsilonNext = Nd4j.create(new int[] {inDepth, miniBatch, inH, inW}, 'c')
                .permute(new int[] {1, 0, 2, 3});
        Convolution.col2im(epsilonPatches, epsilonNext, strides[0], strides[1], pad[0], pad[1], inH, inW);

        DefaultGradient gradient = new DefaultGradient();
        gradient.setGradientFor("b", delta2d.sum(new int[] {1}));
        gradient.setGradientFor("W", weightGrad, 'c');
        return new Pair<>(gradient, epsilonNext);
    }

    /**
     * Computes the pre-activation output {@code z = conv(input, W) + b}.
     *
     * @param training whether this is a training pass (enables drop-connect when configured)
     * @return z with shape {@code [miniBatch, outDepth, outH, outW]}
     * @throws IllegalArgumentException if the input is null, not rank 4, or its depth does not match W
     */
    @Override
    public INDArray preOutput(boolean training) {
        INDArray weights = getParam("W");
        INDArray bias = getParam("b");
        validateInput(weights);
        if (conf.isUseDropConnect() && training && conf.getLayer().getDropOut() > 0.0) {
            weights = Dropout.applyDropConnect(this, "W");
        }

        int miniBatch = input.size(0);
        int inH = input.size(2);
        int inW = input.size(3);
        int outDepth = weights.size(0);
        int inDepth = weights.size(1);
        int kH = weights.size(2);
        int kW = weights.size(3);

        int[] kernel = layerConf().getKernelSize();
        int[] strides = layerConf().getStride();
        int[] pad = layerConf().getPadding();
        int outH = Convolution.outSize(inH, kernel[0], strides[0], pad[0], false);
        int outW = Convolution.outSize(inW, kernel[1], strides[1], pad[1], false);

        // [miniBatch*outH*outW, inDepth*kH*kW] x [kW*kH*inDepth, outDepth] -> [miniBatch*outH*outW, outDepth]
        INDArray col2d = im2col2d(miniBatch, inDepth, kH, kW, outH, outW, strides, pad);
        INDArray weights2d = weights.permute(new int[] {3, 2, 1, 0}).reshape('f', kW * kH * inDepth, outDepth);
        INDArray z = col2d.mmul(weights2d);
        z.addiRowVector(bias);

        // Rows run miniBatch slowest / outW fastest, so an f-order view as [outW, outH, miniBatch, outDepth]
        // followed by this permute yields [miniBatch, outDepth, outH, outW].
        return Shape.newShapeNoCopy(z, new int[] {outW, outH, miniBatch, outDepth}, true)
                .permute(new int[] {2, 3, 1, 0});
    }

    /**
     * Applies input dropout (if configured) and returns the activations of this layer.
     *
     * @param training whether this is a training pass
     * @throws IllegalArgumentException if no input has been set
     */
    @Override
    public INDArray activate(boolean training) {
        if (input == null) {
            throw new IllegalArgumentException("No null input allowed");
        }
        applyDropOutIfNecessary(training);
        // Pass the configured extra args so the forward activation matches the derivative used in
        // backpropGradient.
        return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory()
                .createTransform(conf.getLayer().getActivationFunction(), preOutput(training), conf.getExtraArgs()));
    }

    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public Gradient calcGradient(Gradient layerError, INDArray indArray) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Does nothing: this layer provides no standalone fit implementation. */
    @Override
    public void fit(INDArray data) {
        // intentionally empty
    }

    @Override
    public void merge(Layer layer, int batchSize) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns all parameter arrays flattened in c order and concatenated in the iteration order of
     * the params map.
     */
    @Override
    public INDArray params() {
        return Nd4j.toFlattened('c', params.values());
    }

    /** Sets the parameters from a flat array, interpreting it in c order. */
    @Override
    public void setParams(INDArray params) {
        setParams(params, 'c');
    }

    /**
     * Unrolls the current input into a patch matrix of shape
     * {@code [miniBatch*outH*outW, inDepth*kH*kW]}.
     *
     * <p>Rows vary miniBatch slowest and outW fastest; columns are (inDepth, kH, kW) in c order.
     */
    private INDArray im2col2d(int miniBatch, int inDepth, int kH, int kW, int outH, int outW,
                              int[] strides, int[] pad) {
        // Allocate as [miniBatch, outH, outW, inDepth, kH, kW] (c order). im2col fills the
        // permuted view [miniBatch, inDepth, kH, kW, outH, outW], so the 2d reshape below stays a
        // contiguous c-order view.
        INDArray col = Nd4j.create(new int[] {miniBatch, outH, outW, inDepth, kH, kW}, 'c');
        Convolution.im2col(input, kH, kW, strides[0], strides[1], pad[0], pad[1], false,
                col.permute(new int[] {0, 3, 4, 5, 1, 2}));
        return col.reshape('c', miniBatch * outH * outW, inDepth * kH * kW);
    }

    /**
     * Checks that the current input is non-null, rank 4, and has a depth (dimension 1) equal to the
     * weights' input depth (dimension 1).
     */
    private void validateInput(INDArray weights) {
        if (input == null) {
            throw new IllegalArgumentException("No null input allowed");
        }
        if (input.rank() != 4) {
            throw new IllegalArgumentException(
                    "Expected rank 4 input [miniBatch, depth, height, width] but got rank " + input.rank());
        }
        if (input.size(1) != weights.size(1)) {
            throw new IllegalArgumentException("Input depth " + input.size(1)
                    + " does not match layer input depth " + weights.size(1));
        }
    }
}
