package org.deeplearning4j.nn.params;

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/params/GravesBidirectionalLSTMParamInitializer.class */
/**
 * {@link ParamInitializer} for {@link GravesBidirectionalLSTM} layers.
 *
 * <p>All parameters of the layer are stored in a single flattened row array. This class defines
 * how that array is partitioned into six contiguous segments, in this storage order
 * ({@code nIn}/{@code nOut} come from the layer configuration):
 *
 * <ol>
 *   <li>{@code WF} - forward input weights: {@code nIn * 4 * nOut} elements, shaped {@code [nIn, 4*nOut]}</li>
 *   <li>{@code RWF} - forward recurrent weights: {@code nOut * (4*nOut + 3)} elements, shaped
 *       {@code [nOut, 4*nOut + 3]}</li>
 *   <li>{@code bF} - forward bias: {@code 4 * nOut} elements (kept as a row)</li>
 *   <li>{@code WB} - backward input weights (same layout as {@code WF})</li>
 *   <li>{@code RWB} - backward recurrent weights (same layout as {@code RWF})</li>
 *   <li>{@code bB} - backward bias (same layout as {@code bF})</li>
 * </ol>
 *
 * <p>Weight segments are reshaped in Fortran ('f') order. The three extra recurrent-weight
 * columns are presumably the peephole weights of the Graves LSTM formulation (TODO confirm
 * against the layer implementation). Segments are obtained by interval-indexing the supplied
 * array, which in ND4J yields views sharing its storage.
 */
public class GravesBidirectionalLSTMParamInitializer implements ParamInitializer {
    private static final GravesBidirectionalLSTMParamInitializer INSTANCE = new GravesBidirectionalLSTMParamInitializer();

    public static final String INPUT_WEIGHT_KEY_FORWARDS = "WF";
    public static final String INPUT_WEIGHT_KEY_BACKWARDS = "WB";
    public static final String RECURRENT_WEIGHT_KEY_FORWARDS = "RWF";
    public static final String RECURRENT_WEIGHT_KEY_BACKWARDS = "RWB";
    public static final String BIAS_KEY_FORWARDS = "bF";
    public static final String BIAS_KEY_BACKWARDS = "bB";

    private static final List<String> WEIGHT_KEYS = Collections.unmodifiableList(Arrays.asList(
            INPUT_WEIGHT_KEY_FORWARDS, INPUT_WEIGHT_KEY_BACKWARDS,
            RECURRENT_WEIGHT_KEY_FORWARDS, RECURRENT_WEIGHT_KEY_BACKWARDS));
    private static final List<String> BIAS_KEYS = Collections.unmodifiableList(Arrays.asList(
            BIAS_KEY_FORWARDS, BIAS_KEY_BACKWARDS));
    private static final List<String> ALL_PARAM_KEYS = Collections.unmodifiableList(Arrays.asList(
            INPUT_WEIGHT_KEY_FORWARDS, INPUT_WEIGHT_KEY_BACKWARDS,
            RECURRENT_WEIGHT_KEY_FORWARDS, RECURRENT_WEIGHT_KEY_BACKWARDS,
            BIAS_KEY_FORWARDS, BIAS_KEY_BACKWARDS));

    /**
     * Order in which the six segments are laid out in the flattened array. This is also the order
     * in which variables are registered on the configuration.
     */
    private static final List<String> STORAGE_ORDER = Collections.unmodifiableList(Arrays.asList(
            INPUT_WEIGHT_KEY_FORWARDS, RECURRENT_WEIGHT_KEY_FORWARDS, BIAS_KEY_FORWARDS,
            INPUT_WEIGHT_KEY_BACKWARDS, RECURRENT_WEIGHT_KEY_BACKWARDS, BIAS_KEY_BACKWARDS));

    /** Returns the shared, stateless instance of this initializer. */
    public static GravesBidirectionalLSTMParamInitializer getInstance() {
        return INSTANCE;
    }

    @Override
    public long numParams(NeuralNetConfiguration conf) {
        return numParams(conf.getLayer());
    }

    /**
     * Total parameter count for both directions:
     * {@code 2 * (nIn*4*nOut + nOut*(4*nOut + 3) + 4*nOut)}.
     */
    @Override
    public long numParams(Layer layer) {
        GravesBidirectionalLSTM layerConf = (GravesBidirectionalLSTM) layer;
        // Widen to long before multiplying so large layer sizes cannot overflow int arithmetic.
        long nIn = layerConf.getNIn();
        long nOut = layerConf.getNOut();
        return totalLength(segmentLengths(nIn, nOut));
    }

    @Override
    public List<String> paramKeys(Layer layer) {
        return ALL_PARAM_KEYS;
    }

    @Override
    public List<String> weightKeys(Layer layer) {
        return WEIGHT_KEYS;
    }

    @Override
    public List<String> biasKeys(Layer layer) {
        return BIAS_KEYS;
    }

    /** True for the four input/recurrent weight keys of either direction. */
    @Override
    public boolean isWeightParam(Layer layer, String key) {
        return WEIGHT_KEYS.contains(key);
    }

    /** True for the forward and backward bias keys. */
    @Override
    public boolean isBiasParam(Layer layer, String key) {
        return BIAS_KEYS.contains(key);
    }

    /**
     * Splits {@code paramsView} into the six parameter arrays. It registers their names on
     * {@code conf} and optionally initializes their values.
     *
     * <p>When {@code initializeParams} is true, the following happens:
     * <ul>
     *   <li>In each bias, the block {@code [nOut, 2*nOut)} is set to the layer's
     *       {@code forgetGateBiasInit}. Other bias entries are not written here.</li>
     *   <li>Each weight segment is then filled by the layer's weight-init function. This happens
     *       in the order WF, RWF, WB, RWB.</li>
     * </ul>
     * Otherwise, the existing values are only reshaped.
     *
     * @param conf             network configuration whose layer must be a {@link GravesBidirectionalLSTM}
     * @param paramsView       flattened row array of exactly {@link #numParams(NeuralNetConfiguration)} elements
     * @param initializeParams whether to write initial values into {@code paramsView}
     * @return a synchronized, insertion-ordered map from parameter key to its array
     * @throws IllegalStateException if {@code paramsView} has the wrong length
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) {
        GravesBidirectionalLSTM layerConf = (GravesBidirectionalLSTM) conf.getLayer();
        long nOut = layerConf.getNOut();
        long nIn = layerConf.getNIn();

        // Validate and slice before mutating the configuration.
        Map<String, INDArray> segments = sliceFlattened(paramsView, nIn, nOut);
        for (String key : STORAGE_ORDER) {
            conf.addVariable(key);
        }

        long[] inputWeightShape = {nIn, 4 * nOut};
        long[] recurrentWeightShape = {nOut, 4 * nOut + 3};

        if (initializeParams) {
            double forgetGateBiasInit = layerConf.getForgetGateBiasInit();
            initForgetGateBias(segments.get(BIAS_KEY_FORWARDS), nOut, forgetGateBiasInit);
            initForgetGateBias(segments.get(BIAS_KEY_BACKWARDS), nOut, forgetGateBiasInit);
        }

        // Insertion order (and therefore weight-init call order) matches the storage order, which
        // keeps seeded random initialization reproducible.
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        params.put(INPUT_WEIGHT_KEY_FORWARDS, weightParam(layerConf, initializeParams, nIn, nOut,
                inputWeightShape, segments.get(INPUT_WEIGHT_KEY_FORWARDS)));
        params.put(RECURRENT_WEIGHT_KEY_FORWARDS, weightParam(layerConf, initializeParams, nIn, nOut,
                recurrentWeightShape, segments.get(RECURRENT_WEIGHT_KEY_FORWARDS)));
        params.put(BIAS_KEY_FORWARDS, segments.get(BIAS_KEY_FORWARDS));
        params.put(INPUT_WEIGHT_KEY_BACKWARDS, weightParam(layerConf, initializeParams, nIn, nOut,
                inputWeightShape, segments.get(INPUT_WEIGHT_KEY_BACKWARDS)));
        params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, weightParam(layerConf, initializeParams, nIn, nOut,
                recurrentWeightShape, segments.get(RECURRENT_WEIGHT_KEY_BACKWARDS)));
        params.put(BIAS_KEY_BACKWARDS, segments.get(BIAS_KEY_BACKWARDS));
        return params;
    }

    /**
     * Splits a flattened gradient array into per-parameter gradients. They use the same layout and
     * shapes as {@link #init}.
     *
     * @param conf         network configuration whose layer must be a {@link GravesBidirectionalLSTM}
     * @param gradientView flattened row array of exactly {@link #numParams(NeuralNetConfiguration)} elements
     * @return an insertion-ordered map from parameter key to its gradient array
     * @throws IllegalStateException if {@code gradientView} has the wrong length
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        GravesBidirectionalLSTM layerConf = (GravesBidirectionalLSTM) conf.getLayer();
        long nOut = layerConf.getNOut();
        long nIn = layerConf.getNIn();
        Map<String, INDArray> segments = sliceFlattened(gradientView, nIn, nOut);

        Map<String, INDArray> gradients = new LinkedHashMap<String, INDArray>();
        gradients.put(INPUT_WEIGHT_KEY_FORWARDS,
                segments.get(INPUT_WEIGHT_KEY_FORWARDS).reshape('f', nIn, 4 * nOut));
        gradients.put(RECURRENT_WEIGHT_KEY_FORWARDS,
                segments.get(RECURRENT_WEIGHT_KEY_FORWARDS).reshape('f', nOut, 4 * nOut + 3));
        gradients.put(BIAS_KEY_FORWARDS, segments.get(BIAS_KEY_FORWARDS));
        gradients.put(INPUT_WEIGHT_KEY_BACKWARDS,
                segments.get(INPUT_WEIGHT_KEY_BACKWARDS).reshape('f', nIn, 4 * nOut));
        gradients.put(RECURRENT_WEIGHT_KEY_BACKWARDS,
                segments.get(RECURRENT_WEIGHT_KEY_BACKWARDS).reshape('f', nOut, 4 * nOut + 3));
        gradients.put(BIAS_KEY_BACKWARDS, segments.get(BIAS_KEY_BACKWARDS));
        return gradients;
    }

    /** Lengths of the six segments, in {@link #STORAGE_ORDER}. */
    private static long[] segmentLengths(long nIn, long nOut) {
        long inputWeights = nIn * 4 * nOut;
        long recurrentWeights = nOut * (4 * nOut + 3);
        long bias = 4 * nOut;
        return new long[] {inputWeights, recurrentWeights, bias, inputWeights, recurrentWeights, bias};
    }

    /** Sum of the given segment lengths. */
    private static long totalLength(long[] lengths) {
        long total = 0;
        for (long length : lengths) {
            total += length;
        }
        return total;
    }

    /**
     * Validates the length of {@code flat} and slices it into its six un-reshaped row segments.
     * The result is keyed by parameter name and ordered as {@link #STORAGE_ORDER}.
     * Assumes {@code flat} is a single-row (rank-2, {@code [1, N]}) array, as the row index below
     * requires.
     */
    private static Map<String, INDArray> sliceFlattened(INDArray flat, long nIn, long nOut) {
        long[] lengths = segmentLengths(nIn, nOut);
        long expected = totalLength(lengths);
        if (flat.length() != expected) {
            throw new IllegalStateException("Expected flattened array of length " + expected
                    + " for GravesBidirectionalLSTM (nIn=" + nIn + ", nOut=" + nOut + "), got length "
                    + flat.length());
        }

        Map<String, INDArray> segments = new LinkedHashMap<String, INDArray>();
        long offset = 0;
        for (int i = 0; i < lengths.length; i++) {
            long end = offset + lengths[i];
            segments.put(STORAGE_ORDER.get(i),
                    flat.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(offset, end)));
            offset = end;
        }
        return segments;
    }

    /**
     * Writes {@code value} into columns {@code [nOut, 2*nOut)} of a {@code [1, 4*nOut]} bias row.
     * Per the config getter's name, this block is the forget-gate bias.
     */
    private static void initForgetGateBias(INDArray bias, long nOut, double value) {
        bias.put(new INDArrayIndex[] {NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(nOut, 2 * nOut)},
                Nd4j.ones(1, nOut).muli(value));
    }

    /**
     * Returns the weight array for one segment. When {@code initializeParams} is true, the segment
     * is filled by the layer's weight-init function; otherwise it is only reshaped to {@code shape}.
     *
     * <p>NOTE(review): fanIn = nOut and fanOut = nIn + nOut are used for both input and recurrent
     * weights. This is kept as-is so existing initializations stay reproducible; confirm whether
     * it is intended.
     */
    private static INDArray weightParam(GravesBidirectionalLSTM layerConf, boolean initializeParams,
                                        long nIn, long nOut, long[] shape, INDArray segment) {
        if (initializeParams) {
            long fanIn = nOut;
            long fanOut = nIn + nOut;
            return layerConf.getWeightInitFn().init(fanIn, fanOut, shape, 'f', segment);
        }
        return WeightInitUtil.reshapeWeights(shape, segment);
    }
}
