package org.nd4j.evaluation;

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/nd4j/evaluation/EvaluationUtils.class */
public class EvaluationUtils {

    /**
     * Calculates precision: {@code tp / (tp + fp)}.
     *
     * @param tpCount  true positive count
     * @param fpCount  false positive count
     * @param edgeCase value to return when both counts are zero (the ratio would be 0/0)
     * @return precision (in [0, 1] for non-negative counts), or {@code edgeCase} when both counts are zero
     */
    public static double precision(long tpCount, long fpCount, double edgeCase) {
        return ratioOrEdgeCase(tpCount, fpCount, edgeCase);
    }

    /**
     * Calculates recall: {@code tp / (tp + fn)}.
     *
     * @param tpCount  true positive count
     * @param fnCount  false negative count
     * @param edgeCase value to return when both counts are zero (the ratio would be 0/0)
     * @return recall (in [0, 1] for non-negative counts), or {@code edgeCase} when both counts are zero
     */
    public static double recall(long tpCount, long fnCount, double edgeCase) {
        return ratioOrEdgeCase(tpCount, fnCount, edgeCase);
    }

    /**
     * Calculates the false positive rate: {@code fp / (fp + tn)}.
     *
     * @param fpCount  false positive count
     * @param tnCount  true negative count
     * @param edgeCase value to return when both counts are zero (the ratio would be 0/0)
     * @return the false positive rate, or {@code edgeCase} when both counts are zero
     */
    public static double falsePositiveRate(long fpCount, long tnCount, double edgeCase) {
        return ratioOrEdgeCase(fpCount, tnCount, edgeCase);
    }

    /**
     * Calculates the false negative rate: {@code fn / (fn + tp)}.
     *
     * @param fnCount  false negative count
     * @param tpCount  true positive count
     * @param edgeCase value to return when both counts are zero (the ratio would be 0/0)
     * @return the false negative rate, or {@code edgeCase} when both counts are zero
     */
    public static double falseNegativeRate(long fnCount, long tpCount, double edgeCase) {
        return ratioOrEdgeCase(fnCount, tpCount, edgeCase);
    }

    /**
     * Computes {@code numerator / (numerator + other)} in floating point, or returns
     * {@code edgeCase} when both inputs are zero.
     */
    private static double ratioOrEdgeCase(long numerator, long other, double edgeCase) {
        if (numerator == 0 && other == 0) {
            return edgeCase;
        }
        // Widen to double before dividing: long/long would truncate the result to 0 or 1.
        return numerator / ((double) numerator + other);
    }

    /**
     * Calculates the F-beta score from raw counts, via precision {@code tp / (tp + fp)} and
     * recall {@code tp / (tp + fn)}.
     *
     * <p>Returns 0.0 when {@code tpCount} is zero and at least one of {@code fpCount} or
     * {@code fnCount} is non-zero. Returns NaN when all three counts are zero.
     *
     * @param beta    beta weighting (1.0 gives the F1 score)
     * @param tpCount true positive count
     * @param fpCount false positive count
     * @param fnCount false negative count
     * @return the F-beta score
     */
    public static double fBeta(double beta, long tpCount, long fpCount, long fnCount) {
        // Double division: integer division would truncate and could throw ArithmeticException on a zero sum.
        double precision = tpCount / ((double) tpCount + fpCount);
        double recall = tpCount / ((double) tpCount + fnCount);
        return fBeta(beta, precision, recall);
    }

    /**
     * Calculates the F-beta score: {@code (1 + beta^2) * p * r / (beta^2 * p + r)}.
     *
     * @param beta      beta weighting (1.0 gives the F1 score)
     * @param precision precision value
     * @param recall    recall value
     * @return the F-beta score, or 0.0 if either precision or recall is exactly 0.0
     */
    public static double fBeta(double beta, double precision, double recall) {
        if (precision == 0.0d || recall == 0.0d) {
            return 0.0d;
        }
        double betaSquared = beta * beta;
        return (((1.0d + betaSquared) * precision) * recall) / ((betaSquared * precision) + recall);
    }

    /**
     * Calculates the G-measure: the geometric mean of precision and recall, {@code sqrt(p * r)}.
     *
     * @param precision precision value
     * @param recall    recall value
     * @return the G-measure
     */
    public static double gMeasure(double precision, double recall) {
        return Math.sqrt(precision * recall);
    }

    /**
     * Calculates the Matthews correlation coefficient:
     * {@code (tp*tn - fp*fn) / sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn))}.
     *
     * <p>Returns NaN when any of the four sums in the denominator is zero. In that case the
     * numerator is also zero, giving 0/0.
     *
     * @param tp true positive count
     * @param fp false positive count
     * @param fn false negative count
     * @param tn true negative count
     * @return the Matthews correlation coefficient
     */
    public static double matthewsCorrelation(long tp, long fp, long fn, long tn) {
        // Use double arithmetic throughout. The four-way product in the denominator
        // overflows a long once the pairwise sums reach roughly 55,000 each.
        double numerator = (double) tp * tn - (double) fp * fn;
        double denominator = Math.sqrt(((double) tp + fp) * ((double) tp + fn) * ((double) tn + fp) * ((double) tn + fn));
        return numerator / denominator;
    }

    /**
     * Reshapes a rank-3 array into a 2d matrix. Dimensions 0 and 2 are folded into rows and
     * dimension 1 becomes the columns.
     *
     * <p>NOTE(review): the input is presumably laid out as [minibatch, size, timeSeriesLength].
     * Confirm against callers. The rank is not validated here.
     *
     * @param timeSeries rank-3 array to reshape
     * @return the 2d reshaped array (a view or a copy, depending on the branch taken)
     */
    public static INDArray reshapeTimeSeriesTo2d(INDArray timeSeries) {
        long[] shape = timeSeries.shape();
        if (shape[0] == 1) {
            // Edge case: dimension 0 has size 1. Take its single slice and swap the two remaining axes.
            return timeSeries.tensorAlongDimension(0L, 1, 2).permutei(1, 0);
        }
        if (shape[2] == 1) {
            // Edge case: dimension 2 has size 1.
            return timeSeries.tensorAlongDimension(0L, 1, 0);
        }
        // General case: move dim 2 next to dim 0, then fold both into rows using 'f' (column-major) order.
        long rows = shape[0] * shape[2];
        return timeSeries.permute(0, 2, 1).reshape('f', rows, shape[1]);
    }

    /**
     * Flattens rank-3 label and prediction arrays to 2d. When a mask is given, keeps only the
     * rows whose mask value is exactly {@code 1.0f}.
     *
     * <p>Both arrays are first duplicated in 'f' order. The flattened mask is also built in
     * 'f' order, so row {@code i} of the 2d arrays is intended to line up with element
     * {@code i} of the flattened mask.
     *
     * @param labels     rank-3 labels array
     * @param predicted  rank-3 predictions array
     * @param outputMask rank-2 mask, or {@code null} for no masking. Mask values other than
     *                   exactly 1.0 (including fractional values) are treated as masked out.
     * @return a pair of (labels2d, predictions2d). Returns {@code null} when a mask is supplied
     *         and no entry equals 1.0.
     * @throws IllegalArgumentException if {@code labels} or {@code predicted} is not rank 3, or
     *                                  if {@code outputMask} is non-null and not rank 2
     */
    public static Pair<INDArray, INDArray> extractNonMaskedTimeSteps(INDArray labels, INDArray predicted, INDArray outputMask) {
        if (labels.rank() != 3 || predicted.rank() != 3) {
            throw new IllegalArgumentException("Invalid data: expect rank 3 arrays. Got arrays with shapes labels=" + Arrays.toString(labels.shape()) + ", predictions=" + Arrays.toString(predicted.shape()));
        }
        // Duplicate in 'f' order so the reshape below sees a consistent buffer layout.
        INDArray labels2d = reshapeTimeSeriesTo2d(labels.dup('f'));
        INDArray predicted2d = reshapeTimeSeriesTo2d(predicted.dup('f'));
        if (outputMask == null) {
            return new Pair<>(labels2d, predicted2d);
        }

        float[] maskValues = reshapeTimeSeriesMaskToVector(outputMask).dup().data().asFloat();
        int[] rowsToPull = new int[maskValues.length];
        int keptCount = 0;
        for (int row = 0; row < maskValues.length; row++) {
            if (maskValues[row] == 1.0f) {
                rowsToPull[keptCount++] = row;
            }
        }
        if (keptCount == 0) {
            // Every step is masked out: there is nothing to extract.
            return null;
        }
        int[] keptRows = Arrays.copyOfRange(rowsToPull, 0, keptCount);
        return new Pair<>(Nd4j.pullRows(labels2d, 1, keptRows), Nd4j.pullRows(predicted2d, 1, keptRows));
    }

    /**
     * Flattens a rank-2 mask into a column vector of shape [length, 1], using 'f'
     * (column-major) order.
     *
     * @param mask rank-2 mask array. It is duplicated first if it is not already in 'f' order.
     * @return the mask as a [length, 1] array
     * @throws IllegalArgumentException if {@code mask} is not rank 2
     */
    public static INDArray reshapeTimeSeriesMaskToVector(INDArray mask) {
        if (mask.rank() != 2) {
            throw new IllegalArgumentException("Cannot reshape mask: rank is not 2");
        }
        INDArray fOrderMask = mask.ordering() == 'f' ? mask : mask.dup('f');
        return fOrderMask.reshape('f', fOrderMask.length(), 1);
    }
}
