package org.neo4j.gds.embeddings.graphsage.algo;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.immutables.value.Value;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.annotation.Configuration;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.config.BatchSizeConfig;
import org.neo4j.gds.config.EmbeddingDimensionConfig;
import org.neo4j.gds.config.FeaturePropertiesConfig;
import org.neo4j.gds.config.IterationsConfig;
import org.neo4j.gds.config.RandomSeedConfig;
import org.neo4j.gds.config.RelationshipWeightConfig;
import org.neo4j.gds.config.ToleranceConfig;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.embeddings.graphsage.ActivationFunction;
import org.neo4j.gds.embeddings.graphsage.Aggregator;
import org.neo4j.gds.embeddings.graphsage.LayerConfig;
import org.neo4j.gds.model.ModelConfig;

/**
 * Configuration for training a GraphSage model.
 *
 * <p>The concrete implementation ({@code GraphSageTrainConfigImpl}) is generated from this
 * interface by the {@link Configuration} annotation processor; {@link #of} constructs it from
 * user-supplied Cypher map input.
 */
@ValueClass
@Configuration("GraphSageTrainConfigImpl")
public interface GraphSageTrainConfig extends AlgoBaseConfig, ModelConfig, BatchSizeConfig, IterationsConfig, ToleranceConfig, EmbeddingDimensionConfig, RelationshipWeightConfig, FeaturePropertiesConfig, RandomSeedConfig {
    public static final long serialVersionUID = 66;

    @Value.Default
    default int embeddingDimension() {
        return 64;
    }

    /** Number of neighbours sampled per layer; one entry per layer. Defaults to {@code [25, 10]}. */
    @Configuration.IntegerRange(min = 1)
    @Configuration.ConvertWith("convertToIntSamples")
    @Value.Default
    default List<Integer> sampleSizes() {
        return List.of(25, 10);
    }

    /**
     * Converts user-supplied numeric sample sizes to {@code int}s.
     *
     * @param list the raw numbers provided by the user
     * @return the same values as {@code Integer}s, in the same order
     * @throws IllegalArgumentException if a value does not fit into an {@code int}
     */
    static List<Integer> convertToIntSamples(List<Number> list) {
        try {
            // Go through long + toIntExact so out-of-range values fail loudly instead of
            // being silently truncated by Number.intValue().
            return list.stream()
                .map(Number::longValue)
                .map(Math::toIntExact)
                .collect(Collectors.toList());
        } catch (ArithmeticException e) {
            throw new IllegalArgumentException("Sample size must smaller than 2^31", e);
        }
    }

    @Configuration.ConvertWith("org.neo4j.gds.embeddings.graphsage.Aggregator.AggregatorType#parse")
    @Value.Default
    @Configuration.ToMapValue("org.neo4j.gds.embeddings.graphsage.Aggregator.AggregatorType#toString")
    default Aggregator.AggregatorType aggregator() {
        return Aggregator.AggregatorType.MEAN;
    }

    @Configuration.ConvertWith("org.neo4j.gds.embeddings.graphsage.ActivationFunction#parse")
    @Value.Default
    @Configuration.ToMapValue("org.neo4j.gds.embeddings.graphsage.ActivationFunction#toString")
    default ActivationFunction activationFunction() {
        return ActivationFunction.SIGMOID;
    }

    @Value.Default
    default double tolerance() {
        return 1.0E-4d;
    }

    @Value.Default
    default double learningRate() {
        return 0.1d;
    }

    @Configuration.IntegerRange(min = 1)
    @Value.Default
    default int epochs() {
        return 1;
    }

    @Value.Default
    default int maxIterations() {
        return 10;
    }

    /** Optional user-provided batch sampling ratio in {@code (0, 1]}, exposed under key {@code batchSamplingRatio}. */
    @Configuration.Key("batchSamplingRatio")
    @Configuration.DoubleRange(min = 0.0d, max = 1.0d, minInclusive = false)
    Optional<Double> maybeBatchSamplingRatio();

    @Configuration.DoubleRange(min = 0.0d)
    default double penaltyL2() {
        return 0.0d;
    }

    /**
     * Computes how many batches are processed per iteration.
     *
     * <p>The number of batches needed to cover all {@code j} elements is multiplied by a
     * sampling ratio. The ratio is the user-supplied {@code batchSamplingRatio} if present,
     * otherwise {@code min(1, batchSize * concurrency / j)}.
     *
     * @param j the number of elements to split into batches (presumably the node count — confirm with callers)
     * @return the number of batches per iteration (0 when {@code j} is 0)
     */
    @Configuration.Ignore
    @Value.Derived
    default int batchesPerIteration(long j) {
        // All arithmetic is done in double: integer division would truncate the default ratio
        // to 0 whenever batchSize * concurrency < j, and would floor the batch count before
        // Math.ceil could round it up. The double product also avoids int overflow.
        double defaultSamplingRatio = Math.min(1.0d, ((double) batchSize() * concurrency()) / j);
        double samplingRatio = maybeBatchSamplingRatio().orElse(defaultSamplingRatio);
        double totalNumberOfBatches = Math.ceil(j / (double) batchSize());
        return (int) Math.ceil(samplingRatio * totalNumberOfBatches);
    }

    @Value.Default
    default int searchDepth() {
        return 5;
    }

    @Value.Default
    default int negativeSampleWeight() {
        return 20;
    }

    /** When present, the model is trained in multi-label mode (see {@link #isMultiLabel()}). */
    @Configuration.IntegerRange(min = 1)
    Optional<Integer> projectedFeatureDimension();

    /** Feature properties are required on every node label unless training is multi-label. */
    @Configuration.Ignore
    default boolean propertiesMustExistForEachNodeLabel() {
        return !isMultiLabel();
    }

    /**
     * Builds one {@link LayerConfig} per entry of {@link #sampleSizes()}.
     *
     * <p>Every layer has {@code embeddingDimension()} rows. The first layer has {@code i}
     * columns, and later layers have {@code embeddingDimension()} columns. Per-layer random
     * seeds are drawn from a {@link Random} that is seeded with {@link #randomSeed()} when
     * one is configured, so the result is reproducible for a fixed seed.
     *
     * @param i the column count of the first layer (presumably the input feature dimension — confirm)
     * @return a new mutable list of layer configurations, in layer order
     */
    @Configuration.Ignore
    default List<LayerConfig> layerConfigs(int i) {
        List<Integer> sampleSizes = sampleSizes();
        List<LayerConfig> layers = new ArrayList<>(sampleSizes.size());

        Random random = new Random();
        randomSeed().ifPresent(random::setSeed);

        for (int layer = 0; layer < sampleSizes.size(); layer++) {
            layers.add(LayerConfig.builder()
                .aggregatorType(aggregator())
                .activationFunction(activationFunction())
                .rows(embeddingDimension())
                .cols(layer == 0 ? i : embeddingDimension())
                .sampleSize(sampleSizes.get(layer).intValue())
                .randomSeed(random.nextLong())
                .build());
        }
        return layers;
    }

    /** Multi-label mode is enabled exactly when {@link #projectedFeatureDimension()} is set. */
    @Configuration.Ignore
    default boolean isMultiLabel() {
        return projectedFeatureDimension().isPresent();
    }

    /** Feature dimension used for memory estimation: the projected dimension if set, else the number of feature properties. */
    @Configuration.Ignore
    default int estimationFeatureDimension() {
        return projectedFeatureDimension().orElse(featureProperties().size());
    }

    /** Rejects configurations without any feature property. */
    @Value.Check
    default void validate() {
        if (featureProperties().isEmpty()) {
            throw new IllegalArgumentException("GraphSage requires at least one property.");
        }
    }

    /**
     * Rejects graphs that have no relationships of the selected types.
     *
     * @throws IllegalArgumentException if the selected relationship types have no relationships in total
     */
    @Configuration.GraphStoreValidationCheck
    @Value.Default
    default void validateNonEmptyGraph(GraphStore graphStore, Collection<NodeLabel> collection, Collection<RelationshipType> collection2) {
        long relationshipCount = collection2.stream()
            .mapToLong(graphStore::relationshipCount)
            .sum();
        if (relationshipCount == 0) {
            throw new IllegalArgumentException("There should be at least one relationship in the graph.");
        }
    }

    /**
     * Creates a configuration from user input.
     *
     * @param str forwarded as the first constructor argument of the generated implementation
     *            (presumably the username — confirm)
     * @param cypherMapWrapper the user-supplied configuration map
     */
    static GraphSageTrainConfig of(String str, CypherMapWrapper cypherMapWrapper) {
        return new GraphSageTrainConfigImpl(str, cypherMapWrapper);
    }
}
