package de.jungblut.classification.eval;

import com.google.common.base.Preconditions;
import de.jungblut.datastructure.ArrayUtils;
import de.jungblut.math.DoubleVector;
import java.util.Arrays;
import java.util.Deque;
import java.util.LinkedList;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Holder for a train/test partition of feature vectors and their outcome vectors.
 *
 * <p>A split can be built directly from already partitioned arrays, via {@link #create} (plain
 * split at a fraction), or via {@link #createStratified} (split that keeps per-class proportions in
 * the training set). The arrays handed to the constructor are stored and returned as-is; no
 * defensive copies are made.
 */
public class EvaluationSplit {
    private static final Logger LOG = LogManager.getLogger(EvaluationSplit.class);
    private final DoubleVector[] trainFeatures;
    private final DoubleVector[] trainOutcome;
    private final DoubleVector[] testFeatures;
    private final DoubleVector[] testOutcome;

    /**
     * Creates a split from already partitioned arrays.
     *
     * @param trainFeatures the training feature vectors.
     * @param trainOutcome the training outcome vectors, expected to be index-aligned with
     *     {@code trainFeatures} (not checked here).
     * @param testFeatures the test feature vectors.
     * @param testOutcome the test outcome vectors, expected to be index-aligned with
     *     {@code testFeatures} (not checked here).
     */
    public EvaluationSplit(DoubleVector[] trainFeatures, DoubleVector[] trainOutcome,
            DoubleVector[] testFeatures, DoubleVector[] testOutcome) {
        this.trainFeatures = trainFeatures;
        this.trainOutcome = trainOutcome;
        this.testFeatures = testFeatures;
        this.testOutcome = testOutcome;
    }

    /** @return the training feature vectors (the stored array itself, not a copy). */
    public DoubleVector[] getTrainFeatures() {
        return this.trainFeatures;
    }

    /** @return the training outcome vectors (the stored array itself, not a copy). */
    public DoubleVector[] getTrainOutcome() {
        return this.trainOutcome;
    }

    /** @return the test feature vectors (the stored array itself, not a copy). */
    public DoubleVector[] getTestFeatures() {
        return this.testFeatures;
    }

    /** @return the test outcome vectors (the stored array itself, not a copy). */
    public DoubleVector[] getTestOutcome() {
        return this.testOutcome;
    }

    /**
     * Splits the given data into a training and a test set.
     *
     * <p>The first {@code (int) (features.length * splitFraction)} rows become the training set,
     * all remaining rows the test set. The returned arrays are copies of those ranges.
     *
     * @param features the feature vectors.
     * @param outcome the outcome vectors, index-aligned with {@code features}.
     * @param splitFraction fraction of rows, in [0, 1], that go into the training set.
     * @param random if true, {@code features} and {@code outcome} are shuffled via
     *     {@code ArrayUtils.multiShuffle} before splitting. Its return value is ignored, so the
     *     caller's arrays are reordered in place.
     * @return the resulting split.
     * @throws IllegalArgumentException if the arrays differ in length or {@code splitFraction} is
     *     outside [0, 1].
     */
    public static EvaluationSplit create(DoubleVector[] features, DoubleVector[] outcome,
            float splitFraction, boolean random) {
        checkSplitArguments(features, outcome, splitFraction);
        if (random) {
            // Presumably applies one permutation to both arrays so rows stay aligned; confirm in
            // ArrayUtils.multiShuffle.
            ArrayUtils.multiShuffle(features, new DoubleVector[][] {outcome});
        }
        int trainSize = (int) (features.length * splitFraction);
        // Half-open ranges: train = [0, trainSize), test = [trainSize, length).
        return new EvaluationSplit(
                Arrays.copyOfRange(features, 0, trainSize),
                Arrays.copyOfRange(outcome, 0, trainSize),
                Arrays.copyOfRange(features, trainSize, features.length),
                Arrays.copyOfRange(outcome, trainSize, outcome.length));
    }

    /**
     * Splits the given data into a training and a test set, keeping the class proportions of the
     * whole data set in the training set.
     *
     * <p>Rows are bucketed by class. When the outcome dimension is at most two, the class is
     * {@code (int) outcome.get(0)}; otherwise it is {@code outcome.maxIndex()}. Each class then
     * contributes {@code (int) (trainSize * classShare)} rows, taken in input order, to the training
     * set. Here {@code classShare} is the fraction of all rows belonging to that class. Every
     * remaining row goes to the test set.
     *
     * <p>Because the per-class counts are floored, the training set can end up smaller than
     * {@code (int) (features.length * splitFraction)}; this is logged as a warning.
     *
     * <p>Note: each outcome in the returned arrays is the last outcome vector seen for its class.
     * The same instance is reused for every row of that class.
     *
     * @param features the feature vectors; this array is not modified.
     * @param outcome the outcome vectors, index-aligned with {@code features}; not modified.
     * @param splitFraction fraction of rows, in [0, 1], targeted for the training set.
     * @param random if true, the resulting training and test sets are each shuffled after being
     *     filled.
     * @return the stratified split.
     * @throws IllegalArgumentException if the arrays differ in length, are empty,
     *     {@code splitFraction} is outside [0, 1], or some class would get no training rows.
     * @throws NullPointerException if a class index in [0, numClasses) has no rows at all.
     */
    public static EvaluationSplit createStratified(DoubleVector[] features, DoubleVector[] outcome,
            float splitFraction, boolean random) {
        checkSplitArguments(features, outcome, splitFraction);
        Preconditions.checkArgument(features.length > 0, "Can't stratify an empty set of vectors!");

        // A one-dimensional outcome still describes two classes (label 0 or 1).
        int numClasses = Math.max(2, outcome[0].getDimension());
        @SuppressWarnings("unchecked") // generic array creation; only Deque<DoubleVector> is stored
        Deque<DoubleVector>[] classQueues = new Deque[numClasses];
        DoubleVector[] classOutcome = new DoubleVector[numClasses];
        for (int row = 0; row < features.length; row++) {
            int label = classIndexOf(outcome[row], numClasses);
            if (classQueues[label] == null) {
                classQueues[label] = new LinkedList<>();
            }
            classQueues[label].addLast(features[row]);
            // One representative outcome per class; reused for every row of that class below.
            classOutcome[label] = outcome[row];
        }
        for (int label = 0; label < numClasses; label++) {
            Preconditions.checkNotNull(classQueues[label], "Queue for class %s couldn't be found."
                    + " This happens when the mentioned class label doesn't exists in the given set"
                    + " of vectors.", label);
        }

        int trainSize = (int) (features.length * splitFraction);
        double[] classShare = new double[numClasses];
        int[] trainPerClass = new int[numClasses];
        int adjustedTrainSize = 0;
        for (int label = 0; label < numClasses; label++) {
            // Divide as double: int/int division would truncate every share to zero.
            classShare[label] = classQueues[label].size() / (double) features.length;
            trainPerClass[label] = (int) (trainSize * classShare[label]);
            Preconditions.checkArgument(trainPerClass[label] > 0, "Can't stratify the class %s"
                    + " because the split size was too small to satisfy the sampling requirement.",
                    label);
            adjustedTrainSize += trainPerClass[label];
        }
        if (adjustedTrainSize != trainSize) {
            LOG.warn("Correcting the split size from {} to {}, to satisfy the sampling target.",
                    trainSize, adjustedTrainSize);
            trainSize = adjustedTrainSize;
        }
        LOG.info("Sampling probabilities by class: {}", Arrays.toString(classShare));

        // Training set: the first trainPerClass[label] rows of every class, class by class.
        DoubleVector[] trainFeatures = new DoubleVector[trainSize];
        DoubleVector[] trainOutcome = new DoubleVector[trainSize];
        int trainIndex = 0;
        for (int label = 0; label < numClasses; label++) {
            for (int k = 0; k < trainPerClass[label]; k++) {
                trainFeatures[trainIndex] = classQueues[label].poll();
                trainOutcome[trainIndex] = classOutcome[label];
                trainIndex++;
            }
        }
        Preconditions.checkArgument(trainIndex == trainFeatures.length,
                "Didn't fill up the targeted split size of %s vectors in the training set!", trainSize);

        // Test set: whatever is left in each class queue.
        DoubleVector[] testFeatures = new DoubleVector[features.length - trainSize];
        DoubleVector[] testOutcome = new DoubleVector[features.length - trainSize];
        int testIndex = 0;
        for (int label = 0; label < numClasses; label++) {
            Deque<DoubleVector> queue = classQueues[label];
            while (!queue.isEmpty()) {
                Preconditions.checkArgument(testIndex < testFeatures.length,
                        "Features are overflowing the calculated testset size, stratifying failed.");
                testFeatures[testIndex] = queue.poll();
                testOutcome[testIndex] = classOutcome[label];
                testIndex++;
            }
        }

        if (random) {
            ArrayUtils.multiShuffle(trainFeatures, new DoubleVector[][] {trainOutcome});
            ArrayUtils.multiShuffle(testFeatures, new DoubleVector[][] {testOutcome});
        }
        return new EvaluationSplit(trainFeatures, trainOutcome, testFeatures, testOutcome);
    }

    /** Validates the arguments shared by both factory methods. */
    private static void checkSplitArguments(DoubleVector[] features, DoubleVector[] outcome,
            float splitFraction) {
        Preconditions.checkArgument(features.length == outcome.length,
                "Feature vector and outcome vector must match in length!");
        Preconditions.checkArgument(splitFraction >= 0.0f && splitFraction <= 1.0f,
                "splitFraction must be between 0 and 1! Given: %s", splitFraction);
    }

    /**
     * Maps an outcome vector to its class index. With exactly two classes this is the truncated
     * first component; otherwise it is the index of the largest component. A two-class label
     * outside {0, 1} makes the caller's array indexing fail.
     */
    private static int classIndexOf(DoubleVector outcome, int numClasses) {
        return numClasses == 2 ? (int) outcome.get(0) : outcome.maxIndex();
    }
}
