package org.nd4j.linalg.api.ops.random.impl;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/api/ops/random/impl/BernoulliDistribution.class */
/**
 * Random op registered under the name {@code distribution_bernoulli} (legacy op number 7).
 *
 * <p>The probability is supplied in one of two ways:
 * <ul>
 *   <li>as a single scalar, which is stored in {@code prob} and published through {@code extraArgs};</li>
 *   <li>as an array of probabilities, which is handed to the {@link BaseRandomOp} constructor as its
 *       first array argument (with the output array {@code z} as the third).</li>
 * </ul>
 * The sampling itself happens in the backend implementation of the op and is not visible here.
 */
public class BernoulliDistribution extends BaseRandomOp {

    /** Scalar probability; set to {@code 0.0} when probabilities are supplied as an array. */
    private double prob;

    /**
     * Creates a SameDiff-attached op with a scalar probability.
     *
     * @param sameDiff owning SameDiff instance
     * @param prob     probability value, stored as-is (no range validation is done here)
     * @param shape    shape forwarded to {@link BaseRandomOp}; presumably the output shape — TODO confirm
     */
    public BernoulliDistribution(SameDiff sameDiff, double prob, long[] shape) {
        super(sameDiff, shape);
        this.prob = prob;
        storeProbAsExtraArgs();
    }

    /** No-arg constructor; presumably needed for reflective instantiation — verify against callers. */
    public BernoulliDistribution() {
    }

    /**
     * Allocates an uninitialized output array of the requested type and shape, then delegates to
     * {@link #BernoulliDistribution(INDArray, double)}.
     *
     * @param prob     scalar probability
     * @param dataType data type of the freshly allocated output array
     * @param shape    shape of the freshly allocated output array
     */
    public BernoulliDistribution(double prob, DataType dataType, long... shape) {
        this(Nd4j.createUninitialized(dataType, shape), prob);
    }

    /**
     * Creates an op that writes into {@code z} using a scalar probability.
     *
     * @param z    output array (must not be null)
     * @param prob scalar probability
     * @throws NullPointerException if {@code z} is null
     */
    public BernoulliDistribution(@NonNull INDArray z, double prob) {
        super(null, null, z);
        // Explicit check right after super(...), in the form Lombok's @NonNull produces.
        if (z == null) {
            throw new NullPointerException("z is marked @NonNull but is null");
        }
        this.prob = prob;
        storeProbAsExtraArgs();
    }

    /**
     * Creates an op that writes into {@code z} using an array of probabilities.
     *
     * @param z    output array (must not be null)
     * @param prob probability array; must have element-wise stride 1 and the same length as {@code z}
     * @throws NullPointerException       if {@code z} or {@code prob} is null
     * @throws ND4JIllegalStateException  if {@code prob} has a stride other than 1 or a mismatched length
     */
    public BernoulliDistribution(@NonNull INDArray z, @NonNull INDArray prob) {
        super(prob, null, z);
        if (z == null) {
            throw new NullPointerException("z is marked @NonNull but is null");
        }
        if (prob == null) {
            throw new NullPointerException("prob is marked @NonNull but is null");
        }
        if (prob.elementWiseStride() != 1) {
            throw new ND4JIllegalStateException("Probabilities should have ElementWiseStride of 1");
        }
        long probLength = prob.length();
        long outputLength = z.length();
        if (probLength != outputLength) {
            throw new ND4JIllegalStateException("Length of probabilities array [" + probLength
                    + "] doesn't match length of output array [" + outputLength + "]");
        }
        // The scalar slot is still published; presumably ignored when an array is supplied — TODO confirm.
        this.prob = 0.0;
        storeProbAsExtraArgs();
    }

    /** Publishes the current scalar {@code prob} (boxed) as the op's only extra argument. */
    private void storeProbAsExtraArgs() {
        this.extraArgs = new Object[] {prob};
    }

    /** @return the legacy op number, {@code 7} */
    @Override
    public int opNum() {
        return 7;
    }

    /** @return the op name, {@code distribution_bernoulli} */
    @Override
    public String opName() {
        return "distribution_bernoulli";
    }

    /**
     * No ONNX mapping exists for this op.
     *
     * @throws NoOpNameFoundException always
     */
    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    /**
     * No TensorFlow mapping exists for this op.
     *
     * @throws NoOpNameFoundException always
     */
    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
    }

    /**
     * Treats the op as non-differentiable.
     *
     * @param outputGradients ignored
     * @return an empty list, always
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> outputGradients) {
        return Collections.<SDVariable>emptyList();
    }

    /**
     * Checks that no input data types were supplied and reports the output type.
     *
     * <p>The result is always {@link DataType#DOUBLE}. It does not consult the {@code DataType}
     * passed to {@link #BernoulliDistribution(double, DataType, long...)}.
     *
     * @param dataTypes input data types; must be null or empty
     * @return a singleton list containing {@link DataType#DOUBLE}
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        boolean noInputs = dataTypes == null || dataTypes.isEmpty();
        Preconditions.checkState(noInputs, "Expected no input datatypes (no args) for %s, got %s", getClass(), dataTypes);
        return Collections.singletonList(DataType.DOUBLE);
    }
}
