package org.deeplearning4j.clustering.vptree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.deeplearning4j.berkeley.CounterMap;
import org.deeplearning4j.berkeley.PriorityQueue;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.sptree.HeapItem;
import org.deeplearning4j.util.MathUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/clustering/vptree/VPTree.class */
/**
 * Vantage-point tree (VP-tree) over a list of {@link DataPoint}s, used for k-nearest-neighbour search.
 *
 * <p>Each {@link Node} refers to its vantage point by position in {@link #getItems()}. Building the
 * tree reorders that list in place so that every subtree owns one contiguous range of it. Points in
 * a node's left subtree are no farther from the vantage point than the node's threshold; points in
 * its right subtree are no closer. Pairwise distances are memoised in a {@link CounterMap}.
 *
 * <p>Not safe for concurrent searches: {@link #search(DataPoint, int, List, List)} keeps its pruning
 * radius in an instance field and may add entries to the shared distance cache.
 */
public class VPTree {
    public static final String EUCLIDEAN = "euclidean";

    /** Indexed points; reordered in place while the tree is built. */
    private List<DataPoint> items;
    /** Distance to the current k-th nearest candidate while a search is running. */
    private double tau;
    private Node root;
    /** Memoised pairwise distances; getDistance treats a stored 0.0 as a cache miss. */
    private CounterMap<DataPoint, DataPoint> distances;
    private String similarityFunction;
    private boolean invert = true;

    /** A tree node: the vantage point's position in {@code items} plus its split radius. */
    public static class Node {
        private int index;
        private double threshold;
        private Node left;
        private Node right;

        public Node(int index, double threshold) {
            this.index = index;
            this.threshold = threshold;
        }

        /** Structural equality: same index, same threshold and equal subtrees. */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Node other = (Node) obj;
            if (index != other.index || Double.compare(other.threshold, threshold) != 0) {
                return false;
            }
            boolean leftEqual = left != null ? left.equals(other.left) : other.left == null;
            if (!leftEqual) {
                return false;
            }
            return right != null ? right.equals(other.right) : other.right == null;
        }

        /** Hash consistent with {@link #equals(Object)}; recurses into both subtrees. */
        @Override
        public int hashCode() {
            long bits = Double.doubleToLongBits(threshold);
            int result = index;
            result = 31 * result + (int) (bits ^ (bits >>> 32));
            result = 31 * result + (left != null ? left.hashCode() : 0);
            result = 31 * result + (right != null ? right.hashCode() : 0);
            return result;
        }

        public int getIndex() {
            return index;
        }

        public void setIndex(int index) {
            this.index = index;
        }

        public double getThreshold() {
            return threshold;
        }

        public void setThreshold(double threshold) {
            this.threshold = threshold;
        }

        public Node getLeft() {
            return left;
        }

        public void setLeft(Node left) {
            this.left = left;
        }

        public Node getRight() {
            return right;
        }

        public void setRight(Node right) {
            this.right = right;
        }
    }

    /**
     * Builds a tree over the slices of {@code data}; slice {@code i} becomes a DataPoint with index {@code i}.
     *
     * @param data               points, one per slice
     * @param similarityFunction distance function name passed to each DataPoint
     * @param invert             invert flag passed to each DataPoint
     */
    public VPTree(INDArray data, String similarityFunction, boolean invert) {
        this.similarityFunction = similarityFunction;
        this.invert = invert;
        List<DataPoint> points = new ArrayList<>();
        for (int i = 0; i < data.slices(); i++) {
            points.add(new DataPoint(i, data.slice(i), this.similarityFunction, invert));
        }
        this.items = points;
        // count() may be invoked on other threads (presumably runPairWise parallelises); attach each
        // of them to the device of the constructing thread.
        final int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue();
        this.distances = CounterMap.runPairWise(points, new CounterMap.CountFunction<DataPoint>() {
            @Override
            public double count(DataPoint a, DataPoint b) {
                Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread(), Integer.valueOf(deviceId));
                return a.distance(b);
            }
        });
        this.root = buildFromPoints(0, this.items.size());
    }

    /**
     * Builds a tree over {@code items} using a precomputed distance cache.
     * {@code items} is reordered in place and must therefore be mutable.
     */
    public VPTree(List<DataPoint> items, CounterMap<DataPoint, DataPoint> distances, String similarityFunction, boolean invert) {
        this.items = items;
        this.distances = distances;
        this.invert = invert;
        this.similarityFunction = similarityFunction;
        this.root = buildFromPoints(0, items.size());
    }

    /**
     * Builds a tree over {@code items}, first computing all pairwise distances.
     * {@code items} is reordered in place and must therefore be mutable.
     */
    public VPTree(List<DataPoint> items, String similarityFunction, boolean invert) {
        this.items = items;
        this.invert = invert;
        this.similarityFunction = similarityFunction;
        this.distances = CounterMap.runPairWise(items, new CounterMap.CountFunction<DataPoint>() {
            @Override
            public double count(DataPoint a, DataPoint b) {
                return a.distance(b);
            }
        });
        this.root = buildFromPoints(0, items.size());
    }

    public VPTree(INDArray data, String similarityFunction) {
        this(data, similarityFunction, true);
    }

    public VPTree(List<DataPoint> items, CounterMap<DataPoint, DataPoint> distances, String similarityFunction) {
        this(items, distances, similarityFunction, true);
    }

    public VPTree(List<DataPoint> items, String similarityFunction) {
        this(items, similarityFunction, true);
    }

    public VPTree(INDArray data) {
        this(data, EUCLIDEAN);
    }

    public VPTree(List<DataPoint> items, CounterMap<DataPoint, DataPoint> distances) {
        this(items, distances, EUCLIDEAN);
    }

    public VPTree(List<DataPoint> items) {
        this(items, EUCLIDEAN);
    }

    /**
     * Stacks the points of {@code list} into a matrix with one slice per point, in list order.
     * The width is taken from the first point's {@code getD()}, so {@code list} must be non-empty.
     */
    public static INDArray buildFromData(List<DataPoint> list) {
        INDArray matrix = Nd4j.create(list.size(), list.get(0).getD());
        for (int row = 0; row < matrix.slices(); row++) {
            matrix.putSlice(row, list.get(row).getPoint());
        }
        return matrix;
    }

    /** Returns the live item list, in the order produced by tree construction. */
    public List<DataPoint> getItems() {
        return this.items;
    }

    /** Replaces the item list without rebuilding the tree; node indices then refer to the new list. */
    public void setItems(List<DataPoint> list) {
        this.items = list;
    }

    /**
     * Returns the distance between two points, memoising it symmetrically in {@link #distances}.
     * A cached value of exactly 0.0 is treated as a miss and is recomputed. Query points that are not
     * in {@code items} are cached as well, so repeated queries grow the cache.
     */
    private double getDistance(DataPoint a, DataPoint b) {
        double cached = this.distances.getCount(a, b);
        if (cached != 0.0d) {
            return cached;
        }
        double distance = a.distance(b);
        this.distances.setCount(a, b, distance);
        this.distances.setCount(b, a, distance);
        return distance;
    }

    /**
     * Recursively builds the subtree for {@code items[lower, upper)}, reordering only that range.
     *
     * <p>A random vantage point is swapped to {@code lower}. The remaining points are sorted by their
     * distance to it and split at the midpoint {@code median}. Points in {@code [lower+1, median)} are
     * at distance &lt;= threshold and form the left subtree; points in {@code [median, upper)} are at
     * distance &gt;= threshold and form the right subtree. The threshold is the distance from the
     * vantage point to {@code items[median]}.
     *
     * @return the subtree root, or {@code null} for an empty range
     */
    private Node buildFromPoints(int lower, int upper) {
        if (upper == lower) {
            return null;
        }
        Node node = new Node(lower, 0.0d);
        if (upper - lower > 1) {
            // Put a randomly chosen vantage point at the front of this range.
            int vantageIndex = MathUtils.randomNumberBetween(lower, upper - 1);
            Collections.swap(this.items, lower, vantageIndex);
            DataPoint vantage = this.items.get(lower);

            // Sort the rest of the range by distance to the vantage point. Each distance is computed
            // once and cached in an array, then an index permutation is sorted against it.
            int count = upper - lower - 1;
            final double[] dist = new double[count];
            Integer[] order = new Integer[count];
            for (int k = 0; k < count; k++) {
                dist[k] = getDistance(vantage, this.items.get(lower + 1 + k));
                order[k] = k;
            }
            Arrays.sort(order, new Comparator<Integer>() {
                @Override
                public int compare(Integer a, Integer b) {
                    return Double.compare(dist[a], dist[b]);
                }
            });
            List<DataPoint> sorted = new ArrayList<>(count);
            for (Integer k : order) {
                sorted.add(this.items.get(lower + 1 + k));
            }
            for (int k = 0; k < count; k++) {
                this.items.set(lower + 1 + k, sorted.get(k));
            }

            int median = (lower + upper) >>> 1;
            node.setThreshold(getDistance(vantage, this.items.get(median)));
            node.setLeft(buildFromPoints(lower + 1, median));
            node.setRight(buildFromPoints(median, upper));
        }
        return node;
    }

    /**
     * Finds up to {@code k} nearest items to {@code target}.
     *
     * @param target          query point
     * @param k               number of neighbours wanted; {@code k <= 0} yields empty results
     * @param results         cleared, then filled with the neighbours in ascending distance order
     * @param resultDistances cleared, then filled with the matching distances
     */
    public void search(DataPoint target, int k, List<DataPoint> results, List<Double> resultDistances) {
        PriorityQueue<HeapItem> heap = new PriorityQueue<>();
        this.tau = Double.MAX_VALUE;
        search(this.root, target, k, heap);
        results.clear();
        resultDistances.clear();
        // The head of the queue is the farthest candidate, so drain it and then reverse.
        while (!heap.isEmpty()) {
            HeapItem head = (HeapItem) heap.peek();
            results.add(this.items.get(head.getIndex()));
            resultDistances.add(Double.valueOf(head.getDistance()));
            heap.next();
        }
        Collections.reverse(results);
        Collections.reverse(resultDistances);
    }

    /**
     * Recursive k-NN step. It offers {@code node}'s vantage point to {@code heap}, which keeps at most
     * {@code k} candidates keyed by distance, then descends into the subtrees that may still contain
     * closer points. Pruning relies on the triangle inequality, so it assumes
     * {@code DataPoint.distance} behaves as a metric; that is not checked here.
     */
    public void search(Node node, DataPoint target, int k, PriorityQueue<HeapItem> heap) {
        if (node == null || k <= 0) {
            // k <= 0: the eviction below would call next() on an empty queue.
            return;
        }
        double distance = getDistance(this.items.get(node.getIndex()), target);
        if (distance < this.tau) {
            if (heap.size() == k) {
                heap.next(); // evict the current farthest candidate
            }
            heap.add(new HeapItem(node.getIndex(), distance), distance);
            if (heap.size() == k) {
                // Once k candidates are held, only points closer than the farthest one matter.
                this.tau = ((HeapItem) heap.peek()).getDistance();
            }
        }
        if (node.getLeft() == null && node.getRight() == null) {
            return;
        }
        double threshold = node.getThreshold();
        boolean mayHitLeft = distance - this.tau <= threshold;
        if (distance < threshold) {
            // The target falls inside the vantage radius: visit the inner (left) side first.
            if (mayHitLeft) {
                search(node.getLeft(), target, k, heap);
            }
            // tau may have shrunk during the left search; re-read it.
            if (distance + this.tau >= threshold) {
                search(node.getRight(), target, k, heap);
            }
            return;
        }
        if (distance + this.tau >= threshold) {
            search(node.getRight(), target, k, heap);
        }
        // tau may have shrunk during the right search; re-read it.
        if (distance - this.tau <= threshold) {
            search(node.getLeft(), target, k, heap);
        }
    }

    /** Returns the live pairwise distance cache. */
    public CounterMap<DataPoint, DataPoint> getDistances() {
        return this.distances;
    }

    /** Replaces the distance cache used by subsequent distance lookups. */
    public void setDistances(CounterMap<DataPoint, DataPoint> counterMap) {
        this.distances = counterMap;
    }
}
