package org.deeplearning4j.nn.params;

import com.clearspring.analytics.stream.frequency.CountMinSketch;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.class */
/**
 * Parameter initializer for the {@link VariationalAutoencoder} layer configuration.
 *
 * <p>All parameters live in one flattened view. The layout used consistently by
 * {@link #numParams}, {@link #paramKeys}, {@link #init} and {@link #getGradientsFromFlattened} is:
 * <ol>
 *   <li>for each encoder layer {@code i}: weights {@code "e" + i + "W"}, then biases {@code "e" + i + "b"};</li>
 *   <li>{@link #PZX_MEAN_W}, {@link #PZX_MEAN_B};</li>
 *   <li>{@link #PZX_LOGSTD2_W}, {@link #PZX_LOGSTD2_B};</li>
 *   <li>for each decoder layer {@code i}: weights {@code "d" + i + "W"}, then biases {@code "d" + i + "b"};</li>
 *   <li>{@link #PXZ_W}, {@link #PXZ_B}.</li>
 * </ol>
 */
public class VariationalAutoencoderParamInitializer extends DefaultParamInitializer {
    private static final VariationalAutoencoderParamInitializer INSTANCE = new VariationalAutoencoderParamInitializer();
    public static final String WEIGHT_KEY_SUFFIX = "W";
    public static final String BIAS_KEY_SUFFIX = "b";
    public static final String PZX_PREFIX = "pZX";
    public static final String PZX_MEAN_PREFIX = "pZXMean";
    public static final String PZX_LOGSTD2_PREFIX = "pZXLogStd2";
    public static final String ENCODER_PREFIX = "e";
    public static final String DECODER_PREFIX = "d";
    public static final String PZX_MEAN_W = "pZXMeanW";
    public static final String PZX_MEAN_B = "pZXMeanb";
    public static final String PZX_LOGSTD2_W = "pZXLogStd2W";
    public static final String PZX_LOGSTD2_B = "pZXLogStd2b";
    public static final String PXZ_PREFIX = "pXZ";
    public static final String PXZ_W = "pXZW";
    public static final String PXZ_B = "pXZb";

    /**
     * Returns the shared instance of this initializer.
     *
     * @return the single instance held by this class
     */
    public static VariationalAutoencoderParamInitializer getInstance() {
        return INSTANCE;
    }

    /**
     * Counts the parameters (weights plus biases) of every section of the layer.
     *
     * <p>The count is accumulated in {@code long} so that large layer sizes are not
     * truncated before being returned.
     *
     * @param neuralNetConfiguration configuration whose layer is cast to {@link VariationalAutoencoder}
     * @return total number of parameters
     * @throws ND4JArraySizeException if the layer's nIn exceeds {@link Integer#MAX_VALUE}
     */
    @Override
    public long numParams(NeuralNetConfiguration neuralNetConfiguration) {
        VariationalAutoencoder vae = (VariationalAutoencoder) neuralNetConfiguration.getLayer();
        long nIn = vae.getNIn();
        long nOut = vae.getNOut();
        int[] encoderSizes = vae.getEncoderLayerSizes();
        int[] decoderSizes = vae.getDecoderLayerSizes();

        long total = 0L;

        // Encoder layers: (inputs + 1 bias) * layer size.
        for (int layer = 0; layer < encoderSizes.length; layer++) {
            long layerIn = layer == 0 ? nIn : encoderSizes[layer - 1];
            total += (layerIn + 1L) * encoderSizes[layer];
        }

        // Last encoder layer -> two nOut-sized heads (mean and log std^2), hence the factor 2.
        total += (encoderSizes[encoderSizes.length - 1] + 1L) * 2L * nOut;

        // Decoder layers; the first one takes nOut inputs.
        for (int layer = 0; layer < decoderSizes.length; layer++) {
            long layerIn = layer == 0 ? nOut : decoderSizes[layer - 1];
            total += (layerIn + 1L) * decoderSizes[layer];
        }

        // distributionInputSize takes an int, so nIn must fit in one.
        if (nIn > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }
        long distributionSize = vae.getOutputDistribution().distributionInputSize((int) nIn);

        // Last decoder layer -> reconstruction distribution parameters.
        total += (decoderSizes[decoderSizes.length - 1] + 1L) * distributionSize;
        return total;
    }

    /**
     * Lists every parameter key in flattened-layout order.
     *
     * @param layer layer configuration, cast to {@link VariationalAutoencoder}
     * @return keys for encoder layers, p(z|x) mean, p(z|x) log std^2, decoder layers and p(x|z)
     */
    @Override
    public List<String> paramKeys(Layer layer) {
        VariationalAutoencoder vae = (VariationalAutoencoder) layer;
        int[] encoderSizes = vae.getEncoderLayerSizes();
        int[] decoderSizes = vae.getDecoderLayerSizes();

        List<String> keys = new ArrayList<>();
        for (int i = 0; i < encoderSizes.length; i++) {
            keys.add(ENCODER_PREFIX + i + WEIGHT_KEY_SUFFIX);
            keys.add(ENCODER_PREFIX + i + BIAS_KEY_SUFFIX);
        }
        keys.add(PZX_MEAN_W);
        keys.add(PZX_MEAN_B);
        keys.add(PZX_LOGSTD2_W);
        keys.add(PZX_LOGSTD2_B);
        for (int i = 0; i < decoderSizes.length; i++) {
            keys.add(DECODER_PREFIX + i + WEIGHT_KEY_SUFFIX);
            keys.add(DECODER_PREFIX + i + BIAS_KEY_SUFFIX);
        }
        keys.add(PXZ_W);
        keys.add(PXZ_B);
        return keys;
    }

    /**
     * Returns the subset of {@link #paramKeys(Layer)} accepted by {@link #isWeightParam(Layer, String)}.
     *
     * @param layer layer configuration
     * @return weight keys, in the same relative order as {@link #paramKeys(Layer)}
     */
    @Override
    public List<String> weightKeys(Layer layer) {
        List<String> weights = new ArrayList<>();
        for (String key : paramKeys(layer)) {
            if (isWeightParam(layer, key)) {
                weights.add(key);
            }
        }
        return weights;
    }

    /**
     * Returns the subset of {@link #paramKeys(Layer)} accepted by {@link #isBiasParam(Layer, String)}.
     *
     * @param layer layer configuration
     * @return bias keys, in the same relative order as {@link #paramKeys(Layer)}
     */
    @Override
    public List<String> biasKeys(Layer layer) {
        List<String> biases = new ArrayList<>();
        for (String key : paramKeys(layer)) {
            if (isBiasParam(layer, key)) {
                biases.add(key);
            }
        }
        return biases;
    }

    /**
     * Tests whether a key names a weight parameter.
     *
     * @param layer layer configuration (not consulted)
     * @param str   parameter key
     * @return {@code true} if the key ends with {@link #WEIGHT_KEY_SUFFIX}
     */
    @Override
    public boolean isWeightParam(Layer layer, String str) {
        return str.endsWith(WEIGHT_KEY_SUFFIX);
    }

    /**
     * Tests whether a key names a bias parameter.
     *
     * @param layer layer configuration (not consulted)
     * @param str   parameter key
     * @return {@code true} if the key ends with {@link #BIAS_KEY_SUFFIX}
     */
    @Override
    public boolean isBiasParam(Layer layer, String str) {
        return str.endsWith(BIAS_KEY_SUFFIX);
    }

    /**
     * Splits the flattened parameter view into one array per parameter and registers each key
     * on the configuration via {@code addVariable}.
     *
     * <p>Weights are produced by the inherited {@code createWeightMatrix} using the layer's
     * weight-init function; biases by the inherited {@code createBias} with an init value of 0.0.
     *
     * @param neuralNetConfiguration configuration whose layer is cast to {@link VariationalAutoencoder}
     * @param iNDArray               flattened parameter view; its length must equal {@link #numParams}
     * @param z                      flag forwarded unchanged to {@code createWeightMatrix}/{@code createBias}
     *                               (presumably "initialize parameters" - confirm against DefaultParamInitializer)
     * @return insertion-ordered map from parameter key to its array
     * @throws IllegalArgumentException if the view length differs from the expected parameter count
     * @throws ND4JArraySizeException   if the layer's nIn exceeds {@link Integer#MAX_VALUE}
     */
    @Override
    public Map<String, INDArray> init(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray, boolean z) {
        long actualLength = iNDArray.length();
        long expectedLength = numParams(neuralNetConfiguration);
        if (actualLength != expectedLength) {
            throw new IllegalArgumentException("Incorrect paramsView length: Expected length " + expectedLength + ", got length " + actualLength);
        }

        Map<String, INDArray> params = new LinkedHashMap<>();
        VariationalAutoencoder vae = (VariationalAutoencoder) neuralNetConfiguration.getLayer();
        long nIn = vae.getNIn();
        long nOut = vae.getNOut();
        int[] encoderSizes = vae.getEncoderLayerSizes();
        int[] decoderSizes = vae.getDecoderLayerSizes();
        IWeightInit weightInit = vae.getWeightInitFn();

        // Running position in the flattened view; long so large layouts are not truncated.
        long offset = 0L;

        // Encoder layers.
        for (int layer = 0; layer < encoderSizes.length; layer++) {
            long layerIn = layer == 0 ? nIn : encoderSizes[layer - 1];
            long weightCount = layerIn * encoderSizes[layer];
            INDArray weightView = viewOf(iNDArray, offset, offset + weightCount);
            offset += weightCount;
            INDArray biasView = viewOf(iNDArray, offset, offset + encoderSizes[layer]);
            offset += encoderSizes[layer];

            INDArray weights = createWeightMatrix(layerIn, encoderSizes[layer], weightInit, weightView, z);
            INDArray biases = createBias(encoderSizes[layer], 0.0d, biasView, z);
            String weightKey = ENCODER_PREFIX + layer + WEIGHT_KEY_SUFFIX;
            String biasKey = ENCODER_PREFIX + layer + BIAS_KEY_SUFFIX;
            params.put(weightKey, weights);
            params.put(biasKey, biases);
            neuralNetConfiguration.addVariable(weightKey);
            neuralNetConfiguration.addVariable(biasKey);
        }

        // Last encoder layer -> p(z|x) mean.
        int lastEncoderSize = encoderSizes[encoderSizes.length - 1];
        long latentWeightCount = lastEncoderSize * nOut;
        INDArray meanWeightView = viewOf(iNDArray, offset, offset + latentWeightCount);
        offset += latentWeightCount;
        INDArray meanBiasView = viewOf(iNDArray, offset, offset + nOut);
        offset += nOut;
        INDArray meanWeights = createWeightMatrix(lastEncoderSize, nOut, weightInit, meanWeightView, z);
        INDArray meanBiases = createBias(nOut, 0.0d, meanBiasView, z);
        params.put(PZX_MEAN_W, meanWeights);
        params.put(PZX_MEAN_B, meanBiases);
        neuralNetConfiguration.addVariable(PZX_MEAN_W);
        neuralNetConfiguration.addVariable(PZX_MEAN_B);

        // Last encoder layer -> p(z|x) log std^2 (same shapes as the mean head).
        INDArray logStd2WeightView = viewOf(iNDArray, offset, offset + latentWeightCount);
        offset += latentWeightCount;
        INDArray logStd2BiasView = viewOf(iNDArray, offset, offset + nOut);
        offset += nOut;
        INDArray logStd2Weights = createWeightMatrix(lastEncoderSize, nOut, weightInit, logStd2WeightView, z);
        INDArray logStd2Biases = createBias(nOut, 0.0d, logStd2BiasView, z);
        params.put(PZX_LOGSTD2_W, logStd2Weights);
        params.put(PZX_LOGSTD2_B, logStd2Biases);
        neuralNetConfiguration.addVariable(PZX_LOGSTD2_W);
        neuralNetConfiguration.addVariable(PZX_LOGSTD2_B);

        // Decoder layers; the first one takes nOut inputs.
        for (int layer = 0; layer < decoderSizes.length; layer++) {
            long layerIn = layer == 0 ? nOut : decoderSizes[layer - 1];
            long weightCount = layerIn * decoderSizes[layer];
            INDArray weightView = viewOf(iNDArray, offset, offset + weightCount);
            offset += weightCount;
            INDArray biasView = viewOf(iNDArray, offset, offset + decoderSizes[layer]);
            offset += decoderSizes[layer];

            INDArray weights = createWeightMatrix(layerIn, decoderSizes[layer], weightInit, weightView, z);
            INDArray biases = createBias(decoderSizes[layer], 0.0d, biasView, z);
            String weightKey = DECODER_PREFIX + layer + WEIGHT_KEY_SUFFIX;
            String biasKey = DECODER_PREFIX + layer + BIAS_KEY_SUFFIX;
            params.put(weightKey, weights);
            params.put(biasKey, biases);
            neuralNetConfiguration.addVariable(weightKey);
            neuralNetConfiguration.addVariable(biasKey);
        }

        // distributionInputSize takes an int, so nIn must fit in one.
        if (nIn > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }

        // Last decoder layer -> p(x|z) distribution parameters.
        int distributionSize = vae.getOutputDistribution().distributionInputSize((int) nIn);
        int lastDecoderSize = decoderSizes[decoderSizes.length - 1];
        long pxzWeightCount = (long) lastDecoderSize * distributionSize;
        INDArray pxzWeightView = viewOf(iNDArray, offset, offset + pxzWeightCount);
        offset += pxzWeightCount;
        INDArray pxzBiasView = viewOf(iNDArray, offset, offset + distributionSize);
        INDArray pxzWeights = createWeightMatrix(lastDecoderSize, distributionSize, weightInit, pxzWeightView, z);
        INDArray pxzBiases = createBias(distributionSize, 0.0d, pxzBiasView, z);
        params.put(PXZ_W, pxzWeights);
        params.put(PXZ_B, pxzBiases);
        neuralNetConfiguration.addVariable(PXZ_W);
        neuralNetConfiguration.addVariable(PXZ_B);
        return params;
    }

    /**
     * Splits a flattened gradient view into one array per parameter, using the same layout and
     * keys as {@link #init}. Nothing is registered on the configuration.
     *
     * <p>NOTE(review): encoder and p(z|x)-mean weights are shaped with {@code reshape('f', ...)},
     * while the remaining sections go through {@code createWeightMatrix}/{@code createBias} with
     * a {@code null} weight-init and the flag set to {@code false}. These are presumably
     * equivalent, but the inherited helpers are not visible here, so the split is kept as-is.
     *
     * @param neuralNetConfiguration configuration whose layer is cast to {@link VariationalAutoencoder}
     * @param iNDArray               flattened gradient view
     * @return insertion-ordered map from parameter key to its gradient array
     * @throws ND4JArraySizeException if the layer's nIn exceeds {@link Integer#MAX_VALUE}
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        Map<String, INDArray> gradients = new LinkedHashMap<>();
        VariationalAutoencoder vae = (VariationalAutoencoder) neuralNetConfiguration.getLayer();
        long nIn = vae.getNIn();
        long nOut = vae.getNOut();
        int[] encoderSizes = vae.getEncoderLayerSizes();
        int[] decoderSizes = vae.getDecoderLayerSizes();

        // Running position in the flattened view; long so large layouts are not truncated.
        long offset = 0L;

        // Encoder layers.
        for (int layer = 0; layer < encoderSizes.length; layer++) {
            long layerIn = layer == 0 ? nIn : encoderSizes[layer - 1];
            long weightCount = layerIn * encoderSizes[layer];
            INDArray weightView = viewOf(iNDArray, offset, offset + weightCount);
            offset += weightCount;
            INDArray biasView = viewOf(iNDArray, offset, offset + encoderSizes[layer]);
            offset += encoderSizes[layer];

            gradients.put(ENCODER_PREFIX + layer + WEIGHT_KEY_SUFFIX, weightView.reshape('f', layerIn, encoderSizes[layer]));
            gradients.put(ENCODER_PREFIX + layer + BIAS_KEY_SUFFIX, biasView);
        }

        // Last encoder layer -> p(z|x) mean.
        int lastEncoderSize = encoderSizes[encoderSizes.length - 1];
        long latentWeightCount = lastEncoderSize * nOut;
        INDArray meanWeightView = viewOf(iNDArray, offset, offset + latentWeightCount);
        offset += latentWeightCount;
        INDArray meanBiasView = viewOf(iNDArray, offset, offset + nOut);
        offset += nOut;
        gradients.put(PZX_MEAN_W, meanWeightView.reshape('f', lastEncoderSize, nOut));
        gradients.put(PZX_MEAN_B, meanBiasView);

        // Last encoder layer -> p(z|x) log std^2.
        INDArray logStd2WeightView = viewOf(iNDArray, offset, offset + latentWeightCount);
        offset += latentWeightCount;
        INDArray logStd2BiasView = viewOf(iNDArray, offset, offset + nOut);
        offset += nOut;
        gradients.put(PZX_LOGSTD2_W, createWeightMatrix(lastEncoderSize, nOut, null, logStd2WeightView, false));
        gradients.put(PZX_LOGSTD2_B, logStd2BiasView);

        // Decoder layers; the first one takes nOut inputs.
        for (int layer = 0; layer < decoderSizes.length; layer++) {
            long layerIn = layer == 0 ? nOut : decoderSizes[layer - 1];
            long weightCount = layerIn * decoderSizes[layer];
            INDArray weightView = viewOf(iNDArray, offset, offset + weightCount);
            offset += weightCount;
            INDArray biasView = viewOf(iNDArray, offset, offset + decoderSizes[layer]);
            offset += decoderSizes[layer];

            INDArray weights = createWeightMatrix(layerIn, decoderSizes[layer], null, weightView, false);
            INDArray biases = createBias(decoderSizes[layer], 0.0d, biasView, false);
            gradients.put(DECODER_PREFIX + layer + WEIGHT_KEY_SUFFIX, weights);
            gradients.put(DECODER_PREFIX + layer + BIAS_KEY_SUFFIX, biases);
        }

        // distributionInputSize takes an int, so nIn must fit in one.
        if (nIn > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }

        // Last decoder layer -> p(x|z) distribution parameters.
        int distributionSize = vae.getOutputDistribution().distributionInputSize((int) nIn);
        int lastDecoderSize = decoderSizes[decoderSizes.length - 1];
        long pxzWeightCount = (long) lastDecoderSize * distributionSize;
        INDArray pxzWeightView = viewOf(iNDArray, offset, offset + pxzWeightCount);
        offset += pxzWeightCount;
        INDArray pxzBiasView = viewOf(iNDArray, offset, offset + distributionSize);
        INDArray pxzWeights = createWeightMatrix(lastDecoderSize, distributionSize, null, pxzWeightView, false);
        INDArray pxzBiases = createBias(distributionSize, 0.0d, pxzBiasView, false);
        gradients.put(PXZ_W, pxzWeights);
        gradients.put(PXZ_B, pxzBiases);
        return gradients;
    }

    /**
     * Takes the sub-array of {@code flattened} selected by the inclusive interval [0, 0] on the
     * first dimension and the interval from {@code from} to {@code to} on the second.
     */
    private static INDArray viewOf(INDArray flattened, long from, long to) {
        return flattened.get(NDArrayIndex.interval(0L, 0L, true), NDArrayIndex.interval(from, to));
    }
}
