package org.neo4j.gds.ml;

import java.util.List;
import java.util.Objects;
import org.immutables.value.Value;
import org.neo4j.gds.embeddings.graphsage.AdamOptimizer;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.Weights;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;
import org.neo4j.graphalgo.annotation.ValueClass;

@ValueClass
/**
 * Immutable hyper-parameter bundle for an ML training loop.
 *
 * <p>The {@code @Value.Default} methods are attributes of the generated Immutables value class
 * and return the listed defaults unless overridden via the builder. The remaining default
 * methods are factories that turn these settings into the collaborators a trainer needs:
 * an {@link Updater}, a {@link TrainingStopper} and a {@link BatchQueue}.
 */
public interface TrainingSettings {

    /**
     * Batch size passed to {@link BatchQueue} by {@link #batchQueue(long)}.
     *
     * @return the batch size; defaults to {@code 100}
     */
    @Value.Default
    default int batchSize() {
        return 100;
    }

    /**
     * First argument forwarded to {@link StreakStopper} by {@link #stopper()}.
     * Presumably the minimum number of iterations before stopping is allowed (TODO confirm).
     *
     * @return defaults to {@code 1}
     */
    @Value.Default
    default int minIterations() {
        return 1;
    }

    /**
     * Second argument forwarded to {@link StreakStopper} by {@link #stopper()}.
     * Presumably the number of non-improving rounds tolerated before stopping (TODO confirm).
     *
     * @return defaults to {@code 1}
     */
    @Value.Default
    default int maxStreakCount() {
        return 1;
    }

    /**
     * Third argument forwarded to {@link StreakStopper} by {@link #stopper()}.
     * Presumably a hard upper bound on training iterations (TODO confirm).
     *
     * @return defaults to {@code 100}
     */
    @Value.Default
    default int maxIterations() {
        return 100;
    }

    /**
     * Fourth argument forwarded to {@link StreakStopper} by {@link #stopper()}.
     * Presumably the number of recent losses considered when judging progress (TODO confirm).
     *
     * @return defaults to {@code 1}
     */
    @Value.Default
    default int windowSize() {
        return 1;
    }

    /**
     * Fifth argument forwarded to {@link StreakStopper} by {@link #stopper()}.
     * Presumably the minimal improvement that counts as progress (TODO confirm).
     *
     * @return defaults to {@code 0.001}
     */
    @Value.Default
    default double tolerance() {
        return 0.001d;
    }

    /**
     * Creates an {@link Updater} backed by a new {@link AdamOptimizer} over the given weights.
     *
     * <p>Every call constructs a separate {@code AdamOptimizer} instance. The returned updater
     * delegates to that instance's {@code update} method.
     *
     * @param list the weights the optimizer is constructed with
     * @return an updater bound to a freshly created optimizer
     */
    default Updater updater(List<Weights<? extends Tensor<?>>> list) {
        // `new` never yields null, so no null guard is needed before binding the method reference.
        AdamOptimizer adamOptimizer = new AdamOptimizer(list);
        return adamOptimizer::update;
    }

    /**
     * Creates a new {@link StreakStopper} configured from this object's attributes.
     * The arguments are, in order: {@link #minIterations()}, {@link #maxStreakCount()},
     * {@link #maxIterations()}, {@link #windowSize()} and {@link #tolerance()}.
     *
     * @return a new stopper instance on every call
     */
    default TrainingStopper stopper() {
        return new StreakStopper(minIterations(), maxStreakCount(), maxIterations(), windowSize(), tolerance());
    }

    /**
     * Returns {@code false}. This method is not annotated with {@code @Value.Default}, so it is
     * not a builder-settable attribute; it can only be changed by overriding.
     * NOTE(review): presumably tells the trainer whether one {@link Updater} is shared across
     * workers; confirm against callers.
     *
     * @return {@code false}
     */
    default boolean sharedUpdater() {
        return false;
    }

    /**
     * Creates a {@link BatchQueue} over {@code j} elements using {@link #batchSize()}.
     *
     * @param j first argument to the {@code BatchQueue} constructor; presumably the total number
     *          of training elements to batch (TODO confirm)
     * @return a new batch queue on every call
     */
    default BatchQueue batchQueue(long j) {
        return new BatchQueue(j, batchSize());
    }
}
