package org.neo4j.gds.ml.linkmodels;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.batch.HugeBatchQueue;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.features.FeatureExtractor;
import org.neo4j.gds.ml.linkmodels.logisticregression.LinkLogisticRegressionData;
import org.neo4j.gds.ml.linkmodels.logisticregression.LinkLogisticRegressionPredictor;
import org.neo4j.gds.ml.linkmodels.logisticregression.LinkLogisticRegressionTrain;
import org.neo4j.gds.ml.linkmodels.logisticregression.LinkLogisticRegressionTrainConfig;
import org.neo4j.gds.ml.linkmodels.metrics.LinkMetric;
import org.neo4j.gds.ml.nodemodels.ImmutableModelStats;
import org.neo4j.gds.ml.nodemodels.MetricData;
import org.neo4j.gds.ml.nodemodels.ModelStats;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
import org.neo4j.gds.ml.util.ShuffleUtil;

/* loaded from: input_file:org/neo4j/gds/ml/linkmodels/LinkPredictionTrain.class */
/**
 * Trains a link prediction model backed by link logistic regression.
 *
 * <p>The run consists of three phases, each reported to the progress tracker:
 * <ol>
 *   <li>"ModelSelection": every candidate in {@code config.paramConfigs()} is trained and
 *       scored on each train/validation split; the candidate with the best average
 *       validation score for the first configured metric wins.</li>
 *   <li>"Training": the winning parameters are used to train on all (shuffled) node ids.</li>
 *   <li>"Evaluation": the trained model is scored on the train graph and on the test graph.</li>
 * </ol>
 */
public class LinkPredictionTrain extends Algorithm<Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo>> {
    public static final String MODEL_TYPE = "Link Prediction";
    private final Graph trainGraph;
    private final Graph testGraph;
    private final LinkPredictionTrainConfig config;
    private final AllocationTracker allocationTracker;
    private final List<FeatureExtractor> trainExtractors;
    private final List<FeatureExtractor> testExtractors;

    /**
     * Outcome of model selection: the winning parameters together with the per-metric
     * statistics of every candidate on the training and validation portions of the splits.
     */
    @ValueClass
    /* loaded from: input_file:org/neo4j/gds/ml/linkmodels/LinkPredictionTrain$ModelSelectResult.class */
    public interface ModelSelectResult {
        LinkLogisticRegressionTrainConfig bestParameters();

        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> trainStats();

        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> validationStats();

        /**
         * Creates a result; the arguments are, in order, the best parameters, the train
         * statistics and the validation statistics.
         */
        static ModelSelectResult of(LinkLogisticRegressionTrainConfig linkLogisticRegressionTrainConfig, Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> map, Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> map2) {
            return ImmutableModelSelectResult.of(linkLogisticRegressionTrainConfig, map, map2);
        }
    }

    /**
     * Accumulates the minimum, maximum and sum of each metric over the validation folds of
     * a single parameter candidate and turns them into {@link ModelStats}.
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/neo4j/gds/ml/linkmodels/LinkPredictionTrain$ModelStatsBuilder.class */
    public static class ModelStatsBuilder {
        private final Map<LinkMetric, Double> min = new HashMap<>();
        private final Map<LinkMetric, Double> max = new HashMap<>();
        private final Map<LinkMetric, Double> sum = new HashMap<>();
        private final LinkLogisticRegressionTrainConfig modelParams;
        private final int numberOfSplits;

        /**
         * Memory estimate for one builder: three {@link HashMap} instances plus a fixed
         * 8 bytes.
         * NOTE(review): what the extra 8 bytes account for is not visible here - confirm.
         */
        /* JADX INFO: Access modifiers changed from: package-private */
        public static long sizeInBytes() {
            long mapSize = MemoryUsage.sizeOfInstance(HashMap.class);
            return (3 * mapSize) + 8;
        }

        /**
         * @param linkLogisticRegressionTrainConfig the candidate parameters the statistics belong to
         * @param i the number of splits; used as the divisor when averaging
         */
        ModelStatsBuilder(LinkLogisticRegressionTrainConfig linkLogisticRegressionTrainConfig, int i) {
            this.modelParams = linkLogisticRegressionTrainConfig;
            this.numberOfSplits = i;
        }

        /** Folds one observed value of {@code linkMetric} into the running min, max and sum. */
        void update(LinkMetric linkMetric, double d) {
            this.min.merge(linkMetric, d, Math::min);
            this.max.merge(linkMetric, d, Math::max);
            this.sum.merge(linkMetric, d, Double::sum);
        }

        /**
         * Builds the statistics for {@code linkMetric}: average (sum divided by the number
         * of splits), minimum and maximum. At least one value must have been recorded for
         * the metric via {@link #update}, otherwise the map lookups yield {@code null}.
         */
        ModelStats<LinkLogisticRegressionTrainConfig> modelStats(LinkMetric linkMetric) {
            double average = this.sum.get(linkMetric) / this.numberOfSplits;
            return ImmutableModelStats.of(this.modelParams, average, this.min.get(linkMetric), this.max.get(linkMetric));
        }
    }

    /**
     * Memory estimation for a {@link ModelSelectResult}: fixed sizes for the instance, the
     * two stats maps and lists, and one {@code ImmutableModelStats} per parameter candidate.
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static MemoryEstimation estimateModelSelectResult(LinkPredictionTrainConfig linkPredictionTrainConfig) {
        long mapSize = MemoryUsage.sizeOfInstance(HashMap.class);
        long listSize = MemoryUsage.sizeOfInstance(ArrayList.class);
        long statsSize = MemoryUsage.sizeOfInstance(ImmutableModelStats.class) * linkPredictionTrainConfig.paramConfigs().size();
        return MemoryEstimations.builder("model selection result")
            .fixed("instance", MemoryUsage.sizeOfInstance(ImmutableModelSelectResult.class))
            .fixed("model stats map train", mapSize)
            .fixed("model stats list train", listSize)
            .fixed("model stats map test", mapSize)
            .fixed("model stats list test", listSize)
            .fixed("model stats train", statsSize)
            .build();
    }

    /**
     * Derives the train and test graphs from {@code graph} by filtering on the configured
     * train and test relationship types, and creates property feature extractors for both.
     */
    public LinkPredictionTrain(Graph graph, LinkPredictionTrainConfig linkPredictionTrainConfig, ProgressTracker progressTracker) {
        super(progressTracker);
        this.trainGraph = graph.relationshipTypeFilteredGraph(Set.of(linkPredictionTrainConfig.trainRelationshipType()));
        this.testGraph = graph.relationshipTypeFilteredGraph(Set.of(linkPredictionTrainConfig.testRelationshipType()));
        this.trainExtractors = FeatureExtraction.propertyExtractors(this.trainGraph, linkPredictionTrainConfig.featureProperties());
        this.testExtractors = FeatureExtraction.propertyExtractors(this.testGraph, linkPredictionTrainConfig.featureProperties());
        this.config = linkPredictionTrainConfig;
        this.allocationTracker = AllocationTracker.empty();
    }

    /**
     * Runs model selection, final training and evaluation, and wraps the result in a
     * {@link Model} of type {@link #MODEL_TYPE}.
     *
     * @return the trained model together with its selection and evaluation metrics
     */
    /* renamed from: compute, reason: merged with bridge method [inline-methods] */
    public Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo> m38compute() {
        this.progressTracker.beginSubTask();

        // Identity-filled array of length nodeCount, shuffled with the configured seed.
        HugeLongArray nodeIds = HugeLongArray.newArray(this.trainGraph.nodeCount(), this.allocationTracker);
        nodeIds.setAll(index -> index);
        ShuffleUtil.shuffleHugeLongArray(nodeIds, ShuffleUtil.createRandomDataGenerator(this.config.randomSeed()));

        this.progressTracker.beginSubTask("ModelSelection");
        ModelSelectResult modelSelectResult = modelSelect(nodeIds);
        this.progressTracker.endSubTask("ModelSelection");
        LinkLogisticRegressionTrainConfig bestParameters = modelSelectResult.bestParameters();

        this.progressTracker.beginSubTask("Training");
        LinkLogisticRegressionData modelData = trainModel(nodeIds, bestParameters, this.progressTracker);
        this.progressTracker.endSubTask("Training");

        this.progressTracker.beginSubTask("Evaluation");
        this.progressTracker.beginSubTask("Training");
        Map<LinkMetric, Double> outerTrainMetrics = computeMetric(this.trainGraph, nodeIds, predictor(modelData, this.trainExtractors), this.progressTracker);
        this.progressTracker.endSubTask("Training");
        this.progressTracker.beginSubTask("Testing");
        Map<LinkMetric, Double> testMetrics = computeMetric(this.testGraph, nodeIds, predictor(modelData, this.testExtractors), this.progressTracker);
        this.progressTracker.endSubTask("Testing");
        Map<LinkMetric, MetricData<LinkLogisticRegressionTrainConfig>> metrics = mergeMetrics(modelSelectResult, outerTrainMetrics, testMetrics);
        this.progressTracker.endSubTask("Evaluation");

        this.progressTracker.endSubTask();
        return Model.of(
            this.config.username(),
            this.config.modelName(),
            MODEL_TYPE,
            this.trainGraph.schema(),
            modelData,
            this.config,
            LinkPredictionModelInfo.of(modelSelectResult.bestParameters(), metrics)
        );
    }

    /** Creates a predictor for the given model data using the given feature extractors. */
    public LinkLogisticRegressionPredictor predictor(LinkLogisticRegressionData linkLogisticRegressionData, List<FeatureExtractor> list) {
        return new LinkLogisticRegressionPredictor(linkLogisticRegressionData, list);
    }

    /**
     * Combines, for every metric present in the validation stats, the model selection
     * statistics with the final train-graph and test-graph scores.
     */
    private Map<LinkMetric, MetricData<LinkLogisticRegressionTrainConfig>> mergeMetrics(ModelSelectResult modelSelectResult, Map<LinkMetric, Double> outerTrainMetrics, Map<LinkMetric, Double> testMetrics) {
        return modelSelectResult.validationStats().keySet().stream().collect(Collectors.toMap(
            Function.identity(),
            metric -> MetricData.of(
                modelSelectResult.trainStats().get(metric),
                modelSelectResult.validationStats().get(metric),
                outerTrainMetrics.get(metric),
                testMetrics.get(metric)
            )
        ));
    }

    /**
     * Trains and scores every parameter candidate on every train/validation split and
     * picks the candidate whose validation stats rank highest under
     * {@code ModelStats.COMPARE_AVERAGE} for the first configured metric.
     */
    private ModelSelectResult modelSelect(HugeLongArray nodeIds) {
        List<TrainingExamplesSplit> splits = trainValidationSplits(ReadOnlyHugeLongArray.of(nodeIds));
        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> trainStats = initStatsMap();
        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> validationStats = initStatsMap();

        this.config.paramConfigs().forEach(modelParams -> {
            ModelStatsBuilder trainStatsBuilder = new ModelStatsBuilder(modelParams, this.config.validationFolds());
            ModelStatsBuilder validationStatsBuilder = new ModelStatsBuilder(modelParams, this.config.validationFolds());

            for (TrainingExamplesSplit split : splits) {
                HugeLongArray trainSet = split.trainSet();
                HugeLongArray validationSet = split.testSet();
                // Inner training and scoring is not reported; only one progress tick per fold.
                LinkLogisticRegressionData modelData = trainModel(trainSet, modelParams, ProgressTracker.NULL_TRACKER);
                LinkLogisticRegressionPredictor predictor = predictor(modelData, this.trainExtractors);
                this.progressTracker.logProgress();

                computeMetric(this.trainGraph, trainSet, predictor, ProgressTracker.NULL_TRACKER)
                    .forEach(trainStatsBuilder::update);
                // Validation nodes are scored against the train graph as well.
                computeMetric(this.trainGraph, validationSet, predictor, ProgressTracker.NULL_TRACKER)
                    .forEach(validationStatsBuilder::update);
            }

            // computeIfAbsent guards against configured metrics that initStatsMap did not seed.
            this.config.metrics().forEach(metric -> {
                validationStats.computeIfAbsent(metric, ignored -> new ArrayList<>()).add(validationStatsBuilder.modelStats(metric));
                trainStats.computeIfAbsent(metric, ignored -> new ArrayList<>()).add(trainStatsBuilder.modelStats(metric));
            });
        });

        LinkMetric mainMetric = this.config.metrics().get(0);
        ModelStats<LinkLogisticRegressionTrainConfig> winner = Collections.max(validationStats.get(mainMetric), ModelStats.COMPARE_AVERAGE);
        return ModelSelectResult.of(winner.params(), trainStats, validationStats);
    }

    /**
     * Splits the node ids into {@code config.validationFolds()} folds. The target array
     * handed to the splitter returns 0 for every index, so all nodes share one class.
     */
    private List<TrainingExamplesSplit> trainValidationSplits(ReadOnlyHugeLongArray nodeIds) {
        ReadOnlyHugeLongArray constantTargets = new ReadOnlyHugeLongArray() { // from class: org.neo4j.gds.ml.linkmodels.LinkPredictionTrain.1
            @Override
            public long get(long j) {
                return 0L;
            }

            @Override
            public long size() {
                return LinkPredictionTrain.this.trainGraph.nodeCount();
            }
        };
        return new StratifiedKFoldSplitter(this.config.validationFolds(), nodeIds, constantTargets, this.config.randomSeed()).splits();
    }

    /** Creates a mutable stats map pre-seeded with an empty list for {@code AUCPR}. */
    private Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> initStatsMap() {
        Map<LinkMetric, List<ModelStats<LinkLogisticRegressionTrainConfig>>> statsMap = new HashMap<>();
        statsMap.put(LinkMetric.AUCPR, new ArrayList<>());
        return statsMap;
    }

    /**
     * Collects signed probabilities for the given node ids on {@code graph} using
     * {@code config.concurrency()} consumers and evaluates every configured metric on them.
     */
    private Map<LinkMetric, Double> computeMetric(Graph graph, HugeLongArray nodeIds, LinkLogisticRegressionPredictor predictor, ProgressTracker progressTracker) {
        SignedProbabilities signedProbabilities = SignedProbabilities.create(graph.relationshipCount());
        progressTracker.setVolume(graph.nodeCount());

        HugeBatchQueue queue = new HugeBatchQueue(ReadOnlyHugeLongArray.of(nodeIds));
        // Each consumer gets its own concurrent copy of the graph.
        queue.parallelConsume(
            this.config.concurrency(),
            ignored -> new SignedProbabilitiesCollector(graph.concurrentCopy(), predictor, signedProbabilities, progressTracker),
            this.terminationFlag
        );

        return this.config.metrics().stream().collect(Collectors.toMap(
            Function.identity(),
            metric -> metric.compute(signedProbabilities, this.config.negativeClassWeight())
        ));
    }

    /** Trains link logistic regression on the train graph for the given node ids and parameters. */
    private LinkLogisticRegressionData trainModel(HugeLongArray nodeIds, LinkLogisticRegressionTrainConfig modelParams, ProgressTracker progressTracker) {
        LinkLogisticRegressionTrain training = new LinkLogisticRegressionTrain(
            this.trainGraph,
            nodeIds,
            this.trainExtractors,
            modelParams,
            progressTracker,
            this.terminationFlag,
            this.config.concurrency()
        );
        return training.compute();
    }

    /** Intentionally a no-op. */
    public void release() {
    }
}
