package org.neo4j.gds.ml.models.linearregression;

import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.gradientdescent.Training;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.RegressorTrainer;

/* loaded from: input_file:org/neo4j/gds/ml/models/linearregression/LinearRegressionTrainer.class */
/**
 * {@link RegressorTrainer} that fits a {@link LinearRegressor} by minimizing a
 * {@link LinearRegressionObjective} with the generic {@link Training} loop.
 */
public final class LinearRegressionTrainer implements RegressorTrainer {
    private final int concurrency;
    private final TerminationFlag terminationFlag;
    private final ProgressTracker progressTracker;
    private final LinearRegressionTrainConfig trainConfig;

    /**
     * Creates a trainer; all arguments are stored as-is and used on every {@link #train} call.
     *
     * @param concurrency     concurrency handed to {@link Training#train}
     * @param trainConfig     training configuration; supplies the penalty and batch size
     * @param terminationFlag flag handed to the {@link Training} loop
     * @param progressTracker tracker handed to the {@link Training} loop
     */
    public LinearRegressionTrainer(
        int concurrency,
        LinearRegressionTrainConfig trainConfig,
        TerminationFlag terminationFlag,
        ProgressTracker progressTracker
    ) {
        this.concurrency = concurrency;
        this.trainConfig = trainConfig;
        this.terminationFlag = terminationFlag;
        this.progressTracker = progressTracker;
    }

    /**
     * Trains a linear regressor on the given examples.
     *
     * @param features features of the examples
     * @param targets  regression targets; presumably indexed like {@code features} — TODO confirm
     * @param trainSet ids of the examples to train on (assumed to index into
     *                 {@code features}/{@code targets} — verify against callers)
     * @return a regressor wrapping the model data of the optimized objective
     */
    @Override // org.neo4j.gds.ml.models.RegressorTrainer
    public LinearRegressor train(Features features, HugeDoubleArray targets, ReadOnlyHugeLongArray trainSet) {
        LinearRegressionObjective objective =
            new LinearRegressionObjective(features, targets, trainConfig.penalty());

        Training training = new Training(trainConfig, progressTracker, trainSet.size(), terminationFlag);
        // A fresh BatchQueue is built on every supplier invocation, split by the configured batch size.
        training.train(objective, () -> BatchQueue.fromArray(trainSet, trainConfig.batchSize()), concurrency);

        // The objective holds the optimized model data after training.
        return new LinearRegressor(objective.modelData());
    }
}
