package org.deeplearning4j.nn.conf.layers.recurrent;

import java.util.Collection;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.class */
/**
 * Configuration wrapper that takes a layer with recurrent (RNN) output and exposes a
 * feed-forward output of the same size.
 *
 * <p>At runtime the wrapped layer is placed inside a {@link LastTimeStepLayer}. NOTE(review): as
 * the name suggests, that layer presumably selects the last time step of the recurrent
 * activations. Confirm against {@code LastTimeStepLayer}.
 */
public class LastTimeStep extends BaseWrapperLayer {

    /**
     * No-argument constructor. It leaves the wrapped layer and the layer name unset.
     *
     * <p>NOTE(review): presumably it exists only for reflective deserialization of configurations.
     * Confirm before removing.
     */
    private LastTimeStep() {
    }

    /**
     * Wraps {@code layer} and adopts its layer name as this wrapper's name.
     *
     * @param layer the layer configuration to wrap. Its output type must be recurrent (see
     *              {@link #getOutputType(int, InputType)}).
     */
    public LastTimeStep(Layer layer) {
        super(layer);
        this.layerName = layer.getLayerName();
    }

    /**
     * Returns the wrapped layer configuration.
     *
     * @return the layer configuration passed to the constructor
     */
    @Override
    public Layer getUnderlying() {
        return this.underlying;
    }

    /**
     * Instantiates the wrapped layer and wraps the result in a {@link LastTimeStepLayer}.
     *
     * <p>The configuration is cloned first. The clone's layer (expected to be this wrapper) is then
     * replaced with its own underlying layer, so the wrapped layer is instantiated against a
     * configuration that describes itself rather than the wrapper. The {@code conf} object passed
     * in is not modified by this method.
     *
     * @param conf              network configuration whose layer is expected to be a {@code LastTimeStep}
     * @param trainingListeners listeners forwarded unchanged to the wrapped layer
     * @param layerIndex        index of this layer, forwarded unchanged
     * @param layerParamsView   parameter view, forwarded unchanged
     * @param initializeParams  initialization flag, forwarded unchanged
     * @return a {@link LastTimeStepLayer} wrapping the instantiated underlying layer
     * @throws IllegalStateException if the configuration's layer is not a {@code LastTimeStep}
     */
    @Override
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
                                                       Collection<TrainingListener> trainingListeners,
                                                       int layerIndex, INDArray layerParamsView,
                                                       boolean initializeParams) {
        // Work on a clone so the caller's configuration keeps pointing at the wrapper.
        NeuralNetConfiguration underlyingConf = conf.m6296clone();
        Layer clonedLayer = underlyingConf.getLayer();
        if (!(clonedLayer instanceof LastTimeStep)) {
            throw new IllegalStateException(
                    "Expected configuration layer to be LastTimeStep - got " + clonedLayer);
        }
        // Take the wrapped layer from the clone (not this.underlying) so the cloned configuration
        // does not share the original layer instance.
        underlyingConf.setLayer(((LastTimeStep) clonedLayer).getUnderlying());
        org.deeplearning4j.nn.api.Layer inner = this.underlying.instantiate(
                underlyingConf, trainingListeners, layerIndex, layerParamsView, initializeParams);
        return new LastTimeStepLayer(inner);
    }

    /**
     * Computes the output type: a feed-forward type whose size equals the size of the wrapped
     * layer's recurrent output.
     *
     * @param layerIndex index of this layer, forwarded to the wrapped layer
     * @param inputType  input type. It must be {@link InputType.Type#RNN}.
     * @return {@code InputType.feedForward(size)}, where {@code size} is the wrapped layer's
     *         recurrent output size
     * @throws IllegalArgumentException if {@code inputType} is not an RNN type
     * @throws IllegalStateException    if the wrapped layer does not produce a recurrent output type
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType.getType() != InputType.Type.RNN) {
            throw new IllegalArgumentException("Require RNN input type - got " + inputType);
        }
        InputType underlyingOut = this.underlying.getOutputType(layerIndex, inputType);
        // Check explicitly so a mis-configured wrapped layer gives a clear message rather than a
        // bare ClassCastException.
        if (!(underlyingOut instanceof InputType.InputTypeRecurrent)) {
            throw new IllegalStateException(
                    "Underlying layer must produce an RNN output type - got " + underlyingOut);
        }
        return InputType.feedForward(((InputType.InputTypeRecurrent) underlyingOut).getSize());
    }
}
