package org.deeplearning4j.nn.layers.feedforward.autoencoder.recursive;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.params.RecursiveParamInitializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/nn/layers/feedforward/autoencoder/recursive/RecursiveAutoEncoder.class */
/**
 * Recursive auto-encoder layer.
 *
 * <p>Training ({@link #computeGradientAndScore()}) walks the rows of the layer input:
 * <ul>
 *   <li>rows 0 and 1 are concatenated along dimension 0;</li>
 *   <li>every later row is concatenated in front of that running combination.</li>
 * </ul>
 * At each step the combination is encoded, decoded, and the reconstruction residual is used to
 * accumulate gradient terms and a squared-error score.
 */
public class RecursiveAutoEncoder extends BaseLayer<org.deeplearning4j.nn.conf.layers.RecursiveAutoEncoder> {
    // Parameter keys read via getParam(...); the decoder weight key comes from RecursiveParamInitializer.
    private static final String ENCODER_WEIGHT_KEY = "W";
    private static final String ENCODER_BIAS_KEY = "b";
    private static final String DECODER_BIAS_KEY = "vb";

    /** Running combination of input rows for the current training step (null until training runs). */
    private INDArray currInput = null;
    /** Same array as currInput for the current step; used when computing the score. */
    private INDArray allInput = null;
    /** Accumulated y^T * residual terms (second argument to createGradient). */
    private INDArray visibleLoss = null;
    /** Accumulated z^T * hiddenDelta terms (first argument to createGradient). */
    private INDArray hiddenLoss = null;
    /** Accumulated dimension-0 means of the reconstruction residual (fourth argument to createGradient). */
    private INDArray vbLoss = null;
    /** Accumulated dimension-0 means of the hidden delta (third argument to createGradient). */
    private INDArray bLoss = null;
    /** Encoding produced in the current training step. */
    private INDArray y = null;
    /** Reconstruction produced in the current training step. */
    private INDArray z = null;
    /** Score of the most recent computeGradientAndScore() call. */
    double currScore = 0.0d;

    public RecursiveAutoEncoder(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURSIVE;
    }

    /** Returns the score computed by the most recent {@link #computeGradientAndScore()} call. */
    @Override
    public double score() {
        return this.currScore;
    }

    /**
     * Encodes the current training combination ({@code currInput}) as
     * activation(currInput * W + b), using the layer's configured activation function.
     *
     * @param training currently unused; kept for API compatibility
     * @return the encoded combination
     * @throws NullPointerException if no training step has populated {@code currInput} yet
     */
    public INDArray encode(boolean training) {
        return affineActivation(this.currInput, ENCODER_WEIGHT_KEY, ENCODER_BIAS_KEY);
    }

    /**
     * Decodes an encoding as activation(hidden * U + vb), where U is the decoder weight
     * parameter, using the layer's configured activation function.
     *
     * @param hidden the encoding to reconstruct from
     * @return the reconstruction
     */
    public INDArray decode(INDArray hidden) {
        return affineActivation(hidden, RecursiveParamInitializer.DECODER_WEIGHT_KEY, DECODER_BIAS_KEY);
    }

    /**
     * Sets the layer input and returns its encoding.
     *
     * @param layerInput input to encode
     * @param training currently unused; kept for API compatibility
     * @return activation(input * W + b)
     */
    @Override
    public INDArray activate(INDArray layerInput, boolean training) {
        setInput(layerInput);
        // Encode what setInput just stored, not the training-time combination held in currInput;
        // otherwise the result would not depend on the argument at all.
        return affineActivation(this.input, ENCODER_WEIGHT_KEY, ENCODER_BIAS_KEY);
    }

    /** Same as {@link #activate(INDArray, boolean)} with {@code training = true}. */
    @Override
    public INDArray activate(INDArray layerInput) {
        return activate(layerInput, true);
    }

    /**
     * Returns decode(encode(training)), i.e. the reconstruction of {@code currInput}.
     *
     * <p>NOTE(review): this reads the training-time {@code currInput}, not {@code input}, and
     * throws NullPointerException before any training step; confirm this is intended.
     */
    @Override
    public INDArray activate(boolean training) {
        return decode(encode(training));
    }

    /** Same as {@link #activate(boolean)} with {@code training = false}. */
    @Override
    public INDArray activate() {
        return decode(encode(false));
    }

    /** Intentionally a no-op for this layer. */
    @Override
    public void iterate(INDArray iNDArray) {
    }

    /**
     * Runs one recursive pass over the input rows, then stores the summed score and the
     * accumulated gradient terms.
     *
     * @throws IllegalStateException if the input has fewer than two rows, because there is
     *     nothing to combine
     */
    @Override
    public void computeGradientAndScore() {
        int rows = this.input.rows();
        if (rows < 2) {
            throw new IllegalStateException(
                    "RecursiveAutoEncoder needs at least 2 input rows to combine, got " + rows);
        }

        // Start each pass from a clean slate. Otherwise the first step would combine row 0 with a
        // stale combination (and silently skip row 1), and addi() would mutate the gradient
        // arrays handed out by the previous call.
        this.currScore = 0.0d;
        this.currInput = null;
        this.visibleLoss = null;
        this.hiddenLoss = null;
        this.vbLoss = null;
        this.bLoss = null;

        // First step pairs rows 0 and 1; each later row is concatenated in front of the running
        // combination along dimension 0.
        this.currInput = Nd4j.concat(0, new INDArray[]{this.input.slice(0), this.input.slice(1)});
        accumulateCurrentStep();
        for (int row = 2; row < rows; row++) {
            this.currInput = Nd4j.concat(0, new INDArray[]{this.input.slice(row), this.currInput});
            accumulateCurrentStep();
        }

        this.gradient = createGradient(this.hiddenLoss, this.visibleLoss, this.bLoss, this.vbLoss);
        this.score = this.currScore;
    }

    /** Encodes/decodes currInput and adds this step's gradient terms and score to the running totals. */
    private void accumulateCurrentStep() {
        this.allInput = this.currInput;
        this.y = encode(true);
        this.z = decode(this.y);

        INDArray residual = this.currInput.sub(this.z);
        // NOTE(review): y * (1 - y) is the sigmoid derivative, but it is applied regardless of the
        // configured activation function; confirm this layer is only used with sigmoid.
        INDArray hiddenDelta = residual.mmul(getParam(ENCODER_WEIGHT_KEY)).muli(this.y).muli(this.y.rsub(1));
        INDArray hiddenWeightTerm = this.z.transpose().mmul(hiddenDelta);
        INDArray visibleWeightTerm = this.y.transpose().mmul(residual);

        this.visibleLoss = accumulateInto(this.visibleLoss, visibleWeightTerm);
        this.hiddenLoss = accumulateInto(this.hiddenLoss, hiddenWeightTerm);
        this.vbLoss = accumulateInto(this.vbLoss, meanOverDimZero(residual));
        this.bLoss = accumulateInto(this.bLoss, meanOverDimZero(hiddenDelta));

        this.currScore += 0.5d * Transforms.pow(this.z.sub(this.allInput), 2).mean(new int[]{Integer.MAX_VALUE}).getDouble(0);
    }

    /** Computes activation(x * getParam(weightKey) + getParam(biasKey)) with the configured activation. */
    private INDArray affineActivation(INDArray x, String weightKey, String biasKey) {
        INDArray weights = getParam(weightKey);
        INDArray preActivation = x.mmul(weights).addiRowVector(getParam(biasKey));
        return Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), preActivation));
    }

    /** Returns {@code term} when {@code total} is null, otherwise adds {@code term} into {@code total} in place. */
    private static INDArray accumulateInto(INDArray total, INDArray term) {
        if (total == null) {
            return term;
        }
        total.addi(term);
        return total;
    }

    /** Returns the mean over dimension 0 for a matrix, or the array itself otherwise. */
    private static INDArray meanOverDimZero(INDArray values) {
        return values.isMatrix() ? values.mean(new int[]{0}) : values;
    }
}
