package com.github.chen0040.glm.solvers;

import Jama.Matrix;
import Jama.SingularValueDecomposition;
import com.github.chen0040.glm.enums.GlmDistributionFamily;
import com.github.chen0040.glm.links.LinkFunction;
import com.github.chen0040.glm.maths.Mean;
import com.github.chen0040.glm.maths.StdDev;
import com.github.chen0040.glm.maths.Variance;
import com.github.chen0040.glm.metrics.GlmStatistics;
import com.github.chen0040.ls.LocalSearch;

/* loaded from: input_file:com/github/chen0040/glm/solvers/GlmAlgorithmIrlsSvdNewton.class */
/**
 * Generalized linear model solver using iteratively reweighted least squares (IRLS), where each
 * Newton step is solved in the basis of the left singular vectors of the model matrix.
 *
 * <p>The model matrix {@code A} (one row per observation, one column per coefficient) is
 * decomposed once as {@code A = U S V^T}. Each iteration works on the linear predictor
 * {@code t = U s}. It forms the working response {@code z} and the weights {@code W}, then solves
 * {@code (U^T W U) s = U^T W z} through a Cholesky factorisation. When the loop ends, the
 * coefficients are recovered as {@code x = V S^-1 U^T t}.
 *
 * <p>{@link #solve()} returns {@code null} when the model matrix has fewer rows than columns, or
 * when it has a singular value below {@link #EPSILON}.
 */
public class GlmAlgorithmIrlsSvdNewton extends GlmAlgorithm {

    /** Singular values below this threshold are treated as zero (near rank deficiency). */
    private static final double EPSILON = 1.0E-34d;

    /** IRLS weights below this threshold are counted as "tiny" and trigger a warning. */
    private static final double TINY_WEIGHT = 2.0d * EPSILON;

    /** Model matrix: one row per observation, one column per coefficient. */
    private Matrix A;
    /** Response values as a column vector (one entry per row of {@link #A}). */
    private Matrix b;
    /** Cached transpose of {@link #A}. */
    private Matrix At;

    /**
     * Copies the state of {@code glmAlgorithm} into this instance, deep-copying its matrices.
     *
     * @param glmAlgorithm the source algorithm; must be a {@code GlmAlgorithmIrlsSvdNewton}
     */
    @Override
    public void copy(GlmAlgorithm glmAlgorithm) {
        super.copy(glmAlgorithm);
        GlmAlgorithmIrlsSvdNewton other = (GlmAlgorithmIrlsSvdNewton) glmAlgorithm;
        this.A = copyOf(other.A);
        this.b = copyOf(other.b);
        this.At = copyOf(other.At);
    }

    /** Returns a deep copy of {@code matrix}, or {@code null} when it is {@code null}. */
    private static Matrix copyOf(Matrix matrix) {
        return matrix == null ? null : matrix.copy();
    }

    /** Returns a new instance holding a deep copy of this solver's state. */
    @Override
    public GlmAlgorithm makeCopy() {
        GlmAlgorithmIrlsSvdNewton clone = new GlmAlgorithmIrlsSvdNewton();
        clone.copy(this);
        return clone;
    }

    /** Creates an empty solver, typically populated afterwards via {@link #copy(GlmAlgorithm)}. */
    public GlmAlgorithmIrlsSvdNewton() {
    }

    /**
     * Creates a solver for the given family and explicit link function.
     *
     * @param distributionFamily the GLM distribution family
     * @param linkFunction the link function to use
     * @param data model matrix rows (one array per observation); must be non-empty
     * @param response response values, one per row of {@code data}
     */
    public GlmAlgorithmIrlsSvdNewton(GlmDistributionFamily distributionFamily, LinkFunction linkFunction, double[][] data, double[] response) {
        super(distributionFamily, linkFunction, (double[][]) null, (double[]) null, (LocalSearch) null);
        setData(data, response);
    }

    /**
     * Creates a solver for the given family, using the link function chosen by the superclass.
     *
     * @param distributionFamily the GLM distribution family
     * @param data model matrix rows (one array per observation); must be non-empty
     * @param response response values, one per row of {@code data}
     */
    public GlmAlgorithmIrlsSvdNewton(GlmDistributionFamily distributionFamily, double[][] data, double[] response) {
        super(distributionFamily);
        setData(data, response);
    }

    /** Stores the model matrix, the response and its transpose, and allocates the statistics holder. */
    private void setData(double[][] data, double[] response) {
        this.A = new Matrix(data);
        this.b = columnVector(response);
        this.At = this.A.transpose();
        this.mStats = new GlmStatistics(data[0].length, response.length);
    }

    /** Returns {@code values} as a {@code values.length}-by-1 column vector. */
    private static Matrix columnVector(double[] values) {
        int length = values.length;
        Matrix vector = new Matrix(length, 1);
        for (int i = 0; i < length; i++) {
            vector.set(i, 0, values[i]);
        }
        return vector;
    }

    /** Returns a zero column vector with {@code length} rows. */
    private static Matrix columnVector(int length) {
        return new Matrix(length, 1);
    }

    /**
     * Runs IRLS until the Euclidean norm of the change in {@code s} drops below {@code mTol} or
     * {@code maxIters} iterations have been performed. It then stores the coefficients and
     * updates the fit statistics.
     *
     * @return the fitted coefficients, or {@code null} if the model matrix has fewer rows than
     *     columns or has a singular value below {@link #EPSILON}
     */
    @Override
    public double[] solve() {
        int m = this.A.getRowDimension();
        int n = this.A.getColumnDimension();

        // An m-by-n matrix with m < n has rank at most m < n, so its coefficients are not
        // identifiable. Jama's SVD is also only specified for m >= n. Report this through the
        // existing rank-deficiency path instead of failing inside Jama.
        if (m < n) {
            System.out.println("Near rank-deficient model matrix");
            return null;
        }

        SingularValueDecomposition svd = this.A.svd();
        Matrix u = svd.getU();   // m x n
        Matrix v = svd.getV();   // n x n
        Matrix sigma = svd.getS(); // n x n diagonal
        Matrix uT = u.transpose();

        // S^-1. It is diagonal, so it equals its own transpose.
        Matrix sigmaInv = new Matrix(n, n);
        for (int i = 0; i < n; i++) {
            double singularValue = sigma.get(i, i);
            if (singularValue < EPSILON) {
                System.out.println("Near rank-deficient model matrix");
                return null;
            }
            sigmaInv.set(i, i, 1.0d / singularValue);
        }

        Matrix t = columnVector(m);       // linear predictor, t = U s
        Matrix s = columnVector(n);       // solution in the U basis
        Matrix forward = columnVector(n); // scratch for the forward substitution
        double[] weights = new double[m]; // diag(W) from the most recent iteration

        for (int iter = 0; iter < this.maxIters; iter++) {
            // Working response z = t + (y - mu) / mu'(t), evaluated per observation.
            double[] mu = new double[m];
            double[] muPrime = new double[m];
            Matrix z = columnVector(m);
            for (int i = 0; i < m; i++) {
                double eta = t.get(i, 0);
                mu[i] = this.linkFunc.GetInvLink(eta);
                muPrime[i] = this.linkFunc.GetInvLinkDerivative(eta);
                z.set(i, 0, eta + ((this.b.get(i, 0) - mu[i]) / muPrime[i]));
            }

            // IRLS weights W_ii = mu'(t)^2 / Var(mu).
            int tinyWeights = 0;
            for (int i = 0; i < m; i++) {
                double weight = (muPrime[i] * muPrime[i]) / getVariance(mu[i]);
                weights[i] = weight;
                if (weight < TINY_WEIGHT) {
                    tinyWeights++;
                }
            }
            if (tinyWeights > 0) {
                System.out.println("Warning: tiny weights encountered, (diag(W)) is too small");
            }

            Matrix sOld = s;

            // U^T W: scale column j of U^T by weights[j].
            Matrix uTW = new Matrix(n, m);
            for (int row = 0; row < n; row++) {
                for (int col = 0; col < m; col++) {
                    uTW.set(row, col, uT.get(row, col) * weights[col]);
                }
            }

            // Solve (U^T W U) s = U^T W z with U^T W U = L L^T. The triangular solves are written
            // out by hand because Jama's CholeskyDecomposition.solve() rejects matrices that fail
            // its exact-symmetry SPD check, which U^T W U can fail through rounding alone.
            Matrix lower = uTW.times(u).chol().getL();
            Matrix upper = lower.transpose();
            Matrix rhs = uTW.times(z);

            // Forward substitution: L * forward = rhs.
            for (int i = 0; i < n; i++) {
                double sum = 0.0d;
                for (int k = 0; k < i; k++) {
                    sum += lower.get(i, k) * forward.get(k, 0);
                }
                forward.set(i, 0, (rhs.get(i, 0) - sum) / lower.get(i, i));
            }

            // Back substitution: L^T * s = forward.
            s = columnVector(n);
            for (int i = n - 1; i >= 0; i--) {
                double sum = 0.0d;
                for (int k = i + 1; k < n; k++) {
                    sum += upper.get(i, k) * s.get(k, 0);
                }
                s.set(i, 0, (forward.get(i, 0) - sum) / upper.get(i, i));
            }

            t = u.times(s);
            if (sOld.minus(s).norm2() < this.mTol) {
                break;
            }
        }

        // Map back to the original coefficient basis: x = V S^-1 U^T t.
        Matrix x = v.times(sigmaInv).times(uT).times(t);
        this.glmCoefficients = new double[n];
        for (int i = 0; i < n; i++) {
            this.glmCoefficients[i] = x.get(i, 0);
        }
        updateStatistics(weights);
        return getCoefficients();
    }

    /**
     * Returns the coefficients computed by the last call to {@link #solve()}. This is the internal
     * array, not a copy.
     */
    @Override
    public double[] getCoefficients() {
        return this.glmCoefficients;
    }

    /**
     * Scales {@code matrix} by the entries of {@code factors} and returns the result as a new
     * matrix.
     *
     * <p>If {@code factors.length} equals the column count, column j is multiplied by
     * {@code factors[j]} ({@code matrix * diag(factors)}). Otherwise, if it equals the row count,
     * row i is multiplied by {@code factors[i]} ({@code diag(factors) * matrix}). Columns are tried
     * first so that {@code At * diag(W)} stays correct when the model matrix is square.
     *
     * @throws IllegalArgumentException if {@code factors.length} matches neither dimension
     */
    private Matrix scalarMultiply(Matrix matrix, double[] factors) {
        int length = factors.length;
        int rows = matrix.getRowDimension();
        int cols = matrix.getColumnDimension();
        Matrix scaled = new Matrix(rows, cols);
        if (length == cols) {
            for (int j = 0; j < cols; j++) {
                for (int i = 0; i < rows; i++) {
                    scaled.set(i, j, matrix.get(i, j) * factors[j]);
                }
            }
        } else if (length == rows) {
            for (int i = 0; i < rows; i++) {
                for (int j = 0; j < cols; j++) {
                    scaled.set(i, j, matrix.get(i, j) * factors[i]);
                }
            }
        } else {
            throw new IllegalArgumentException("factor count " + length
                    + " matches neither dimension of a " + rows + "x" + cols + " matrix");
        }
        return scaled;
    }

    /**
     * Fills in the fit statistics for the current coefficients: standard errors and the
     * variance-covariance matrix {@code (A^T W A)^-1}, response-scale residuals, response mean
     * and variance, R^2, and adjusted R^2 {@code 1 - (1 - R^2)(m - 1) / (m - p - 1)}, where m is
     * the number of observations and p the number of coefficients.
     *
     * @param weights IRLS weights (diagonal of {@code W}), one per observation
     */
    protected void updateStatistics(double[] weights) {
        Matrix covariance = scalarMultiply(this.At, weights).times(this.A).inverse();
        int coefficientCount = covariance.getRowDimension();
        int observationCount = this.b.getRowDimension();

        double[] standardErrors = this.mStats.getStandardErrors();
        double[][] vCovMatrix = this.mStats.getVCovMatrix();
        double[] residuals = this.mStats.getResiduals();

        for (int i = 0; i < coefficientCount; i++) {
            standardErrors[i] = Math.sqrt(covariance.get(i, i));
            for (int j = 0; j < coefficientCount; j++) {
                vCovMatrix[i][j] = covariance.get(i, j);
            }
        }

        double[] outcomes = new double[observationCount];
        for (int i = 0; i < observationCount; i++) {
            double eta = 0.0d;
            for (int j = 0; j < coefficientCount; j++) {
                eta += this.A.get(i, j) * this.glmCoefficients[j];
            }
            residuals[i] = this.b.get(i, 0) - this.linkFunc.GetInvLink(eta);
            outcomes[i] = this.b.get(i, 0);
        }

        this.mStats.setResidualStdDev(StdDev.apply(residuals, 0.0d));
        this.mStats.setResponseMean(Mean.apply(outcomes));
        this.mStats.setResponseVariance(Variance.apply(outcomes, this.mStats.getResponseMean()));

        double residualStdDev = this.mStats.getResidualStdDev();
        double unexplainedFraction = (residualStdDev * residualStdDev) / this.mStats.getResponseVariance();
        this.mStats.setR2(1.0d - unexplainedFraction);
        // The adjustment uses the number of observations, not the number of coefficients.
        this.mStats.setAdjustedR2(1.0d - ((unexplainedFraction * (observationCount - 1))
                / ((observationCount - this.glmCoefficients.length) - 1)));
    }
}
