package org.deeplearning4j.models.featuredetectors.da;

import java.io.Serializable;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.optimize.optimizers.da.DenoisingAutoEncoderOptimizer;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.deeplearning4j.util.MathUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/models/featuredetectors/da/DenoisingAutoEncoder.class */
/**
 * Denoising autoencoder: a single-hidden-layer network trained to reconstruct its input from a
 * randomly corrupted copy of that input.
 *
 * <p>Encoding is {@code sigmoid(x * W + hBias)} and decoding is {@code sigmoid(h * W^T + vBias)},
 * i.e. the encoder and decoder share the weight matrix {@code W}. When
 * {@code conf.isConcatBiases()} is set, the hidden bias is concatenated onto {@code W} for
 * encoding, and decoding appends a column of ones instead of adding {@code vBias}.
 *
 * <p>NOTE(review): this source appears to be decompiled (see the renamed {@code ...2} builder
 * methods); parameter names such as {@code f}/{@code f2} are not the original ones.
 */
public class DenoisingAutoEncoder extends BaseNeuralNetwork implements Serializable {
    private static final long serialVersionUID = -6445530486350763837L;

    /**
     * Builder for {@link DenoisingAutoEncoder}; every setter delegates to
     * {@link BaseNeuralNetwork.Builder} and returns this builder for chaining.
     */
    public static class Builder extends BaseNeuralNetwork.Builder<DenoisingAutoEncoder> {
        public Builder() {
            this.clazz = DenoisingAutoEncoder.class;
        }

        @Override // org.deeplearning4j.nn.BaseNeuralNetwork.Builder
        public BaseNeuralNetwork.Builder<DenoisingAutoEncoder> withClazz(Class<? extends BaseNeuralNetwork> cls) {
            super.withClazz(cls);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseNeuralNetwork.Builder
        /* renamed from: withInput */
        public BaseNeuralNetwork.Builder<DenoisingAutoEncoder> withInput2(INDArray iNDArray) {
            super.withInput2(iNDArray);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseNeuralNetwork.Builder
        /* renamed from: withWeights */
        public BaseNeuralNetwork.Builder<DenoisingAutoEncoder> withWeights2(INDArray iNDArray) {
            super.withWeights2(iNDArray);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseNeuralNetwork.Builder
        /* renamed from: withVisibleBias */
        public BaseNeuralNetwork.Builder<DenoisingAutoEncoder> withVisibleBias2(INDArray iNDArray) {
            super.withVisibleBias2(iNDArray);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseNeuralNetwork.Builder
        /* renamed from: withHBias */
        public BaseNeuralNetwork.Builder<DenoisingAutoEncoder> withHBias2(INDArray iNDArray) {
            super.withHBias2(iNDArray);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseNeuralNetwork.Builder
        /* renamed from: withClazz, reason: avoid collision after fix types in other method */
        public /* bridge */ /* synthetic */ BaseNeuralNetwork.Builder<DenoisingAutoEncoder> withClazz2(Class cls) {
            return withClazz((Class<? extends BaseNeuralNetwork>) cls);
        }
    }

    /**
     * No-arg constructor with no callers in this file; presumably used reflectively (builder or
     * deserialization) — TODO confirm.
     */
    private DenoisingAutoEncoder() {
    }

    /**
     * Creates an autoencoder by forwarding all arguments unchanged to the
     * {@link BaseNeuralNetwork} constructor.
     */
    public DenoisingAutoEncoder(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, NeuralNetConfiguration neuralNetConfiguration) {
        super(iNDArray, iNDArray2, iNDArray3, iNDArray4, neuralNetConfiguration);
    }

    /**
     * Returns a corrupted copy of {@code iNDArray}: the input multiplied element-wise by a random
     * 0/1 mask.
     *
     * <p>Each mask entry is drawn as {@code MathUtils.binomial(conf.getRng(), 1, 1 - f)}; assuming
     * that helper samples Binomial(n, p), each input entry is therefore zeroed with probability
     * {@code f}.
     *
     * @param iNDArray the clean input matrix
     * @param f the corruption level
     * @return {@code mask.mul(iNDArray)}, with the mask shaped like {@code iNDArray}
     */
    public INDArray getCorruptedInput(INDArray iNDArray, float f) {
        int rowCount = iNDArray.rows();
        int columnCount = iNDArray.columns();
        float keepProbability = 1.0f - f;
        INDArray mask = Nd4j.zeros(rowCount, columnCount);
        for (int row = 0; row < rowCount; row++) {
            for (int column = 0; column < columnCount; column++) {
                // Boxed explicitly to keep the same put(int, int, Number) overload as before.
                mask.put(row, column, Integer.valueOf(MathUtils.binomial(this.conf.getRng(), 1, keepProbability)));
            }
        }
        return mask.mul(iNDArray);
    }

    /**
     * Encodes {@code iNDArray}. The encoding is deterministic, so the same array is returned as
     * both elements of the pair.
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray iNDArray) {
        INDArray hiddenValues = getHiddenValues(iNDArray);
        return new Pair<>(hiddenValues, hiddenValues);
    }

    /**
     * Decodes {@code iNDArray}. The reconstruction is deterministic, so the same array is
     * returned as both elements of the pair.
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray iNDArray) {
        INDArray reconstructedInput = getReconstructedInput(iNDArray);
        return new Pair<>(reconstructedInput, reconstructedInput);
    }

    /** Hidden-layer activation; identical to {@link #getHiddenValues(INDArray)}. */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public INDArray hiddenActivation(INDArray iNDArray) {
        return getHiddenValues(iNDArray);
    }

    /**
     * Encodes {@code iNDArray} as {@code sigmoid(x * W + hBias)}. In concat-biases mode the input
     * is multiplied by {@code [W | hBias^T]} instead. The result is then passed to
     * {@code applyDropOutIfNecessary}, which presumably applies dropout in place — TODO confirm.
     */
    public INDArray getHiddenValues(INDArray iNDArray) {
        INDArray sigmoid = Transforms.sigmoid(this.conf.isConcatBiases() ? iNDArray.mmul(Nd4j.hstack(new INDArray[]{this.W, this.hBias.transpose()})) : iNDArray.mmul(this.W).addiRowVector(this.hBias));
        applyDropOutIfNecessary(sigmoid);
        return sigmoid;
    }

    /**
     * Decodes hidden values {@code iNDArray} as {@code sigmoid(h * W^T + vBias)}. In
     * concat-biases mode a column of ones is appended to {@code h * W^T} before the sigmoid,
     * instead of adding {@code vBias}.
     */
    public INDArray getReconstructedInput(INDArray iNDArray) {
        if (this.conf.isConcatBiases()) {
            INDArray mmul = iNDArray.mmul(this.W.transpose());
            return Transforms.sigmoid(Nd4j.hstack(new INDArray[]{mmul, Nd4j.ones(mmul.rows(), 1)}));
        }
        INDArray mmul2 = iNDArray.mmul(this.W.transpose());
        mmul2.addiRowVector(this.vBias);
        return Transforms.sigmoid(mmul2);
    }

    /**
     * Performs one gradient step on {@code iNDArray}, or on the previously stored input when
     * {@code iNDArray} is null.
     *
     * <p>NOTE(review): {@link #getGradient(Object[])} reads corruption level, learning rate and
     * iteration count from {@code conf} and ignores its argument array, so {@code f}, {@code f2}
     * and {@code i} currently have no effect.
     *
     * @param iNDArray new training batch, or null to reuse the stored input
     * @param f placed in the gradient-argument slot that {@code iterate} fills with the learning rate
     * @param f2 placed in the slot that {@code iterate} fills with the corruption level
     * @param i placed in the third (iteration) slot of the gradient arguments
     */
    public void train(INDArray iNDArray, float f, float f2, int i) {
        if (iNDArray != null) {
            this.input = iNDArray;
        }
        // Read from the stored input so a null argument does not NPE here.
        this.lastMiniBatchSize = this.input.rows();
        NeuralNetworkGradient gradient = getGradient(new Object[]{Float.valueOf(f2), Float.valueOf(f), Integer.valueOf(i)});
        addGradientInPlace(gradient);
    }

    /** Encodes then decodes {@code iNDArray}, returning its reconstruction. */
    @Override // org.deeplearning4j.nn.BaseNeuralNetwork, org.deeplearning4j.nn.api.Model
    public INDArray transform(INDArray iNDArray) {
        return getReconstructedInput(getHiddenValues(iNDArray));
    }

    /**
     * Trains with a freshly created {@link DenoisingAutoEncoderOptimizer}, built from the
     * learning rate, optimization algorithm and loss function in {@code conf}.
     *
     * @param iNDArray training data, or null to reuse the stored input
     * @param objArr extra arguments forwarded to the optimizer
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit(INDArray iNDArray, Object[] objArr) {
        if (iNDArray != null) {
            this.input = iNDArray;
        }
        this.lastMiniBatchSize = this.input.rows();
        this.optimizer = new DenoisingAutoEncoderOptimizer(this, this.conf.getLr(), objArr, this.conf.getOptimizationAlgo(), this.conf.getLossFunction());
        // Train on the stored input so a null argument means "reuse previous data", not null.
        this.optimizer.train(this.input);
    }

    /** Equivalent to {@code fit(iNDArray, null)}. */
    @Override // org.deeplearning4j.nn.BaseNeuralNetwork, org.deeplearning4j.nn.api.Model
    public void fit(INDArray iNDArray) {
        fit(iNDArray, null);
    }

    /**
     * Performs one gradient step using the corruption level and learning rate from
     * {@code conf}.
     *
     * @param iNDArray new batch, which is preprocessed via {@code preProcessInput} before being
     *     stored; or null to reuse the stored input
     * @param objArr unused
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void iterate(INDArray iNDArray, Object[] objArr) {
        float corruptionLevel = this.conf.getCorruptionLevel();
        if (iNDArray != null) {
            this.input = preProcessInput(iNDArray);
        }
        // The raw batch's row count is used when one is given; otherwise fall back to the stored input.
        this.lastMiniBatchSize = iNDArray != null ? iNDArray.rows() : this.input.rows();
        NeuralNetworkGradient gradient = getGradient(new Object[]{Float.valueOf(corruptionLevel), Float.valueOf(this.conf.getLr()), 0});
        addGradientInPlace(gradient);
    }

    /**
     * Plots the current gradient every {@code conf.getRenderWeightIterations()} iterations.
     * Iteration 0 is included. Plotting is disabled when that setting is not positive.
     */
    @Override // org.deeplearning4j.nn.BaseNeuralNetwork, org.deeplearning4j.nn.api.NeuralNetwork, org.deeplearning4j.optimize.api.IterationListener
    public void iterationDone(int i) {
        int renderWeightIterations = this.conf.getRenderWeightIterations();
        if (renderWeightIterations <= 0) {
            return;
        }
        // 0 % n == 0, so iteration 0 is rendered without a separate check.
        if (i % renderWeightIterations == 0) {
            // NOTE(review): getGradient ignores these argument values.
            new NeuralNetPlotter().plotNetworkGradient(this, getGradient(new Object[]{Double.valueOf(0.3d), Double.valueOf(0.001d), 1000}), getInput().rows());
        }
    }

    /**
     * Computes the denoising-autoencoder gradient for the stored input.
     *
     * <p>Corruption level, learning rate and iteration count come from {@code conf}. The AdaGrad
     * step sizes, when present, are reset to that learning rate. {@code objArr} is not read.
     *
     * @param objArr ignored
     * @return the gradient after {@code updateGradientAccordingToParams}; its components are
     *     presumably (weights, visible bias, hidden bias), judging by their shapes
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public NeuralNetworkGradient getGradient(Object[] objArr) {
        float corruptionLevel = this.conf.getCorruptionLevel();
        float lr = this.conf.getLr();
        int numIterations = this.conf.getNumIterations();
        if (this.wAdaGrad != null) {
            this.wAdaGrad.setMasterStepSize(lr);
        }
        if (this.hBiasAdaGrad != null) {
            this.hBiasAdaGrad.setMasterStepSize(lr);
        }
        if (this.vBiasAdaGrad != null) {
            this.vBiasAdaGrad.setMasterStepSize(lr);
        }
        INDArray corruptedInput = getCorruptedInput(this.input, corruptionLevel);
        INDArray hiddenValues = getHiddenValues(corruptedInput);
        // Reconstruction error against the clean (uncorrupted) input.
        INDArray visibleError = this.input.sub(getReconstructedInput(hiddenValues));
        // Error propagated back through the shared W, times h*(1-h), or h*(h - sparsity) when sparsity is set.
        INDArray hiddenDelta = this.conf.getSparsity() == 0.0f ? visibleError.mmul(this.W).mul(hiddenValues).mul(hiddenValues.rsub(1)) : visibleError.mmul(this.W).mul(hiddenValues).mul(hiddenValues.add(Float.valueOf(-this.conf.getSparsity())));
        // W receives both the encoder and the decoder contributions because the weights are shared.
        INDArray weightGradient = corruptedInput.transpose().mmul(hiddenDelta).add(visibleError.transpose().mmul(hiddenValues));
        NeuralNetworkGradient neuralNetworkGradient = new NeuralNetworkGradient(weightGradient, visibleError.mean(0), hiddenDelta.mean(0));
        updateGradientAccordingToParams(neuralNetworkGradient, numIterations, lr);
        return neuralNetworkGradient;
    }

    /** Adds each gradient component in place to vBias, W and hBias, in that order. */
    private void addGradientInPlace(NeuralNetworkGradient gradient) {
        this.vBias.addi(gradient.getvBiasGradient());
        this.W.addi(gradient.getwGradient());
        this.hBias.addi(gradient.gethBiasGradient());
    }
}
