package org.neo4j.gds.ml.linkmodels;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.neo4j.gds.ml.batch.HugeBatchQueue;
import org.neo4j.gds.ml.linkmodels.logisticregression.LinkLogisticRegressionData;
import org.neo4j.gds.ml.linkmodels.logisticregression.LinkLogisticRegressionPredictor;
import org.neo4j.gds.ml.linkmodels.logisticregression.LinkLogisticRegressionTrain;
import org.neo4j.gds.ml.linkmodels.logisticregression.LinkLogisticRegressionTrainConfig;
import org.neo4j.gds.ml.linkmodels.metrics.LinkMetric;
import org.neo4j.gds.ml.nodemodels.ImmutableModelStats;
import org.neo4j.gds.ml.nodemodels.MetricData;
import org.neo4j.gds.ml.nodemodels.ModelStats;
import org.neo4j.gds.ml.splitting.NodeSplit;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.util.ShuffleUtil;
import org.neo4j.graphalgo.Algorithm;
import org.neo4j.graphalgo.annotation.ValueClass;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.model.Model;
import org.neo4j.graphalgo.core.utils.mem.AllocationTracker;
import org.neo4j.graphalgo.core.utils.paged.HugeLongArray;
import org.neo4j.logging.Log;

/**
 * Trains a link prediction model backed by {@code LinkLogisticRegressionTrain}.
 *
 * <p>Hyperparameters are selected by k-fold cross-validation over the (shuffled) node ids of
 * {@code trainGraph}. Each candidate parameter map from {@code config.params()} is trained on every
 * training fold. It is scored on both the training fold and the held-out validation fold, using
 * {@code trainGraph} for both. The candidate with the best average validation score on the first
 * configured metric is retrained on all nodes. The final model is then evaluated on both
 * {@code trainGraph} and {@code testGraph}.
 */
public class LinkPredictionTrain extends Algorithm<LinkPredictionTrain, Model<LinkLogisticRegressionData, LinkPredictionTrainConfig>> {
    public static final String MODEL_TYPE = "Link Prediction";

    private final Graph trainGraph;
    private final Graph testGraph;
    private final LinkPredictionTrainConfig config;
    private final Log log;
    private final AllocationTracker allocationTracker = AllocationTracker.empty();

    /** Outcome of hyperparameter selection: the winning parameters plus per-metric statistics. */
    @ValueClass
    public interface ModelSelectResult {
        /** Parameters of the candidate with the highest average validation score on the first configured metric. */
        Map<String, Object> bestParameters();

        /** For each metric, one {@link ModelStats} entry per candidate, computed on the training folds. */
        Map<LinkMetric, List<ModelStats>> trainStats();

        /** For each metric, one {@link ModelStats} entry per candidate, computed on the validation folds. */
        Map<LinkMetric, List<ModelStats>> validationStats();

        static ModelSelectResult of(
            Map<String, Object> bestParameters,
            Map<LinkMetric, List<ModelStats>> trainStats,
            Map<LinkMetric, List<ModelStats>> validationStats
        ) {
            return ImmutableModelSelectResult.of(bestParameters, trainStats, validationStats);
        }
    }

    /**
     * Accumulates the min, max and sum of each metric across the cross-validation splits of a
     * single candidate parameter map.
     */
    private static final class ModelStatsBuilder {
        private final Map<LinkMetric, Double> min = new HashMap<>();
        private final Map<LinkMetric, Double> max = new HashMap<>();
        private final Map<LinkMetric, Double> sum = new HashMap<>();
        private final Map<String, Object> modelParams;
        private final int numberOfSplits;

        /**
         * @param modelParams    the candidate parameters these statistics belong to
         * @param numberOfSplits the divisor used to turn the accumulated sum into an average
         */
        ModelStatsBuilder(Map<String, Object> modelParams, int numberOfSplits) {
            this.modelParams = modelParams;
            this.numberOfSplits = numberOfSplits;
        }

        /** Folds one observed {@code value} of {@code metric} into the running min/max/sum. */
        void update(LinkMetric metric, double value) {
            this.min.merge(metric, value, Math::min);
            this.max.merge(metric, value, Math::max);
            this.sum.merge(metric, value, Double::sum);
        }

        /**
         * Builds the statistics for {@code metric}. The average is the accumulated sum divided by
         * the split count passed to the constructor. Expects {@link #update} to have been called
         * at least once for {@code metric}; otherwise the lookups yield {@code null} and unboxing fails.
         */
        ModelStats modelStats(LinkMetric metric) {
            return ImmutableModelStats.of(
                this.modelParams,
                this.sum.get(metric) / this.numberOfSplits,
                this.min.get(metric),
                this.max.get(metric)
            );
        }
    }

    /**
     * @param trainGraph graph used for cross-validation, for final training and for the reported train metrics
     * @param testGraph  graph used only to compute the reported test metrics of the final model
     * @param config     training configuration (candidate params, metrics, folds, seed, concurrency, ...)
     * @param log        passed through to {@code LinkLogisticRegressionTrain}
     */
    public LinkPredictionTrain(
        Graph trainGraph,
        Graph testGraph,
        LinkPredictionTrainConfig config,
        Log log
    ) {
        this.trainGraph = trainGraph;
        this.testGraph = testGraph;
        this.config = config;
        this.log = log;
    }

    /**
     * Runs model selection, retrains the best candidate on all nodes and wraps the result in a
     * {@link Model} of type {@value #MODEL_TYPE}.
     */
    @Override
    public Model<LinkLogisticRegressionData, LinkPredictionTrainConfig> compute() {
        // Node ids 0..nodeCount-1, shuffled (reproducibly if config.randomSeed() is set).
        HugeLongArray allNodeIds = HugeLongArray.newArray(this.trainGraph.nodeCount(), this.allocationTracker);
        allNodeIds.setAll(nodeId -> nodeId);
        ShuffleUtil.shuffleHugeLongArray(allNodeIds, getRandomDataGenerator());

        ModelSelectResult modelSelectResult = modelSelect(allNodeIds);
        Map<String, Object> bestParameters = modelSelectResult.bestParameters();
        LinkLogisticRegressionPredictor finalModel = trainModel(allNodeIds, bestParameters);

        Map<LinkMetric, Double> trainMetrics = computeMetric(this.trainGraph, allNodeIds, finalModel);
        Map<LinkMetric, Double> testMetrics = computeMetric(this.testGraph, allNodeIds, finalModel);

        return Model.of(
            this.config.username(),
            this.config.modelName(),
            MODEL_TYPE,
            this.trainGraph.schema(),
            finalModel.modelData(),
            this.config,
            LinkPredictionModelInfo.of(
                bestParameters,
                mergeMetrics(modelSelectResult, trainMetrics, testMetrics)
            )
        );
    }

    /** Creates a random generator, re-seeded with {@code config.randomSeed()} when one is present. */
    private RandomDataGenerator getRandomDataGenerator() {
        RandomDataGenerator random = new RandomDataGenerator();
        this.config.randomSeed().ifPresent(random::reSeed);
        return random;
    }

    /**
     * Combines the cross-validation statistics with the final model's train and test scores,
     * producing one {@link MetricData} per metric present in the validation statistics.
     */
    private Map<LinkMetric, MetricData> mergeMetrics(
        ModelSelectResult modelSelectResult,
        Map<LinkMetric, Double> trainMetrics,
        Map<LinkMetric, Double> testMetrics
    ) {
        return modelSelectResult.validationStats().keySet().stream().collect(Collectors.toMap(
            metric -> metric,
            metric -> MetricData.of(
                modelSelectResult.trainStats().get(metric),
                modelSelectResult.validationStats().get(metric),
                trainMetrics.get(metric),
                testMetrics.get(metric)
            )
        ));
    }

    /**
     * Evaluates every candidate parameter map with k-fold cross-validation on {@code trainGraph}.
     * It picks the one with the highest average validation score on the first configured metric.
     *
     * @param nodeIds the node ids to split into training/validation folds
     */
    private ModelSelectResult modelSelect(HugeLongArray nodeIds) {
        List<NodeSplit> splits = trainValidationSplits(nodeIds);
        Map<LinkMetric, List<ModelStats>> trainStats = initStatsMap();
        Map<LinkMetric, List<ModelStats>> validationStats = initStatsMap();

        for (Map<String, Object> modelParams : this.config.params()) {
            // Average over the splits actually produced, not the configured fold count.
            ModelStatsBuilder trainStatsBuilder = new ModelStatsBuilder(modelParams, splits.size());
            ModelStatsBuilder validationStatsBuilder = new ModelStatsBuilder(modelParams, splits.size());

            for (NodeSplit split : splits) {
                HugeLongArray trainSet = split.trainSet();
                HugeLongArray validationSet = split.testSet();
                LinkLogisticRegressionPredictor predictor = trainModel(trainSet, modelParams);

                computeMetric(this.trainGraph, trainSet, predictor).forEach(trainStatsBuilder::update);
                computeMetric(this.trainGraph, validationSet, predictor).forEach(validationStatsBuilder::update);
            }

            for (LinkMetric metric : this.config.metrics()) {
                validationStats.get(metric).add(validationStatsBuilder.modelStats(metric));
                trainStats.get(metric).add(trainStatsBuilder.modelStats(metric));
            }
        }

        LinkMetric mainMetric = this.config.metrics().get(0);
        ModelStats winner = Collections.max(validationStats.get(mainMetric), ModelStats.COMPARE_AVERAGE);
        return ModelSelectResult.of(winner.params(), trainStats, validationStats);
    }

    /**
     * Splits {@code nodeIds} into {@code config.validationFolds()} train/validation folds. Every
     * node gets the same target label (0), so there is only one class to stratify by.
     */
    private List<NodeSplit> trainValidationSplits(HugeLongArray nodeIds) {
        HugeLongArray targets = HugeLongArray.newArray(this.trainGraph.nodeCount(), this.allocationTracker);
        targets.setAll(nodeId -> 0L);
        return new StratifiedKFoldSplitter(this.config.validationFolds(), nodeIds, targets).splits();
    }

    /**
     * Creates an empty statistics list for every configured metric. This keeps the map's key set
     * consistent with the metrics that {@link #modelSelect} and {@link #computeMetric} iterate over.
     */
    private Map<LinkMetric, List<ModelStats>> initStatsMap() {
        Map<LinkMetric, List<ModelStats>> statsMap = new HashMap<>();
        for (LinkMetric metric : this.config.metrics()) {
            statsMap.put(metric, new ArrayList<>());
        }
        return statsMap;
    }

    /**
     * Scores {@code predictor} and evaluates every configured metric.
     *
     * <p>Batches of {@code nodeIds} are consumed in parallel ({@code config.concurrency()}) by
     * {@code SignedProbabilitiesCollector}s. Each collector is given its own
     * {@code graph.concurrentCopy()} and writes into one shared {@code SignedProbabilities}, sized
     * by {@code graph.relationshipCount()}. Each metric is then computed from that with
     * {@code config.classRatio()}.
     */
    private Map<LinkMetric, Double> computeMetric(
        Graph graph,
        HugeLongArray nodeIds,
        LinkLogisticRegressionPredictor predictor
    ) {
        SignedProbabilities signedProbabilities = SignedProbabilities.create(graph.relationshipCount());
        new HugeBatchQueue(nodeIds).parallelConsume(
            this.config.concurrency(),
            ignored -> new SignedProbabilitiesCollector(
                graph.concurrentCopy(),
                predictor,
                signedProbabilities,
                this.progressLogger
            )
        );
        return this.config.metrics().stream().collect(Collectors.toMap(
            metric -> metric,
            metric -> metric.compute(signedProbabilities, this.config.classRatio())
        ));
    }

    /**
     * Trains a logistic regression predictor on {@code trainGraph} restricted to {@code trainSet}.
     * It uses the configured feature properties and concurrency, plus the given candidate parameters.
     */
    private LinkLogisticRegressionPredictor trainModel(HugeLongArray trainSet, Map<String, Object> modelParams) {
        LinkLogisticRegressionTrainConfig trainConfig = LinkLogisticRegressionTrainConfig.of(
            this.config.featureProperties(),
            this.config.concurrency(),
            modelParams
        );
        return new LinkLogisticRegressionTrain(this.trainGraph, trainSet, trainConfig, this.log).compute();
    }

    @Override
    public LinkPredictionTrain me() {
        return this;
    }

    @Override
    public void release() {
        // Nothing to release: this algorithm holds no resources beyond its final fields.
    }
}
