package org.deeplearning4j.rbm;

import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.nn.NeuralNetworkGradient;
import org.deeplearning4j.optimize.NeuralNetworkOptimizer;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;

/**
 * Restricted Boltzmann machine trained with k-step contrastive divergence (CD-k).
 *
 * <p>Both layers use sigmoid activations ({@link #propUp}, {@link #propDown}) and are sampled with
 * {@code MatrixUtil.binomial(mean, 1, rng)}, i.e. binary units. Rows of {@code input} are treated
 * as examples: the hidden pre-activation is {@code input.mmul(W)} plus {@code hBias} as a row
 * vector, and the visible pre-activation is {@code hidden.mmul(W.transpose())} plus {@code vBias}.
 */
public class RBM extends BaseNeuralNetwork {
    private static final long serialVersionUID = 6189188205731511957L;

    /** Optimizer created by the most recent {@code trainTillConvergence} call. */
    protected NeuralNetworkOptimizer optimizer;

    /** Builder producing {@link RBM} instances through {@link BaseNeuralNetwork.Builder}. */
    public static class Builder extends BaseNeuralNetwork.Builder<RBM> {
        public Builder() {
            this.clazz = RBM.class;
        }
    }

    /**
     * No-arg constructor.
     *
     * <p>NOTE(review): presumably needed so the Builder (which only records {@code RBM.class}) or
     * deserialization can instantiate the class reflectively; confirm before removing.
     */
    public RBM() {
    }

    /**
     * Creates an RBM by delegating to the 7-argument superclass constructor.
     *
     * <p>NOTE(review): {@code realDistribution} is accepted but NOT forwarded to the superclass.
     * Confirm whether that is intentional or whether an overload taking the distribution should
     * be called instead. The parameter meanings are presumably (nVisible, nHidden, W, hBias,
     * vBias, rng, fanIn); verify against {@code BaseNeuralNetwork}.
     */
    public RBM(int i, int i2, DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2, DoubleMatrix doubleMatrix3, RandomGenerator randomGenerator, double d, RealDistribution realDistribution) {
        super(i, i2, doubleMatrix, doubleMatrix2, doubleMatrix3, randomGenerator, d);
    }

    /**
     * Creates an RBM by delegating every argument, including {@code realDistribution}, to the
     * 9-argument superclass constructor.
     *
     * <p>NOTE(review): the parameter meanings are presumably (input, nVisible, nHidden, W, hBias,
     * vBias, rng, fanIn, dist); verify against {@code BaseNeuralNetwork}.
     */
    public RBM(DoubleMatrix doubleMatrix, int i, int i2, DoubleMatrix doubleMatrix2, DoubleMatrix doubleMatrix3, DoubleMatrix doubleMatrix4, RandomGenerator randomGenerator, double d, RealDistribution realDistribution) {
        super(doubleMatrix, i, i2, doubleMatrix2, doubleMatrix3, doubleMatrix4, randomGenerator, d, realDistribution);
    }

    /**
     * Trains with a fresh {@link RBMOptimizer} configured for CD-k.
     *
     * @param learningRate learning rate handed to the optimizer and used as the gradient scale
     * @param k number of Gibbs steps per gradient evaluation (must be at least 1)
     * @param newInput training data; when {@code null} the previously stored input is used
     */
    public void trainTillConvergence(double learningRate, int k, DoubleMatrix newInput) {
        if (newInput != null) {
            this.input = newInput;
        }
        this.optimizer = new RBMOptimizer(this, learningRate, new Object[]{Integer.valueOf(k), Double.valueOf(learningRate)});
        // Hand over the effective input so a null argument falls back to the stored input
        // instead of passing null to the optimizer.
        this.optimizer.train(this.input);
    }

    /**
     * Performs one CD-k update, adding the computed gradients to {@code W}, {@code hBias} and
     * {@code vBias} in place.
     *
     * @param learningRate scale applied to the gradients
     * @param k number of Gibbs steps (must be at least 1)
     * @param newInput training data; when {@code null} the previously stored input is used
     */
    public void contrastiveDivergence(double learningRate, int k, DoubleMatrix newInput) {
        if (newInput != null) {
            this.input = newInput;
        }
        NeuralNetworkGradient gradient = getGradient(new Object[]{Integer.valueOf(k), Double.valueOf(learningRate)});
        this.W.addi(gradient.getwGradient());
        this.hBias.addi(gradient.gethBiasGradient());
        this.vBias.addi(gradient.getvBiasGradient());
    }

    /**
     * Computes the CD-k gradient for the currently stored input. The model parameters are not
     * modified.
     *
     * @param params {@code params[0]} is an {@link Integer} k (number of Gibbs steps, at least 1)
     *     and {@code params[1]} is a {@link Double} learning rate that scales every gradient
     * @return gradients for W, the visible bias and the hidden bias, in the constructor order
     *     {@code NeuralNetworkGradient(w, vBias, hBias)}
     * @throws IllegalArgumentException if k is less than 1
     * @throws IllegalStateException if no input has been set
     */
    @Override
    public synchronized NeuralNetworkGradient getGradient(Object[] params) {
        int k = ((Integer) params[0]).intValue();
        double learningRate = ((Double) params[1]).doubleValue();
        if (k < 1) {
            throw new IllegalArgumentException("k (number of Gibbs steps) must be >= 1 but was " + k);
        }
        if (this.input == null) {
            throw new IllegalStateException("No input set; provide training data before computing a gradient");
        }

        // Positive phase: hidden units driven by the data.
        Pair<DoubleMatrix, DoubleMatrix> positiveHidden = sampleHiddenGivenVisible(this.input);
        DoubleMatrix positiveHiddenSample = positiveHidden.getSecond();

        // Negative phase: k alternating Gibbs steps starting from the positive hidden sample.
        DoubleMatrix negativeVisibleSample = null;
        DoubleMatrix negativeHiddenMean = null;
        DoubleMatrix chainHiddenSample = positiveHiddenSample;
        for (int step = 0; step < k; step++) {
            Pair<Pair<DoubleMatrix, DoubleMatrix>, Pair<DoubleMatrix, DoubleMatrix>> gibbs = gibbhVh(chainHiddenSample);
            negativeVisibleSample = gibbs.getFirst().getSecond();
            negativeHiddenMean = gibbs.getSecond().getFirst();
            chainHiddenSample = gibbs.getSecond().getSecond();
        }

        DoubleMatrix wGradient = this.input.transpose().mmul(positiveHiddenSample)
                .sub(negativeVisibleSample.transpose().mmul(negativeHiddenMean))
                .mul(learningRate);
        if (this.useRegularization) {
            // mul (not muli): muli would rescale this.W in place as a side effect of merely
            // computing the gradient.
            wGradient.subi(this.W.mul(this.l2));
        }
        if (this.momentum != 0.0d) {
            wGradient.muli(1.0d - this.momentum);
        }
        // Average over the examples (rows) in the batch.
        wGradient.divi(this.input.rows);

        DoubleMatrix vBiasGradient = MatrixUtil.mean(this.input.sub(negativeVisibleSample), 0).mul(learningRate);
        DoubleMatrix hBiasGradient = this.sparsity != 0.0d
                ? MatrixUtil.mean(positiveHiddenSample.add(-this.sparsity), 0).mul(learningRate)
                : MatrixUtil.mean(positiveHiddenSample.sub(negativeHiddenMean), 0).mul(learningRate);
        return new NeuralNetworkGradient(wGradient, vBiasGradient, hBiasGradient);
    }

    /**
     * Samples the hidden layer given visible activations.
     *
     * @return (hidden means from {@link #propUp}, binary hidden sample)
     */
    public Pair<DoubleMatrix, DoubleMatrix> sampleHiddenGivenVisible(DoubleMatrix visible) {
        DoubleMatrix hiddenMean = propUp(visible);
        return new Pair<>(hiddenMean, MatrixUtil.binomial(hiddenMean, 1, this.rng));
    }

    /**
     * Runs one Gibbs step hidden -> visible -> hidden.
     *
     * @return ((visible means, visible sample), (hidden means, hidden sample)), where the hidden
     *     pair is computed from the visible sample
     */
    public Pair<Pair<DoubleMatrix, DoubleMatrix>, Pair<DoubleMatrix, DoubleMatrix>> gibbhVh(DoubleMatrix hidden) {
        Pair<DoubleMatrix, DoubleMatrix> visible = sampleVGivenH(hidden);
        return new Pair<>(visible, sampleHiddenGivenVisible(visible.getSecond()));
    }

    /**
     * Samples the visible layer given hidden activations.
     *
     * @return (visible means from {@link #propDown}, binary visible sample)
     */
    public Pair<DoubleMatrix, DoubleMatrix> sampleVGivenH(DoubleMatrix hidden) {
        DoubleMatrix visibleMean = propDown(hidden);
        return new Pair<>(visibleMean, MatrixUtil.binomial(visibleMean, 1, this.rng));
    }

    /** Computes {@code sigmoid(visible * W + hBias)}, with the bias added per row. */
    public DoubleMatrix propUp(DoubleMatrix visible) {
        // mmul returns a fresh matrix, so the in-place row-vector add is safe.
        return MatrixUtil.sigmoid(visible.mmul(this.W).addiRowVector(this.hBias));
    }

    /** Computes {@code sigmoid(hidden * W^T + vBias)}, with the bias added per row. */
    public DoubleMatrix propDown(DoubleMatrix hidden) {
        // mmul returns a fresh matrix, so the in-place row-vector add is safe.
        return MatrixUtil.sigmoid(hidden.mmul(this.W.transpose()).addiRowVector(this.vBias));
    }

    /** Reconstructs the input with a deterministic up-then-down pass (means, no sampling). */
    @Override
    public DoubleMatrix reconstruct(DoubleMatrix visible) {
        return propDown(propUp(visible));
    }

    /**
     * Trains with a fresh {@link RBMOptimizer} using caller-supplied optimizer parameters.
     *
     * @param newInput training data; when {@code null} the previously stored input is used
     * @param learningRate learning rate handed to the optimizer
     * @param params forwarded to the optimizer; {@link #getGradient} expects {k, learningRate}
     */
    @Override
    public void trainTillConvergence(DoubleMatrix newInput, double learningRate, Object[] params) {
        if (newInput != null) {
            this.input = newInput;
        }
        this.optimizer = new RBMOptimizer(this, learningRate, params);
        // Hand over the effective input so a null argument falls back to the stored input.
        this.optimizer.train(this.input);
    }

    /**
     * Returns the reconstruction cross-entropy from the superclass; {@code params} is ignored.
     */
    @Override
    public double lossFunction(Object[] params) {
        return getReConstructionCrossEntropy();
    }

    /**
     * Performs one contrastive-divergence update.
     *
     * @param params {@code params[0]} is an {@link Integer} k (number of Gibbs steps)
     */
    @Override
    public void train(DoubleMatrix newInput, double learningRate, Object[] params) {
        contrastiveDivergence(learningRate, ((Integer) params[0]).intValue(), newInput);
    }
}
