package org.neo4j.gds.embeddings.graphsage.algo;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.immutables.value.Value;
import org.neo4j.gds.embeddings.graphsage.ActivationFunction;
import org.neo4j.gds.embeddings.graphsage.Aggregator;
import org.neo4j.gds.embeddings.graphsage.LayerConfig;
import org.neo4j.graphalgo.annotation.Configuration;
import org.neo4j.graphalgo.annotation.ValueClass;
import org.neo4j.graphalgo.config.AlgoBaseConfig;
import org.neo4j.graphalgo.config.BatchSizeConfig;
import org.neo4j.graphalgo.config.EmbeddingDimensionConfig;
import org.neo4j.graphalgo.config.GraphCreateConfig;
import org.neo4j.graphalgo.config.IterationsConfig;
import org.neo4j.graphalgo.config.RelationshipWeightConfig;
import org.neo4j.graphalgo.config.ToleranceConfig;
import org.neo4j.graphalgo.config.TrainConfig;
import org.neo4j.graphalgo.core.CypherMapWrapper;

@ValueClass
@Configuration("GraphSageTrainConfigImpl")
/**
 * Training configuration for the GraphSage embedding algorithm.
 *
 * <p>Values not supplied by the user fall back to the {@link Value.Default} methods below.
 * {@link #validate()} is registered as an Immutables {@link Value.Check}, so it runs when an
 * immutable instance is built.
 */
public interface GraphSageTrainConfig extends AlgoBaseConfig, TrainConfig, BatchSizeConfig, IterationsConfig, ToleranceConfig, EmbeddingDimensionConfig, RelationshipWeightConfig {

    /**
     * Sentinel value for {@link #projectedFeatureSize()}.
     *
     * <p>Any value {@code <= 0} disables multi-label mode (see {@link #isMultiLabel()}).
     */
    int PROJECTED_FEATURE_SIZE = -1;

    /** Size of the produced embeddings; also used as the row count of every layer. Defaults to 64. */
    @Value.Default
    default int embeddingDimension() {
        return 64;
    }

    /** Names of the node properties used as input features. Defaults to none. */
    @Value.Default
    default List<String> nodePropertyNames() {
        return List.of();
    }

    /**
     * One entry per layer; each entry becomes that layer's {@code sampleSize}.
     *
     * <p>Defaults to {@code [25, 10]}, i.e. two layers.
     */
    @Value.Default
    default List<Long> sampleSizes() {
        return List.of(25L, 10L);
    }

    /** Aggregator applied in every layer. Defaults to {@code MEAN}. */
    @Configuration.ConvertWith("org.neo4j.gds.embeddings.graphsage.Aggregator.AggregatorType#parse")
    @Value.Default
    @Configuration.ToMapValue("org.neo4j.gds.embeddings.graphsage.Aggregator.AggregatorType#toString")
    default Aggregator.AggregatorType aggregator() {
        return Aggregator.AggregatorType.MEAN;
    }

    /** Activation function applied in every layer. Defaults to {@code SIGMOID}. */
    @Configuration.ConvertWith("org.neo4j.gds.embeddings.graphsage.ActivationFunction#parse")
    @Value.Default
    @Configuration.ToMapValue("org.neo4j.gds.embeddings.graphsage.ActivationFunction#toString")
    default ActivationFunction activationFunction() {
        return ActivationFunction.SIGMOID;
    }

    /** Convergence tolerance. Defaults to {@code 1e-4}. */
    @Value.Default
    default double tolerance() {
        return 1.0E-4d;
    }

    /** Learning rate used during training. Defaults to {@code 0.1}. */
    @Value.Default
    default double learningRate() {
        return 0.1d;
    }

    /** Number of training epochs. Defaults to 1. */
    @Value.Default
    default int epochs() {
        return 1;
    }

    /** Maximum number of iterations. Defaults to 10. */
    @Value.Default
    default int maxIterations() {
        return 10;
    }

    /** Search depth setting. Defaults to 5 (semantics defined by the training code — not visible here). */
    @Value.Default
    default int searchDepth() {
        return 5;
    }

    /** Negative-sample weight. Defaults to 20 (semantics defined by the training code — not visible here). */
    @Value.Default
    default int negativeSampleWeight() {
        return 20;
    }

    /** Whether node degree is used as an additional input feature. Defaults to {@code false}. */
    @Value.Default
    default boolean degreeAsProperty() {
        return false;
    }

    /**
     * Feature size used when running in multi-label mode.
     *
     * <p>Defaults to {@link #PROJECTED_FEATURE_SIZE}, which keeps multi-label mode off.
     */
    @Value.Default
    default int projectedFeatureSize() {
        return PROJECTED_FEATURE_SIZE;
    }

    /**
     * Builds one {@link LayerConfig} per entry of {@link #sampleSizes()}.
     *
     * <p>Every layer shares the configured aggregator and activation function, and has
     * {@link #embeddingDimension()} rows. The first layer's column count is
     * {@link #featuresSize()}; later layers use {@link #embeddingDimension()} columns, matching
     * the previous layer's output.
     *
     * @return a new mutable list of layer configurations, in {@code sampleSizes()} order
     */
    @Configuration.Ignore
    default List<LayerConfig> layerConfigs() {
        List<Long> sampleSizes = sampleSizes();
        Aggregator.AggregatorType aggregatorType = aggregator();
        ActivationFunction activation = activationFunction();
        int dimension = embeddingDimension();

        List<LayerConfig> layers = new ArrayList<>(sampleSizes.size());
        for (int layer = 0; layer < sampleSizes.size(); layer++) {
            layers.add(LayerConfig.builder()
                .aggregatorType(aggregatorType)
                .activationFunction(activation)
                .rows(dimension)
                .cols(layer == 0 ? featuresSize() : dimension)
                .sampleSize(sampleSizes.get(layer).longValue())
                .build());
        }
        return layers;
    }

    /** Returns {@code true} when {@link #projectedFeatureSize()} is positive. */
    @Configuration.Ignore
    default boolean isMultiLabel() {
        return projectedFeatureSize() > 0;
    }

    /**
     * Number of input feature columns for the first layer.
     *
     * @return {@link #projectedFeatureSize()} in multi-label mode; otherwise the number of node
     *     properties, plus one if {@link #degreeAsProperty()} is set
     */
    @Configuration.Ignore
    default int featuresSize() {
        if (isMultiLabel()) {
            return projectedFeatureSize();
        }
        return nodePropertyNames().size() + (degreeAsProperty() ? 1 : 0);
    }

    /**
     * Ensures at least one input feature source is configured.
     *
     * @throws IllegalArgumentException if {@link #nodePropertyNames()} is empty and
     *     {@link #degreeAsProperty()} is {@code false}
     */
    @Value.Check
    default void validate() {
        if (nodePropertyNames().isEmpty() && !degreeAsProperty()) {
            throw new IllegalArgumentException("GraphSage requires at least one property. Either `nodePropertyNames` or `degreeAsProperty` must be set.");
        }
    }

    /**
     * Creates a configuration from user-supplied map input via the generated
     * {@code GraphSageTrainConfigImpl}.
     *
     * @param username the requesting user
     * @param graphName name of the graph, if any
     * @param maybeImplicitCreate implicit graph-creation config, if any
     * @param userInput raw configuration values
     * @return the parsed configuration
     */
    static GraphSageTrainConfig of(
        String username,
        Optional<String> graphName,
        Optional<GraphCreateConfig> maybeImplicitCreate,
        CypherMapWrapper userInput
    ) {
        // NOTE: the generated constructor takes username third; argument order is intentional.
        return new GraphSageTrainConfigImpl(graphName, maybeImplicitCreate, username, userInput);
    }

    /**
     * Convenience factory that builds an immutable configuration programmatically.
     *
     * <p>All settings not passed here keep their defaults.
     */
    static GraphSageTrainConfig of(
        String modelName,
        ActivationFunction activationFunction,
        Aggregator.AggregatorType aggregator,
        int batchSize,
        int embeddingDimension,
        List<String> nodePropertyNames,
        double tolerance
    ) {
        return ImmutableGraphSageTrainConfig.builder()
            .modelName(modelName)
            .activationFunction(activationFunction)
            .aggregator(aggregator)
            .batchSize(batchSize)
            .embeddingDimension(embeddingDimension)
            .nodePropertyNames(nodePropertyNames)
            .tolerance(tolerance)
            .build();
    }
}
