package org.deeplearning4j.nn.conf.layers.variational;

import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/variational/GaussianReconstructionDistribution.class */
/**
 * Reconstruction distribution that models every data column as an independent Gaussian.
 *
 * <p>For {@code n} data columns the network supplies {@code 2 * n} distribution-parameter columns. After the
 * activation function has been applied:
 * <ul>
 *   <li>columns {@code [0, n)} hold the means {@code mu};</li>
 *   <li>columns {@code [n, 2n)} hold the log-variances {@code log(sigma^2)}.</li>
 * </ul>
 */
public class GaussianReconstructionDistribution implements ReconstructionDistribution {
    /** Per-dimension constant term of the Gaussian log-density: {@code -0.5 * log(2 * pi)}. */
    private static final double NEG_HALF_LOG_2PI = -0.5 * Math.log(2 * Math.PI);

    private final IActivation activationFn;

    /** Creates a distribution whose parameters are passed through the identity activation. */
    public GaussianReconstructionDistribution() {
        this(Activation.IDENTITY);
    }

    /**
     * Creates a distribution using the given built-in activation.
     *
     * @param activation activation applied to the raw (pre-output) distribution parameters
     */
    public GaussianReconstructionDistribution(Activation activation) {
        this(activation.getActivationFunction());
    }

    /**
     * Creates a distribution using the given activation function instance.
     *
     * @param activationFn activation applied to the raw (pre-output) distribution parameters
     */
    public GaussianReconstructionDistribution(IActivation activationFn) {
        this.activationFn = activationFn;
    }

    @Override // org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution
    public boolean hasLossFunction() {
        return false;
    }

    /**
     * Returns {@code 2 * dataSize}: one mean and one log-variance per data column.
     */
    @Override // org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution
    public int distributionInputSize(int dataSize) {
        return 2 * dataSize;
    }

    /**
     * Negative log-likelihood of {@code x} under the Gaussian described by the parameters.
     *
     * @param x                        data, one example per row
     * @param preOutDistributionParams raw parameters (before activation), {@code 2 * n} columns
     * @param average                  if true, divide the total by the number of examples ({@code x.size(0)})
     * @return the summed (or per-example averaged) negative log-likelihood
     */
    @Override // org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution
    public double negLogProbability(INDArray x, INDArray preOutDistributionParams, boolean average) {
        long size = preOutDistributionParams.size(1) / 2;
        long numExamples = x.size(0);

        INDArray[] logProbArrays = calcLogProbArrayExConstants(x, preOutDistributionParams);
        // log p = N*n*(-0.5*log(2*pi)) - 0.5*sum(log sigma^2) - sum((x - mu)^2 / (2 sigma^2))
        double logProb = numExamples * size * NEG_HALF_LOG_2PI
                - 0.5 * logProbArrays[0].sumNumber().doubleValue()
                - logProbArrays[1].sumNumber().doubleValue();

        return average ? -logProb / numExamples : -logProb;
    }

    /**
     * Per-example negative log-likelihood.
     *
     * @param x                        data, one example per row
     * @param preOutDistributionParams raw parameters (before activation), {@code 2 * n} columns
     * @return a column vector (row sums kept as a 2d array) with one negative log-likelihood per example
     */
    @Override // org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution
    public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
        long size = preOutDistributionParams.size(1) / 2;

        INDArray[] logProbArrays = calcLogProbArrayExConstants(x, preOutDistributionParams);

        // -log p_i = 0.5*sum_j(log sigma^2) - n*(-0.5*log(2*pi)) + sum_j((x - mu)^2 / (2 sigma^2))
        return logProbArrays[0].sum(true, 1).muli(0.5)
                .subi(size * NEG_HALF_LOG_2PI)
                .addi(logProbArrays[1].sum(true, 1));
    }

    /**
     * Computes the data-dependent parts of the log-likelihood (everything except the 2*pi constant).
     *
     * @return {@code [log(sigma^2), (x - mu)^2 / (2 sigma^2)]}, both element-wise with the data's shape
     */
    private INDArray[] calcLogProbArrayExConstants(INDArray x, INDArray preOutDistributionParams) {
        INDArray output = preOutDistributionParams.dup();
        activationFn.getActivation(output, false);

        long size = output.size(1) / 2;
        INDArray mean = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0L, size));
        INDArray logVariance = output.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size));

        INDArray variance = Transforms.exp(logVariance, true);
        INDArray lastTerm = x.sub(mean.castTo(x.dataType()));
        lastTerm.muli(lastTerm);
        lastTerm.divi(variance.castTo(lastTerm.dataType())).divi(2);

        return new INDArray[] {logVariance, lastTerm};
    }

    /**
     * Gradient of the negative log-likelihood with respect to the raw (pre-activation) distribution
     * parameters.
     *
     * @param x                        data, one example per row
     * @param preOutDistributionParams raw parameters (before activation), {@code 2 * n} columns
     * @return dL/dz with the same shape as {@code preOutDistributionParams}
     */
    @Override // org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution
    public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
        INDArray output = preOutDistributionParams.dup();
        activationFn.getActivation(output, true);

        long size = output.size(1) / 2;
        INDArray mean = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0L, size));
        INDArray logVariance = output.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size));

        INDArray variance = Transforms.exp(logVariance, true).castTo(x.dataType());

        INDArray xSubMean = x.sub(mean.castTo(x.dataType()));
        INDArray xSubMeanSq = xSubMean.mul(xSubMean);

        // d(log p)/d(mu) = (x - mu) / sigma^2; xSubMean is not used afterwards, so divide in place.
        INDArray dLdMean = xSubMean.divi(variance);

        INDArray sigma = Transforms.sqrt(variance, true);
        INDArray sigmaCubed = Transforms.pow(variance, 1.5);

        // d(log p)/d(sigma) = -1/sigma + (x - mu)^2 / sigma^3.
        // This MUST be computed into its own variable before sigma is halved in place below. Inlining it
        // as the argument of sigma.divi(2).muli(...) would evaluate divi(2) first (the invocation target is
        // evaluated before the arguments) and yield -2/sigma instead of -1/sigma.
        INDArray dLdSigma = sigma.rdiv(-1).addi(xSubMeanSq.divi(sigmaCubed));
        // Chain rule: d(sigma)/d(log sigma^2) = sigma / 2.
        INDArray dLdLogVariance = sigma.divi(2).muli(dLdSigma);

        INDArray dLdOutput = Nd4j.createUninitialized(preOutDistributionParams.dataType(), output.shape());
        dLdOutput.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(0L, size)}, dLdMean);
        dLdOutput.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size)},
                dLdLogVariance);
        // Above is the gradient of log p; negate to get the gradient of the negative log-likelihood.
        dLdOutput.negi();

        // Backpropagate through the activation function to get dL/dz.
        return activationFn.backprop(preOutDistributionParams.dup(), dLdOutput).getFirst();
    }

    /**
     * Draws one sample per example: {@code mu + sigma * N(0, 1)}.
     *
     * @param preOutDistributionParams raw parameters (before activation), {@code 2 * n} columns
     * @return samples with {@code n} columns
     */
    @Override // org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution
    public INDArray generateRandom(INDArray preOutDistributionParams) {
        INDArray output = preOutDistributionParams.dup();
        activationFn.getActivation(output, true);

        long size = output.size(1) / 2;
        INDArray mean = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0L, size));
        INDArray logVariance = output.get(NDArrayIndex.all(), NDArrayIndex.interval(size, 2 * size));

        // exp(log sigma^2) -> sigma^2, then an in-place sqrt -> sigma
        INDArray sigma = Transforms.exp(logVariance, true);
        Transforms.sqrt(sigma, false);

        return Nd4j.randn(sigma.shape()).muli(sigma).addi(mean);
    }

    /**
     * Returns the distribution means.
     *
     * <p>Only the first (mean) half of the parameters is passed through the activation function.
     *
     * @param preOutDistributionParams raw parameters (before activation), {@code 2 * n} columns
     * @return means with {@code n} columns
     */
    @Override // org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution
    public INDArray generateAtMean(INDArray preOutDistributionParams) {
        long size = preOutDistributionParams.size(1) / 2;
        INDArray mean = preOutDistributionParams.get(NDArrayIndex.all(), NDArrayIndex.interval(0L, size)).dup();
        activationFn.getActivation(mean, false);
        return mean;
    }

    public String toString() {
        return "GaussianReconstructionDistribution(afn=" + this.activationFn + ")";
    }

    public IActivation getActivationFn() {
        return this.activationFn;
    }

    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof GaussianReconstructionDistribution)) {
            return false;
        }
        GaussianReconstructionDistribution other = (GaussianReconstructionDistribution) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        IActivation mine = getActivationFn();
        IActivation theirs = other.getActivationFn();
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof GaussianReconstructionDistribution;
    }

    public int hashCode() {
        IActivation fn = getActivationFn();
        return 59 + (fn == null ? 43 : fn.hashCode());
    }
}
