package de.jungblut.classification.eval;

import com.google.common.base.Preconditions;
import de.jungblut.classification.Classifier;
import de.jungblut.classification.eval.Evaluator;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.minimize.IterationCompletionListener;
import java.util.Comparator;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Iteration listener that periodically evaluates the current weights of a minimizer on a held-out
 * test set and remembers the best-scoring weights seen so far.
 *
 * <p>Every {@code evaluationInterval} iterations (including iteration 0), the current weight vector
 * is turned into a classifier via the supplied {@link WeightMapper} and scored with
 * {@link Evaluator#testClassifier} on the test features/outcomes of the {@link EvaluationSplit}.
 * The first evaluation is always kept. Afterwards a new result replaces the stored best one when
 * {@code resultComparison.compare(currentBest, candidate) > 0}. The comparator must therefore
 * order the current best "after" a better candidate.
 *
 * <p>A deep copy of the weights is stored, so later mutation of the minimizer's vector does not
 * affect {@link #getBestWeights()}. This class keeps mutable state without synchronization.
 *
 * @param <T> the classifier type produced by the weight mapper
 */
public class TestSetIterationCallback<T extends Classifier> implements IterationCompletionListener {
    private static final Logger LOG = LogManager.getLogger(TestSetIterationCallback.class);

    /** Default number of iterations between two test-set evaluations. */
    private static final int DEFAULT_EVALUATION_INTERVAL = 10;

    private final EvaluationSplit split;
    private final WeightMapper<T> mapper;
    private final Comparator<Evaluator.EvaluationResult> resultComparison;
    private final int evaluationInterval;

    private Evaluator.EvaluationResult bestResult;
    private DoubleVector bestWeights;

    /**
     * Creates a callback that evaluates every {@code evaluationInterval} iterations.
     *
     * @param split the split whose test features/outcomes are used for evaluation
     * @param mapper maps a raw weight vector to a classifier
     * @param resultComparison decides whether a candidate beats the current best; a candidate wins
     *     when {@code compare(currentBest, candidate) > 0}
     * @param evaluationInterval evaluate when {@code iteration % evaluationInterval == 0}; must be
     *     positive
     * @throws NullPointerException if {@code split}, {@code mapper} or {@code resultComparison} is
     *     null
     * @throws IllegalArgumentException if {@code evaluationInterval <= 0}
     */
    public TestSetIterationCallback(EvaluationSplit split, WeightMapper<T> mapper,
            Comparator<Evaluator.EvaluationResult> resultComparison, int evaluationInterval) {
        // Fail fast: a zero interval would otherwise surface as an ArithmeticException (x % 0)
        // on the first iteration callback, far away from the faulty configuration.
        Preconditions.checkArgument(evaluationInterval > 0,
                "evaluationInterval must be positive, but was %s", evaluationInterval);
        this.evaluationInterval = evaluationInterval;
        this.resultComparison = Preconditions.checkNotNull(resultComparison, "resultComparison");
        this.split = Preconditions.checkNotNull(split, "split");
        this.mapper = Preconditions.checkNotNull(mapper, "mapper");
    }

    /**
     * Creates a callback that evaluates every 10 iterations.
     *
     * @param split the split whose test features/outcomes are used for evaluation
     * @param mapper maps a raw weight vector to a classifier
     * @param resultComparison decides whether a candidate beats the current best
     * @throws NullPointerException if any argument is null
     */
    public TestSetIterationCallback(EvaluationSplit split, WeightMapper<T> mapper,
            Comparator<Evaluator.EvaluationResult> resultComparison) {
        this(split, mapper, resultComparison, DEFAULT_EVALUATION_INTERVAL);
    }

    /**
     * Evaluates {@code weights} on the test set if {@code iteration} falls on the evaluation
     * interval, and keeps it when it is the first or a better result.
     *
     * @param iteration the iteration number reported by the minimizer
     * @param cost the current cost (not used by this listener)
     * @param weights the current weight vector; copied before being stored
     */
    @Override
    public void onIterationFinished(int iteration, double cost, DoubleVector weights) {
        if (iteration % evaluationInterval != 0) {
            return;
        }
        T classifier = mapper.mapWeights(weights);
        Evaluator.EvaluationResult result = Evaluator.testClassifier(classifier,
                split.getTestFeatures(), split.getTestOutcome());
        if (bestResult == null) {
            bestResult = result;
            bestWeights = weights.deepCopy();
        } else if (resultComparison.compare(bestResult, result) > 0) {
            LOG.info("Found better weights with result:");
            result.print(LOG);
            bestResult = result;
            bestWeights = weights.deepCopy();
        }
    }

    /**
     * Returns the best evaluation result seen so far.
     *
     * @return the best result, or {@code null} if no evaluation has happened yet
     */
    public Evaluator.EvaluationResult getBestResult() {
        return bestResult;
    }

    /**
     * Returns the weights that produced {@link #getBestResult()}. The stored copy is returned
     * directly, so callers should not mutate it.
     *
     * @return the best weights, or {@code null} if no evaluation has happened yet
     */
    public DoubleVector getBestWeights() {
        return bestWeights;
    }
}
