package org.bigml.mimir.deepnet.layers;

import java.io.IOException;
import java.io.ObjectInputStream;
import org.bigml.mimir.deepnet.layers.twod.Layer2D;
import org.bigml.mimir.deepnet.layers.twod.OutputTensor;
import org.bigml.mimir.math.Matrices;

/* loaded from: input_file:org/bigml/mimir/deepnet/layers/BatchNormalize.class */
/**
 * Batch-normalization layer that applies fixed, precomputed statistics.
 *
 * <p>For every element belonging to channel {@code c} the layer computes
 * {@code gamma[c] * ((x - mean[c]) / stdev[c]) + beta[c]}, where
 * {@code stdev[c] = sqrt(variance[c] + EPSILON)} is computed once in the constructor.
 *
 * <p>The arrays returned by {@link #propagate}, {@link #propagate2D} and
 * {@link #getLastOutput} come from the layer's {@link OutputTensor}. Presumably that is
 * the same buffer on every call, so results may be overwritten by the next propagation.
 * TODO confirm against {@code OutputTensor.get()}.
 */
public class BatchNormalize implements Layer, Layer2D {
    /** Added to each variance before the square root so the divisor stays non-zero. */
    private static final float EPSILON = 0.001f;
    private int _index;
    private int[] _inputShape;
    private float[] _mean;
    private float[] _stdev;
    private float[] _beta;
    private float[] _gamma;
    private transient OutputTensor _output;
    private static final long serialVersionUID = 1;

    /**
     * Creates the layer from per-channel statistics and affine parameters.
     *
     * @param mean per-channel mean subtracted from the input
     * @param variance per-channel variance; stored as {@code sqrt(variance + EPSILON)}
     * @param beta per-channel offset added after scaling
     * @param gamma per-channel scale applied to the normalized value
     */
    public BatchNormalize(double[] mean, double[] variance, double[] beta, double[] gamma) {
        this._mean = Matrices.toFloat(mean);
        this._beta = Matrices.toFloat(beta);
        this._gamma = Matrices.toFloat(gamma);
        this._stdev = new float[variance.length];
        for (int i = 0; i < variance.length; i++) {
            // EPSILON is widened to double before the addition, then the root is narrowed.
            this._stdev[i] = (float) Math.sqrt(variance[i] + EPSILON);
        }
        // Until initialize() is called, the layer behaves as a flat 1-D layer.
        this._inputShape = new int[] {this._mean.length};
        this._output = new OutputTensor(this._inputShape);
    }

    @Override
    public int getIndex() {
        return this._index;
    }

    @Override
    public void setIndex(int i) {
        this._index = i;
    }

    /**
     * Records the input shape and allocates an output buffer of the same shape.
     *
     * @param iArr the input shape; also used as the output shape
     * @return the same {@code iArr} (normalization does not change the shape)
     */
    @Override
    public int[] initialize(int[] iArr) {
        this._inputShape = iArr;
        this._output = new OutputTensor(this._inputShape);
        return iArr;
    }

    /**
     * Returns the number of channels (the length of the mean vector).
     * This does not reflect a multi-dimensional shape set via {@link #initialize}.
     */
    @Override
    public int getOutputLength() {
        return this._mean.length;
    }

    @Override
    public int[] getOutputShape() {
        return this._output.getShape();
    }

    @Override
    public int[] getInputShape() {
        return this._inputShape;
    }

    /**
     * Normalizes {@code fArr[from, to)} in place.
     *
     * <p>Element {@code k} is normalized with the parameters at index {@code k}. This is only
     * meaningful when array indices coincide with channel indices, as in {@link #propagate}.
     *
     * @param fArr values to normalize in place
     * @param from first index (inclusive)
     * @param to last index (exclusive)
     */
    public void normalize(float[] fArr, int from, int to) {
        for (int k = from; k < to; k++) {
            fArr[k] = (fArr[k] - this._mean[k]) / this._stdev[k];
            fArr[k] = (this._gamma[k] * fArr[k]) + this._beta[k];
        }
    }

    /**
     * Normalizes a flat input whose element {@code k} belongs to channel {@code k}.
     *
     * <p>The first {@code out.length} elements of {@code fArr} are copied into the output
     * buffer, which is then normalized in place. {@code fArr} itself is not modified.
     *
     * @param fArr input values; must hold at least as many elements as the output buffer
     * @return the output buffer holding the normalized values
     */
    @Override
    public float[] propagate(float[] fArr) {
        float[] out = this._output.get();
        System.arraycopy(fArr, 0, out, 0, out.length);
        normalize(out, 0, out.length);
        return out;
    }

    /**
     * Normalizes a flattened multi-dimensional input.
     *
     * <p>The input is treated as consecutive groups of {@code getOutputLength()} values, and
     * the channel index cycles fastest. Presumably this is a channels-last layout; confirm
     * against the unroll order of {@code Matrices.unroll}.
     *
     * @param fArr flattened input; its length must be a multiple of the channel count
     * @return the output buffer holding the normalized values
     * @throws IllegalArgumentException if {@code fArr} is non-empty and its length is not a
     *     positive multiple of the channel count
     */
    @Override
    public float[] propagate2D(float[] fArr) {
        float[] out = this._output.get();
        int channels = this._mean.length;
        int n = fArr.length;
        // Validate first: with zero channels the old loop never advanced (hang), and a
        // ragged trailing group failed part-way through with an index error.
        if (n > 0 && (channels == 0 || n % channels != 0)) {
            throw new IllegalArgumentException(
                    "Input length " + n + " is not a multiple of the channel count " + channels);
        }
        for (int base = 0; base < n; base += channels) {
            for (int c = 0; c < channels; c++) {
                int i = base + c;
                out[i] = (this._gamma[c] * ((fArr[i] - this._mean[c]) / this._stdev[c]))
                        + this._beta[c];
            }
        }
        return out;
    }

    /**
     * Unrolls {@code dArr}, normalizes it with {@link #propagate2D}, and reshapes the result
     * to this layer's output shape.
     */
    @Override
    public double[][][] propagateArray(double[][][] dArr) {
        return Matrices.reshape(propagate2D(Matrices.unroll(dArr)), this._output.getShape());
    }

    /** Returns the current contents of the output buffer. */
    @Override
    public float[] getLastOutput() {
        return this._output.get();
    }

    /** Returns the per-channel mean (the internal array, not a copy). */
    public float[] getMean() {
        return this._mean;
    }

    /** Returns {@code sqrt(variance + EPSILON)} per channel (the internal array, not a copy). */
    public float[] getStDev() {
        return this._stdev;
    }

    /** Returns the per-channel offset (the internal array, not a copy). */
    public float[] getBeta() {
        return this._beta;
    }

    /** Returns the per-channel scale (the internal array, not a copy). */
    public float[] getGamma() {
        return this._gamma;
    }

    /**
     * Restores the transient output buffer after deserialization.
     *
     * <p>The buffer is rebuilt from the serialized input shape, so a layer that was
     * initialized with a multi-dimensional shape keeps it. If no shape was serialized, it
     * falls back to the flat shape the constructor uses.
     */
    private void readObject(ObjectInputStream objectInputStream)
            throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        if (this._inputShape == null) {
            this._inputShape = new int[] {getOutputLength()};
        }
        this._output = new OutputTensor(this._inputShape);
    }
}
