package io.proximax.core.math;

import java.util.ArrayList;
import java.util.Collection;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleConsumer;
import java.util.function.DoubleUnaryOperator;

/* loaded from: input_file:io/proximax/core/math/Matrix.class */
/**
 * Abstract base class for a two-dimensional matrix of {@code double} values.
 *
 * <p>Storage and iteration are delegated to subclasses through {@link #create(int, int)},
 * {@link #getAtUnchecked(int, int)}, {@link #setAtUnchecked(int, int, double)} and
 * {@link #forEach(ElementVisitorFunction)}. Every aggregate and transformation in this class is
 * expressed in terms of those primitives, so which elements take part in an operation is
 * determined by the elements the subclass's {@code forEach} chooses to visit.
 */
public abstract class Matrix {
    private final int numRows;
    private final int numCols;

    /**
     * Visitor that receives an element together with a consumer for a replacement value.
     */
    @FunctionalInterface
    public interface ElementVisitorFunction {
        /**
         * Visits a single element.
         *
         * @param row the row index of the element
         * @param col the column index of the element
         * @param value the current value of the element
         * @param setter consumer the visitor may call with a new value for the element; how the
         *     value is stored is up to the {@code forEach} implementation that supplies it
         */
        void visit(int row, int col, double value, DoubleConsumer setter);
    }

    /**
     * Visitor that receives an element without any means of changing it.
     */
    @FunctionalInterface
    public interface ReadOnlyElementVisitorFunction {
        /**
         * Visits a single element.
         *
         * @param row the row index of the element
         * @param col the column index of the element
         * @param value the current value of the element
         */
        void visit(int row, int col, double value);
    }

    /**
     * Creates a matrix with the given dimensions.
     *
     * @param numRows the number of rows
     * @param numCols the number of columns
     */
    public Matrix(int numRows, int numCols) {
        this.numRows = numRows;
        this.numCols = numCols;
    }

    /**
     * Returns the total number of elements ({@code rows * columns}).
     *
     * @return the element count
     */
    public final int getElementCount() {
        return this.numRows * this.numCols;
    }

    /**
     * Returns the number of rows.
     *
     * @return the row count
     */
    public final int getRowCount() {
        return this.numRows;
    }

    /**
     * Returns the number of columns.
     *
     * @return the column count
     */
    public final int getColumnCount() {
        return this.numCols;
    }

    /**
     * Returns the element at the given position after validating the indices.
     *
     * @param row the row index
     * @param col the column index
     * @return the element value
     * @throws IndexOutOfBoundsException if either index is outside the matrix
     */
    public final double getAt(int row, int col) {
        checkBounds(row, col);
        return getAtUnchecked(row, col);
    }

    /**
     * Sets the element at the given position after validating the indices.
     *
     * @param row the row index
     * @param col the column index
     * @param value the value to store
     * @throws IndexOutOfBoundsException if either index is outside the matrix
     */
    public final void setAt(int row, int col, double value) {
        checkBounds(row, col);
        setAtUnchecked(row, col, value);
    }

    /**
     * Adds {@code delta} to the element at the given position.
     *
     * @param row the row index
     * @param col the column index
     * @param delta the amount to add
     * @throws IndexOutOfBoundsException if either index is outside the matrix
     */
    public final void incrementAt(int row, int col, double delta) {
        // getAt validates the indices, so the unchecked setter is safe afterwards.
        final double current = getAt(row, col);
        setAtUnchecked(row, col, current + delta);
    }

    /**
     * Sums the visited elements of every row.
     *
     * @return a vector with one entry per row holding that row's sum
     */
    public final ColumnVector getRowSumVector() {
        final double[] rowSums = new double[this.numRows];
        forEach((row, col, value) -> rowSums[row] += value);
        return new ColumnVector(rowSums);
    }

    /**
     * Sums the visited elements of every column.
     *
     * @return a vector with one entry per column holding that column's sum
     */
    public final ColumnVector getColumnSumVector() {
        return new ColumnVector(getColumnSums(value -> value));
    }

    /**
     * Sums {@code mapper(value)} over the visited elements of every column.
     *
     * @param mapper function applied to each element before it is added
     * @return an array with one entry per column
     */
    private double[] getColumnSums(DoubleUnaryOperator mapper) {
        final double[] columnSums = new double[this.numCols];
        forEach((row, col, value) -> columnSums[col] += mapper.applyAsDouble(value));
        return columnSums;
    }

    /**
     * Divides every visited element by the sum of absolute values of its column, leaving columns
     * whose absolute sum is zero untouched.
     *
     * @return the indices of the columns whose absolute sum is zero
     */
    public Collection<Integer> normalizeColumns() {
        final double[] columnAbsSums = getColumnSums(Math::abs);

        final Collection<Integer> zeroColumns = new ArrayList<>();
        for (int index = 0; index < this.numCols; index++) {
            if (0.0d == columnAbsSums[index]) {
                zeroColumns.add(index);
            }
        }

        forEach((row, col, value, setter) -> {
            final double columnAbsSum = columnAbsSums[col];
            if (0.0d == columnAbsSum) {
                // Nothing to normalize by; dividing would only produce NaN or infinity.
                return;
            }
            setter.accept(value / columnAbsSum);
        });
        return zeroColumns;
    }

    /**
     * Replaces every visited negative element with zero.
     */
    public void removeNegatives() {
        removeLessThan(0.0d);
    }

    /**
     * Replaces every visited element that is strictly less than {@code minValue} with zero.
     *
     * @param minValue the threshold below which elements are zeroed
     */
    public void removeLessThan(double minValue) {
        forEach((row, col, value, setter) -> {
            if (value < minValue) {
                setter.accept(0.0d);
            }
        });
    }

    /**
     * Scales this matrix in place by <em>dividing</em> every visited element by {@code scale}.
     *
     * @param scale the divisor applied to each element
     */
    public final void scale(double scale) {
        forEach((row, col, value, setter) -> setter.accept(value / scale));
    }

    /**
     * Multiplies this matrix element-wise with another matrix of the same size.
     *
     * @param matrix the other matrix
     * @return a new matrix holding the element-wise products
     * @throws IllegalArgumentException if the matrices differ in size
     */
    public Matrix multiplyElementWise(Matrix matrix) {
        return join(matrix, false, (lhs, rhs) -> lhs * rhs);
    }

    /**
     * Adds this matrix element-wise to another matrix of the same size.
     *
     * @param matrix the other matrix
     * @return a new matrix holding the element-wise sums
     * @throws IllegalArgumentException if the matrices differ in size
     */
    public Matrix addElementWise(Matrix matrix) {
        return join(matrix, true, (lhs, rhs) -> lhs + rhs);
    }

    /**
     * Combines this matrix with another one element by element.
     *
     * @param matrix the other matrix; must have the same dimensions
     * @param isTwoWay when {@code true}, the elements visited by {@code matrix.forEach} are
     *     combined as well, so positions visited only by the other matrix also end up in the
     *     result (presumably needed for sparse implementations - confirm against subclasses)
     * @param op operator invoked as {@code op(thisValue, otherValue)}
     * @return a new matrix obtained from {@link #create(int, int)} holding the combined values
     * @throws IllegalArgumentException if the matrices differ in size
     */
    private Matrix join(Matrix matrix, boolean isTwoWay, DoubleBinaryOperator op) {
        if (!isSameSize(matrix)) {
            throw new IllegalArgumentException("matrix sizes must be equal");
        }

        final Matrix result = create(getRowCount(), getColumnCount());
        forEach((row, col, value) ->
                result.setAtUnchecked(row, col, op.applyAsDouble(value, matrix.getAtUnchecked(row, col))));

        if (isTwoWay) {
            // Keep the operand order (this, other) identical to the first pass so that a
            // non-commutative operator would still produce consistent results.
            matrix.forEach((row, col, otherValue) ->
                    result.setAtUnchecked(row, col, op.applyAsDouble(getAtUnchecked(row, col), otherValue)));
        }
        return result;
    }

    /**
     * Returns the sum of the absolute values of the visited elements.
     *
     * @return the absolute sum
     */
    public final double absSum() {
        return aggregate(Math::abs);
    }

    /**
     * Returns the sum of the visited elements.
     *
     * @return the sum
     */
    public final double sum() {
        return aggregate(value -> value);
    }

    /**
     * Sums {@code mapper(value)} over all visited elements.
     *
     * @param mapper function applied to each element before it is added
     * @return the accumulated total
     */
    private double aggregate(DoubleUnaryOperator mapper) {
        // Single-element array because a lambda can only capture effectively final locals.
        final double[] total = {0.0d};
        forEach((row, col, value) -> total[0] += mapper.applyAsDouble(value));
        return total[0];
    }

    /**
     * Multiplies this matrix by a column vector.
     *
     * @param columnVector the vector; its size must equal this matrix's column count
     * @return a new vector with one entry per row of this matrix
     * @throws IllegalArgumentException if the vector size differs from the column count
     */
    public ColumnVector multiply(ColumnVector columnVector) {
        if (this.numCols != columnVector.size()) {
            throw new IllegalArgumentException("vector size and matrix column count must be equal");
        }

        final double[] product = new double[this.numRows];
        final double[] vectorValues = columnVector.getRaw();
        forEach((row, col, value) -> product[row] += value * vectorValues[col]);
        return new ColumnVector(product);
    }

    /**
     * Returns the transpose of this matrix.
     *
     * @return a new matrix with rows and columns swapped
     */
    public final Matrix transpose() {
        final Matrix transposed = create(getColumnCount(), getRowCount());
        forEach((row, col, value) -> transposed.setAtUnchecked(col, row, value));
        return transposed;
    }

    /**
     * Returns a copy of this matrix with every visited element rounded to the given number of
     * decimal places.
     *
     * @param numPlaces the number of decimal places to keep
     * @return a new matrix with the rounded values
     */
    public Matrix roundTo(int numPlaces) {
        final double multiplier = Math.pow(10.0d, numPlaces);
        return transform(value -> Math.round(value * multiplier) / multiplier);
    }

    /**
     * Returns a copy of this matrix with every visited element multiplied by {@code scalar}.
     *
     * @param scalar the factor
     * @return a new matrix with the scaled values
     */
    public Matrix multiply(double scalar) {
        return transform(value -> value * scalar);
    }

    /**
     * Returns a copy of this matrix with {@code scalar} added to every visited element.
     *
     * @param scalar the addend
     * @return a new matrix with the shifted values
     */
    public Matrix add(double scalar) {
        return transform(value -> value + scalar);
    }

    /**
     * Returns a copy of this matrix holding the absolute value of every visited element.
     *
     * @return a new matrix with the absolute values
     */
    public Matrix abs() {
        return transform(Math::abs);
    }

    /**
     * Returns a copy of this matrix holding the square root of every visited element.
     *
     * @return a new matrix with the square roots
     */
    public Matrix sqrt() {
        return transform(Math::sqrt);
    }

    /**
     * Applies {@code op} to every visited element and stores the results in a new matrix of the
     * same size obtained from {@link #create(int, int)}.
     *
     * @param op the function applied to each element
     * @return the new matrix
     */
    private Matrix transform(DoubleUnaryOperator op) {
        final Matrix result = create(getRowCount(), getColumnCount());
        forEach((row, col, value) -> result.setAtUnchecked(row, col, op.applyAsDouble(value)));
        return result;
    }

    /**
     * Checks whether another matrix has the same dimensions as this one.
     *
     * @param matrix the matrix to compare with
     * @return {@code true} if row and column counts both match
     */
    public final boolean isSameSize(Matrix matrix) {
        return this.numRows == matrix.numRows && this.numCols == matrix.numCols;
    }

    /**
     * Validates a pair of indices against the matrix dimensions.
     *
     * @param row the row index
     * @param col the column index
     * @throws IndexOutOfBoundsException if either index is outside the matrix
     */
    private void checkBounds(int row, int col) {
        if (row < 0 || row >= this.numRows) {
            throw new IndexOutOfBoundsException("Row index out of bounds");
        }
        if (col < 0 || col >= this.numCols) {
            throw new IndexOutOfBoundsException("Column index out of bounds");
        }
    }

    /**
     * Checks whether the absolute sum of the visited elements is zero.
     *
     * @return {@code true} if {@link #absSum()} is exactly zero
     */
    public final boolean isZeroMatrix() {
        return 0.0d == absSum();
    }

    /**
     * Returns a hash derived from the dimensions only, which keeps it consistent with
     * {@link #equals(Object)} regardless of element values.
     */
    @Override
    public int hashCode() {
        return getRowCount() ^ getColumnCount();
    }

    /**
     * Compares this matrix with another object. Two matrices are equal when they have the same
     * size and no compared element pair differs under {@code ==}.
     *
     * <p>Because {@code ==} is used, a compared {@code NaN} element never matches, not even itself.
     */
    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof Matrix)) {
            return false;
        }

        final Matrix other = (Matrix) obj;
        if (!isSameSize(other)) {
            return false;
        }

        // Mark every differing position with 1 and check that no mark was produced.
        final Matrix mismatches = join(other, true, (lhs, rhs) -> lhs == rhs ? 0.0d : 1.0d);
        return 0.0d == mismatches.sum();
    }

    /**
     * Visits the elements of this matrix without offering a way to change them.
     *
     * @param func the visitor to call for each visited element
     */
    public void forEach(ReadOnlyElementVisitorFunction func) {
        forEach((row, col, value, setter) -> func.visit(row, col, value));
    }

    /**
     * Creates a new matrix with the given dimensions; used by this class to allocate results.
     *
     * @param numRows the number of rows
     * @param numCols the number of columns
     * @return the new matrix
     */
    protected abstract Matrix create(int numRows, int numCols);

    /**
     * Returns the element at the given position without bounds checking.
     *
     * @param row the row index
     * @param col the column index
     * @return the element value
     */
    protected abstract double getAtUnchecked(int row, int col);

    /**
     * Sets the element at the given position without bounds checking.
     *
     * @param row the row index
     * @param col the column index
     * @param value the value to store
     */
    public abstract void setAtUnchecked(int row, int col, double value);

    /**
     * Visits elements of this matrix, handing each visit a setter for a replacement value. Which
     * elements are visited is defined by the implementing subclass.
     *
     * @param func the visitor to call for each visited element
     */
    protected abstract void forEach(ElementVisitorFunction func);

    /**
     * Returns an iterator over the given row; judging by its name it covers the row's non-zero
     * elements - confirm against the implementations.
     *
     * @param row the row index
     * @return the row iterator
     */
    public abstract MatrixNonZeroElementRowIterator getNonZeroElementRowIterator(int row);
}
