package de.jungblut.classification.knn;

import de.jungblut.datastructure.DistanceResult;
import de.jungblut.datastructure.InvertedIndex;
import de.jungblut.distance.DistanceMeasurer;
import de.jungblut.jrpt.VectorDistanceTuple;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.named.KeyedDoubleVector;
import gnu.trove.map.hash.TIntObjectHashMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:de/jungblut/classification/knn/SparseKNearestNeighbours.class */
/**
 * k-nearest-neighbour classifier that answers neighbour queries through an {@link InvertedIndex}
 * instead of a linear scan over all training rows.
 *
 * <p>Each training feature vector is wrapped in a {@link KeyedDoubleVector} whose key is its
 * zero-based position in the training input. That key is used to look up the matching outcome
 * vector in {@code featureOutcomeMap}.
 */
public final class SparseKNearestNeighbours extends AbstractKNearestNeighbours {
    /** Index over the keyed training feature vectors; rebuilt by {@link #train}. */
    private final InvertedIndex<DoubleVector, Integer> index;
    /** Maps a training row's zero-based position to its outcome vector. */
    private final TIntObjectHashMap<DoubleVector> featureOutcomeMap;

    /**
     * Creates an untrained classifier whose index measures distances with the given measurer.
     *
     * @param i forwarded unchanged to {@code AbstractKNearestNeighbours(int, int)} (presumably the
     *     number of outcomes — TODO confirm against the superclass)
     * @param i2 forwarded unchanged to {@code AbstractKNearestNeighbours(int, int)} (presumably the
     *     neighbour count k — TODO confirm against the superclass)
     * @param distanceMeasurer distance function passed to {@link InvertedIndex#createVectorIndex}
     */
    public SparseKNearestNeighbours(int i, int i2, DistanceMeasurer distanceMeasurer) {
        super(i, i2);
        this.featureOutcomeMap = new TIntObjectHashMap<>();
        this.index = InvertedIndex.createVectorIndex(distanceMeasurer);
    }

    /**
     * Builds the inverted index from the training features and records each row's outcome.
     *
     * <p>Features and outcomes are consumed pairwise. Any outcomes left over after the features are
     * exhausted are ignored.
     *
     * <p>NOTE(review): {@code featureOutcomeMap} is not cleared between calls. Keys restart at 0, so
     * retraining on fewer rows leaves stale entries behind. This is harmless only if
     * {@code index.build} replaces the previous index contents — TODO confirm.
     *
     * @param features training feature vectors
     * @param outcomes outcome vector for each feature, in the same order
     * @throws IllegalArgumentException if {@code outcomes} yields fewer elements than
     *     {@code features}
     */
    @Override // de.jungblut.classification.AbstractClassifier, de.jungblut.classification.Classifier
    public void train(Iterable<DoubleVector> features, Iterable<DoubleVector> outcomes) {
        List<DoubleVector> keyedFeatures = new ArrayList<>();
        Iterator<DoubleVector> featureIterator = features.iterator();
        Iterator<DoubleVector> outcomeIterator = outcomes.iterator();
        int row = 0;
        while (featureIterator.hasNext()) {
            DoubleVector feature = featureIterator.next();
            // Fail with a descriptive message instead of an opaque NoSuchElementException.
            if (!outcomeIterator.hasNext()) {
                throw new IllegalArgumentException(
                        "Missing outcome for training row " + row
                                + ": outcomes must have at least as many elements as features");
            }
            keyedFeatures.add(new KeyedDoubleVector(row, feature));
            this.featureOutcomeMap.put(row, outcomeIterator.next());
            row++;
        }
        this.index.build(keyedFeatures);
    }

    /**
     * Queries the index for the nearest training rows to {@code doubleVector} and pairs each hit
     * with its recorded outcome.
     *
     * @param doubleVector the vector to find neighbours for
     * @param i number of neighbours requested; forwarded to {@code index.query} together with a
     *     distance bound of {@link Double#MAX_VALUE}
     * @return one tuple per index hit, in the order returned by the index
     */
    @Override // de.jungblut.classification.knn.AbstractKNearestNeighbours
    protected List<VectorDistanceTuple<DoubleVector>> getNearestNeighbours(DoubleVector doubleVector, int i) {
        List<VectorDistanceTuple<DoubleVector>> neighbours = new ArrayList<>();
        for (DistanceResult<DoubleVector> hit : this.index.query(doubleVector, i, Double.MAX_VALUE)) {
            DoubleVector neighbour = hit.get();
            // train() only inserts KeyedDoubleVector instances, so this cast recovers the row key.
            DoubleVector outcome = this.featureOutcomeMap.get(((KeyedDoubleVector) neighbour).getKey());
            neighbours.add(new VectorDistanceTuple<DoubleVector>(neighbour, outcome, hit.getDistance()));
        }
        return neighbours;
    }
}
