package org.deeplearning4j.nn.graph.vertex.impl.rnn;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.class */
/**
 * Graph vertex that reduces 3d (time series) activations to 2d activations by selecting one time step per example.
 * <p>
 * The input {@code inputs[0]} is expected to be rank 3, indexed as {@code [example, feature, timeStep]}; the output is
 * indexed as {@code [example, feature]}. The time step selected for each example depends on the mask array (if any)
 * of the network input named {@code inputName}:
 * <ul>
 *     <li>No mask array: the last time step ({@code size(2) - 1}) is used for every example. In this case the
 *     returned array is a view of the input, not a copy.</li>
 *     <li>Mask array present: each mask row is multiplied by {@code [0, 1, ..., T-1]} and the argMax of the result is
 *     used. For a 0/1 mask this is the index of the last unmasked time step of that example. NOTE(review): for an
 *     all-zero mask row argMax presumably returns 0 (first time step) - confirm if such rows can occur.</li>
 * </ul>
 * In the backward pass each example's epsilon is written back to the time step selected in the forward pass; all
 * other time steps receive zero gradient. This vertex has no parameters, so the returned {@link Gradient} is null.
 */
public class LastTimeStepVertex extends BaseGraphVertex {
    /** Name of the network input whose mask array (if any) decides which time step is selected per example. */
    private final String inputName;
    /** Position of {@link #inputName} in the network's input list; used to look up that input's mask array. */
    private final int inputIdx;
    /** Shape of the 3d input seen by the most recent forward pass; the backward-pass output has this shape. */
    private int[] fwdPassShape;
    /**
     * Time step selected for each example in the most recent forward pass, or {@code null} when no mask was present
     * (meaning the last time step was used for every example).
     */
    private int[] fwdPassTimeSteps;

    /**
     * Creates the vertex without explicit input/output vertex indices.
     *
     * @param graph       the graph this vertex belongs to
     * @param name        name of this vertex
     * @param vertexIndex index of this vertex in the graph
     * @param inputName   name of the network input whose mask array selects the time step per example
     * @throws IllegalArgumentException if {@code inputName} is not one of the graph's network inputs
     */
    public LastTimeStepVertex(ComputationGraph graph, String name, int vertexIndex, String inputName) {
        this(graph, name, vertexIndex, null, null, inputName);
    }

    /**
     * Creates the vertex.
     *
     * @param graph          the graph this vertex belongs to
     * @param name           name of this vertex
     * @param vertexIndex    index of this vertex in the graph
     * @param inputVertices  vertices feeding into this vertex (may be null)
     * @param outputVertices vertices this vertex feeds into (may be null)
     * @param inputName      name of the network input whose mask array selects the time step per example
     * @throws IllegalArgumentException if {@code inputName} is not one of the graph's network inputs
     */
    public LastTimeStepVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices,
                    VertexIndices[] outputVertices, String inputName) {
        super(graph, name, vertexIndex, inputVertices, outputVertices);
        this.inputName = inputName;
        this.inputIdx = graph.getConfiguration().getNetworkInputs().indexOf(inputName);
        if (this.inputIdx == -1) {
            throw new IllegalArgumentException("Invalid input name: \"" + inputName
                            + "\" not found in list of network inputs ("
                            + graph.getConfiguration().getNetworkInputs() + ")");
        }
    }

    @Override
    public boolean hasLayer() {
        return false;
    }

    @Override
    public boolean isOutputVertex() {
        return false;
    }

    @Override
    public Layer getLayer() {
        return null;
    }

    /**
     * Selects one time step per example from the rank 3 input.
     *
     * @param training training flag (not used by this vertex)
     * @return 2d activations {@code [example, feature]}; a view of the input when no mask array is present
     * @throws IllegalStateException if the input is not rank 3, or the mask array does not have shape
     *                               {@code [miniBatch, timeSeriesLength]}
     */
    @Override
    public INDArray doForward(boolean training) {
        INDArray input = this.inputs[0];
        if (input.rank() != 3) {
            throw new IllegalStateException(this + ": expected rank 3 input [miniBatch, features, timeSeriesLength]"
                            + " but got input of rank " + input.rank());
        }

        INDArray[] maskArrays = this.graph.getInputMaskArrays();
        INDArray mask = (maskArrays != null) ? maskArrays[this.inputIdx] : null;
        this.fwdPassShape = input.shape();
        int miniBatch = this.fwdPassShape[0];
        int tsLength = this.fwdPassShape[2];

        if (mask == null) {
            // Same (final) time step for every example; remembered as null for the backward pass.
            this.fwdPassTimeSteps = null;
            return input.get(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(),
                            NDArrayIndex.point(tsLength - 1)});
        }

        if (mask.rank() != 2 || mask.size(0) != miniBatch || mask.size(1) != tsLength) {
            throw new IllegalStateException("Invalid mask array for input \"" + this.inputName + "\" in " + this
                            + ": expected shape [" + miniBatch + "," + tsLength + "] but got rank " + mask.rank()
                            + " array with " + mask.length() + " elements");
        }

        // Weight each mask entry by its time index: for a 0/1 mask, the row-wise argMax is then the index of
        // the last unmasked time step of each example.
        INDArray weightedMask = mask.mulRowVector(Nd4j.linspace(0, tsLength - 1, tsLength));
        INDArray lastStepIdx = Nd4j.argMax(weightedMask, new int[] {1});

        INDArray out = Nd4j.create(new int[] {miniBatch, this.fwdPassShape[1]});
        this.fwdPassTimeSteps = new int[miniBatch];
        for (int ex = 0; ex < miniBatch; ex++) {
            int step = (int) lastStepIdx.getDouble(ex);
            this.fwdPassTimeSteps[ex] = step;
            out.putRow(ex, input.get(new INDArrayIndex[] {NDArrayIndex.point(ex), NDArrayIndex.all(),
                            NDArrayIndex.point(step)}));
        }
        return out;
    }

    /**
     * Routes the incoming epsilon back to the time steps selected by the most recent forward pass.
     *
     * @param tbptt flag passed by the graph (not used by this vertex)
     * @return a pair of (null gradient - this vertex has no parameters, a single epsilon array with the shape of the
     *         forward-pass input)
     * @throws IllegalStateException if called before {@link #doForward(boolean)}
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt) {
        if (this.fwdPassShape == null) {
            throw new IllegalStateException(this + ": cannot perform backward pass before forward pass");
        }
        INDArray epsilon = this.epsilons[0];
        // Relies on Nd4j.create returning a zero-filled array: unselected time steps get zero gradient.
        INDArray epsilonOut = Nd4j.create(this.fwdPassShape);

        if (this.fwdPassTimeSteps == null) {
            epsilonOut.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(),
                            NDArrayIndex.point(this.fwdPassShape[2] - 1)}, epsilon);
        } else {
            for (int ex = 0; ex < this.fwdPassTimeSteps.length; ex++) {
                epsilonOut.put(new INDArrayIndex[] {NDArrayIndex.point(ex), NDArrayIndex.all(),
                                NDArrayIndex.point(this.fwdPassTimeSteps[ex])}, epsilon.getRow(ex));
            }
        }
        return new Pair<>(null, new INDArray[] {epsilonOut});
    }

    @Override
    public String toString() {
        return "LastTimeStepVertex(inputName=" + this.inputName + ")";
    }
}
