package org.nd4j.linalg.api.ops.impl.accum;

import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BaseAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ComplexUtil;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/accum/Variance.class */
/**
 * Accumulation op that computes the variance of the elements of {@code x}.
 *
 * <p>Each call to {@link #update(Number)} (or {@link #update(IComplexNumber)} for complex data)
 * adds the squared deviation of one element from a precomputed mean to the running result. Once
 * {@code numProcessed() == n()}, the sum is normalized:
 *
 * <ul>
 *   <li>bias-corrected (the default): {@code (sum - bias^2 / n) / (n - 1)}
 *   <li>not bias-corrected: {@code sum / n}
 * </ul>
 *
 * <p>{@code mean} and {@code bias} are obtained in {@link #init} by executing the {@link Mean} and
 * {@link Bias} ops on {@code x}.
 */
public class Variance extends BaseAccumulation {
    /** Mean of {@code x}, computed by {@link #init}. */
    private double mean;
    /** Result of the {@link Bias} op on {@code x}; only computed by {@link #init} when bias-corrected. */
    private double bias;
    /** Whether to normalize with the bias-corrected formula (divide by n - 1); defaults to true. */
    private boolean biasCorrected = true;

    /**
     * Creates a bias-corrected variance op.
     *
     * <p>NOTE(review): this constructor does not call {@link #init} itself. If the superclass
     * constructor does, that call happens before {@code biasCorrected} is assigned, so
     * {@code bias} is left at 0. Confirm against {@code BaseOp}.
     *
     * @param x the input array
     * @param y the secondary array, may be {@code null}
     * @param z the result array
     * @param n the number of elements to process
     */
    public Variance(INDArray x, INDArray y, INDArray z, int n) {
        super(x, y, z, n);
    }

    /**
     * Creates a bias-corrected variance op that uses {@code x} as the result array.
     *
     * @param x the input array
     * @param y the secondary array, may be {@code null}
     * @param n the number of elements to process
     */
    public Variance(INDArray x, INDArray y, int n) {
        this(x, y, x, n);
    }

    /**
     * Creates a bias-corrected variance op over all elements of {@code x}.
     *
     * @param x the input array
     */
    public Variance(INDArray x) {
        this(x, null, x, x.length(), true);
    }

    /**
     * Creates a bias-corrected variance op.
     *
     * @param x the input array
     * @param y the secondary array, may be {@code null}
     */
    public Variance(INDArray x, INDArray y) {
        super(x, y);
    }

    /**
     * Creates a variance op with an explicit bias-correction setting.
     *
     * @param x the input array
     * @param y the secondary array, may be {@code null}
     * @param z the result array
     * @param n the number of elements to process
     * @param biasCorrected {@code true} for the (n - 1) normalization, {@code false} for n
     */
    public Variance(INDArray x, INDArray y, INDArray z, int n, boolean biasCorrected) {
        super(x, y, z, n);
        this.biasCorrected = biasCorrected;
        // init runs after biasCorrected is assigned, so the Bias op executes when requested.
        init(x, y, z, n);
    }

    /**
     * Creates a variance op with an explicit bias-correction setting.
     *
     * @param x the input array
     * @param y the secondary array, may be {@code null}
     * @param n the number of elements to process
     * @param biasCorrected {@code true} for the (n - 1) normalization, {@code false} for n
     */
    public Variance(INDArray x, INDArray y, int n, boolean biasCorrected) {
        super(x, y, n);
        this.biasCorrected = biasCorrected;
        init(x, y, this.z, n);
    }

    /**
     * Creates a variance op over {@code x} with an explicit bias-correction setting.
     *
     * @param x the input array
     * @param biasCorrected {@code true} for the (n - 1) normalization, {@code false} for n
     */
    public Variance(INDArray x, boolean biasCorrected) {
        super(x);
        this.biasCorrected = biasCorrected;
        init(x, this.y, this.z, this.n);
    }

    /**
     * Creates a variance op with an explicit bias-correction setting, using {@code x} as the result
     * array and processing all of its elements.
     *
     * @param x the input array
     * @param y the secondary array, may be {@code null}
     * @param biasCorrected {@code true} for the (n - 1) normalization, {@code false} for n
     */
    public Variance(INDArray x, INDArray y, boolean biasCorrected) {
        super(x, y);
        this.biasCorrected = biasCorrected;
        init(x, y, x, x.length());
    }

    /**
     * Accumulates one real element and, after the last element, normalizes the result.
     *
     * @param number the next element of {@code x}
     */
    public void update(Number number) {
        this.currentResult = Double.valueOf(
                currentResult().doubleValue() + FastMath.pow(number.doubleValue() - this.mean, 2));
        if (numProcessed() == n()) {
            if (this.biasCorrected) {
                // (sum - bias^2 / n) / (n - 1)
                this.currentResult = Double.valueOf(
                        (this.currentResult.doubleValue() - (FastMath.pow(this.bias, 2.0d) / n()))
                                / (n() - 1.0d));
            } else {
                // sum / n
                this.currentResult = Double.valueOf(currentResult().doubleValue() / this.n);
            }
        }
    }

    /**
     * Accumulates one complex element and, after the last element, normalizes the result using the
     * same formulas as {@link #update(Number)}.
     *
     * @param iComplexNumber the next element of {@code x}
     */
    public void update(IComplexNumber iComplexNumber) {
        this.currentComplexResult.addi(
                ComplexUtil.pow(iComplexNumber.sub(Double.valueOf(this.mean)), 2.0d));
        if (numProcessed() == n()) {
            if (this.biasCorrected) {
                // (sum - bias^2 / n) / (n - 1)
                this.currentComplexResult = this.currentComplexResult
                        .sub(ComplexUtil.pow(Nd4j.createComplexNumber(Double.valueOf(this.bias), 0), 2.0d)
                                .div(Nd4j.createComplexNumber(Integer.valueOf(n()), 0)))
                        .div(Nd4j.createComplexNumber(Double.valueOf(n() - 1.0d), Double.valueOf(0.0d)));
            } else {
                // sum / n. This previously divided by n - 1, which made the
                // non-bias-corrected complex result disagree with the real path.
                this.currentComplexResult.divi(Integer.valueOf(this.n));
            }
        }
    }

    @Override // org.nd4j.linalg.api.ops.Op
    public String name() {
        return "var";
    }

    /**
     * Builds a variance op over one vector of {@code x} (and of {@code y}, when present) along the
     * given dimension. The sub-op keeps this op's bias-correction setting.
     *
     * @param index which vector along the dimension
     * @param dimension the dimension to slice along
     * @return a new {@link Variance} over the selected vector
     */
    @Override // org.nd4j.linalg.api.ops.Op
    public Op opForDimension(int index, int dimension) {
        INDArray xVector = this.x.vectorAlongDimension(index, dimension);
        INDArray yVector = y() != null ? this.y.vectorAlongDimension(index, dimension) : null;
        // Propagate biasCorrected: previously sub-ops always used the default (true).
        return new Variance(xVector, yVector, xVector, xVector.length(), this.biasCorrected);
    }

    /**
     * Initializes the op and precomputes the statistics that {@code update} reads. {@code bias} is
     * computed only when {@code biasCorrected} is set. {@code mean} is always computed. Both are
     * taken over {@code iNDArray} and stored in {@code extraArgs} as {@code [zero(), bias, mean]}.
     */
    @Override // org.nd4j.linalg.api.ops.BaseAccumulation, org.nd4j.linalg.api.ops.BaseOp, org.nd4j.linalg.api.ops.Op
    public void init(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int i) {
        super.init(iNDArray, iNDArray2, iNDArray3, i);
        if (this.biasCorrected) {
            this.bias = Nd4j.getExecutioner().execAndReturn((Accumulation) new Bias(iNDArray)).currentResult().doubleValue();
        }
        this.mean = Nd4j.getExecutioner().execAndReturn((Accumulation) new Mean(iNDArray)).currentResult().doubleValue();
        this.extraArgs = new Object[]{zero(), Double.valueOf(this.bias), Double.valueOf(this.mean)};
    }
}
