package net.myrrix.common.math;

import com.google.common.base.Preconditions;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Iterator;
import net.myrrix.common.collection.FastByIDFloatMap;
import net.myrrix.common.collection.FastByIDMap;
import net.myrrix.common.collection.FastIDSet;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:WEB-INF/lib/myrrix-common-0.10.jar:net/myrrix/common/math/MatrixUtils.class */
/**
 * Static helpers for matrices stored either as ID-keyed maps ({@link FastByIDMap} of
 * {@link FastByIDFloatMap} or of {@code float[]}) or as commons-math {@link RealMatrix} instances.
 *
 * <p>The class is not instantiable. The singularity threshold used by {@link #invert(RealMatrix)}
 * is read once, at class initialization, from the system property
 * {@code common.matrix.singularityThreshold} (default {@code 0.001}).</p>
 */
public final class MatrixUtils {

    /** Width, in characters, of every cell produced by {@link #matrixToString(FastByIDMap)}. */
    private static final int PRINT_COLUMN_WIDTH = 12;

    // Must be initialized before RDIAG_FIELD: loadField() logs through it on failure.
    private static final Logger log = LoggerFactory.getLogger(MatrixUtils.class);

    private static final double SINGULARITY_THRESHOLD =
        Double.parseDouble(System.getProperty("common.matrix.singularityThreshold", "0.001"));

    /** Private {@code rDiag} field of {@link QRDecomposition}; read only for diagnostics in {@link #invert}. */
    private static final Field RDIAG_FIELD = loadField(QRDecomposition.class, "rDiag");

    /**
     * Looks up a declared field and makes it accessible.
     *
     * @param clazz class declaring the field
     * @param fieldName name of the field
     * @return the accessible field
     * @throws IllegalStateException if the class declares no field with that name
     */
    private static Field loadField(Class<?> clazz, String fieldName) {
        try {
            Field field = clazz.getDeclaredField(fieldName);
            field.setAccessible(true);
            return field;
        } catch (NoSuchFieldException e) {
            log.error("Can't access {}.{}", clazz, fieldName);
            throw new IllegalStateException(e);
        }
    }

    private MatrixUtils() {
    }

    /**
     * Adds {@code f} to entry ({@code j}, {@code j2}) of a sparse matrix that is kept in two
     * mirrored maps, creating the inner maps on demand.
     *
     * @param j key into {@code fastByIDMap}; inner key in {@code fastByIDMap2}
     * @param j2 inner key in {@code fastByIDMap}; key into {@code fastByIDMap2}
     * @param f amount to add
     * @param fastByIDMap map indexed first by {@code j}, then by {@code j2}
     * @param fastByIDMap2 map indexed first by {@code j2}, then by {@code j}
     */
    public static void addTo(long j, long j2, float f, FastByIDMap<FastByIDFloatMap> fastByIDMap, FastByIDMap<FastByIDFloatMap> fastByIDMap2) {
        incrementEntry(fastByIDMap, j, j2, f);
        incrementEntry(fastByIDMap2, j2, j, f);
    }

    /** Adds {@code delta} to {@code matrix[outerKey][innerKey]}, creating the inner map if it is missing. */
    private static void incrementEntry(FastByIDMap<FastByIDFloatMap> matrix, long outerKey, long innerKey, float delta) {
        FastByIDFloatMap vector = matrix.get(outerKey);
        if (vector == null) {
            vector = new FastByIDFloatMap();
            matrix.put(outerKey, vector);
        }
        vector.increment(innerKey, delta);
    }

    /**
     * Removes entry ({@code j}, {@code j2}) from both mirrored maps and drops any inner map
     * that becomes empty as a result. Missing entries are ignored.
     *
     * @param j key into {@code fastByIDMap}; inner key in {@code fastByIDMap2}
     * @param j2 inner key in {@code fastByIDMap}; key into {@code fastByIDMap2}
     * @param fastByIDMap map indexed first by {@code j}, then by {@code j2}
     * @param fastByIDMap2 map indexed first by {@code j2}, then by {@code j}
     */
    public static void remove(long j, long j2, FastByIDMap<FastByIDFloatMap> fastByIDMap, FastByIDMap<FastByIDFloatMap> fastByIDMap2) {
        removeEntry(fastByIDMap, j, j2);
        removeEntry(fastByIDMap2, j2, j);
    }

    /** Removes {@code matrix[outerKey][innerKey]}; also removes the inner map once it is empty. */
    private static void removeEntry(FastByIDMap<FastByIDFloatMap> matrix, long outerKey, long innerKey) {
        FastByIDFloatMap vector = matrix.get(outerKey);
        if (vector == null) {
            return;
        }
        vector.remove(innerKey);
        if (vector.isEmpty()) {
            matrix.remove(outerKey);
        }
    }

    /**
     * Maps every vector {@code v} of the input to {@code (M' * M)^-1 * v}, where {@code M} is the
     * matrix whose rows are the vectors of the input map.
     *
     * @param fastByIDMap vectors keyed by ID; may be {@code null} or empty
     * @return a new map with the same keys and transformed vectors, or the argument itself when
     *  it is {@code null} or empty
     * @throws SingularMatrixException if {@code M' * M} is judged singular by {@link #invert}
     */
    public static FastByIDMap<float[]> getPseudoInverse(FastByIDMap<float[]> fastByIDMap) {
        if (fastByIDMap == null || fastByIDMap.isEmpty()) {
            return fastByIDMap;
        }
        return multiply(getTransposeTimesSelfInverse(fastByIDMap), fastByIDMap);
    }

    /**
     * Computes {@code (M' * M)^-1} for the matrix {@code M} whose rows are the vectors of the map.
     *
     * @param fastByIDMap vectors keyed by ID; may be {@code null} or empty
     * @return the inverse, or {@code null} when the argument is {@code null} or empty
     * @throws SingularMatrixException if {@code M' * M} is judged singular by {@link #invert}
     */
    public static RealMatrix getTransposeTimesSelfInverse(FastByIDMap<float[]> fastByIDMap) {
        if (fastByIDMap == null || fastByIDMap.isEmpty()) {
            return null;
        }
        return invert(transposeTimesSelf(fastByIDMap));
    }

    /**
     * Inverts a matrix through its QR decomposition, using the configured singularity threshold.
     *
     * <p>When the decomposition reports a singular matrix, a warning and some diagnostics about the
     * decomposition's diagonal are logged before the exception is rethrown.</p>
     *
     * @param realMatrix matrix to invert
     * @return the inverse, as a new {@link Array2DRowRealMatrix}
     * @throws SingularMatrixException if the solver reports the matrix as singular
     */
    public static RealMatrix invert(RealMatrix realMatrix) {
        QRDecomposition decomposition = new QRDecomposition(realMatrix, SINGULARITY_THRESHOLD);
        try {
            return new Array2DRowRealMatrix(decomposition.getSolver().getInverse().getData());
        } catch (SingularMatrixException e) {
            log.warn("{} x {} matrix is near-singular (threshold {}); add more data or decrease the value of model.features ({})",
                     realMatrix.getRowDimension(),
                     realMatrix.getColumnDimension(),
                     SINGULARITY_THRESHOLD,
                     e.toString());
            logSuggestedFeatureCount(decomposition);
            throw e;
        }
    }

    /**
     * Logs the diagonal stored in the decomposition's {@code rDiag} field and the index of its
     * first element whose magnitude does not exceed the singularity threshold, if there is one.
     * Purely diagnostic: failure to read the field is logged and otherwise ignored.
     */
    private static void logSuggestedFeatureCount(QRDecomposition decomposition) {
        double[] rDiag;
        try {
            rDiag = (double[]) RDIAG_FIELD.get(decomposition);
        } catch (IllegalAccessException iae) {
            log.warn("Can't read QR decomposition fields to suggest dimensionality");
            return;
        }
        log.info("QR decomposition diagonal: {}", Arrays.toString(rDiag));
        for (int i = 0; i < rDiag.length; i++) {
            if (FastMath.abs(rDiag[i]) <= SINGULARITY_THRESHOLD) {
                log.info("Suggested value of -Dmodel.features is less than {}", i);
                break;
            }
        }
    }

    /**
     * Left-multiplies every vector of the map by the given matrix.
     *
     * @param realMatrix matrix to multiply by
     * @param fastByIDMap vectors keyed by ID
     * @return a new map holding {@code realMatrix * v} for every vector {@code v}, under the same key
     */
    public static FastByIDMap<float[]> multiply(RealMatrix realMatrix, FastByIDMap<float[]> fastByIDMap) {
        FastByIDMap<float[]> products = new FastByIDMap<>(fastByIDMap.size(), 1.25f);
        double[][] matrixData = accessMatrixDataDirectly(realMatrix);
        for (FastByIDMap.MapEntry<float[]> entry : fastByIDMap.entrySet()) {
            products.put(entry.getKey(), matrixMultiply(matrixData, entry.getValue()));
        }
        return products;
    }

    /**
     * Computes {@code X * Y'} as a dense matrix.
     *
     * <p>NOTE: vectors are looked up under the keys {@code 0 .. size-1}, so both maps are expected
     * to be keyed densely from zero.</p>
     *
     * @param fastByIDMap the matrix X, one row vector per key
     * @param fastByIDMap2 the matrix Y, one row vector per key
     * @return matrix of {@code fastByIDMap.size()} rows by {@code fastByIDMap2.size()} columns whose
     *  entry (r, c) is the dot product of X's vector r and Y's vector c
     */
    public static RealMatrix multiplyXYT(FastByIDMap<float[]> fastByIDMap, FastByIDMap<float[]> fastByIDMap2) {
        int numColumns = fastByIDMap2.size();
        int numRows = fastByIDMap.size();
        Array2DRowRealMatrix product = new Array2DRowRealMatrix(numRows, numColumns);
        for (int row = 0; row < numRows; row++) {
            float[] xVector = fastByIDMap.get(row);
            for (int col = 0; col < numColumns; col++) {
                product.setEntry(row, col, SimpleVectorMath.dot(xVector, fastByIDMap2.get(col)));
            }
        }
        return product;
    }

    /**
     * Returns the matrix entries as a row-major array, without copying where possible.
     *
     * <p>For an {@link Array2DRowRealMatrix} this is the matrix's own backing array, so the result
     * must be treated as read-only. Other implementations are read through
     * {@link RealMatrix#getData()}.</p>
     */
    private static double[][] accessMatrixDataDirectly(RealMatrix realMatrix) {
        if (realMatrix instanceof Array2DRowRealMatrix) {
            return ((Array2DRowRealMatrix) realMatrix).getDataRef();
        }
        return realMatrix.getData();
    }

    /**
     * Multiplies a matrix by a column vector.
     *
     * @param realMatrix matrix to multiply by; needs at least {@code fArr.length} columns
     * @param fArr the vector
     * @return {@code realMatrix * fArr}, one element per matrix row
     */
    public static double[] multiply(RealMatrix realMatrix, float[] fArr) {
        double[][] matrixData = accessMatrixDataDirectly(realMatrix);
        double[] product = new double[matrixData.length];
        for (int row = 0; row < matrixData.length; row++) {
            product[row] = dot(matrixData[row], fArr);
        }
        return product;
    }

    /** Same as {@link #multiply(RealMatrix, float[])} on raw row data, narrowing each result to {@code float}. */
    private static float[] matrixMultiply(double[][] matrixData, float[] vector) {
        float[] product = new float[matrixData.length];
        for (int row = 0; row < matrixData.length; row++) {
            product[row] = (float) dot(matrixData[row], vector);
        }
        return product;
    }

    /** Dot product over the first {@code vector.length} elements, accumulated in double precision. */
    private static double dot(double[] row, float[] vector) {
        double sum = 0.0;
        for (int i = 0; i < vector.length; i++) {
            sum += vector[i] * row[i];
        }
        return sum;
    }

    /**
     * Computes {@code M' * M} for the matrix {@code M} whose rows are the vectors of the map.
     * The result is sized from the first vector encountered.
     *
     * @param fastByIDMap vectors keyed by ID; must contain at least one vector
     * @return the square product matrix
     * @throws NullPointerException if the map contains no vectors
     */
    public static RealMatrix transposeTimesSelf(FastByIDMap<float[]> fastByIDMap) {
        Array2DRowRealMatrix product = null;
        for (FastByIDMap.MapEntry<float[]> entry : fastByIDMap.entrySet()) {
            float[] vector = entry.getValue();
            int dimension = vector.length;
            if (product == null) {
                product = new Array2DRowRealMatrix(dimension, dimension);
            }
            // Accumulate the outer product of this vector with itself.
            for (int row = 0; row < dimension; row++) {
                float rowValue = vector[row];
                for (int col = 0; col < dimension; col++) {
                    product.addToEntry(row, col, rowValue * vector[col]);
                }
            }
        }
        return Preconditions.checkNotNull(product);
    }

    /**
     * Renders a sparse matrix as tab-separated text: a header line of column keys, a blank line,
     * then one line per row key. Keys are listed in ascending order and every cell is padded or
     * truncated to a fixed width. Cells whose value is NaN (presumably what the inner map returns
     * for absent entries) are left blank.
     *
     * @param fastByIDMap matrix to print, keyed by row and then by column
     * @return the formatted text
     */
    public static String matrixToString(FastByIDMap<FastByIDFloatMap> fastByIDMap) {
        StringBuilder out = new StringBuilder();
        long[] columnKeys = unionColumnKeysInOrder(fastByIDMap);

        // Header: an empty corner cell followed by the column keys.
        appendWithPadOrTruncate("", out);
        for (long columnKey : columnKeys) {
            out.append('\t');
            appendWithPadOrTruncate(columnKey, out);
        }
        out.append("\n\n");

        for (long rowKey : keysInOrder(fastByIDMap)) {
            appendWithPadOrTruncate(rowKey, out);
            FastByIDFloatMap row = fastByIDMap.get(rowKey);
            for (long columnKey : columnKeys) {
                out.append('\t');
                float value = row.get(columnKey);
                if (Float.isNaN(value)) {
                    appendWithPadOrTruncate("", out);
                } else {
                    appendWithPadOrTruncate(value, out);
                }
            }
            out.append('\n');
        }
        out.append('\n');
        return out.toString();
    }

    /** Returns the keys of the map in ascending order. */
    private static long[] keysInOrder(FastByIDMap<?> fastByIDMap) {
        // Map keys are already distinct, so they can be collected straight into an array.
        long[] keys = new long[fastByIDMap.size()];
        LongPrimitiveIterator keyIterator = fastByIDMap.keySetIterator();
        int count = 0;
        while (keyIterator.hasNext()) {
            keys[count++] = keyIterator.nextLong();
        }
        Arrays.sort(keys);
        return keys;
    }

    /** Returns, in ascending order, every distinct key that appears in any of the inner maps. */
    private static long[] unionColumnKeysInOrder(FastByIDMap<FastByIDFloatMap> fastByIDMap) {
        FastIDSet columnKeys = new FastIDSet(1000, 1.25f);
        for (FastByIDMap.MapEntry<FastByIDFloatMap> entry : fastByIDMap.entrySet()) {
            LongPrimitiveIterator keyIterator = entry.getValue().keySetIterator();
            while (keyIterator.hasNext()) {
                columnKeys.add(keyIterator.nextLong());
            }
        }
        long[] sortedKeys = columnKeys.toArray();
        Arrays.sort(sortedKeys);
        return sortedKeys;
    }

    /** Appends the decimal form of {@code value} as one fixed-width cell. */
    private static void appendWithPadOrTruncate(long value, StringBuilder sb) {
        appendWithPadOrTruncate(Long.toString(value), sb);
    }

    /**
     * Appends {@code value} as one fixed-width cell. Values that compare {@code >= 0} get a leading
     * space so that their digits line up with those of negative values.
     */
    private static void appendWithPadOrTruncate(float value, StringBuilder sb) {
        String text = Float.toString(value);
        if (value >= 0.0f) {
            text = ' ' + text;
        }
        appendWithPadOrTruncate(text, sb);
    }

    /**
     * Appends {@code text} as one cell of exactly {@link #PRINT_COLUMN_WIDTH} characters:
     * longer text is cut off on the right, shorter text is left-padded with spaces.
     */
    private static void appendWithPadOrTruncate(CharSequence text, StringBuilder sb) {
        int length = text.length();
        if (length >= PRINT_COLUMN_WIDTH) {
            sb.append(text, 0, PRINT_COLUMN_WIDTH);
            return;
        }
        for (int i = length; i < PRINT_COLUMN_WIDTH; i++) {
            sb.append(' ');
        }
        sb.append(text);
    }
}
