package org.nd4j.autodiff.samediff.ops;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.transforms.Pad;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/autodiff/samediff/ops/SDNN.class */
/**
 * Neural-network operations for a {@link SameDiff} graph.
 *
 * <p>Most operations come in two overloads: one without a name, which forwards a {@code null}
 * output name, and one taking an explicit {@code String} name. Where shown, the named overloads
 * check their inputs with {@code SDValidation.validateFloatingPoint}, build the op through the
 * function factory {@code f()}, and then name the result via {@code updateVariableNameAndReference}.
 */
public class SDNN extends SDOps {
    /**
     * Creates this op namespace.
     *
     * @param sameDiff graph this namespace operates on (passed to the {@code SDOps} constructor)
     */
    public SDNN(SameDiff sameDiff) {
        super(sameDiff);
    }

    /**
     * Same as {@link #batchNorm(String, SDVariable, SDVariable, SDVariable, SDVariable, SDVariable, boolean, boolean, double, int...)}
     * with a {@code null} name and {@code true} for both boolean flags.
     */
    public SDVariable batchNorm(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, SDVariable sDVariable5, double d, int... iArr) {
        return batchNorm(null, sDVariable, sDVariable2, sDVariable3, sDVariable4, sDVariable5, true, true, d, iArr);
    }

    /**
     * Batch normalization built by {@code f().batchNorm}.
     *
     * @param str         output name, or {@code null}
     * @param sDVariable  input (must be floating point)
     * @param sDVariable2 mean (must be floating point)
     * @param sDVariable3 variance (must be floating point)
     * @param sDVariable4 gamma (must be floating point)
     * @param sDVariable5 beta (must be floating point)
     * @param z           flag forwarded to the factory (presumably "apply gamma" — confirm)
     * @param z2          flag forwarded to the factory (presumably "apply beta" — confirm)
     * @param d           scalar forwarded to the factory (presumably epsilon — confirm)
     * @param iArr        ints forwarded to the factory (presumably the normalization axes — confirm)
     * @return the named output variable
     */
    public SDVariable batchNorm(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, SDVariable sDVariable5, boolean z, boolean z2, double d, int... iArr) {
        SDValidation.validateFloatingPoint("batchNorm", "input", sDVariable);
        SDValidation.validateFloatingPoint("batchNorm", "mean", sDVariable2);
        SDValidation.validateFloatingPoint("batchNorm", "variance", sDVariable3);
        SDValidation.validateFloatingPoint("batchNorm", "gamma", sDVariable4);
        SDValidation.validateFloatingPoint("batchNorm", "beta", sDVariable5);
        return updateVariableNameAndReference(f().batchNorm(sDVariable, sDVariable2, sDVariable3, sDVariable4, sDVariable5, z, z2, d, iArr), str);
    }

    /**
     * Same as the full named {@code batchNorm} overload with {@code true} for both boolean flags.
     */
    public SDVariable batchNorm(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, SDVariable sDVariable5, double d, int... iArr) {
        return batchNorm(str, sDVariable, sDVariable2, sDVariable3, sDVariable4, sDVariable5, true, true, d, iArr);
    }

    /** Same as {@link #biasAdd(String, SDVariable, SDVariable, boolean)} with a {@code null} name. */
    public SDVariable biasAdd(SDVariable sDVariable, SDVariable sDVariable2, boolean z) {
        return biasAdd(null, sDVariable, sDVariable2, z);
    }

    /**
     * Adds a bias to the input using {@code f().biasAdd}.
     *
     * @param str         output name, or {@code null}
     * @param sDVariable  input (must be floating point)
     * @param sDVariable2 bias (must be floating point)
     * @param z           flag forwarded to the factory (presumably selects the data layout — confirm)
     * @return the named output variable
     */
    public SDVariable biasAdd(String str, SDVariable sDVariable, SDVariable sDVariable2, boolean z) {
        SDValidation.validateFloatingPoint("biasAdd", "input", sDVariable);
        SDValidation.validateFloatingPoint("biasAdd", "bias", sDVariable2);
        return updateVariableNameAndReference(f().biasAdd(sDVariable, sDVariable2, z), str);
    }

    /** Same as {@link #dropout(String, SDVariable, double)} with a {@code null} name. */
    public SDVariable dropout(SDVariable sDVariable, double d) {
        return dropout(null, sDVariable, d);
    }

    /**
     * Applies dropout via {@code f().dropout}.
     *
     * @param str        output name, or {@code null}
     * @param sDVariable floating-point input
     * @param d          probability forwarded to the factory (keep vs. drop semantics defined there — confirm)
     * @return the named output variable
     */
    public SDVariable dropout(String str, SDVariable sDVariable, double d) {
        SDValidation.validateFloatingPoint("dropout", sDVariable);
        return updateVariableNameAndReference(f().dropout(sDVariable, d), str);
    }

    /** Same as {@link #elu(String, SDVariable)} with a {@code null} name. */
    public SDVariable elu(SDVariable sDVariable) {
        return elu(null, sDVariable);
    }

    /** Applies ELU ({@code f().elu}) to a floating-point input and names the result {@code str}. */
    public SDVariable elu(String str, SDVariable sDVariable) {
        SDValidation.validateFloatingPoint("elu", sDVariable);
        return updateVariableNameAndReference(f().elu(sDVariable), str);
    }

    /** Same as {@link #gelu(String, SDVariable)} with a {@code null} name. */
    public SDVariable gelu(SDVariable sDVariable) {
        return gelu(null, sDVariable);
    }

    /**
     * Applies GELU via {@code f().gelu} to a floating-point input. The factory's boolean argument
     * is always {@code false} here (presumably selecting the approximate form — confirm).
     */
    public SDVariable gelu(String str, SDVariable sDVariable) {
        SDValidation.validateFloatingPoint("gelu", sDVariable);
        return updateVariableNameAndReference(f().gelu(sDVariable, false), str);
    }

    /** Same as {@link #hardSigmoid(String, SDVariable)} with a {@code null} name. */
    public SDVariable hardSigmoid(SDVariable sDVariable) {
        return hardSigmoid(null, sDVariable);
    }

    /** Applies hard sigmoid ({@code f().hardSigmoid}) to a floating-point input. */
    public SDVariable hardSigmoid(String str, SDVariable sDVariable) {
        SDValidation.validateFloatingPoint("hard sigmoid", sDVariable);
        return updateVariableNameAndReference(f().hardSigmoid(sDVariable), str);
    }

    /** Same as {@link #hardTanh(String, SDVariable)} with a {@code null} name. */
    public SDVariable hardTanh(SDVariable sDVariable) {
        return hardTanh(null, sDVariable);
    }

    /** Applies hard tanh ({@code f().hardTanh}) to a floating-point input. */
    public SDVariable hardTanh(String str, SDVariable sDVariable) {
        SDValidation.validateFloatingPoint("hard Tanh", sDVariable);
        return updateVariableNameAndReference(f().hardTanh(sDVariable), str);
    }

    /** Same as {@link #hardTanhDerivative(String, SDVariable)} with a {@code null} name. */
    public SDVariable hardTanhDerivative(SDVariable sDVariable) {
        return hardTanhDerivative(null, sDVariable);
    }

    /** Computes the hard-tanh derivative op ({@code f().hardTanhDerivative}) for a floating-point input. */
    public SDVariable hardTanhDerivative(String str, SDVariable sDVariable) {
        SDValidation.validateFloatingPoint("hard Tanh derivative", sDVariable);
        return updateVariableNameAndReference(f().hardTanhDerivative(sDVariable), str);
    }

    /** Same as {@link #leakyRelu(String, SDVariable, double)} with a {@code null} name. */
    public SDVariable leakyRelu(SDVariable sDVariable, double d) {
        return leakyRelu(null, sDVariable, d);
    }

    /**
     * Applies leaky ReLU via {@code f().leakyRelu}.
     *
     * @param d scalar forwarded to the factory (presumably the negative-side slope — confirm)
     */
    public SDVariable leakyRelu(String str, SDVariable sDVariable, double d) {
        SDValidation.validateFloatingPoint("leaky ReLU", sDVariable);
        return updateVariableNameAndReference(f().leakyRelu(sDVariable, d), str);
    }

    /** Computes the leaky-ReLU derivative op ({@code f().leakyReluDerivative}) for a floating-point input. */
    public SDVariable leakyReluDerivative(String str, SDVariable sDVariable, double d) {
        SDValidation.validateFloatingPoint("leaky ReLU derivative", sDVariable);
        return updateVariableNameAndReference(f().leakyReluDerivative(sDVariable, d), str);
    }

    /** Same as {@link #linear(String, SDVariable, SDVariable, SDVariable)} with a {@code null} name. */
    public SDVariable linear(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        return linear(null, sDVariable, sDVariable2, sDVariable3);
    }

    /**
     * Linear layer built with {@code f().xwPlusB} (presumably input x weights + bias — confirm).
     *
     * @param sDVariable  input (must be floating point)
     * @param sDVariable2 weights (must be floating point)
     * @param sDVariable3 bias (must be floating point)
     */
    public SDVariable linear(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        SDValidation.validateFloatingPoint("linear", "input", sDVariable);
        SDValidation.validateFloatingPoint("linear", "weights", sDVariable2);
        SDValidation.validateFloatingPoint("linear", "bias", sDVariable3);
        return updateVariableNameAndReference(f().xwPlusB(sDVariable, sDVariable2, sDVariable3), str);
    }

    /** Same as {@link #logSigmoid(String, SDVariable)} with a {@code null} name. */
    public SDVariable logSigmoid(SDVariable sDVariable) {
        return logSigmoid(null, sDVariable);
    }

    /** Applies log-sigmoid ({@code f().logSigmoid}) to a floating-point input. */
    public SDVariable logSigmoid(String str, SDVariable sDVariable) {
        SDValidation.validateFloatingPoint("log sigmoid", sDVariable);
        return updateVariableNameAndReference(f().logSigmoid(sDVariable), str);
    }

    /** Same as {@link #logSoftmax(String, SDVariable)} with a {@code null} name. */
    public SDVariable logSoftmax(SDVariable sDVariable) {
        // Cast disambiguates from logSoftmax(SDVariable, int) is not needed, but keeps the target explicit.
        return logSoftmax((String) null, sDVariable);
    }

    /** Applies log-softmax ({@code f().logSoftmax}) with the factory's default dimension. */
    public SDVariable logSoftmax(String str, SDVariable sDVariable) {
        SDValidation.validateFloatingPoint("log softmax", sDVariable);
        return updateVariableNameAndReference(f().logSoftmax(sDVariable), str);
    }

    /** Same as {@link #logSoftmax(String, SDVariable, int)} with a {@code null} name. */
    public SDVariable logSoftmax(SDVariable sDVariable, int i) {
        return logSoftmax(null, sDVariable, i);
    }

    /**
     * Applies log-softmax via {@code f().logSoftmax}.
     *
     * @param i int forwarded to the factory (presumably the dimension — confirm)
     */
    public SDVariable logSoftmax(String str, SDVariable sDVariable, int i) {
        SDValidation.validateFloatingPoint("log softmax", sDVariable);
        return updateVariableNameAndReference(f().logSoftmax(sDVariable, i), str);
    }

    /** Same as {@link #relu(String, SDVariable, double)} with a {@code null} name. */
    public SDVariable relu(SDVariable sDVariable, double d) {
        return relu(null, sDVariable, d);
    }

    /**
     * Applies ReLU via {@code f().relu}.
     *
     * @param d scalar forwarded to the factory (presumably the cutoff — confirm)
     */
    public SDVariable relu(String str, SDVariable sDVariable, double d) {
        SDValidation.validateFloatingPoint("ReLU", sDVariable);
        return updateVariableNameAndReference(f().relu(sDVariable, d), str);
    }

    /** Same as {@link #relu6(String, SDVariable, double)} with a {@code null} name. */
    public SDVariable relu6(SDVariable sDVariable, double d) {
        return relu6(null, sDVariable, d);
    }

    /**
     * Applies ReLU6 via {@code f().relu6}.
     *
     * @param d scalar forwarded to the factory (presumably the cutoff — confirm)
     */
    public SDVariable relu6(String str, SDVariable sDVariable, double d) {
        SDValidation.validateFloatingPoint("ReLU6", sDVariable);
        return updateVariableNameAndReference(f().relu6(sDVariable, d), str);
    }

    /** Same as {@link #reluLayer(String, SDVariable, SDVariable, SDVariable)} with a {@code null} name. */
    public SDVariable reluLayer(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        return reluLayer(null, sDVariable, sDVariable2, sDVariable3);
    }

    /**
     * ReLU layer built with {@code f().reluLayer}.
     *
     * @param sDVariable  input (must be floating point)
     * @param sDVariable2 weights (must be floating point)
     * @param sDVariable3 bias (must be floating point)
     */
    public SDVariable reluLayer(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        SDValidation.validateFloatingPoint("reluLayer", "input", sDVariable);
        SDValidation.validateFloatingPoint("reluLayer", "weights", sDVariable2);
        SDValidation.validateFloatingPoint("reluLayer", "bias", sDVariable3);
        return updateVariableNameAndReference(f().reluLayer(sDVariable, sDVariable2, sDVariable3), str);
    }

    /**
     * PReLU via {@code f().prelu}. Unlike the named overload, the result is returned directly without
     * passing through {@code updateVariableNameAndReference}.
     *
     * @param sDVariable  input (non-null)
     * @param sDVariable2 alpha (non-null)
     * @param iArr        shared axes (non-null)
     * @throws NullPointerException if any argument is null
     */
    public SDVariable prelu(@NonNull SDVariable sDVariable, @NonNull SDVariable sDVariable2, @NonNull int... iArr) {
        if (sDVariable == null) {
            throw new NullPointerException("input is marked @NonNull but is null");
        }
        if (sDVariable2 == null) {
            throw new NullPointerException("alpha is marked @NonNull but is null");
        }
        if (iArr == null) {
            throw new NullPointerException("sharedAxes is marked @NonNull but is null");
        }
        return f().prelu(sDVariable, sDVariable2, iArr);
    }

    /**
     * PReLU via {@code f().prelu}, with the result named {@code str}.
     *
     * @param str         output name, or {@code null}
     * @param sDVariable  input (non-null)
     * @param sDVariable2 alpha (non-null)
     * @param iArr        shared axes (non-null)
     * @throws NullPointerException if any of the variable/axes arguments is null
     */
    public SDVariable prelu(String str, @NonNull SDVariable sDVariable, @NonNull SDVariable sDVariable2, @NonNull int... iArr) {
        if (sDVariable == null) {
            throw new NullPointerException("input is marked @NonNull but is null");
        }
        if (sDVariable2 == null) {
            throw new NullPointerException("alpha is marked @NonNull but is null");
        }
        if (iArr == null) {
            throw new NullPointerException("sharedAxes is marked @NonNull but is null");
        }
        return updateVariableNameAndReference(f().prelu(sDVariable, sDVariable2, iArr), str);
    }

    /** Same as {@link #selu(String, SDVariable)} with a {@code null} name. */
    public SDVariable selu(SDVariable sDVariable) {
        return selu(null, sDVariable);
    }

    /** Applies SELU ({@code f().selu}) to a floating-point input. */
    public SDVariable selu(String str, SDVariable sDVariable) {
        SDValidation.validateFloatingPoint("selu", sDVariable);
        return updateVariableNameAndReference(f().selu(sDVariable), str);
    }

    /** Same as {@link #sigmoid(String, SDVariable)} with a {@code null} name. */
    public SDVariable sigmoid(SDVariable sDVariable) {
        return sigmoid(null, sDVariable);
    }

    /** Applies sigmoid ({@code f().sigmoid}) to a floating-point input. */
    public SDVariable sigmoid(String str, SDVariable sDVariable) {
        SDValidation.validateFloatingPoint("sigmoid", sDVariable);
        return updateVariableNameAndReference(f().sigmoid(sDVariable), str);
    }

    /** Same as {@link #sigmoidDerivative(String, SDVariable, SDVariable)} with a {@code null} name. */
    public SDVariable sigmoidDerivative(SDVariable sDVariable, SDVariable sDVariable2) {
        return sigmoidDerivative(null, sDVariable, sDVariable2);
    }

    /**
     * Sigmoid derivative op ({@code f().sigmoidDerivative}). Only the first variable is validated
     * as floating point; the second is passed through unchecked.
     */
    public SDVariable sigmoidDerivative(String str, SDVariable sDVariable, SDVariable sDVariable2) {
        SDValidation.validateFloatingPoint("sigmoidDerivative", sDVariable);
        return updateVariableNameAndReference(f().sigmoidDerivative(sDVariable, sDVariable2), str);
    }

    /** Same as {@link #softmax(String, SDVariable)} with a {@code null} name. */
    public SDVariable softmax(SDVariable sDVariable) {
        return softmax((String) null, sDVariable);
    }

    /** Applies softmax ({@code f().softmax}) with the factory's default dimension. */
    public SDVariable softmax(String str, SDVariable sDVariable) {
        SDValidation.validateFloatingPoint("softmax", sDVariable);
        return updateVariableNameAndReference(f().softmax(sDVariable), str);
    }

    /** Same as {@link #softmax(String, SDVariable, int)} with a {@code null} name. */
    public SDVariable softmax(SDVariable sDVariable, int i) {
        return softmax(null, sDVariable, i);
    }

    /**
     * Applies softmax via {@code f().softmax}.
     *
     * @param i int forwarded to the factory (presumably the dimension — confirm)
     */
    public SDVariable softmax(String str, SDVariable sDVariable, int i) {
        SDValidation.validateFloatingPoint("softmax", sDVariable);
        return updateVariableNameAndReference(f().softmax(sDVariable, i), str);
    }

    /** Same as {@link #softmaxDerivative(String, SDVariable, SDVariable, Integer)} with a {@code null} dimension. */
    public SDVariable softmaxDerivative(String str, SDVariable sDVariable, SDVariable sDVariable2) {
        return softmaxDerivative(str, sDVariable, sDVariable2, null);
    }

    /**
     * Softmax derivative op ({@code f().softmaxDerivative}). Only the first variable is validated
     * as floating point.
     *
     * @param num optional int forwarded to the factory; may be {@code null}
     */
    public SDVariable softmaxDerivative(String str, SDVariable sDVariable, SDVariable sDVariable2, Integer num) {
        SDValidation.validateFloatingPoint("softmaxDerivative", sDVariable);
        return updateVariableNameAndReference(f().softmaxDerivative(sDVariable, sDVariable2, num), str);
    }

    /** Same as {@link #softplus(String, SDVariable)} with a {@code null} name. */
    public SDVariable softplus(SDVariable sDVariable) {
        return softplus(null, sDVariable);
    }

    /** Applies softplus ({@code f().softplus}) to a floating-point input. */
    public SDVariable softplus(String str, SDVariable sDVariable) {
        SDValidation.validateFloatingPoint("softplus", sDVariable);
        return updateVariableNameAndReference(f().softplus(sDVariable), str);
    }

    /** Same as {@link #softsign(String, SDVariable)} with a {@code null} name. */
    public SDVariable softsign(SDVariable sDVariable) {
        return softsign(null, sDVariable);
    }

    /** Applies softsign ({@code f().softsign}) to a floating-point input. */
    public SDVariable softsign(String str, SDVariable sDVariable) {
        SDValidation.validateFloatingPoint("softsign", sDVariable);
        return updateVariableNameAndReference(f().softsign(sDVariable), str);
    }

    /** Same as {@link #softsignDerivative(String, SDVariable)} with a {@code null} name. */
    public SDVariable softsignDerivative(SDVariable sDVariable) {
        return softsignDerivative(null, sDVariable);
    }

    /** Computes the softsign derivative op ({@code f().softsignDerivative}) for a floating-point input. */
    public SDVariable softsignDerivative(String str, SDVariable sDVariable) {
        SDValidation.validateFloatingPoint("softsignDerivative", sDVariable);
        return updateVariableNameAndReference(f().softsignDerivative(sDVariable), str);
    }

    /** Same as {@link #swish(String, SDVariable)} with a {@code null} name. */
    public SDVariable swish(SDVariable sDVariable) {
        return swish(null, sDVariable);
    }

    /** Applies swish ({@code f().swish}) to a floating-point input. */
    public SDVariable swish(String str, SDVariable sDVariable) {
        SDValidation.validateFloatingPoint("swish", sDVariable);
        return updateVariableNameAndReference(f().swish(sDVariable), str);
    }

    /** Convenience alias that delegates to the math namespace's {@code tanh(String, SDVariable)}. */
    public SDVariable tanh(String str, SDVariable sDVariable) {
        return this.sd.math().tanh(str, sDVariable);
    }

    /** Convenience alias that delegates to the math namespace's {@code tanh(SDVariable)}. */
    public SDVariable tanh(SDVariable sDVariable) {
        return this.sd.math().tanh(sDVariable);
    }

    /** Same as the named gain-and-bias {@code layerNorm} overload with a {@code null} name. */
    public SDVariable layerNorm(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, boolean z, int... iArr) {
        return layerNorm(null, sDVariable, sDVariable2, sDVariable3, z, iArr);
    }

    /**
     * Layer normalization with gain and bias, built by {@code f().layerNorm}.
     *
     * @param sDVariable  input (must be floating point)
     * @param sDVariable2 gain (must be floating point)
     * @param sDVariable3 bias (must be floating point)
     * @param z           flag forwarded to the factory (presumably a layout flag — confirm)
     * @param iArr        ints forwarded to the factory (presumably the normalization dimensions — confirm)
     */
    public SDVariable layerNorm(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, boolean z, int... iArr) {
        SDValidation.validateFloatingPoint("layerNorm", "input", sDVariable);
        SDValidation.validateFloatingPoint("layerNorm", "gain", sDVariable2);
        SDValidation.validateFloatingPoint("layerNorm", "bias", sDVariable3);
        return updateVariableNameAndReference(f().layerNorm(sDVariable, sDVariable2, sDVariable3, z, iArr), str);
    }

    /** Same as the named gain-only {@code layerNorm} overload with a {@code null} name. */
    public SDVariable layerNorm(SDVariable sDVariable, SDVariable sDVariable2, boolean z, int... iArr) {
        return layerNorm((String) null, sDVariable, sDVariable2, z, iArr);
    }

    /**
     * Layer normalization with gain only (no bias), built by {@code f().layerNorm}.
     *
     * @param sDVariable  input (must be floating point)
     * @param sDVariable2 gain (must be floating point)
     */
    public SDVariable layerNorm(String str, SDVariable sDVariable, SDVariable sDVariable2, boolean z, int... iArr) {
        SDValidation.validateFloatingPoint("layerNorm", "input", sDVariable);
        SDValidation.validateFloatingPoint("layerNorm", "gain", sDVariable2);
        return updateVariableNameAndReference(f().layerNorm(sDVariable, sDVariable2, z, iArr), str);
    }

    /**
     * Constant-mode padding where the padding amounts are given as a Java array; the array is turned
     * into a constant graph variable via {@code Nd4j.createFromArray}.
     */
    public SDVariable pad(SDVariable sDVariable, int[][] iArr, double d) {
        return pad(sDVariable, this.sd.constant(Nd4j.createFromArray(iArr)), d);
    }

    /** Same as {@link #pad(String, SDVariable, SDVariable, Pad.Mode, double)} with a {@code null} name and {@code Pad.Mode.CONSTANT}. */
    public SDVariable pad(SDVariable sDVariable, SDVariable sDVariable2, double d) {
        return pad(null, sDVariable, sDVariable2, Pad.Mode.CONSTANT, d);
    }

    /**
     * Pads the input via {@code f().pad}. No floating-point validation is performed.
     *
     * @param sDVariable2 padding amounts variable
     * @param mode        padding mode
     * @param d           scalar forwarded to the factory (presumably the constant fill value — confirm)
     */
    public SDVariable pad(String str, SDVariable sDVariable, SDVariable sDVariable2, Pad.Mode mode, double d) {
        return updateVariableNameAndReference(f().pad(sDVariable, sDVariable2, mode, d), str);
    }

    /** Same as {@link #dotProductAttention(String, SDVariable, SDVariable, SDVariable, SDVariable, boolean)} with a {@code null} name. */
    public SDVariable dotProductAttention(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, boolean z) {
        return dotProductAttention((String) null, sDVariable, sDVariable2, sDVariable3, sDVariable4, z);
    }

    /**
     * Dot-product attention built by {@code f().dotProductAttention}, returning a single output.
     * The four variables are presumably queries, keys, values and mask, and {@code z} presumably
     * toggles scaling — confirm against the factory.
     */
    public SDVariable dotProductAttention(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, boolean z) {
        return updateVariableNameAndReference(f().dotProductAttention(sDVariable, sDVariable2, sDVariable3, sDVariable4, z), str);
    }

    /** Same as the named list-returning {@code dotProductAttention} overload with a {@code null} name. */
    public List<SDVariable> dotProductAttention(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, boolean z, boolean z2) {
        return dotProductAttention(null, sDVariable, sDVariable2, sDVariable3, sDVariable4, z, z2);
    }

    /**
     * Dot-product attention that can also return the attention weights.
     *
     * @param str output name, or {@code null}; when non-null the weights output is named
     *            {@code str + ":weights"}, otherwise it keeps a generated name
     * @param z2  when {@code true} only the attention output is returned; when {@code false} the
     *            list holds the attention output followed by the weights output
     * @return one or two named output variables
     */
    public List<SDVariable> dotProductAttention(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, boolean z, boolean z2) {
        List<SDVariable> outputs = f().dotProductAttention(sDVariable, sDVariable2, sDVariable3, sDVariable4, z, z2);
        SDVariable attention = updateVariableNameAndReference(outputs.get(0), str);
        if (z2) {
            return Collections.singletonList(attention);
        }
        // Previously "str + \":weights\"" produced the literal name "null:weights" for unnamed calls.
        SDVariable weights = updateVariableNameAndReference(outputs.get(1), attentionWeightsName(str));
        return Arrays.asList(attention, weights);
    }

    /** Same as the named single-output {@code multiHeadDotProductAttention} overload with a {@code null} name. */
    public SDVariable multiHeadDotProductAttention(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, SDVariable sDVariable5, SDVariable sDVariable6, SDVariable sDVariable7, SDVariable sDVariable8, boolean z) {
        return multiHeadDotProductAttention((String) null, sDVariable, sDVariable2, sDVariable3, sDVariable4, sDVariable5, sDVariable6, sDVariable7, sDVariable8, z);
    }

    /**
     * Multi-head dot-product attention built by {@code f().multiHeadDotProductAttention}, returning a
     * single output. The eight variables are presumably queries, keys, values, the four projection
     * weights, and mask — confirm against the factory.
     */
    public SDVariable multiHeadDotProductAttention(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, SDVariable sDVariable5, SDVariable sDVariable6, SDVariable sDVariable7, SDVariable sDVariable8, boolean z) {
        return updateVariableNameAndReference(f().multiHeadDotProductAttention(sDVariable, sDVariable2, sDVariable3, sDVariable4, sDVariable5, sDVariable6, sDVariable7, sDVariable8, z), str);
    }

    /** Same as the named list-returning {@code multiHeadDotProductAttention} overload with a {@code null} name. */
    public List<SDVariable> multiHeadDotProductAttention(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, SDVariable sDVariable5, SDVariable sDVariable6, SDVariable sDVariable7, SDVariable sDVariable8, boolean z, boolean z2) {
        return multiHeadDotProductAttention(null, sDVariable, sDVariable2, sDVariable3, sDVariable4, sDVariable5, sDVariable6, sDVariable7, sDVariable8, z, z2);
    }

    /**
     * Multi-head dot-product attention that can also return the attention weights.
     *
     * @param str output name, or {@code null}; when non-null the weights output is named
     *            {@code str + ":weights"}, otherwise it keeps a generated name
     * @param z2  when {@code true} only the attention output is returned; when {@code false} the
     *            list holds the attention output followed by the weights output
     * @return one or two named output variables
     */
    public List<SDVariable> multiHeadDotProductAttention(String str, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, SDVariable sDVariable5, SDVariable sDVariable6, SDVariable sDVariable7, SDVariable sDVariable8, boolean z, boolean z2) {
        List<SDVariable> outputs = f().multiHeadDotProductAttention(sDVariable, sDVariable2, sDVariable3, sDVariable4, sDVariable5, sDVariable6, sDVariable7, sDVariable8, z, z2);
        SDVariable attention = updateVariableNameAndReference(outputs.get(0), str);
        if (z2) {
            return Collections.singletonList(attention);
        }
        // Previously "str + \":weights\"" produced the literal name "null:weights" for unnamed calls.
        SDVariable weights = updateVariableNameAndReference(outputs.get(1), attentionWeightsName(str));
        return Arrays.asList(attention, weights);
    }

    /**
     * Max pooling that also returns argmax indices, built by {@code f().maxPoolWithArgmaxs}; the
     * outputs are named from {@code strArr} via {@code sd.updateVariableNamesAndReferences}.
     */
    public SDVariable[] maxPoolWithArgmax(String[] strArr, SDVariable sDVariable, Pooling2DConfig pooling2DConfig) {
        return this.sd.updateVariableNamesAndReferences(f().maxPoolWithArgmaxs(sDVariable, pooling2DConfig), strArr);
    }

    /**
     * Fused batch normalization built by {@code f().fusedBatchNorm}; the outputs are named from
     * {@code strArr} via {@code sd.updateVariableNamesAndReferences}. No input validation is done here.
     */
    public SDVariable[] fusedBatchNorm(String[] strArr, SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, SDVariable sDVariable5) {
        return this.sd.updateVariableNamesAndReferences(f().fusedBatchNorm(sDVariable, sDVariable2, sDVariable3, sDVariable4, sDVariable5), strArr);
    }

    /**
     * Derives the name of the attention-weights output from the main output's name.
     *
     * @param name main output name, possibly {@code null}
     * @return {@code name + ":weights"}, or {@code null} when {@code name} is {@code null} so that a
     *         name is generated instead of the literal {@code "null:weights"}
     */
    private static String attentionWeightsName(String name) {
        return name == null ? null : name + ":weights";
    }
}
