package com.github.chen0040.glm.solvers;

import Jama.Matrix;
import Jama.QRDecomposition;
import com.github.chen0040.glm.enums.GlmDistributionFamily;
import com.github.chen0040.glm.links.LinkFunction;
import com.github.chen0040.glm.maths.Mean;
import com.github.chen0040.glm.maths.StdDev;
import com.github.chen0040.glm.maths.Variance;
import com.github.chen0040.glm.metrics.GlmStatistics;
import com.github.chen0040.glm.search.LocalSearch;
import java.util.Random;

/**
 * GLM solver that fits coefficients by iteratively re-weighted least squares (IRLS), carrying
 * out each Newton step in the orthonormal basis of a QR decomposition of the design matrix.
 *
 * <p>The design matrix {@code A} (one row per observation, one column per coefficient) is
 * factored once as {@code A = Q R}. Each iteration forms the working response {@code z} and the
 * weights {@code W}, solves {@code (Q^T W Q) s = Q^T W z} through a Cholesky factorization and
 * sets the linear predictor to {@code eta = Q s}. After the loop the coefficients {@code x} are
 * recovered from {@code R x = Q^T eta}.
 */
public class GlmAlgorithmIrlsQrNewton extends GlmAlgorithm {
    /** A warning is printed when the smallest IRLS weight falls below {@code sqrt(EPSILON)}. */
    private static final double EPSILON = 1.0E-20d;

    /** Design matrix: one row per observation, one column per coefficient. */
    private Matrix A;
    /** Observed responses as an (observations x 1) column vector. */
    private Matrix b;
    /** Cached transpose of {@link #A}. */
    private Matrix At;

    /**
     * Copies the solver state of {@code glmAlgorithm} into this instance, deep-copying the
     * design matrix, the response vector and the cached transpose.
     *
     * @param glmAlgorithm source solver; it is cast to {@code GlmAlgorithmIrlsQrNewton}
     */
    @Override // com.github.chen0040.glm.solvers.GlmAlgorithm
    public void copy(GlmAlgorithm glmAlgorithm) {
        super.copy(glmAlgorithm);
        GlmAlgorithmIrlsQrNewton source = (GlmAlgorithmIrlsQrNewton) glmAlgorithm;
        this.A = source.A == null ? null : source.A.copy();
        this.b = source.b == null ? null : source.b.copy();
        this.At = source.At == null ? null : source.At.copy();
    }

    /**
     * Creates a new solver holding a deep copy of this one's state.
     *
     * @return an independent copy of this solver
     */
    @Override // com.github.chen0040.glm.solvers.GlmAlgorithm
    public GlmAlgorithm makeCopy() {
        GlmAlgorithmIrlsQrNewton clone = new GlmAlgorithmIrlsQrNewton();
        clone.copy(this);
        return clone;
    }

    /** Creates a solver with no data; {@link #makeCopy()} uses this before calling {@link #copy}. */
    public GlmAlgorithmIrlsQrNewton() {
    }

    /**
     * Creates a solver for the given distribution family and explicit link function.
     *
     * @param glmDistributionFamily distribution family of the response
     * @param linkFunction link function used to map the linear predictor to the mean
     * @param data design matrix rows, one per observation; all rows must have equal length
     * @param response observed responses, one per row of {@code data}
     * @throws IllegalArgumentException if {@code data} is empty or ragged, or the response
     *     length differs from the number of rows
     */
    public GlmAlgorithmIrlsQrNewton(GlmDistributionFamily glmDistributionFamily, LinkFunction linkFunction, double[][] data, double[] response) {
        super(glmDistributionFamily, linkFunction, (double[][]) null, (double[]) null, (LocalSearch) null);
        initData(data, response);
    }

    /**
     * Creates a solver for the given distribution family, using the superclass's default link.
     *
     * @param glmDistributionFamily distribution family of the response
     * @param data design matrix rows, one per observation; all rows must have equal length
     * @param response observed responses, one per row of {@code data}
     * @throws IllegalArgumentException if {@code data} is empty or ragged, or the response
     *     length differs from the number of rows
     */
    public GlmAlgorithmIrlsQrNewton(GlmDistributionFamily glmDistributionFamily, double[][] data, double[] response) {
        super(glmDistributionFamily);
        initData(data, response);
    }

    /** Validates the inputs and builds the design matrix, response vector and statistics holder. */
    private void initData(double[][] data, double[] response) {
        if (data == null || data.length == 0) {
            throw new IllegalArgumentException("data must contain at least one row");
        }
        if (response == null || response.length != data.length) {
            throw new IllegalArgumentException("response length must equal the number of data rows ("
                    + data.length + "), was " + (response == null ? "null" : String.valueOf(response.length)));
        }
        this.A = toMatrix(data);
        this.b = columnVector(response);
        this.At = this.A.transpose();
        this.mStats = new GlmStatistics(data[0].length, response.length);
    }

    /** Copies a rectangular 2-D array into a new matrix, rejecting ragged rows. */
    private static Matrix toMatrix(double[][] rows) {
        int rowCount = rows.length;
        int colCount = rows[0].length;
        Matrix matrix = new Matrix(rowCount, colCount);
        for (int row = 0; row < rowCount; row++) {
            if (rows[row].length != colCount) {
                throw new IllegalArgumentException("row " + row + " has " + rows[row].length
                        + " columns; expected " + colCount);
            }
            for (int col = 0; col < colCount; col++) {
                matrix.set(row, col, rows[row][col]);
            }
        }
        return matrix;
    }

    /** Copies {@code values} into a new (length x 1) column matrix. */
    private static Matrix columnVector(double[] values) {
        Matrix matrix = new Matrix(values.length, 1);
        for (int i = 0; i < values.length; i++) {
            matrix.set(i, 0, values[i]);
        }
        return matrix;
    }

    /** Creates a zero-filled (size x 1) column matrix. */
    private static Matrix columnVector(int size) {
        return new Matrix(size, 1);
    }

    /**
     * Runs IRLS until the change in the Q-basis iterate drops below {@code mTol} or
     * {@code maxIters} iterations have run, then stores the coefficients and updates the fit
     * statistics.
     *
     * @return the fitted coefficients
     * @throws IllegalStateException if {@code maxIters} is less than 1
     */
    @Override // com.github.chen0040.glm.solvers.GlmAlgorithm
    public double[] solve() {
        if (this.maxIters < 1) {
            // Without at least one iteration there are no weights for updateStatistics.
            throw new IllegalStateException("maxIters must be at least 1, was " + this.maxIters);
        }
        int observationCount = this.A.getRowDimension();
        int coefCount = this.A.getColumnDimension();

        QRDecomposition qr = this.A.qr();
        Matrix q = qr.getQ();
        Matrix r = qr.getR();
        Matrix qt = q.transpose();

        Matrix s = columnVector(coefCount);            // Newton iterate in the Q basis, starts at 0
        Matrix eta = columnVector(observationCount);   // linear predictor, starts at 0
        double[] mu = new double[observationCount];
        double[] muPrime = new double[observationCount];
        double[] weights = null;

        for (int iter = 0; iter < this.maxIters; iter++) {
            // Working response z = eta + (y - mu) / mu'(eta).
            Matrix z = columnVector(observationCount);
            for (int i = 0; i < observationCount; i++) {
                double etaI = eta.get(i, 0);
                mu[i] = this.linkFunc.GetInvLink(etaI);
                muPrime[i] = this.linkFunc.GetInvLinkDerivative(etaI);
                z.set(i, 0, etaI + ((this.b.get(i, 0) - mu[i]) / muPrime[i]));
            }

            // IRLS weights W = mu'(eta)^2 / Var(mu).
            weights = new double[observationCount];
            double minWeight = Double.MAX_VALUE;
            for (int i = 0; i < observationCount; i++) {
                weights[i] = (muPrime[i] * muPrime[i]) / getVariance(mu[i]);
                minWeight = Math.min(weights[i], minWeight);
            }
            if (minWeight < Math.sqrt(EPSILON)) {
                System.out.println("Warning: Tiny weights encountered, min(diag(W)) is too small");
            }

            // Solve (Q^T W Q) s = Q^T W z via Cholesky: L y = rhs, then L^T s = y. The
            // substitutions are done by hand because Jama's CholeskyDecomposition.solve rejects
            // matrices that are not exactly symmetric, which Q^T W Q may not be in floating point.
            Matrix normalMatrix = qt.times(scalarMultiply(q, weights));
            Matrix normalRhs = qt.times(scalarMultiply(z, weights));
            Matrix lower = normalMatrix.chol().getL();

            Matrix previous = s;
            s = backSubstitute(lower.transpose(), forwardSubstitute(lower, normalRhs));
            eta = q.times(s);
            // For a column vector the Frobenius norm equals the Euclidean norm.
            if (previous.minus(s).normF() < this.mTol) {
                break;
            }
        }

        // Recover coefficients from R x = Q^T eta.
        this.glmCoefficients = backSubstitute(r, qt.times(eta)).getColumnPackedCopy();
        updateStatistics(weights);
        return getCoefficients();
    }

    /** Solves {@code lower * y = rhs} for a lower-triangular square {@code lower}. */
    private static Matrix forwardSubstitute(Matrix lower, Matrix rhs) {
        int n = lower.getRowDimension();
        Matrix y = columnVector(n);
        for (int i = 0; i < n; i++) {
            double sum = 0.0d;
            for (int j = 0; j < i; j++) {
                sum += lower.get(i, j) * y.get(j, 0);
            }
            y.set(i, 0, (rhs.get(i, 0) - sum) / lower.get(i, i));
        }
        return y;
    }

    /** Solves {@code upper * x = rhs} for an upper-triangular square {@code upper}. */
    private static Matrix backSubstitute(Matrix upper, Matrix rhs) {
        int n = upper.getColumnDimension();
        Matrix x = columnVector(n);
        for (int i = n - 1; i >= 0; i--) {
            double sum = 0.0d;
            for (int j = i + 1; j < n; j++) {
                sum += upper.get(i, j) * x.get(j, 0);
            }
            x.set(i, 0, (rhs.get(i, 0) - sum) / upper.get(i, i));
        }
        return x;
    }

    /**
     * Returns a copy of {@code matrix} with each row (if {@code scale.length} equals the row
     * count) or else each column (if it equals the column count) multiplied by the matching
     * entry of {@code scale}. Row scaling wins when both dimensions match.
     *
     * @throws IllegalArgumentException if {@code scale.length} matches neither dimension
     */
    private Matrix scalarMultiply(Matrix matrix, double[] scale) {
        int rows = matrix.getRowDimension();
        int cols = matrix.getColumnDimension();
        Matrix result = new Matrix(rows, cols);
        if (scale.length == rows) {
            for (int i = 0; i < rows; i++) {
                for (int j = 0; j < cols; j++) {
                    result.set(i, j, matrix.get(i, j) * scale[i]);
                }
            }
        } else if (scale.length == cols) {
            for (int j = 0; j < cols; j++) {
                for (int i = 0; i < rows; i++) {
                    result.set(i, j, matrix.get(i, j) * scale[j]);
                }
            }
        } else {
            throw new IllegalArgumentException("scale length " + scale.length + " matches neither "
                    + rows + " rows nor " + cols + " columns");
        }
        return result;
    }

    /**
     * Fills {@code mStats} from the fitted coefficients: the (A^T W A)^-1 covariance matrix and
     * standard errors, response-scale residuals, residual standard deviation, response mean and
     * variance, R^2 and adjusted R^2.
     *
     * @param dArr IRLS weights, one per observation
     */
    protected void updateStatistics(double[] dArr) {
        // Scale the rows of A (length == observations) so W is applied on the correct side of
        // A^T W A; scaling At instead picks the wrong side when A is square.
        Matrix covariance = this.At.times(scalarMultiply(this.A, dArr)).inverse();
        int coefCount = covariance.getRowDimension();
        int observationCount = this.b.getRowDimension();

        double[] standardErrors = this.mStats.getStandardErrors();
        double[][] vCovMatrix = this.mStats.getVCovMatrix();
        double[] residuals = this.mStats.getResiduals();
        for (int i = 0; i < coefCount; i++) {
            standardErrors[i] = Math.sqrt(covariance.get(i, i));
            for (int j = 0; j < coefCount; j++) {
                vCovMatrix[i][j] = covariance.get(i, j);
            }
        }

        double[] outcomes = new double[observationCount];
        for (int row = 0; row < observationCount; row++) {
            double linearPredictor = 0.0d;
            for (int col = 0; col < coefCount; col++) {
                linearPredictor += this.A.get(row, col) * this.glmCoefficients[col];
            }
            residuals[row] = this.b.get(row, 0) - this.linkFunc.GetInvLink(linearPredictor);
            outcomes[row] = this.b.get(row, 0);
        }

        this.mStats.setResidualStdDev(StdDev.apply(residuals, 0.0d));
        this.mStats.setResponseMean(Mean.apply(outcomes));
        this.mStats.setResponseVariance(Variance.apply(outcomes, this.mStats.getResponseMean()));

        double residualStdDev = this.mStats.getResidualStdDev();
        double unexplainedRatio = (residualStdDev * residualStdDev) / this.mStats.getResponseVariance();
        this.mStats.setR2(1.0d - unexplainedRatio);
        // Adjusted R^2 = 1 - (1 - R^2) (n - 1) / (n - p - 1) with n = number of observations
        // and p = number of coefficients.
        this.mStats.setAdjustedR2(1.0d - ((unexplainedRatio * (observationCount - 1))
                / ((observationCount - this.glmCoefficients.length) - 1)));
    }
}
