package org.neo4j.gds.ml.models.mlp;

import java.util.Optional;
import java.util.SplittableRandom;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.gradientdescent.Training;
import org.neo4j.gds.ml.models.ClassifierTrainer;
import org.neo4j.gds.ml.models.Features;

/* loaded from: input_file:org/neo4j/gds/ml/models/mlp/MLPClassifierTrainer.class */
/**
 * {@link ClassifierTrainer} that fits an {@link MLPClassifier} using the gradient-descent
 * {@link Training} loop.
 *
 * <p>One {@link SplittableRandom} is created at construction time. Every {@link #train} call
 * on this instance reuses it when creating the classifier's initial {@code MLPClassifierData}.
 */
public class MLPClassifierTrainer implements ClassifierTrainer {
    private final int numberOfClasses;
    private final MLPClassifierTrainConfig trainConfig;
    private final SplittableRandom random;
    private final ProgressTracker progressTracker;
    private final LogLevel messageLogLevel;
    private final TerminationFlag terminationFlag;
    private final int concurrency;

    /**
     * Creates a trainer.
     *
     * @param numberOfClasses  number of output classes of the classifier to train
     * @param trainConfig      hyper-parameters (hidden layer sizes, penalty, focus weight,
     *                         class weights, batch size, and the settings read by {@link Training})
     * @param randomSeed       seed for the internal random generator; when empty, a seed is drawn
     *                         from a freshly constructed {@link SplittableRandom}
     * @param progressTracker  passed through to {@link Training}
     * @param messageLogLevel  passed through to {@link Training}
     * @param terminationFlag  passed through to {@link Training}
     * @param concurrency      concurrency passed to {@link Training#train}
     */
    public MLPClassifierTrainer(
        int numberOfClasses,
        MLPClassifierTrainConfig trainConfig,
        Optional<Long> randomSeed,
        ProgressTracker progressTracker,
        LogLevel messageLogLevel,
        TerminationFlag terminationFlag,
        int concurrency
    ) {
        this.numberOfClasses = numberOfClasses;
        this.trainConfig = trainConfig;

        // No explicit seed: fall back to one taken from an unseeded generator (non-reproducible).
        long seed = randomSeed.orElseGet(() -> new SplittableRandom().nextLong());
        this.random = new SplittableRandom(seed);

        this.progressTracker = progressTracker;
        this.messageLogLevel = messageLogLevel;
        this.terminationFlag = terminationFlag;
        this.concurrency = concurrency;
    }

    /**
     * Builds a freshly initialised {@link MLPClassifier} and runs gradient-descent training on it.
     *
     * @param features  input features; their {@code featureDimension()} sizes the classifier
     * @param labels    per-element class values handed to the objective (presumably class ids,
     *                  indexed like {@code features} — verify against callers)
     * @param trainSet  elements to train on; its size is reported to {@link Training} and batches
     *                  of {@code trainConfig.batchSize()} are drawn from it
     * @return the classifier that was handed to the training objective
     */
    @Override
    public MLPClassifier train(Features features, HugeIntArray labels, ReadOnlyHugeLongArray trainSet) {
        MLPClassifier classifier = new MLPClassifier(
            MLPClassifierData.create(
                numberOfClasses,
                features.featureDimension(),
                trainConfig.hiddenLayerSizes(),
                random
            )
        );

        // Training is constructed before the objective so the call order matches the original.
        Training training = new Training(
            trainConfig,
            progressTracker,
            messageLogLevel,
            trainSet.size(),
            terminationFlag
        );
        MLPClassifierObjective objective = new MLPClassifierObjective(
            classifier,
            features,
            labels,
            trainConfig.penalty(),
            trainConfig.focusWeight(),
            trainConfig.initializeClassWeights(numberOfClasses)
        );

        // A new batch queue is produced on every supplier invocation.
        training.train(
            objective,
            () -> BatchQueue.fromArray(trainSet, trainConfig.batchSize()),
            concurrency
        );
        return classifier;
    }
}
