package tagbio.umap;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import tagbio.umap.metric.Metric;

/* loaded from: input_file:tagbio/umap/ParallelNearestNeighborDescent.class */
/**
 * Multi-threaded variant of {@link NearestNeighborDescent}.
 *
 * <p>Each call to {@code descent} creates a fixed-size thread pool, splits the work into
 * contiguous chunks (of rows, or of trees for the forest initialisation) and waits for every
 * chunk before moving on to the next phase. All chunks push into one shared {@link Heap}.
 * NOTE(review): this relies on {@code Heap.push} being safe for concurrent use from several
 * threads — confirm in {@code Heap}.
 *
 * <p>The caller's {@link Random} is shared by all worker threads. Its draws therefore happen
 * in a scheduling-dependent order, and a fixed seed does not in general give reproducible
 * results.
 */
class ParallelNearestNeighborDescent extends NearestNeighborDescent {

    /** Number of worker threads used by each call to {@code descent}. */
    private final int mThreads;

    /**
     * Creates a parallel nearest-neighbour descent.
     *
     * @param metric distance metric used to compare rows
     * @param threads number of worker threads; must be at least 1
     * @throws IllegalArgumentException if {@code threads < 1}
     */
    public ParallelNearestNeighborDescent(Metric metric, int threads) {
        super(metric);
        if (threads < 1) {
            throw new IllegalArgumentException("Number of threads must be at least 1: " + threads);
        }
        this.mThreads = threads;
    }

    /**
     * Runs {@link #descent(Matrix, int, Random, int, boolean, int, List, float, float)} with
     * {@code delta = 0.001f} and {@code rho = 0.5f}.
     */
    @Override // tagbio.umap.NearestNeighborDescent
    public Heap descent(Matrix data, int nNeighbors, Random random, int maxCandidates,
                        boolean rpTreeInit, int nIters, List<FlatTree> forest) {
        return descent(data, nNeighbors, random, maxCandidates, rpTreeInit, nIters, forest, 0.001f, 0.5f);
    }

    /**
     * Builds an approximate nearest-neighbour graph for the rows of {@code data}.
     *
     * <p>Three phases, each fanned out over a thread pool created for this call:
     * <ol>
     *   <li>every row is paired with {@code nNeighbors} rows drawn by
     *       {@code Utils.rejectionSample}, and each pair is pushed into the graph in both
     *       directions;</li>
     *   <li>if {@code rpTreeInit}, every pair of indices that appear in the same index array
     *       returned by a tree's {@code getIndices()} is pushed in both directions;</li>
     *   <li>for up to {@code nIters} rounds, candidates are built with
     *       {@code Heap.buildCandidates} and candidate pairs are pushed. The rounds stop early
     *       once a round's count of successful pushes is at most
     *       {@code delta * nNeighbors * data.rows()}.</li>
     * </ol>
     *
     * @param data rows to index
     * @param nNeighbors per-row width of the graph heap ({@code new Heap(rows, nNeighbors)})
     *     and number of random samples drawn per row in phase 1
     * @param random randomness source, shared by all worker threads
     * @param maxCandidates number of candidate slots examined per row in each round
     * @param rpTreeInit whether to run phase 2 using {@code forest}
     * @param nIters maximum number of refinement rounds
     * @param forest trees used by phase 2; only read when {@code rpTreeInit} is true
     * @param delta early-termination threshold, as a fraction of {@code nNeighbors * data.rows()}
     * @param rho probability that a candidate slot is flagged in a round; a pair is skipped when
     *     both of its slots are flagged
     * @return the graph after {@code Heap.deheapSort()}
     * @throws RuntimeException wrapping an {@link InterruptedException} (the thread's interrupt
     *     status is restored first) or an {@link ExecutionException} raised by a worker task
     */
    @Override // tagbio.umap.NearestNeighborDescent
    Heap descent(Matrix data, int nNeighbors, Random random, int maxCandidates, boolean rpTreeInit,
                 int nIters, List<FlatTree> forest, float delta, float rho) {
        final ExecutorService executor = Executors.newFixedThreadPool(mThreads);
        try {
            UmapProgress.incTotal(nIters);
            final int rows = data.rows();
            final Heap graph = new Heap(rows, nNeighbors);
            // More chunks than threads: threads * (1 + log2(threads)); presumably to balance
            // uneven per-chunk cost across the pool. Always >= 1 because mThreads >= 1.
            final int chunks = (int) (mThreads * (1.0d + MathUtils.log2(mThreads)));

            seedRandomly(executor, data, graph, nNeighbors, random, chunks);
            if (rpTreeInit) {
                seedFromForest(executor, data, graph, forest, chunks);
            }

            for (int iter = 0; iter < nIters; iter++) {
                if (mVerbose) {
                    Utils.message("NearestNeighborDescent: " + (iter + 1) + " / " + nIters);
                }
                final Heap candidates = graph.buildCandidates(rows, nNeighbors, maxCandidates, random);
                final int updates = refine(executor, data, graph, candidates, maxCandidates, random, rho, chunks);
                if (updates <= delta * nNeighbors * rows) {
                    // Converged: report this round and all skipped rounds as done.
                    UmapProgress.update(nIters - iter);
                    break;
                }
                UmapProgress.update();
            }
            return graph.deheapSort();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        } finally {
            // On success every submitted chunk has already been awaited, so this is equivalent
            // to shutdown(). On failure it also halts chunks that have not started yet, instead
            // of letting them keep mutating the heap after this method has thrown.
            executor.shutdownNow();
        }
    }

    /** Phase 1: pushes {@code nNeighbors} randomly sampled rows for every row, in both directions. */
    private void seedRandomly(ExecutorService executor, Matrix data, Heap graph, int nNeighbors,
                              Random random, int chunks) throws InterruptedException, ExecutionException {
        final int rows = data.rows();
        runChunked(executor, rows, chunks, (from, to) -> {
            for (int row = from; row < to; row++) {
                final float[] point = data.row(row);
                for (final int other : Utils.rejectionSample(nNeighbors, rows, random)) {
                    final float distance = mMetric.distance(point, data.row(other));
                    graph.push(row, distance, other, true);
                    graph.push(other, distance, row, true);
                }
            }
            return 0;
        });
    }

    /** Phase 2: pushes every pair of indices that share an index array of a tree in {@code forest}. */
    private void seedFromForest(ExecutorService executor, Matrix data, Heap graph, List<FlatTree> forest,
                                int chunks) throws InterruptedException, ExecutionException {
        runChunked(executor, forest.size(), chunks, (from, to) -> {
            for (int t = from; t < to; t++) {
                for (final int[] leaf : forest.get(t).getIndices()) {
                    for (int a = 0; a < leaf.length; a++) {
                        final float[] point = data.row(leaf[a]);
                        for (int b = a + 1; b < leaf.length; b++) {
                            final float distance = mMetric.distance(point, data.row(leaf[b]));
                            graph.push(leaf[a], distance, leaf[b], true);
                            graph.push(leaf[b], distance, leaf[a], true);
                        }
                    }
                }
            }
            return 0;
        });
    }

    /**
     * Phase 3, one round: pushes candidate pairs of every row into the graph.
     *
     * @return number of {@code graph.push} calls that returned {@code true}
     */
    private int refine(ExecutorService executor, Matrix data, Heap graph, Heap candidates,
                       int maxCandidates, Random random, float rho, int chunks)
            throws InterruptedException, ExecutionException {
        return runChunked(executor, data.rows(), chunks, (from, to) -> {
            final boolean[] flagged = new boolean[maxCandidates];
            int updates = 0;
            for (int row = from; row < to; row++) {
                for (int j = 0; j < maxCandidates; j++) {
                    flagged[j] = random.nextFloat() < rho;
                }
                for (int j = 0; j < maxCandidates; j++) {
                    final int p = candidates.index(row, j);
                    if (p < 0) {
                        continue;
                    }
                    for (int k = 0; k <= j; k++) {
                        final int q = candidates.index(row, k);
                        if (q < 0 || (flagged[j] && flagged[k])) {
                            continue;
                        }
                        // At least one side of the pair must be marked new.
                        if (!candidates.isNew(row, j) && !candidates.isNew(row, k)) {
                            continue;
                        }
                        final float distance = mMetric.distance(data.row(p), data.row(q));
                        if (graph.push(p, distance, q, true)) {
                            updates++;
                        }
                        if (graph.push(q, distance, p, true)) {
                            updates++;
                        }
                    }
                }
            }
            return updates;
        });
    }

    /** Work over the half-open index range {@code [from, to)}, returning a count to be summed. */
    @FunctionalInterface
    private interface ChunkTask {
        int run(int from, int to);
    }

    /**
     * Splits {@code [0, total)} into {@code chunks} contiguous ranges of size
     * {@code ceil(total / chunks)} (the trailing ones possibly empty). It submits one task per
     * range and waits for all of them.
     *
     * @return the sum of the values returned by every task
     */
    private static int runChunked(ExecutorService executor, int total, int chunks, ChunkTask task)
            throws InterruptedException, ExecutionException {
        final int chunkSize = (total + chunks - 1) / chunks;
        final List<Future<Integer>> futures = new ArrayList<>(chunks);
        for (int c = 0; c < chunks; c++) {
            final int from = c * chunkSize;
            final int to = Math.min(from + chunkSize, total);
            futures.add(executor.submit(() -> task.run(from, to)));
        }
        return waitForFutures(futures);
    }

    /**
     * Blocks until every future completes, then clears {@code futures}.
     *
     * @return the sum of the futures' results
     */
    private static int waitForFutures(List<Future<Integer>> futures)
            throws InterruptedException, ExecutionException {
        int total = 0;
        for (final Future<Integer> future : futures) {
            total += future.get();
        }
        futures.clear();
        return total;
    }
}
