package de.jungblut.classification.knn;

import de.jungblut.classification.AbstractClassifier;
import de.jungblut.jrpt.VectorDistanceTuple;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.dense.SingleEntryDoubleVector;
import java.util.List;

/**
 * Base class for k-nearest-neighbour classifiers. Subclasses provide the neighbour search through
 * {@link #getNearestNeighbours(DoubleVector, int)}; this class turns the outcomes of the returned
 * neighbours into a vote.
 *
 * <p>Outcome encoding, as read by {@link #predict(DoubleVector)}: when {@code numOutcomes == 2}
 * the first entry of each neighbour's outcome vector is cast to {@code int} and used directly as
 * the class index (so it is expected to be 0 or 1); otherwise the class index is the position of
 * the largest entry of the outcome vector.
 */
public abstract class AbstractKNearestNeighbours extends AbstractClassifier {

    /** Number of outcome classes; exactly two selects the binary encoding described above. */
    protected final int numOutcomes;

    /** Number of neighbours requested from {@link #getNearestNeighbours(DoubleVector, int)}. */
    protected final int k;

    /**
     * @param numOutcomes number of outcome classes; {@code 2} enables the binary encoding.
     * @param k number of nearest neighbours that take part in the vote.
     */
    public AbstractKNearestNeighbours(int numOutcomes, int k) {
        this.numOutcomes = numOutcomes;
        this.k = k;
    }

    /**
     * Classifies {@code features} by a majority vote of its {@code k} nearest neighbours.
     *
     * @param features the instance to classify.
     * @return for two outcomes, a single-entry vector holding the winning class index (chosen via
     *     {@code maxIndex()} of the vote counts); otherwise the raw vote count per class.
     */
    @Override
    public DoubleVector predict(DoubleVector features) {
        DenseDoubleVector votes = voteHistogram(features);
        return isBinary() ? new SingleEntryDoubleVector(votes.maxIndex()) : votes;
    }

    /**
     * Returns the vote distribution for {@code features}.
     *
     * <p>For more than two outcomes the vote counts from {@link #predict(DoubleVector)} are
     * normalised to sum to one. If no votes were cast (the neighbour search returned nothing), the
     * all-zero vote vector is returned as-is instead of being divided by zero, which would produce
     * NaN entries. For two outcomes the result of {@link #predict(DoubleVector)} (the winning
     * class index) is returned unchanged.
     *
     * @param features the instance to classify.
     * @return normalised vote fractions, an all-zero vector when there were no votes, or the binary
     *     label.
     */
    @Override
    public DoubleVector predictProbability(DoubleVector features) {
        DoubleVector prediction = predict(features);
        if (isBinary()) {
            return prediction;
        }
        double totalVotes = prediction.sum();
        // 0.0 / 0.0 is NaN in Java floating point; keep the zero histogram rather than emit NaNs.
        return totalVotes == 0d ? prediction : prediction.divide(totalVotes);
    }

    /**
     * Finds the neighbours of {@code features} that should vote on its class.
     *
     * @param features the query instance.
     * @param k the number of neighbours requested.
     * @return neighbour tuples whose value is the neighbour's outcome vector.
     */
    protected abstract List<VectorDistanceTuple<DoubleVector>> getNearestNeighbours(DoubleVector features, int k);

    /** Counts one vote per returned neighbour into a histogram of length {@code numOutcomes}. */
    private DenseDoubleVector voteHistogram(DoubleVector features) {
        DenseDoubleVector votes = new DenseDoubleVector(numOutcomes);
        for (VectorDistanceTuple<DoubleVector> neighbour : getNearestNeighbours(features, k)) {
            int classIndex = outcomeIndex(neighbour.getValue());
            votes.set(classIndex, votes.get(classIndex) + 1.0d);
        }
        return votes;
    }

    /** Decodes a neighbour's outcome vector into the class index it votes for. */
    private int outcomeIndex(DoubleVector outcome) {
        return isBinary() ? (int) outcome.get(0) : outcome.maxIndex();
    }

    /** True when this classifier uses the two-outcome (single label entry) encoding. */
    private boolean isBinary() {
        return numOutcomes == 2;
    }
}
