package org.deeplearning4j.models.featuredetectors.rbm;

import org.apache.commons.math3.random.MersenneTwister;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.optimize.optimizers.NeuralNetworkOptimizer;
import org.deeplearning4j.optimize.optimizers.rbm.RBMOptimizer;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.deeplearning4j.util.RBMUtil;
import org.nd4j.linalg.api.activation.Activations;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.sampling.Sampling;

/**
 * Restricted Boltzmann machine trained with k-step contrastive divergence.
 *
 * <p>The visible and hidden unit types are read from the {@link NeuralNetConfiguration}
 * ({@code getVisibleUnit()} / {@code getHiddenUnit()}). They select how {@link #propUp},
 * {@link #propDown} and the sampling methods transform activations.
 */
public class RBM extends BaseNeuralNetwork {
    private static final long serialVersionUID = 6189188205731511957L;

    /** Optimizer created by the most recent call to {@link #fit(INDArray, Object[])}. */
    protected NeuralNetworkOptimizer optimizer;
    /**
     * {@code var(0)} of the visible data divided by its number of rows. Only updated when the
     * visible unit type is {@link VisibleUnit#GAUSSIAN}.
     */
    protected INDArray sigma;
    /** {@code var(1)} of the hidden pre-activations, updated when sampling GAUSSIAN hidden units. */
    protected INDArray hiddenSigma;

    /**
     * Builder for {@link RBM} instances.
     *
     * <p>The {@code *2}-suffixed methods are decompiler-renamed bridge methods. Their names must
     * stay in sync with the methods they override in {@link BaseNeuralNetwork.Builder}.
     */
    public static class Builder extends BaseNeuralNetwork.Builder<RBM> {
        public Builder() {
            this.clazz = RBM.class;
        }

        @Override
        public RBM buildEmpty() {
            return (RBM) super.buildEmpty();
        }

        @Override
        public BaseNeuralNetwork.Builder<RBM> withClazz(Class<? extends BaseNeuralNetwork> cls) {
            super.withClazz(cls);
            return this;
        }

        @Override
        public BaseNeuralNetwork.Builder<RBM> withInput2(INDArray iNDArray) {
            super.withInput2(iNDArray);
            return this;
        }

        @Override
        public BaseNeuralNetwork.Builder<RBM> asType2(Class<RBM> cls) {
            super.asType2((Class) cls);
            return this;
        }

        @Override
        public BaseNeuralNetwork.Builder<RBM> withWeights2(INDArray iNDArray) {
            super.withWeights2(iNDArray);
            return this;
        }

        @Override
        public BaseNeuralNetwork.Builder<RBM> withVisibleBias2(INDArray iNDArray) {
            super.withVisibleBias2(iNDArray);
            return this;
        }

        @Override
        public BaseNeuralNetwork.Builder<RBM> withHBias2(INDArray iNDArray) {
            super.withHBias2(iNDArray);
            return this;
        }

        @Override
        public RBM build() {
            return (RBM) super.build();
        }

        // Raw-typed bridge that forwards to the generic withClazz above.
        @Override
        public BaseNeuralNetwork.Builder<RBM> withClazz2(Class cls) {
            return withClazz((Class<? extends BaseNeuralNetwork>) cls);
        }
    }

    /** Hidden unit types handled by {@link #propUp} and {@link #sampleHiddenGivenVisible}. */
    public enum HiddenUnit {
        RECTIFIED,
        BINARY,
        GAUSSIAN,
        SOFTMAX
    }

    /** Visible unit types handled by {@link #propDown} and {@link #sampleVisibleGivenHidden}. */
    public enum VisibleUnit {
        BINARY,
        GAUSSIAN,
        SOFTMAX,
        LINEAR
    }

    protected RBM() {
    }

    /** Creates an RBM; every argument is passed unchanged to the {@link BaseNeuralNetwork} constructor. */
    public RBM(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, NeuralNetConfiguration neuralNetConfiguration) {
        super(iNDArray, iNDArray2, iNDArray3, iNDArray4, neuralNetConfiguration);
    }

    /**
     * Runs one contrastive-divergence update with an iteration number of {@code -1}.
     *
     * @see #contrastiveDivergence(double, int, INDArray, int)
     */
    public void contrastiveDivergence(double learningRate, int k, INDArray batch) {
        contrastiveDivergence(learningRate, k, batch, -1);
    }

    /**
     * Runs one contrastive-divergence update and adds the resulting gradients to W, hBias and vBias.
     *
     * <p>NOTE(review): {@code learningRate} and {@code k} are forwarded in the params array, but
     * {@link #getGradient} currently reads k and the learning rate from the configuration. It only
     * uses the last array element (the iteration).
     *
     * @param learningRate learning rate forwarded to {@link #getGradient}
     * @param k number of Gibbs steps forwarded to {@link #getGradient}
     * @param batch new training batch, or {@code null} to reuse the stored input
     * @param iteration iteration number forwarded to {@link #getGradient}
     * @throws IllegalStateException if {@code batch} is null and no input is stored
     */
    public void contrastiveDivergence(double learningRate, int k, INDArray batch, int iteration) {
        if (batch != null) {
            this.input = batch;
        }
        this.lastMiniBatchSize = batchOrStoredInput(batch).rows();
        NeuralNetworkGradient gradient = getGradient(new Object[] {k, learningRate, iteration});
        getW().addi(gradient.getwGradient());
        gethBias().addi(gradient.gethBiasGradient());
        getvBias().addi(gradient.getvBiasGradient());
    }

    /**
     * Computes the CD-k gradient for the stored input.
     *
     * <p>Runs {@code conf.getK()} Gibbs steps starting from a hidden sample of the input. It then
     * builds the weight, visible-bias and hidden-bias gradients (positive statistics minus
     * negative statistics) and passes them through {@code updateGradientAccordingToParams}.
     *
     * @param objArr parameter array; only its last element is read, as the iteration number
     *     ({@code null} means 0)
     * @return the gradient, after {@code updateGradientAccordingToParams} has been applied
     * @throws IllegalStateException if {@code conf.getK()} is less than 1
     */
    @Override
    public NeuralNetworkGradient getGradient(Object[] objArr) {
        int k = this.conf.getK();
        float lr = this.conf.getLr();
        Object lastParam = objArr[objArr.length - 1];
        int iteration = lastParam == null ? 0 : ((Integer) lastParam).intValue();
        if (k < 1) {
            // Without at least one Gibbs step there are no negative-phase statistics.
            throw new IllegalStateException("Number of Gibbs steps k must be at least 1 but was " + k);
        }
        if (this.wAdaGrad != null) {
            this.wAdaGrad.setMasterStepSize(lr);
        }
        if (this.hBiasAdaGrad != null) {
            this.hBiasAdaGrad.setMasterStepSize(lr);
        }
        if (this.vBiasAdaGrad != null) {
            this.vBiasAdaGrad.setMasterStepSize(lr);
        }

        // Positive phase: sample the hidden units from the data.
        INDArray positiveHiddenSample = sampleHiddenGivenVisible(this.input).getSecond();

        // Negative phase: k steps of alternating Gibbs sampling, starting from the positive sample.
        INDArray negativeVisibleSample = null;
        INDArray negativeHiddenMeans = null;
        INDArray chainHidden = positiveHiddenSample;
        for (int step = 0; step < k; step++) {
            Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> chain = gibbhVh(chainHidden);
            negativeVisibleSample = chain.getFirst().getSecond();
            negativeHiddenMeans = chain.getSecond().getFirst();
            chainHidden = chain.getSecond().getSecond();
        }

        INDArray wGradient = this.input.transpose().mmul(positiveHiddenSample)
                .sub(negativeVisibleSample.transpose().mmul(negativeHiddenMeans));
        INDArray vBiasGradient = this.input.sub(negativeVisibleSample).mean(0);
        INDArray hBiasGradient;
        if (this.conf.getSparsity() != 0.0f) {
            // In-place rsubi is safe: the positive sample is not read after this point.
            hBiasGradient = positiveHiddenSample.rsubi(Float.valueOf(this.conf.getSparsity())).mean(0);
        } else {
            hBiasGradient = positiveHiddenSample.sub(negativeHiddenMeans).mean(0);
        }

        NeuralNetworkGradient gradient = new NeuralNetworkGradient(wGradient, vBiasGradient, hBiasGradient);
        updateGradientAccordingToParams(gradient, iteration, lr);
        return gradient;
    }

    /** Fits the model on {@code iNDArray} with no extra optimizer parameters. */
    @Override
    public void fit(INDArray iNDArray) {
        fit(iNDArray, null);
    }

    /**
     * Returns the transposed network from {@code super.transpose()}, with this RBM's
     * {@link #sigma} and {@link #hiddenSigma} copied onto it.
     */
    @Override
    public NeuralNetwork transpose() {
        RBM transposed = (RBM) super.transpose();
        // NOTE(review): the swapped unit types are computed but never applied to the transposed
        // copy. Confirm whether the copy's configuration was meant to be updated with them
        // (falling back to this RBM's own types when RBMUtil.inverse returns null).
        HiddenUnit swappedHidden = RBMUtil.inverse(this.conf.getVisibleUnit());
        VisibleUnit swappedVisible = RBMUtil.inverse(this.conf.getHiddenUnit());
        transposed.sigma = this.sigma;
        transposed.hiddenSigma = this.hiddenSigma;
        return transposed;
    }

    /** Clones via the superclass and shares this RBM's {@link #sigma} and {@link #hiddenSigma} references. */
    @Override
    public NeuralNetwork mo11clone() {
        RBM copy = (RBM) super.mo11clone();
        copy.sigma = this.sigma;
        copy.hiddenSigma = this.hiddenSigma;
        return copy;
    }

    /**
     * Computes {@code -sum(log(1 + exp(v W + hBias))) - dot(v, vBias)}.
     *
     * @param iNDArray visible data {@code v}
     */
    public double freeEnergy(INDArray iNDArray) {
        INDArray hiddenPreActivation = iNDArray.mmul(this.W).addiRowVector(this.hBias);
        // getDouble(0) avoids casting element() to Double, which may fail for non-double buffers.
        double softplusSum = Transforms.log(Transforms.exp(hiddenPreActivation).add(1))
                .sum(Integer.MAX_VALUE).getDouble(0);
        double visibleBiasTerm = Nd4j.getBlasWrapper().dot(iNDArray, this.vBias);
        return -softplusSum - visibleBiasTerm;
    }

    /**
     * Samples the hidden units given visible data.
     *
     * @return pair of (activation from {@link #propUp}, hidden sample after dropout)
     * @throws IllegalStateException if the configured hidden unit type is not supported
     */
    @Override
    public Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray iNDArray) {
        HiddenUnit hiddenUnit = this.conf.getHiddenUnit();
        if (hiddenUnit == HiddenUnit.RECTIFIED) {
            INDArray mean = propUp(iNDArray);
            INDArray noiseScale = Transforms.sqrt(Transforms.sigmoid(mean));
            // Use the configured RNG: a freshly seeded generator would repeat the same noise on every call.
            INDArray noise = Sampling.normal(this.conf.getRng(), mean, 1.0d);
            noise.muli(noiseScale);
            INDArray sample = Transforms.max(mean.add(noise));
            applyDropOutIfNecessary(sample);
            return new Pair<>(mean, sample);
        }
        if (hiddenUnit == HiddenUnit.GAUSSIAN) {
            INDArray mean = propUp(iNDArray);
            this.hiddenSigma = mean.var(1);
            // Non-mutating add, so the returned mean is not overwritten by the sample.
            INDArray sample = mean.add(Sampling.normal(this.conf.getRng(), mean, this.hiddenSigma));
            applyDropOutIfNecessary(sample);
            return new Pair<>(mean, sample);
        }
        if (hiddenUnit == HiddenUnit.SOFTMAX) {
            INDArray mean = propUp(iNDArray);
            INDArray sample = (INDArray) Activations.softMaxRows().apply(mean);
            applyDropOutIfNecessary(sample);
            return new Pair<>(mean, sample);
        }
        if (hiddenUnit != HiddenUnit.BINARY) {
            throw new IllegalStateException(
                    "Hidden unit type must be one of RECTIFIED, GAUSSIAN, SOFTMAX or BINARY but was " + hiddenUnit);
        }
        INDArray mean = propUp(iNDArray);
        INDArray sample = Sampling.binomial(mean, 1, this.conf.getRng());
        applyDropOutIfNecessary(sample);
        return new Pair<>(mean, sample);
    }

    /**
     * Performs one hidden-visible-hidden Gibbs step.
     *
     * @return pair of (visible mean/sample pair, hidden mean/sample pair resampled from that visible sample)
     */
    public Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> gibbhVh(INDArray iNDArray) {
        Pair<INDArray, INDArray> visible = sampleVisibleGivenHidden(iNDArray);
        return new Pair<>(visible, sampleHiddenGivenVisible(visible.getSecond()));
    }

    /**
     * Samples the visible units given hidden activations.
     *
     * @return pair of (activation from {@link #propDown}, visible sample)
     * @throws IllegalStateException if the configured visible unit type is not supported
     */
    @Override
    public Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray iNDArray) {
        INDArray mean = propDown(iNDArray);
        VisibleUnit visibleUnit = this.conf.getVisibleUnit();
        if (visibleUnit == VisibleUnit.GAUSSIAN) {
            return new Pair<>(mean, mean.add(Nd4j.randn(mean.rows(), mean.columns(), this.conf.getRng())));
        }
        if (visibleUnit == VisibleUnit.LINEAR) {
            return new Pair<>(mean, Sampling.normal(this.conf.getRng(), mean, 1.0d));
        }
        if (visibleUnit == VisibleUnit.SOFTMAX) {
            return new Pair<>(mean, (INDArray) Activations.softMaxRows().apply(mean));
        }
        if (visibleUnit == VisibleUnit.BINARY) {
            return new Pair<>(mean, Sampling.binomial(mean, 1, this.conf.getRng()));
        }
        throw new IllegalStateException("Visible type must either be binary,gaussian, softmax, or linear");
    }

    /**
     * Propagates visible data to the hidden layer.
     *
     * <p>Computes {@code iNDArray.mmul(W)} and then either adds hBias row-wise or, when
     * {@code conf.isConcatBiases()} is set, horizontally stacks hBias onto the result. It then
     * applies the hidden unit's transform. For GAUSSIAN visible units it also refreshes
     * {@link #sigma} from {@code iNDArray}.
     *
     * @throws IllegalStateException if the configured hidden unit type is not supported
     */
    public INDArray propUp(INDArray iNDArray) {
        updateVisibleSigma(iNDArray);
        INDArray preActivation = iNDArray.mmul(this.W);
        if (this.conf.isConcatBiases()) {
            preActivation = Nd4j.hstack(new INDArray[] {preActivation, this.hBias});
        } else {
            preActivation.addiRowVector(this.hBias);
        }
        HiddenUnit hiddenUnit = this.conf.getHiddenUnit();
        if (hiddenUnit == HiddenUnit.RECTIFIED) {
            return Transforms.max(preActivation);
        }
        if (hiddenUnit == HiddenUnit.GAUSSIAN) {
            // Add standard-normal noise once; the pre-activation must not be added to itself.
            preActivation.addi(Nd4j.randn(preActivation.rows(), preActivation.columns(), this.conf.getRng()));
            return preActivation;
        }
        if (hiddenUnit == HiddenUnit.BINARY) {
            return Transforms.sigmoid(preActivation);
        }
        if (hiddenUnit == HiddenUnit.SOFTMAX) {
            return (INDArray) Activations.softMaxRows().apply(preActivation);
        }
        throw new IllegalStateException(
                "Hidden unit type should be one of RECTIFIED, GAUSSIAN, BINARY or SOFTMAX but was " + hiddenUnit);
    }

    /** Returns {@link #propUp} of {@code iNDArray}. */
    @Override
    public INDArray hiddenActivation(INDArray iNDArray) {
        return propUp(iNDArray);
    }

    /**
     * Plots a freshly computed gradient every {@code conf.getRenderWeightIterations()} iterations.
     * Does nothing when that setting is not positive.
     *
     * <p>NOTE(review): the plotting path calls {@link #getGradient}, which also sets the AdaGrad
     * step sizes and runs {@code updateGradientAccordingToParams}. Confirm that side effect is intended.
     */
    @Override
    public void iterationDone(int i) {
        int renderWeightIterations = this.conf.getRenderWeightIterations();
        if (renderWeightIterations > 0 && i % renderWeightIterations == 0) {
            new NeuralNetPlotter().plotNetworkGradient(this, getGradient(new Object[] {1, Double.valueOf(0.001d), 1000}), getInput().rows());
        }
    }

    /**
     * Propagates hidden activations back to the visible layer.
     *
     * <p>Computes {@code iNDArray.mmul(W.transpose())} and then either adds vBias row-wise or, when
     * {@code conf.isConcatBiases()} is set, horizontally stacks vBias onto the result. It then
     * applies the visible unit's transform.
     *
     * @throws IllegalStateException if the configured visible unit type is not supported
     */
    public INDArray propDown(INDArray iNDArray) {
        INDArray preActivation = iNDArray.mmul(this.W.transpose());
        if (this.conf.isConcatBiases()) {
            preActivation = Nd4j.hstack(new INDArray[] {preActivation, this.vBias});
        } else {
            preActivation.addiRowVector(this.vBias);
        }
        VisibleUnit visibleUnit = this.conf.getVisibleUnit();
        if (visibleUnit == VisibleUnit.GAUSSIAN) {
            preActivation.addi(Sampling.normal(this.conf.getRng(), preActivation, 1.0d));
            return preActivation;
        }
        if (visibleUnit == VisibleUnit.LINEAR) {
            return Sampling.normal(this.conf.getRng(), preActivation, 1.0d);
        }
        if (visibleUnit == VisibleUnit.BINARY) {
            return Transforms.sigmoid(preActivation);
        }
        if (visibleUnit == VisibleUnit.SOFTMAX) {
            return (INDArray) Activations.softMaxRows().apply(preActivation);
        }
        throw new IllegalStateException(
                "Visible unit type should be one of BINARY, GAUSSIAN, LINEAR or SOFTMAX but was " + visibleUnit);
    }

    /** Reconstructs the input: {@code propDown(propUp(iNDArray))}. */
    @Override
    public INDArray transform(INDArray iNDArray) {
        return propDown(propUp(iNDArray));
    }

    /**
     * Trains the RBM with a new {@link RBMOptimizer}.
     *
     * <p>A non-null {@code iNDArray} is stored (stabilized) as the model input. The raw batch is
     * used for the mini-batch size, the GAUSSIAN {@link #sigma}, and training. A null
     * {@code iNDArray} falls back to the stored input.
     *
     * @throws IllegalStateException if {@code iNDArray} is null and no input is stored
     */
    @Override
    public void fit(INDArray iNDArray, Object[] objArr) {
        if (iNDArray != null) {
            this.input = Transforms.stabilize(iNDArray, 1.0d);
        }
        INDArray data = batchOrStoredInput(iNDArray);
        this.lastMiniBatchSize = data.rows();
        updateVisibleSigma(data);
        this.optimizer = new RBMOptimizer(this, this.conf.getLr(), objArr, this.conf.getOptimizationAlgo(), this.conf.getLossFunction());
        this.optimizer.train(data);
    }

    @Override
    public String toString() {
        return "RBM{optimizer=" + this.optimizer + ", visibleType=" + this.conf.getVisibleUnit() + ", hiddenType=" + this.conf.getHiddenUnit() + ", sigma=" + this.sigma + ", hiddenSigma=" + this.hiddenSigma + "} " + super.toString();
    }

    /**
     * Performs a single contrastive-divergence iteration.
     *
     * @param iNDArray batch to train on, or {@code null} to reuse the stored input
     * @param objArr first element must be an {@code Integer}, forwarded as the k argument of
     *     {@link #contrastiveDivergence(double, int, INDArray)}
     * @throws IllegalStateException if {@code iNDArray} is null and no input is stored
     */
    @Override
    public void iterate(INDArray iNDArray, Object[] objArr) {
        updateVisibleSigma(batchOrStoredInput(iNDArray));
        contrastiveDivergence(this.conf.getLr(), ((Integer) objArr[0]).intValue(), iNDArray);
    }

    /**
     * Returns {@code batch} when it is non-null, otherwise the stored {@code input}.
     *
     * @throws IllegalStateException if both are null
     */
    private INDArray batchOrStoredInput(INDArray batch) {
        INDArray data = batch != null ? batch : this.input;
        if (data == null) {
            throw new IllegalStateException("No input available: pass a non-null batch or set the input first");
        }
        return data;
    }

    /** Sets {@link #sigma} to {@code data.var(0)} divided by {@code data.rows()} when the visible units are GAUSSIAN. */
    private void updateVisibleSigma(INDArray data) {
        if (this.conf.getVisibleUnit() == VisibleUnit.GAUSSIAN) {
            this.sigma = data.var(0).divi(Integer.valueOf(data.rows()));
        }
    }
}
