package org.nd4j.autodiff.samediff.ops;

import java.util.Arrays;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.GRUCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMBlockCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUCellConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.SRUConfiguration;

/* loaded from: input_file:org/nd4j/autodiff/samediff/ops/SDRNN.class */
/**
 * Recurrent-network op namespace for {@link SameDiff}.
 *
 * <p>Each method constructs the corresponding recurrent op ({@link GRUCell}, {@link LSTMCell},
 * {@link LSTMBlockCell}, {@link LSTMLayer}, {@link SRU}, {@link SRUCell}) against the
 * {@code SameDiff} instance held by {@link SDOps} ({@code this.sd}) and returns the op's output
 * variables.
 *
 * <p>Every op has an unnamed overload that calls {@code outputVariables()}. It also has a named
 * overload that calls {@code outputVariables(name)}; the name is presumably used as the base name
 * for the generated output variables (confirm against {@code DifferentialFunction}).
 */
public class SDRNN extends SDOps {

    /**
     * Creates the RNN op namespace bound to the given graph.
     *
     * @param sameDiff the SameDiff instance new ops are created in (passed to {@link SDOps})
     */
    public SDRNN(SameDiff sameDiff) {
        super(sameDiff);
    }

    /**
     * Creates a {@link GRUCell} op and returns all of its output variables.
     *
     * @param configuration the GRU cell inputs/configuration
     * @return a fixed-size list view over the op's output variables
     */
    public List<SDVariable> gru(GRUCellConfiguration configuration) {
        return Arrays.asList(new GRUCell(this.sd, configuration).outputVariables());
    }

    /**
     * Creates a {@link GRUCell} op and returns all of its output variables.
     *
     * @param name          name passed to {@code outputVariables(String)} for the op's outputs
     * @param configuration the GRU cell inputs/configuration
     * @return a fixed-size list view over the op's output variables
     */
    public List<SDVariable> gru(String name, GRUCellConfiguration configuration) {
        return Arrays.asList(new GRUCell(this.sd, configuration).outputVariables(name));
    }

    /**
     * Creates an {@link LSTMCell} op and returns its first output variable.
     *
     * <p>Only element {@code [0]} of the op's outputs is returned; any further outputs are not
     * exposed by this method. Overload added for parity with {@link #gru(GRUCellConfiguration)}.
     *
     * @param configuration the LSTM cell inputs/configuration
     * @return the op's first output variable
     */
    public SDVariable lstmCell(LSTMCellConfiguration configuration) {
        return new LSTMCell(this.sd, configuration).outputVariables()[0];
    }

    /**
     * Creates an {@link LSTMCell} op and returns its first output variable.
     *
     * <p>Only element {@code [0]} of the op's outputs is returned; any further outputs are not
     * exposed by this method.
     *
     * @param name          name passed to {@code outputVariables(String)} for the op's outputs
     * @param configuration the LSTM cell inputs/configuration
     * @return the op's first output variable
     */
    public SDVariable lstmCell(String name, LSTMCellConfiguration configuration) {
        return new LSTMCell(this.sd, configuration).outputVariables(name)[0];
    }

    /**
     * Creates an {@link LSTMBlockCell} op and returns all of its output variables.
     *
     * <p>Overload added for parity with {@link #gru(GRUCellConfiguration)}.
     *
     * @param configuration the LSTM block cell inputs/configuration
     * @return a fixed-size list view over the op's output variables
     */
    public List<SDVariable> lstmBlockCell(LSTMBlockCellConfiguration configuration) {
        return Arrays.asList(new LSTMBlockCell(this.sd, configuration).outputVariables());
    }

    /**
     * Creates an {@link LSTMBlockCell} op and returns all of its output variables.
     *
     * @param name          name passed to {@code outputVariables(String)} for the op's outputs
     * @param configuration the LSTM block cell inputs/configuration
     * @return a fixed-size list view over the op's output variables
     */
    public List<SDVariable> lstmBlockCell(String name, LSTMBlockCellConfiguration configuration) {
        return Arrays.asList(new LSTMBlockCell(this.sd, configuration).outputVariables(name));
    }

    /**
     * Creates an {@link LSTMLayer} op and returns all of its output variables.
     *
     * <p>Overload added for parity with {@link #gru(GRUCellConfiguration)}.
     *
     * @param configuration the LSTM layer inputs/configuration
     * @return a fixed-size list view over the op's output variables
     */
    public List<SDVariable> lstmLayer(LSTMConfiguration configuration) {
        return Arrays.asList(new LSTMLayer(this.sd, configuration).outputVariables());
    }

    /**
     * Creates an {@link LSTMLayer} op and returns all of its output variables.
     *
     * @param name          name passed to {@code outputVariables(String)} for the op's outputs
     * @param configuration the LSTM layer inputs/configuration
     * @return a fixed-size list view over the op's output variables
     */
    public List<SDVariable> lstmLayer(String name, LSTMConfiguration configuration) {
        return Arrays.asList(new LSTMLayer(this.sd, configuration).outputVariables(name));
    }

    /**
     * Creates an {@link SRU} op and returns its first output variable.
     *
     * <p>Only element {@code [0]} of the op's outputs is returned.
     *
     * @param configuration the SRU inputs/configuration
     * @return the op's first output variable
     */
    public SDVariable sru(SRUConfiguration configuration) {
        return new SRU(this.sd, configuration).outputVariables()[0];
    }

    /**
     * Creates an {@link SRU} op and returns its first output variable.
     *
     * <p>Only element {@code [0]} of the op's outputs is returned.
     *
     * @param name          name passed to {@code outputVariables(String)} for the op's outputs
     * @param configuration the SRU inputs/configuration
     * @return the op's first output variable
     */
    public SDVariable sru(String name, SRUConfiguration configuration) {
        return new SRU(this.sd, configuration).outputVariables(name)[0];
    }

    /**
     * Creates an {@link SRUCell} op and returns its first output variable.
     *
     * <p>Only element {@code [0]} of the op's outputs is returned.
     *
     * @param configuration the SRU cell inputs/configuration
     * @return the op's first output variable
     */
    public SDVariable sruCell(SRUCellConfiguration configuration) {
        return new SRUCell(this.sd, configuration).outputVariables()[0];
    }

    /**
     * Creates an {@link SRUCell} op and returns its first output variable.
     *
     * <p>Only element {@code [0]} of the op's outputs is returned.
     *
     * @param name          name passed to {@code outputVariables(String)} for the op's outputs
     * @param configuration the SRU cell inputs/configuration
     * @return the op's first output variable
     */
    public SDVariable sruCell(String name, SRUCellConfiguration configuration) {
        return new SRUCell(this.sd, configuration).outputVariables(name)[0];
    }
}
