package org.neo4j.gds.ml.linkmodels.pipeline.train;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.ProcedureExecutor;
import org.neo4j.gds.executor.ProcedureExecutorSpec;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.splitting.ImmutableSplitRelationshipsMutateConfig;
import org.neo4j.gds.ml.splitting.SplitRelationshipsAlgorithmFactory;
import org.neo4j.gds.ml.splitting.SplitRelationshipsBaseConfig;
import org.neo4j.gds.ml.splitting.SplitRelationshipsMutateProc;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Splits the relationships of a named graph into the relationship sets used when training a
 * link prediction pipeline.
 *
 * <p>The split runs in two steps. Each step executes {@link SplitRelationshipsMutateProc}
 * through a {@link ProcedureExecutor} against the graph named {@code graphName}:
 * <ol>
 *   <li>The input relationships are split with {@link LinkPredictionSplitConfig#testSplit()}
 *       into a test set and a test-complement set.</li>
 *   <li>The test-complement set is split again with
 *       {@link LinkPredictionSplitConfig#trainSplit()} (the "Train/Feature-input" split).</li>
 * </ol>
 * Afterwards the intermediate test-complement relationship type is deleted from the graph store.
 */
public class RelationshipSplitter {
    private static final String SPLIT_ERROR_TEMPLATE = "%s graph contains no relationships. Consider increasing the `%s` or provide a larger graph";

    private final String graphName;
    private final LinkPredictionSplitConfig splitConfig;
    private final ExecutionContext executionContext;
    private final ProgressTracker progressTracker;

    /**
     * Creates a splitter for one graph.
     *
     * <p>The decompiler noted that this constructor was package-private in the original source.
     * It stays public here so that existing callers keep compiling.
     *
     * @param graphName        name of the graph that the split procedure is executed against
     * @param splitConfig      configuration of the test split and the train split
     * @param executionContext context handed to the {@link ProcedureExecutor} that runs each split
     * @param progressTracker  tracker that receives the "Split relationships" sub task
     */
    public RelationshipSplitter(
        String graphName,
        LinkPredictionSplitConfig splitConfig,
        ExecutionContext executionContext,
        ProgressTracker progressTracker
    ) {
        this.graphName = graphName;
        this.splitConfig = splitConfig;
        this.executionContext = executionContext;
        this.progressTracker = progressTracker;
    }

    /**
     * Runs the test split, then the train split, and finally removes the intermediate
     * test-complement relationships.
     *
     * @param graphStore                 graph store used to validate the split configuration,
     *                                   to check that the test set is non-empty and to delete
     *                                   the test-complement relationship type. Presumably this is
     *                                   the store backing {@code graphName}; confirm with callers.
     * @param relationshipTypes          relationship types to split in the first (test) step
     * @param nodeLabels                 node labels passed to both split steps
     * @param randomSeed                 optional seed, forwarded to both steps as {@code randomSeed}
     * @param relationshipWeightProperty optional property, forwarded to both steps as
     *                                   {@code relationshipWeightProperty}
     * @throws IllegalStateException if the test split produced no relationships
     */
    public void splitRelationships(
        GraphStore graphStore,
        List<String> relationshipTypes,
        List<String> nodeLabels,
        Optional<Long> randomSeed,
        Optional<String> relationshipWeightProperty
    ) {
        // NOTE(review): there is no try/finally here. If a step throws, the sub task is never
        // ended and the test-complement relationship type is not deleted.
        progressTracker.beginSubTask("Split relationships");

        splitConfig.validateAgainstGraphStore(graphStore);

        String testComplementRelationshipType = splitConfig.testComplementRelationshipType();

        // Step 1: split the requested relationship types into test and test-complement.
        relationshipSplit(splitConfig.testSplit(), nodeLabels, relationshipTypes, randomSeed, relationshipWeightProperty);
        validateTestSplit(graphStore);

        // Step 2: split the test-complement relationships with the train split configuration.
        relationshipSplit(
            splitConfig.trainSplit(),
            nodeLabels,
            List.of(testComplementRelationshipType),
            randomSeed,
            relationshipWeightProperty
        );

        // The test-complement set is only an intermediate input for step 2.
        graphStore.deleteRelationships(RelationshipType.of(testComplementRelationshipType));

        progressTracker.endSubTask("Split relationships");
    }

    /**
     * Fails fast when the test split left the test relationship type without any relationships.
     *
     * @throws IllegalStateException if the test relationship type has no relationships
     */
    private void validateTestSplit(GraphStore graphStore) {
        RelationshipType testType = RelationshipType.of(splitConfig.testRelationshipType());
        if (graphStore.getGraph(testType).relationshipCount() <= 0) {
            throw new IllegalStateException(
                StringFormatting.formatWithLocale(SPLIT_ERROR_TEMPLATE, "Test", "testFraction")
            );
        }
    }

    /**
     * Executes {@link SplitRelationshipsMutateProc} on {@code graphName}.
     *
     * <p>The procedure configuration is the split map of {@code splitBaseConfig}, extended with
     * the node labels, the relationship types, and, when present, the weight property and the
     * random seed.
     */
    private void relationshipSplit(
        SplitRelationshipsBaseConfig splitBaseConfig,
        List<String> nodeLabels,
        List<String> relationshipTypes,
        Optional<Long> randomSeed,
        Optional<String> relationshipWeightProperty
    ) {
        // A plain map replaces the former double-brace anonymous HashMap subclass, which
        // captured the enclosing splitter instance.
        Map<String, Object> procConfig = new HashMap<>(splitBaseConfig.toSplitMap());
        procConfig.put("nodeLabels", nodeLabels);
        procConfig.put("relationshipTypes", relationshipTypes);
        relationshipWeightProperty.ifPresent(weight -> procConfig.put("relationshipWeightProperty", weight));
        randomSeed.ifPresent(seed -> procConfig.put("randomSeed", seed));

        // NOTE(review): ProcedureExecutor.compute defines the meaning of the two trailing flags.
        new ProcedureExecutor<>(new SplitRelationshipsMutateProc(), new ProcedureExecutorSpec<>(), executionContext)
            .compute(graphName, procConfig, false, false);
    }

    /**
     * Estimates the memory needed by both split steps of {@link #splitRelationships}.
     *
     * <p>The decompiler noted that this method was package-private in the original source.
     * It stays public here so that existing callers keep compiling.
     *
     * @param splitConfig       configuration of the test split and the train split
     * @param relationshipTypes relationship types to split in the first step. The wildcard
     *                          {@code "*"} is mapped to {@link RelationshipType#ALL_RELATIONSHIPS}.
     * @return a "Split relationships" estimation made up of the test split and the train split
     */
    public static MemoryEstimation splitEstimation(
        LinkPredictionSplitConfig splitConfig,
        List<String> relationshipTypes
    ) {
        List<String> resolvedRelationshipTypes = relationshipTypes.stream()
            .map(type -> type.equals("*") ? RelationshipType.ALL_RELATIONSHIPS.name : type)
            .collect(Collectors.toList());

        MemoryEstimation testSplitEstimation = MemoryEstimations.builder("Test/Test-complement split")
            .add(new SplitRelationshipsAlgorithmFactory().memoryEstimation(
                ImmutableSplitRelationshipsMutateConfig.builder()
                    .from(splitConfig.testSplit())
                    .relationshipTypes(resolvedRelationshipTypes)
                    .build()
            ))
            .build();

        // The second step only ever splits the test-complement relationship type.
        MemoryEstimation trainSplitEstimation = MemoryEstimations.builder("Train/Feature-input split")
            .add(new SplitRelationshipsAlgorithmFactory().memoryEstimation(
                ImmutableSplitRelationshipsMutateConfig.builder()
                    .from(splitConfig.trainSplit())
                    .relationshipTypes(List.of(splitConfig.testComplementRelationshipType()))
                    .build()
            ))
            .build();

        return MemoryEstimations.builder("Split relationships")
            .add(testSplitEstimation)
            .add(trainSplitEstimation)
            .build();
    }
}
