package org.deeplearning4j.nn.layers.feedforward.autoencoder;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.BasePretrainNetwork;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Autoencoder layer with tied weights.
 *
 * <p>The input is encoded as {@code activation(input · W + b)}. A hidden representation is
 * decoded as {@code activation(hidden · Wᵀ + vb)}. The parameters are read via {@code getParam}
 * under the keys {@code "W"}, {@code "b"} (hidden bias) and {@code "vb"} (visible bias).
 */
public class AutoEncoder extends BasePretrainNetwork<org.deeplearning4j.nn.conf.layers.AutoEncoder> {
    private static final long serialVersionUID = -6445530486350763837L;

    /**
     * Creates the layer from its configuration.
     *
     * @param conf the network configuration for this layer
     */
    public AutoEncoder(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Creates the layer from its configuration and an initial input.
     *
     * @param conf  the network configuration for this layer
     * @param input the initial input, forwarded to the superclass constructor
     */
    public AutoEncoder(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Sets {@code visible} as the layer input and encodes it in training mode.
     *
     * <p>No sampling is performed. The encoded activations are returned as both elements of the pair.
     *
     * @param visible the visible-layer input
     * @return a pair whose two elements are the same hidden-activation array
     */
    @Override
    public Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray visible) {
        setInput(visible);
        INDArray hidden = encode(true);
        return new Pair<>(hidden, hidden);
    }

    /**
     * Decodes {@code hidden} back into the visible space.
     *
     * <p>No sampling is performed. The reconstruction is returned as both elements of the pair.
     *
     * @param hidden hidden-layer activations
     * @return a pair whose two elements are the same reconstruction array
     */
    @Override
    public Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray hidden) {
        INDArray reconstruction = decode(hidden);
        return new Pair<>(reconstruction, reconstruction);
    }

    /**
     * Encodes the current layer input as {@code activation(input · W + b)}.
     *
     * @param training when {@code true} and the configured drop-out is positive, drop-out is first
     *                 applied to the layer input via {@link Dropout#applyDropout}, and the returned
     *                 mask is stored in {@code dropoutMask}. Drop-out presumably modifies the input
     *                 in place — TODO confirm.
     * @return the hidden activations
     */
    public INDArray encode(boolean training) {
        double dropOut = this.conf.getLayer().getDropOut();
        if (dropOut > 0.0d && training) {
            this.dropoutMask = Dropout.applyDropout(this.input, dropOut, this.dropoutMask);
        }
        INDArray weights = getParam("W");
        INDArray preActivation = this.input.mmul(weights).addiRowVector(getParam("b"));
        return applyActivation(preActivation);
    }

    /**
     * Decodes hidden activations as {@code activation(hidden · Wᵀ + vb)}.
     *
     * @param hidden hidden-layer activations, e.g. as produced by {@link #encode(boolean)}
     * @return the reconstruction in the visible space
     */
    public INDArray decode(INDArray hidden) {
        INDArray weights = getParam("W");
        INDArray visibleBias = getParam("vb");
        // Use transpose(), not transposei(). ND4J's transposei() is the in-place variant, and
        // calling it would alter the shared weight parameter on every decode.
        INDArray preActivation = hidden.mmul(weights.transpose());
        preActivation.addiRowVector(visibleBias);
        return applyActivation(preActivation);
    }

    /** Applies the configured activation function to {@code preActivation} through the ND4J executioner. */
    private INDArray applyActivation(INDArray preActivation) {
        return Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), preActivation));
    }

    /**
     * Sets {@code visible} as the layer input and returns its encoding.
     *
     * @param visible  the layer input
     * @param training forwarded to {@link #encode(boolean)}
     * @return the hidden activations
     */
    @Override
    public INDArray activate(INDArray visible, boolean training) {
        setInput(visible);
        return encode(training);
    }

    /**
     * Sets {@code visible} as the layer input and returns its encoding in training mode.
     *
     * @param visible the layer input
     * @return the hidden activations
     */
    @Override
    public INDArray activate(INDArray visible) {
        setInput(visible);
        return encode(true);
    }

    /**
     * Encodes the current layer input and decodes it again.
     *
     * @param training forwarded to {@link #encode(boolean)}
     * @return the reconstruction of the current input
     */
    @Override
    public INDArray activate(boolean training) {
        return decode(encode(training));
    }

    /**
     * Encodes the current layer input with {@code training == false} and decodes it again.
     *
     * @return the reconstruction of the current input
     */
    @Override
    public INDArray activate() {
        return decode(encode(false));
    }

    /**
     * Computes the parameter gradients and the score for the current layer input.
     *
     * <p>If the configured corruption level is positive, the input is first corrupted, and the
     * corrupted array replaces the layer input (it remains the layer input after this call).
     * The gradient is stored in {@code gradient}. The score is set from the reconstruction.
     */
    @Override
    public void computeGradientAndScore() {
        INDArray weights = getParam("W");
        org.deeplearning4j.nn.conf.layers.AutoEncoder layerConf =
                (org.deeplearning4j.nn.conf.layers.AutoEncoder) layerConf();

        double corruptionLevel = layerConf.getCorruptionLevel();
        INDArray corruptedInput =
                corruptionLevel > 0.0d ? getCorruptedInput(this.input, corruptionLevel) : this.input;
        // NOTE(review): after this call this.input is the corrupted input. The reconstruction error
        // and the score below are therefore measured against it, not against the clean input.
        // Confirm this is intended for denoising.
        setInput(corruptedInput);

        INDArray hidden = encode(true);
        INDArray reconstruction = decode(hidden);

        INDArray visibleLoss = this.input.sub(reconstruction);
        double sparsity = layerConf.getSparsity();
        // Without sparsity this is hidden * (1 - hidden), the sigmoid-derivative form (assumes a
        // sigmoid-like activation). With sparsity, hidden * (hidden - sparsity) is used instead.
        INDArray hiddenLoss = sparsity == 0.0d
                ? visibleLoss.mmul(weights).muli(hidden).muli(hidden.rsub(1))
                : visibleLoss.mmul(weights).muli(hidden).muli(hidden.add(-sparsity));

        // Reduce the biases along axis 0 (rows; the biases are added per row above) before any
        // transpose is taken, so both bias gradients keep their parameter width.
        INDArray visibleBiasGradient = visibleLoss.sum(new int[] {0});
        INDArray hiddenBiasGradient = hiddenLoss.sum(new int[] {0});
        // transpose() rather than the in-place transposei(). corruptedInput is this.input, and
        // mutating it would break the setScoreWithZ call below.
        INDArray weightGradient = corruptedInput.transpose().mmul(hiddenLoss)
                .addi(visibleLoss.transpose().mmul(hidden));

        this.gradient = createGradient(weightGradient, visibleBiasGradient, hiddenBiasGradient);
        setScoreWithZ(reconstruction);
    }
}
