package org.neo4j.gds.ml.nodemodels.multiclasslogisticregression;

import org.neo4j.gds.ml.Training;
import org.neo4j.gds.ml.batch.HugeBatchQueue;
import org.neo4j.gds.ml.nodemodels.logisticregression.MultiClassNLRTrainConfig;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.utils.paged.HugeLongArray;
import org.neo4j.logging.Log;

/* loaded from: input_file:org/neo4j/gds/ml/nodemodels/multiclasslogisticregression/MultiClassNLRTrain.class */
/**
 * Runs training for the multi-class NLR model over a given training set of nodes.
 *
 * <p>{@link #compute()} derives the initial model data from the graph using the configured
 * feature and target properties, wraps it in a predictor and an objective, and optimizes that
 * objective with {@link Training} over batches drawn from the training set.
 */
public class MultiClassNLRTrain {
    private final Graph graph;
    // Node ids to train on; batches are drawn from this array.
    private final HugeLongArray trainSet;
    private final MultiClassNLRTrainConfig config;
    private final Log log;

    /**
     * @param graph    graph supplying node properties and the node count passed to {@link Training}
     * @param trainSet node ids that make up the training set
     * @param config   training configuration (features, target, penalty, batch size, concurrency)
     * @param log      log handed through to {@link Training}
     */
    public MultiClassNLRTrain(Graph graph, HugeLongArray trainSet, MultiClassNLRTrainConfig config, Log log) {
        this.graph = graph;
        this.trainSet = trainSet;
        this.config = config;
        this.log = log;
    }

    /**
     * Trains the model and returns the data held by the objective after training.
     *
     * @return the model data produced by the trained objective
     */
    public MultiClassNLRData compute() {
        // Initial model data derived from the configured feature and target properties.
        MultiClassNLRData initialData =
            MultiClassNLRData.from(graph, config.featureProperties(), config.targetProperty());
        MultiClassNLRPredictor predictor =
            new MultiClassNLRPredictor(initialData, config.featureProperties());
        MultiClassNLRObjective objective =
            new MultiClassNLRObjective(graph, predictor, config.targetProperty(), config.penalty());

        Training training = new Training(config, log, graph.nodeCount());
        // A fresh batch queue over the training set is created each time the supplier is invoked.
        training.train(
            objective,
            () -> new HugeBatchQueue(trainSet, config.batchSize()),
            config.concurrency()
        );

        // The objective owns the (now trained) model data.
        return objective.modelData();
    }
}
