package org.neo4j.gds.ml.linkmodels.logisticregression;

import java.util.List;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.Training;
import org.neo4j.gds.ml.core.batch.HugeBatchQueue;
import org.neo4j.gds.ml.core.features.FeatureExtractor;

/* loaded from: input_file:org/neo4j/gds/ml/linkmodels/logisticregression/LinkLogisticRegressionTrain.class */
/**
 * Drives the training of a link logistic regression model: builds a
 * {@link LinkLogisticRegressionObjective} from the graph and configuration and
 * hands it to {@link Training} together with a batch queue over the train set.
 */
public class LinkLogisticRegressionTrain {
    private final Graph graph;
    private final ReadOnlyHugeLongArray trainSet;
    private final List<FeatureExtractor> extractors;
    private final LinkLogisticRegressionTrainConfig config;
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;
    private final int concurrency;

    /**
     * Stores the collaborators needed by {@link #compute()}.
     *
     * @param graph graph used to build the objective and to look up degrees
     * @param hugeLongArray train set; kept behind a {@link ReadOnlyHugeLongArray} wrapper
     * @param list feature extractors passed on to the objective
     * @param linkLogisticRegressionTrainConfig training configuration (feature properties,
     *        link feature combiner, penalty, batch size)
     * @param progressTracker progress tracker handed to {@link Training}
     * @param terminationFlag termination flag handed to {@link Training}
     * @param i concurrency passed to {@code Training.train}
     */
    public LinkLogisticRegressionTrain(Graph graph, HugeLongArray hugeLongArray, List<FeatureExtractor> list, LinkLogisticRegressionTrainConfig linkLogisticRegressionTrainConfig, ProgressTracker progressTracker, TerminationFlag terminationFlag, int i) {
        this.graph = graph;
        this.trainSet = ReadOnlyHugeLongArray.of(hugeLongArray);
        this.extractors = list;
        this.config = linkLogisticRegressionTrainConfig;
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
        this.concurrency = i;
    }

    /**
     * Builds the objective, trains it over batches of the train set and returns
     * the objective's model data.
     *
     * @return the {@code modelData} field of the trained objective
     */
    public LinkLogisticRegressionData compute() {
        LinkLogisticRegressionObjective objective = new LinkLogisticRegressionObjective(
            LinkLogisticRegressionData.from(
                graph,
                config.featureProperties(),
                LinkFeatureCombiners.valueOf(config.linkFeatureCombiner())
            ),
            extractors,
            config.penalty(),
            graph
        );

        // The degree sum is the third constructor argument of Training;
        // NOTE(review): presumably the amount of work reported to the progress tracker — confirm.
        long degreeSum = sumTrainSetDegrees();

        Training training = new Training(config, progressTracker, degreeSum, terminationFlag);
        training.train(objective, () -> new HugeBatchQueue(trainSet, config.batchSize()), concurrency);

        return objective.modelData;
    }

    /** Adds up {@code graph.degree(...)} for every entry of the train set. */
    private long sumTrainSetDegrees() {
        long total = 0;
        for (long index = 0; index < trainSet.size(); index++) {
            total += graph.degree(trainSet.get(index));
        }
        return total;
    }
}
