package org.deeplearning4j.nn.params;

import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/params/ConvolutionParamInitializer.class */
/**
 * Parameter initializer for {@link ConvolutionLayer}.
 *
 * <p>All parameters live in one flattened array, read through row 0 (see {@code point(0)} in the
 * slicing helper). The layout is bias first, then weights:
 *
 * <pre>
 *   [0, nOut)          bias, one value per output channel
 *   [nOut, numParams)  weights, interpreted in 'c' order as [nOut, nIn, kernelH, kernelW]
 * </pre>
 *
 * <p>NOTE(review): assumes the flattened array passed in is a 1 x numParams row vector; this
 * class does not check that.
 */
public class ConvolutionParamInitializer implements ParamInitializer {
    /** Map / variable key for the convolution weights. */
    public static final String WEIGHT_KEY = "W";
    /** Map / variable key for the per-output-channel bias. */
    public static final String BIAS_KEY = "b";

    /**
     * Returns the total number of parameters: {@code nIn * nOut * kernelH * kernelW} weights
     * plus {@code nOut} biases.
     *
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}
     * @param backprop not used by this initializer
     * @return the parameter count
     * @throws IllegalArgumentException if the kernel size does not have exactly 2 entries
     * @throws IllegalStateException if the parameter count does not fit in an {@code int}
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public int numParams(NeuralNetConfiguration conf, boolean backprop) {
        ConvolutionLayer layer = convolutionLayer(conf);
        int[] kernel = checkedKernelSize(layer);
        int nOut = layer.getNOut();
        // Widen to long first so an oversized layer fails loudly instead of silently wrapping,
        // which would otherwise produce bogus slice bounds in init/getGradientsFromFlattened.
        long count = (long) layer.getNIn() * nOut * kernel[0] * kernel[1] + nOut;
        if (count < 0 || count > Integer.MAX_VALUE) {
            throw new IllegalStateException("Convolution parameter count out of int range: " + count);
        }
        return (int) count;
    }

    /**
     * Splits the flattened parameter array into bias and weights, initializes both, stores them
     * in {@code params} and registers the variables on the configuration.
     *
     * @param params output map; receives {@link #BIAS_KEY} then {@link #WEIGHT_KEY}
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}
     * @param paramsView flattened parameter array (bias first, then weights)
     * @throws IllegalArgumentException if the kernel size does not have exactly 2 entries
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public void init(Map<String, INDArray> params, NeuralNetConfiguration conf, INDArray paramsView) {
        ConvolutionLayer layer = convolutionLayer(conf);
        checkedKernelSize(layer);
        int nOut = layer.getNOut();

        INDArray biasSlice = rowSlice(paramsView, 0, nOut);
        INDArray weightSlice = rowSlice(paramsView, nOut, numParams(conf, true));

        params.put(BIAS_KEY, createBias(conf, biasSlice));
        params.put(WEIGHT_KEY, createWeightMatrix(conf, weightSlice));
        // Registration order (weights, then bias) is kept exactly as before.
        conf.addVariable(WEIGHT_KEY);
        conf.addVariable(BIAS_KEY);
    }

    /**
     * Splits a flattened gradient array using the same layout as the parameters.
     *
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}
     * @param gradientView flattened gradients (bias first, then weights)
     * @return insertion-ordered map: {@link #BIAS_KEY} (row slice of length nOut), then
     *     {@link #WEIGHT_KEY} reshaped in 'c' order to [nOut, nIn, kernelH, kernelW]
     * @throws IllegalArgumentException if the kernel size does not have exactly 2 entries
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        ConvolutionLayer layer = convolutionLayer(conf);
        int[] kernel = checkedKernelSize(layer);
        int nIn = layer.getNIn();
        int nOut = layer.getNOut();

        INDArray biasGradient = rowSlice(gradientView, 0, nOut);
        INDArray weightGradient = rowSlice(gradientView, nOut, numParams(conf, true))
                .reshape('c', new int[] {nOut, nIn, kernel[0], kernel[1]});

        Map<String, INDArray> gradients = new LinkedHashMap<>();
        gradients.put(BIAS_KEY, biasGradient);
        gradients.put(WEIGHT_KEY, weightGradient);
        return gradients;
    }

    /**
     * Fills the given bias slice with the layer's configured bias init value.
     *
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}
     * @param biasView slice to fill
     * @return {@code biasView}, after assignment
     */
    protected INDArray createBias(NeuralNetConfiguration conf, INDArray biasView) {
        biasView.assign(Double.valueOf(convolutionLayer(conf).getBiasInit()));
        return biasView;
    }

    /**
     * Initializes the weight slice using the layer's weight-init scheme and distribution, with
     * shape [nOut, nIn, kernelH, kernelW] in 'c' order.
     *
     * @param conf configuration whose layer must be a {@link ConvolutionLayer}
     * @param weightView slice handed to {@link WeightInitUtil#initWeights}
     * @return whatever {@code WeightInitUtil.initWeights} returns for that slice
     */
    protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weightView) {
        ConvolutionLayer layer = convolutionLayer(conf);
        Distribution dist = Distributions.createDistribution(layer.getDist());
        int[] kernel = layer.getKernelSize();
        int[] shape = {layer.getNOut(), layer.getNIn(), kernel[0], kernel[1]};
        return WeightInitUtil.initWeights(shape, layer.getWeightInit(), dist, 'c', weightView);
    }

    /** Casts the configuration's layer to {@link ConvolutionLayer}. */
    private static ConvolutionLayer convolutionLayer(NeuralNetConfiguration conf) {
        return (ConvolutionLayer) conf.getLayer();
    }

    /** Returns the layer's kernel size, rejecting anything that is not exactly 2-D. */
    private static int[] checkedKernelSize(ConvolutionLayer layer) {
        int[] kernel = layer.getKernelSize();
        if (kernel.length != 2) {
            throw new IllegalArgumentException("Filter size must be == 2, got length " + kernel.length);
        }
        return kernel;
    }

    /** Returns columns [from, to) of row 0 of a flattened parameter/gradient array. */
    private static INDArray rowSlice(INDArray flat, int from, int to) {
        return flat.get(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(from, to)});
    }
}
