package org.deeplearning4j.nn;

import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/WeightInitUtil.class */
public class WeightInitUtil {

    /** Seed used by {@link #initWeights(int[], float, float)}; fixed, so results are reproducible. */
    private static final int DEFAULT_SEED = 123;

    /**
     * Creates an array of the given shape filled with random values between {@code min} and
     * {@code max}, generated by {@code Nd4j.rand} with a Mersenne Twister generator.
     *
     * <p>NOTE: a new generator seeded with the constant {@value #DEFAULT_SEED} is created on every
     * call, so repeated calls with the same arguments return identical arrays. Callers that need
     * distinct weights (e.g. for several layers of the same shape) should not rely on this method
     * alone.
     *
     * @param shape the shape of the array to create
     * @param min lower bound passed to {@code Nd4j.rand}
     * @param max upper bound passed to {@code Nd4j.rand}
     * @return a newly allocated random array of the requested shape
     */
    public static INDArray initWeights(int[] shape, float min, float max) {
        return Nd4j.rand(shape, min, max, new MersenneTwister(DEFAULT_SEED));
    }

    /**
     * Creates a {@code rows x columns} weight matrix using the requested initialization scheme.
     *
     * <ul>
     *   <li>{@code VI}: values uniform in {@code [-r, r)} with
     *       {@code r = sqrt(6) / sqrt(rows + columns + 1)} (a normalized-uniform scheme in the
     *       style of Glorot/Xavier initialization).</li>
     *   <li>{@code DISTRIBUTION}: each row is filled with {@code columns} samples drawn from
     *       {@code realDistribution}.</li>
     * </ul>
     *
     * @param rows number of rows of the returned matrix
     * @param columns number of columns of the returned matrix
     * @param weightInit the initialization scheme; only {@code VI} and {@code DISTRIBUTION} are
     *     supported here
     * @param activationFunction currently unused by this method
     * @param realDistribution distribution to sample from; required for {@code DISTRIBUTION},
     *     ignored otherwise
     * @return a newly allocated {@code rows x columns} weight matrix
     * @throws NullPointerException if {@code weightInit} is null, or if {@code realDistribution}
     *     is null for the {@code DISTRIBUTION} scheme
     * @throws IllegalStateException if {@code weightInit} is any other scheme
     */
    public static INDArray initWeights(int rows, int columns, WeightInit weightInit,
            ActivationFunction activationFunction, RealDistribution realDistribution) {
        switch (weightInit) {
            case VI: {
                double bound = Math.sqrt(6.0d) / Math.sqrt(rows + columns + 1);
                // The affine map u -> 2*bound*u - bound sends [0, 1) onto [-bound, bound) only
                // when u is uniform on [0, 1). Drawing u from a standard normal (as before)
                // shifted every weight's mean to -bound, so sample uniformly instead.
                INDArray weights = Nd4j.rand(rows, columns);
                weights.muli(2).muli(bound).subi(bound);
                return weights;
            }
            case DISTRIBUTION: {
                if (realDistribution == null) {
                    throw new NullPointerException(
                            "realDistribution is required for DISTRIBUTION weight init");
                }
                // Every entry is overwritten below, so start from zeros instead of spending
                // random draws that are immediately discarded.
                INDArray weights = Nd4j.zeros(rows, columns);
                for (int row = 0; row < weights.rows(); row++) {
                    weights.putRow(row, Nd4j.create(realDistribution.sample(weights.columns())));
                }
                return weights;
            }
            default:
                throw new IllegalStateException("Illegal weight init value");
        }
    }
}
