package org.deeplearning4j.nn.params;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.Distributions;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/params/PReLUParamInitializer.class */
/**
 * Parameter initializer for a PReLU layer.
 *
 * <p>Manages exactly one weight parameter, stored under {@link #WEIGHT_KEY}, and no bias parameters.
 * The weight shape is a copy of the shape passed at construction, with every listed shared axis set to
 * size 1.
 */
public class PReLUParamInitializer implements ParamInitializer {
    /** Key under which the single weight parameter is stored and registered. */
    public static final String WEIGHT_KEY = "W";

    /** Shape of the weight parameter; a private copy, so the caller's array is never modified. */
    private final long[] weightShape;
    /** Copy of the shared axes supplied at construction; may be null. */
    private final long[] sharedAxes;

    /**
     * Creates an initializer for the given shape and shared axes.
     *
     * @param shape      shape from which the weight shape is derived; copied, never modified
     * @param sharedAxes 1-based axis indices into {@code shape} whose weight dimension is forced to 1;
     *                   may be null for "no shared axes"
     * @throws IllegalArgumentException if a shared axis lies outside {@code [1, shape.length]}
     */
    public PReLUParamInitializer(long[] shape, long[] sharedAxes) {
        // Work on copies: collapsing shared axes in place would otherwise silently rewrite the
        // caller's shape array.
        this.weightShape = shape == null ? null : shape.clone();
        this.sharedAxes = sharedAxes == null ? null : sharedAxes.clone();
        if (sharedAxes != null) {
            for (long axis : sharedAxes) {
                // Axis k maps to weightShape[k - 1]. Presumably axis 0 is the minibatch dimension,
                // which has no counterpart in the weights -- TODO confirm against PReLULayer.
                // Range-check on the long value so huge axes are not truncated by the int cast.
                if (axis < 1 || axis > weightShape.length) {
                    throw new IllegalArgumentException("Shared axis " + axis + " is out of range [1, "
                            + weightShape.length + "] for shape " + Arrays.toString(shape));
                }
                this.weightShape[(int) axis - 1] = 1;
            }
        }
    }

    /**
     * Static factory equivalent to {@link #PReLUParamInitializer(long[], long[])}.
     *
     * @param shape      shape from which the weight shape is derived; copied, never modified
     * @param sharedAxes 1-based shared axis indices, or null
     * @return a new initializer
     */
    public static PReLUParamInitializer getInstance(long[] shape, long[] sharedAxes) {
        return new PReLUParamInitializer(shape, sharedAxes);
    }

    /** Returns the number of weight parameters; delegates to {@link #numParams(Layer)}. */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public long numParams(NeuralNetConfiguration conf) {
        return numParams(conf.getLayer());
    }

    /**
     * Returns the number of weight parameters, i.e. the product of the weight shape's dimensions.
     * The {@code layer} argument is not consulted; the count is fixed at construction time.
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public long numParams(Layer layer) {
        return numParams(this.weightShape);
    }

    /** Product of all entries of {@code shape} (1 for an empty array). */
    private long numParams(long[] shape) {
        long product = 1;
        for (long dim : shape) {
            product *= dim;
        }
        return product;
    }

    /** All parameter keys: only the weight key, since this initializer has no bias. */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public List<String> paramKeys(Layer layer) {
        return weightKeys(layer);
    }

    /** Returns a singleton list containing {@link #WEIGHT_KEY}. */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public List<String> weightKeys(Layer layer) {
        return Collections.singletonList(WEIGHT_KEY);
    }

    /** Returns an empty list: there are no bias parameters. */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public List<String> biasKeys(Layer layer) {
        return Collections.emptyList();
    }

    /** Returns true only for {@link #WEIGHT_KEY}. */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public boolean isWeightParam(Layer layer, String key) {
        return WEIGHT_KEY.equals(key);
    }

    /** Always false: there are no bias parameters. */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public boolean isBiasParam(Layer layer, String key) {
        return false;
    }

    /**
     * Creates (or re-shapes) the weight parameter as a view into {@code paramsView} and registers
     * {@link #WEIGHT_KEY} as a variable on {@code conf}.
     *
     * @param conf             configuration whose layer must be a {@link FeedForwardLayer}
     * @param paramsView       flattened parameter view; its length must equal {@link #numParams}.
     *                         Row 0 is sliced, so this presumably expects a [1, numParams] row
     *                         vector -- confirm with callers.
     * @param initializeParams if true, weights are initialized via {@link WeightInitUtil#initWeights};
     *                         otherwise the existing view contents are only reshaped
     * @return a synchronized, insertion-ordered map containing the single weight entry
     * @throws IllegalArgumentException if the layer is not a {@link FeedForwardLayer}
     * @throws IllegalStateException    if {@code paramsView} has the wrong length
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        if (!(conf.getLayer() instanceof FeedForwardLayer)) {
            throw new IllegalArgumentException("unsupported layer type: " + conf.getLayer().getClass().getName());
        }
        long length = numParams(conf);
        if (paramsView.length() != length) {
            throw new IllegalStateException("Expected params view of length " + length + ", got length " + paramsView.length());
        }
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        INDArray weightView = paramsView.get(NDArrayIndex.point(0L), NDArrayIndex.interval(0L, length));
        params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
        conf.addVariable(WEIGHT_KEY);
        return params;
    }

    /**
     * Maps the flattened gradient view to the weight gradient.
     *
     * <p>The first {@link #numParams} entries of row 0 are reshaped in 'f' order to the weight shape.
     *
     * @param conf         the layer configuration (used only for the parameter count)
     * @param gradientView flattened gradient view
     * @return an insertion-ordered map containing the single weight-gradient entry
     */
    @Override // org.deeplearning4j.nn.api.ParamInitializer
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        INDArray weightGradient = gradientView
                .get(NDArrayIndex.point(0L), NDArrayIndex.interval(0L, numParams(conf)))
                .reshape('f', this.weightShape);
        Map<String, INDArray> out = new LinkedHashMap<>();
        out.put(WEIGHT_KEY, weightGradient);
        return out;
    }

    /**
     * Builds the weight array on top of {@code weightView}.
     *
     * @param conf             configuration whose layer is cast to {@link FeedForwardLayer}
     * @param weightView       the slice of the parameter view reserved for the weights
     * @param initializeParams if false, the view is only reshaped to the weight shape; if true, weights
     *                         are initialized using the layer's nIn/nOut, weight init and distribution
     * @return the weight array backed by {@code weightView}, as produced by {@link WeightInitUtil}
     */
    protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, boolean initializeParams) {
        FeedForwardLayer layer = (FeedForwardLayer) conf.getLayer();
        if (!initializeParams) {
            return WeightInitUtil.reshapeWeights(this.weightShape, weightView);
        }
        return WeightInitUtil.initWeights(layer.getNIn(), layer.getNOut(), this.weightShape,
                layer.getWeightInit(), Distributions.createDistribution(layer.getDist()), weightView);
    }
}
