package org.nd4j.linalg.sampling;

import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.MathUtils;

/* loaded from: input_file:org/nd4j/linalg/sampling/Sampling.class */
/**
 * Static helpers that populate arrays with random draws, element by element,
 * using a caller-supplied {@link RandomGenerator}.
 */
public class Sampling {
    /**
     * Samples one normally distributed value per element of {@code iNDArray}.
     *
     * <p>For each linear index {@code k}, the mean is {@code iNDArray.get(k)} and the
     * standard deviation is the square root of element {@code k} of
     * {@code iNDArray2.ravel()}; the second array is therefore treated as holding
     * per-element variances.
     *
     * @param randomGenerator source of randomness handed to every distribution
     * @param iNDArray        per-element means; also determines the shape of the result
     * @param iNDArray2       per-element variances, read in raveled order
     * @return the samples, written into a duplicate of the flattened means and
     *         reshaped to {@code iNDArray.shape()}
     */
    public static INDArray normal(RandomGenerator randomGenerator, INDArray iNDArray, INDArray iNDArray2) {
        final INDArray samples = iNDArray.reshape(1, iNDArray.length()).dup();
        final INDArray flatVariances = iNDArray2.ravel();
        final int count = samples.length();
        for (int idx = 0; idx < count; idx++) {
            // Mean is read before the variance, and indices ascend, so the generator
            // is consumed in a fixed, reproducible order.
            samples.putScalar(idx, drawGaussian(randomGenerator, iNDArray.get(idx), FastMath.sqrt(flatVariances.get(idx))));
        }
        return samples.reshape(iNDArray.shape());
    }

    /**
     * Samples one normally distributed value per element of {@code iNDArray}, using
     * a single variance for every element.
     *
     * @param randomGenerator source of randomness handed to every distribution
     * @param iNDArray        per-element means, read through its linear view
     * @param d               variance shared by all elements; its square root is
     *                        used as the standard deviation
     * @return an array created with {@code iNDArray.shape()} holding the samples
     */
    public static INDArray normal(RandomGenerator randomGenerator, INDArray iNDArray, double d) {
        final INDArray result = Nd4j.create(iNDArray.shape());
        final INDArray flatMeans = iNDArray.linearView();
        final INDArray flatResult = result.linearView();
        // The standard deviation is the same for every element, so compute it once.
        final double standardDeviation = FastMath.sqrt(d);
        final int count = flatMeans.length();
        for (int idx = 0; idx < count; idx++) {
            flatResult.putScalar(idx, drawGaussian(randomGenerator, flatMeans.get(idx), standardDeviation));
        }
        return result;
    }

    /**
     * Replaces every element of a copy of {@code iNDArray} with the value returned by
     * {@code MathUtils.binomial(randomGenerator, i, element)}.
     *
     * @param iNDArray        input values, one per draw (presumably success
     *                        probabilities; confirm against {@code MathUtils.binomial})
     * @param i               second argument forwarded to {@code MathUtils.binomial}
     *                        (presumably the number of trials; confirm)
     * @param randomGenerator source of randomness forwarded to {@code MathUtils.binomial}
     * @return a duplicate of {@code iNDArray} whose elements were overwritten with the draws
     */
    public static INDArray binomial(INDArray iNDArray, int i, RandomGenerator randomGenerator) {
        final INDArray result = iNDArray.dup();
        final INDArray flat = result.linearView();
        final int count = result.length();
        for (int idx = 0; idx < count; idx++) {
            // Each element is read from the copy before being overwritten in place.
            final int draw = MathUtils.binomial(randomGenerator, i, flat.get(idx));
            flat.putScalar(idx, Integer.valueOf(draw));
        }
        return result;
    }

    /** Draws a single boxed sample from a normal distribution with the given mean and standard deviation. */
    private static Double drawGaussian(RandomGenerator rng, double mean, double standardDeviation) {
        final NormalDistribution distribution = new NormalDistribution(
                rng, mean, standardDeviation, NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
        return Double.valueOf(distribution.sample());
    }
}
