package org.neo4j.gds.ml.nodemodels;

import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;
import org.eclipse.collections.api.tuple.Pair;
import org.eclipse.collections.impl.tuple.Tuples;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.NodeProperties;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.Training;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionData;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionPredictor;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionTrain;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionTrainConfig;
import org.neo4j.gds.ml.nodemodels.metrics.Metric;
import org.neo4j.gds.ml.nodemodels.metrics.MetricSpecification;
import org.neo4j.gds.ml.splitting.FractionSplitter;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
import org.neo4j.gds.ml.util.ShuffleUtil;
import org.openjdk.jol.util.Multiset;

/* loaded from: input_file:org/neo4j/gds/ml/nodemodels/NodeClassificationTrain.class */
public final class NodeClassificationTrain extends Algorithm<NodeClassificationTrain, Model<NodeLogisticRegressionData, NodeClassificationTrainConfig, NodeClassificationModelInfo>> {
    public static final String MODEL_TYPE = "nodeLogisticRegression";
    private final Graph graph;
    private final NodeClassificationTrainConfig config;
    private final HugeLongArray targets;
    private final Multiset<Long> classCounts;
    private final HugeLongArray nodeIds;
    private final AllocationTracker allocationTracker;
    private final List<Metric> metrics;
    private final StatsMap trainStats;
    private final StatsMap validationStats;

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Result of the model-selection phase: the chosen parameter configuration plus the
     * per-metric statistics collected for every candidate on the train and validation sets.
     */
    @ValueClass
    public interface ModelSelectResult {

        /** @return the parameter configuration that was selected. */
        NodeLogisticRegressionTrainConfig bestParameters();

        /** @return per-metric statistics recorded on the training sets, one entry per candidate. */
        Map<Metric, List<ModelStats<NodeLogisticRegressionTrainConfig>>> trainStats();

        /** @return per-metric statistics recorded on the validation sets, one entry per candidate. */
        Map<Metric, List<ModelStats<NodeLogisticRegressionTrainConfig>>> validationStats();

        /**
         * Creates a result from the selected configuration and two {@code StatsMap}s.
         *
         * @param nodeLogisticRegressionTrainConfig the selected parameter configuration
         * @param statsMap  statistics whose underlying map is forwarded as the second factory argument
         *                  (the caller in this file passes its train statistics here)
         * @param statsMap2 statistics whose underlying map is forwarded as the third factory argument
         *                  (the caller in this file passes its validation statistics here)
         * @return the immutable result value
         */
        static ModelSelectResult of(
            NodeLogisticRegressionTrainConfig nodeLogisticRegressionTrainConfig,
            StatsMap statsMap,
            StatsMap statsMap2
        ) {
            return ImmutableModelSelectResult.of(
                nodeLogisticRegressionTrainConfig,
                statsMap.getMap(),
                statsMap2.getMap()
            );
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/neo4j/gds/ml/nodemodels/NodeClassificationTrain$ModelStatsBuilder.class */
    /**
     * Accumulates, per metric, the minimum, maximum and sum of the scores reported for one
     * parameter configuration and converts them into a {@code ModelStats} summary.
     */
    public static class ModelStatsBuilder {
        // Typed with the diamond operator; the raw HashMap instantiations produced unchecked warnings.
        private final Map<Metric, Double> min = new HashMap<>();
        private final Map<Metric, Double> max = new HashMap<>();
        private final Map<Metric, Double> sum = new HashMap<>();
        private final NodeLogisticRegressionTrainConfig modelParams;
        private final int numberOfSplits;

        /**
         * @param nodeLogisticRegressionTrainConfig the parameter configuration the statistics belong to
         * @param i the divisor used when averaging the summed scores in {@link #build(Metric)}
         */
        ModelStatsBuilder(NodeLogisticRegressionTrainConfig nodeLogisticRegressionTrainConfig, int i) {
            this.modelParams = nodeLogisticRegressionTrainConfig;
            this.numberOfSplits = i;
        }

        /**
         * Folds one observed score into the running minimum, maximum and sum of the given metric.
         *
         * @param metric the metric the score belongs to
         * @param d      the observed score
         */
        void update(Metric metric, double d) {
            this.min.merge(metric, d, Math::min);
            this.max.merge(metric, d, Math::max);
            this.sum.merge(metric, d, Double::sum);
        }

        /**
         * Builds the summary for one metric: sum divided by the number of splits, minimum, maximum.
         *
         * @param metric the metric to summarise
         * @return the statistics of this builder's parameter configuration for {@code metric}
         * @throws NullPointerException if {@link #update(Metric, double)} was never called for {@code metric}
         */
        ModelStats<NodeLogisticRegressionTrainConfig> build(Metric metric) {
            double average = this.sum.get(metric) / this.numberOfSplits;
            return ImmutableModelStats.of(this.modelParams, average, this.min.get(metric), this.max.get(metric));
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Assembles the memory estimation for a training run described by the given configuration.
     * <p>
     * The estimation adds up the components for global targets, global class counts, metrics,
     * node IDs, the outer and inner splits and both stats maps, plus the larger of the
     * "model selection" and "best model evaluation" estimations.
     *
     * @param nodeClassificationTrainConfig the training configuration to estimate for
     * @return the combined memory estimation
     * @throws java.util.NoSuchElementException if {@code paramsConfig()} is empty, because no
     *         maximum batch size exists in that case
     */
    public static MemoryEstimation estimate(NodeClassificationTrainConfig nodeClassificationTrainConfig) {
        int maxBatchSize = nodeClassificationTrainConfig.paramsConfig()
            .stream()
            .mapToInt(NodeLogisticRegressionTrainConfig::batchSize)
            .max()
            .getAsInt();

        // Fixed placeholder sizes; the real numbers are not known at estimation time.
        // NOTE(review): 500 is presumably a feature-count placeholder -- confirm against
        // NodeLogisticRegressionTrain.memoryEstimation.
        int assumedClassCount = 1000;
        int assumedFeatureCount = 500;

        double holdoutFraction = nodeClassificationTrainConfig.holdoutFraction();
        int validationFolds = nodeClassificationTrainConfig.validationFolds();

        MemoryEstimation modelSelection = modelTrainAndEvaluateMemoryUsage(
            maxBatchSize,
            assumedClassCount,
            assumedFeatureCount,
            nodeCount -> (long) (((nodeCount * holdoutFraction) * (validationFolds - 1)) / validationFolds)
        );
        MemoryEstimation bestModelEvaluation = MemoryEstimations.delegateEstimation(
            modelTrainAndEvaluateMemoryUsage(
                maxBatchSize,
                assumedClassCount,
                assumedFeatureCount,
                nodeCount -> (long) (nodeCount * holdoutFraction)
            ),
            "best model evaluation"
        );

        return MemoryEstimations.builder()
            .perNode("global targets", HugeLongArray::memoryEstimation)
            .rangePerNode("global class counts", nodeCount -> MemoryRange.of(16L, assumedClassCount * 8))
            .add("metrics", MetricSpecification.memoryEstimation(assumedClassCount))
            .perNode("node IDs", HugeLongArray::memoryEstimation)
            .add("outer split", FractionSplitter.estimate(1.0d - holdoutFraction))
            .add("inner split", StratifiedKFoldSplitter.memoryEstimation(validationFolds, 1.0d - holdoutFraction))
            .add(
                "stats map train",
                StatsMap.memoryEstimation(
                    nodeClassificationTrainConfig.metrics().size(),
                    nodeClassificationTrainConfig.params().size()
                )
            )
            .add(
                "stats map validation",
                StatsMap.memoryEstimation(
                    nodeClassificationTrainConfig.metrics().size(),
                    nodeClassificationTrainConfig.params().size()
                )
            )
            .add(
                "max of model selection and best model evaluation",
                MemoryEstimations.maxEstimation(List.of(modelSelection, bestModelEvaluation))
            )
            .build();
    }

    /**
     * Name of the root progress task created by {@code progressTask}.
     *
     * @return the constant task name {@code "NCTrain"}
     */
    public static String taskName() {
        return "NCTrain";
    }

    /**
     * Builds the progress task tree for training: a "ShuffleAndSplit" leaf, an iterative
     * "SelectBestModel" task, "TrainSelectedOnRemainder", "EvaluateSelectedModel" and
     * "RetrainSelectedModel", all below a root named by {@link #taskName()}.
     *
     * @param i  number of iterations of each "Model Candidate" task (one "Split" sub-task per iteration)
     * @param i2 number of iterations of the "SelectBestModel" task (one "Model Candidate" per iteration)
     * @return the root task
     */
    public static Task progressTask(int i, int i2) {
        Task shuffleAndSplit = Tasks.leaf("ShuffleAndSplit");

        // The suppliers stay lambdas on purpose: they create fresh sub-task instances on every call.
        Task selectBestModel = Tasks.iterativeFixed(
            "SelectBestModel",
            () -> List.of(
                Tasks.iterativeFixed(
                    "Model Candidate",
                    () -> List.of(
                        Tasks.task("Split", Training.progressTask("Training"), new Task[]{Tasks.leaf("Evaluate")})
                    ),
                    i
                )
            ),
            i2
        );

        Task[] followingSteps = {
            selectBestModel,
            Training.progressTask("TrainSelectedOnRemainder"),
            Tasks.leaf("EvaluateSelectedModel"),
            Training.progressTask("RetrainSelectedModel")
        };
        return Tasks.task(taskName(), shuffleAndSplit, followingSteps);
    }

    /**
     * Estimates the memory needed to train one model and compute its metrics, taking the
     * larger of the training estimation and the "computing metrics" estimation.
     *
     * @param i   forwarded as the third argument of {@code NodeLogisticRegressionTrain.memoryEstimation}
     * @param i2  forwarded as the first argument of the training estimation; also the length of the
     *            "probabilities" double array
     * @param i3  forwarded as the second argument of the training estimation
     * @param longUnaryOperator maps the total node count to the number of nodes the metric arrays
     *            ("local targets", "predicted classes") are sized for
     * @return the estimation labelled "model selection"
     */
    @NotNull
    private static MemoryEstimation modelTrainAndEvaluateMemoryUsage(int i, int i2, int i3, LongUnaryOperator longUnaryOperator) {
        MemoryEstimation training = NodeLogisticRegressionTrain.memoryEstimation(i2, i3, i);

        MemoryEstimation metricsComputation = MemoryEstimations.builder("computing metrics")
            .perNode(
                "local targets",
                nodeCount -> HugeLongArray.memoryEstimation(longUnaryOperator.applyAsLong(nodeCount))
            )
            .perNode(
                "predicted classes",
                nodeCount -> HugeLongArray.memoryEstimation(longUnaryOperator.applyAsLong(nodeCount))
            )
            .fixed("probabilities", MemoryUsage.sizeOfDoubleArray(i2))
            .fixed("computation graph", NodeLogisticRegressionPredictor.sizeOfPredictionsVariableInBytes(100, i3, i2))
            .build();

        return MemoryEstimations.builder("model selection")
            .max(List.of(training, metricsComputation))
            .build();
    }

    /**
     * Creates a training algorithm instance for the given graph and configuration.
     * <p>
     * Reads the target property of every node to obtain the global targets and class counts,
     * derives the metrics from the configuration, and prepares a node-id array initialised to
     * the identity mapping ({@code ids[k] == k}).
     *
     * @param graph                         the graph to train on
     * @param nodeClassificationTrainConfig the training configuration
     * @param allocationTracker             tracker handed to every array allocation
     * @param progressTracker               tracker handed to the algorithm
     * @return a new, not yet executed training instance
     */
    public static NodeClassificationTrain create(Graph graph, NodeClassificationTrainConfig nodeClassificationTrainConfig, AllocationTracker allocationTracker, ProgressTracker progressTracker) {
        NodeProperties targetProperties = graph.nodeProperties(nodeClassificationTrainConfig.targetProperty());
        Pair<HugeLongArray, Multiset<Long>> targetsAndClasses = computeGlobalTargetsAndClasses(
            targetProperties,
            graph.nodeCount(),
            allocationTracker
        );
        // The pair is fully typed, so neither casts nor a raw Multiset are needed.
        HugeLongArray targets = targetsAndClasses.getOne();
        Multiset<Long> classCounts = targetsAndClasses.getTwo();

        List<Metric> metrics = createMetrics(nodeClassificationTrainConfig, classCounts);

        HugeLongArray nodeIds = HugeLongArray.newArray(graph.nodeCount(), allocationTracker);
        nodeIds.setAll(index -> index);

        // Train and validation statistics are two independent StatsMap instances.
        StatsMap trainStats = StatsMap.create(metrics);
        StatsMap validationStats = StatsMap.create(metrics);

        return new NodeClassificationTrain(
            graph,
            nodeClassificationTrainConfig,
            targets,
            classCounts,
            metrics,
            nodeIds,
            trainStats,
            validationStats,
            allocationTracker,
            progressTracker
        );
    }

    /**
     * Reads the target value of every node in {@code [0, j)}, storing it in an array indexed
     * by node id and counting how often each value occurs.
     *
     * @param nodeProperties    source of the per-node target values
     * @param j                 number of nodes to read
     * @param allocationTracker tracker handed to the array allocation
     * @return a pair of (targets array, multiset of target values)
     */
    private static Pair<HugeLongArray, Multiset<Long>> computeGlobalTargetsAndClasses(NodeProperties nodeProperties, long j, AllocationTracker allocationTracker) {
        Multiset<Long> classCounts = new Multiset<>();
        HugeLongArray targets = HugeLongArray.newArray(j, allocationTracker);
        for (long nodeId = 0; nodeId < j; nodeId++) {
            // Read the property once and reuse it for both the array and the counter.
            long target = nodeProperties.longValue(nodeId);
            targets.set(nodeId, target);
            classCounts.add(target);
        }
        return Tuples.pair(targets, classCounts);
    }

    /**
     * Expands every metric specification of the configuration into concrete metrics, using the
     * keys of the given multiset as input to each specification.
     *
     * @param nodeClassificationTrainConfig configuration providing the metric specifications
     * @param multiset                      multiset whose keys are handed to each specification
     * @return all created metrics, in the order the specifications produce them
     */
    private static List<Metric> createMetrics(NodeClassificationTrainConfig nodeClassificationTrainConfig, Multiset<Long> multiset) {
        // No raw (List) cast: the stream is already typed as Stream<Metric>.
        return nodeClassificationTrainConfig.metrics()
            .stream()
            .flatMap(specification -> specification.createMetrics(multiset.keys()))
            .collect(Collectors.toList());
    }

    /**
     * Stores the collaborators and precomputed data; instances are obtained through
     * {@code create(...)}.
     *
     * @param graph                         the graph to train on
     * @param nodeClassificationTrainConfig the training configuration
     * @param hugeLongArray                 target values, stored as {@code targets}
     * @param multiset                      class counts, stored as {@code classCounts}
     * @param list                          metrics to evaluate
     * @param hugeLongArray2                node ids, stored as {@code nodeIds}
     * @param statsMap                      statistics container stored as {@code trainStats}
     * @param statsMap2                     statistics container stored as {@code validationStats}
     * @param allocationTracker             tracker used for later array allocations
     * @param progressTracker               progress tracker forwarded to the superclass
     */
    private NodeClassificationTrain(Graph graph, NodeClassificationTrainConfig nodeClassificationTrainConfig, HugeLongArray hugeLongArray, Multiset<Long> multiset, List<Metric> list, HugeLongArray hugeLongArray2, StatsMap statsMap, StatsMap statsMap2, AllocationTracker allocationTracker, ProgressTracker progressTracker) {
        super(progressTracker);

        // Collaborators.
        this.graph = graph;
        this.config = nodeClassificationTrainConfig;
        this.allocationTracker = allocationTracker;

        // Precomputed per-node data.
        this.nodeIds = hugeLongArray2;
        this.targets = hugeLongArray;
        this.classCounts = multiset;

        // Metrics and the containers that collect their statistics.
        this.metrics = list;
        this.trainStats = statsMap;
        this.validationStats = statsMap2;
    }

    /* renamed from: me, reason: merged with bridge method [inline-methods] */
    /**
     * Returns this instance.
     * <p>
     * NOTE(review): the {@code m75} prefix looks decompiler-generated (the accompanying marker
     * comment says the method was renamed from {@code me}); confirm the intended name against
     * the superclass before relying on it.
     *
     * @return {@code this}
     */
    public NodeClassificationTrain m75me() {
        return this;
    }

    /**
     * No-op: the body is empty, so calling this method has no effect.
     */
    public void release() {
    }
    /* renamed from: compute, reason: merged with bridge method [inline-methods] */
    /**
     * Runs the complete training pipeline: shuffle the node ids, split them into an outer
     * train/holdout split and stratified inner folds, select the best parameter configuration
     * on the inner folds, evaluate it on the outer split, retrain it on all node ids and wrap
     * the result in a {@code Model}.
     * <p>
     * NOTE(review): the {@code m76} prefix looks decompiler-generated (the accompanying marker
     * comment says the method was renamed from {@code compute}); confirm the intended name.
     *
     * @return the trained model including its evaluation metrics
     */
    public Model<NodeLogisticRegressionData, NodeClassificationTrainConfig, NodeClassificationModelInfo> m76compute() {
        // Two nested sub-tasks are opened here; the inner one is closed right after splitting,
        // the outer one just before returning.
        this.progressTracker.beginSubTask();
        this.progressTracker.beginSubTask();

        ShuffleUtil.shuffleHugeLongArray(
            this.nodeIds,
            ShuffleUtil.createRandomDataGenerator(this.config.randomSeed())
        );

        double trainFraction = 1.0d - this.config.holdoutFraction();
        TrainingExamplesSplit outerSplit = new FractionSplitter(this.allocationTracker).split(this.nodeIds, trainFraction);

        StratifiedKFoldSplitter foldSplitter = new StratifiedKFoldSplitter(
            this.config.validationFolds(),
            ReadOnlyHugeLongArray.of(outerSplit.trainSet()),
            ReadOnlyHugeLongArray.of(this.targets),
            this.config.randomSeed()
        );
        List<TrainingExamplesSplit> innerSplits = foldSplitter.splits();
        this.progressTracker.endSubTask();

        ModelSelectResult selection = selectBestModel(innerSplits);
        NodeLogisticRegressionTrainConfig winner = selection.bestParameters();
        Map<Metric, MetricData<NodeLogisticRegressionTrainConfig>> metricResults = evaluateBestModel(outerSplit, selection, winner);
        NodeLogisticRegressionData finalModelData = retrainBestModel(winner);

        this.progressTracker.endSubTask();
        return createModel(winner, metricResults, finalModelData);
    }

    /**
     * Trains every candidate parameter configuration on every given split, records the
     * resulting metric values on the split's test set (validation) and train set, and picks
     * the configuration that {@code validationStats} ranks best for the first metric.
     *
     * @param list the inner splits to train and evaluate each candidate on
     * @return the selected configuration together with the collected train/validation statistics
     * @throws IndexOutOfBoundsException if no metrics are configured
     */
    private ModelSelectResult selectBestModel(List<TrainingExamplesSplit> list) {
        this.progressTracker.beginSubTask();
        for (NodeLogisticRegressionTrainConfig candidate : this.config.paramsConfig()) {
            this.progressTracker.beginSubTask();
            ModelStatsBuilder validationStatsBuilder = new ModelStatsBuilder(candidate, list.size());
            ModelStatsBuilder trainStatsBuilder = new ModelStatsBuilder(candidate, list.size());

            for (TrainingExamplesSplit split : list) {
                this.progressTracker.beginSubTask();
                HugeLongArray trainSet = split.trainSet();
                HugeLongArray validationSet = split.testSet();

                this.progressTracker.beginSubTask("Training");
                NodeLogisticRegressionData modelData = trainModel(trainSet, candidate);
                this.progressTracker.endSubTask("Training");

                // Evaluation covers both sets, hence the combined task volume.
                this.progressTracker.beginSubTask(validationSet.size() + trainSet.size());
                // The decompiled code called update(...) on an undefined receiver (r1); the receivers
                // are the validation builder for the test set and the train builder for the train set.
                computeMetrics(this.classCounts, validationSet, modelData, this.metrics)
                    .forEach(validationStatsBuilder::update);
                computeMetrics(this.classCounts, trainSet, modelData, this.metrics)
                    .forEach(trainStatsBuilder::update);
                this.progressTracker.endSubTask();

                this.progressTracker.endSubTask();
            }
            this.progressTracker.endSubTask();

            this.metrics.forEach(metric -> {
                this.validationStats.add(metric, validationStatsBuilder.build(metric));
                this.trainStats.add(metric, trainStatsBuilder.build(metric));
            });
        }
        this.progressTracker.endSubTask();

        // The first configured metric decides which candidate wins.
        Metric decidingMetric = this.metrics.get(0);
        return ModelSelectResult.of(
            this.validationStats.pickBestModelStats(decidingMetric).params(),
            this.trainStats,
            this.validationStats
        );
    }

    /**
     * Trains the selected configuration on the train part of the outer split and computes all
     * metrics on both the test part and the train part.
     *
     * @param trainingExamplesSplit             the outer train/test split
     * @param modelSelectResult                 statistics gathered during model selection
     * @param nodeLogisticRegressionTrainConfig the selected parameter configuration
     * @return per-metric data combining the selection statistics with the newly computed scores
     */
    private Map<Metric, MetricData<NodeLogisticRegressionTrainConfig>> evaluateBestModel(TrainingExamplesSplit trainingExamplesSplit, ModelSelectResult modelSelectResult, NodeLogisticRegressionTrainConfig nodeLogisticRegressionTrainConfig) {
        HugeLongArray outerTrainSet = trainingExamplesSplit.trainSet();
        HugeLongArray outerTestSet = trainingExamplesSplit.testSet();

        this.progressTracker.beginSubTask("TrainSelectedOnRemainder");
        NodeLogisticRegressionData modelData = trainModel(outerTrainSet, nodeLogisticRegressionTrainConfig);
        this.progressTracker.endSubTask("TrainSelectedOnRemainder");

        // Both sets are scored inside one sub-task sized by their combined length;
        // the test set is scored first, then the train set.
        this.progressTracker.beginSubTask(outerTestSet.size() + outerTrainSet.size());
        Map<Metric, Double> testScores = computeMetrics(this.classCounts, outerTestSet, modelData, this.metrics);
        Map<Metric, Double> trainScores = computeMetrics(this.classCounts, outerTrainSet, modelData, this.metrics);
        this.progressTracker.endSubTask();

        return mergeMetricResults(modelSelectResult, trainScores, testScores);
    }

    /**
     * Trains the given configuration on all node ids, wrapped in the
     * "RetrainSelectedModel" progress sub-task.
     *
     * @param nodeLogisticRegressionTrainConfig the selected parameter configuration
     * @return the data of the retrained model
     */
    private NodeLogisticRegressionData retrainBestModel(NodeLogisticRegressionTrainConfig nodeLogisticRegressionTrainConfig) {
        final String taskDescription = "RetrainSelectedModel";
        this.progressTracker.beginSubTask(taskDescription);
        NodeLogisticRegressionData retrained = trainModel(this.nodeIds, nodeLogisticRegressionTrainConfig);
        this.progressTracker.endSubTask(taskDescription);
        return retrained;
    }

    /**
     * Wraps the trained model data, the configuration and the evaluation results in a
     * {@code Model} of type {@link #MODEL_TYPE}.
     *
     * @param nodeLogisticRegressionTrainConfig the selected parameter configuration
     * @param map                               per-metric evaluation data
     * @param nodeLogisticRegressionData        the trained model data
     * @return the model built from the user name, model name and graph schema of this run
     */
    private Model<NodeLogisticRegressionData, NodeClassificationTrainConfig, NodeClassificationModelInfo> createModel(NodeLogisticRegressionTrainConfig nodeLogisticRegressionTrainConfig, Map<Metric, MetricData<NodeLogisticRegressionTrainConfig>> map, NodeLogisticRegressionData nodeLogisticRegressionData) {
        NodeClassificationModelInfo modelInfo = NodeClassificationModelInfo.of(
            nodeLogisticRegressionData.classIdMap().originalIdsList(),
            nodeLogisticRegressionTrainConfig,
            map
        );
        return Model.of(
            this.config.username(),
            this.config.modelName(),
            MODEL_TYPE,
            this.graph.schema(),
            nodeLogisticRegressionData,
            this.config,
            modelInfo
        );
    }

    /**
     * Combines, for every metric present in the validation statistics, the model-selection
     * statistics with the two scores computed after selection.
     *
     * @param modelSelectResult statistics gathered during model selection
     * @param map               per-metric score passed as the third {@code MetricData.of} argument
     *                          (the caller in this file passes the outer-train scores)
     * @param map2              per-metric score passed as the fourth {@code MetricData.of} argument
     *                          (the caller in this file passes the test scores)
     * @return one {@code MetricData} per metric
     * @throws NullPointerException if either score map lacks an entry for one of the metrics
     */
    private Map<Metric, MetricData<NodeLogisticRegressionTrainConfig>> mergeMetricResults(ModelSelectResult modelSelectResult, Map<Metric, Double> map, Map<Metric, Double> map2) {
        // No raw (Map)/(Double) casts: the collector and both maps are fully typed.
        return modelSelectResult.validationStats()
            .keySet()
            .stream()
            .collect(Collectors.toMap(
                Function.identity(),
                metric -> MetricData.of(
                    modelSelectResult.trainStats().get(metric),
                    modelSelectResult.validationStats().get(metric),
                    map.get(metric).doubleValue(),
                    map2.get(metric).doubleValue()
                )
            ));
    }

    /**
     * Trains a logistic regression model on the given node ids with the given parameters,
     * reusing this algorithm's graph, progress tracker, termination flag and concurrency.
     *
     * @param hugeLongArray                     ids of the nodes to train on
     * @param nodeLogisticRegressionTrainConfig the parameter configuration to train with
     * @return the data of the trained model
     */
    private NodeLogisticRegressionData trainModel(HugeLongArray hugeLongArray, NodeLogisticRegressionTrainConfig nodeLogisticRegressionTrainConfig) {
        NodeLogisticRegressionTrain trainer = new NodeLogisticRegressionTrain(
            this.graph,
            hugeLongArray,
            nodeLogisticRegressionTrainConfig,
            this.progressTracker,
            this.terminationFlag,
            this.config.concurrency()
        );
        return trainer.compute();
    }

    /**
     * Predicts a class for every node in the given evaluation set and scores the predictions
     * against the nodes' target values with each of the given metrics.
     *
     * @param multiset                   class counts handed to every metric computation
     * @param hugeLongArray              ids of the nodes to evaluate
     * @param nodeLogisticRegressionData the model to predict with
     * @param collection                 the metrics to compute
     * @return the computed value of every metric
     */
    private Map<Metric, Double> computeMetrics(Multiset<Long> multiset, HugeLongArray hugeLongArray, NodeLogisticRegressionData nodeLogisticRegressionData, Collection<Metric> collection) {
        NodeLogisticRegressionPredictor predictor = new NodeLogisticRegressionPredictor(
            nodeLogisticRegressionData,
            this.config.featureProperties()
        );
        HugeLongArray predictedClasses = HugeLongArray.newArray(hugeLongArray.size(), this.allocationTracker);

        // The fourth consumer argument is left null, exactly as in the original call.
        // NOTE(review): presumably an optional output for predicted probabilities -- confirm.
        NodeClassificationPredictConsumer consumer = new NodeClassificationPredictConsumer(
            this.graph,
            hugeLongArray::get,
            predictor,
            null,
            predictedClasses,
            this.config.featureProperties(),
            this.progressTracker
        );
        new BatchQueue(hugeLongArray.size()).parallelConsume(consumer, this.config.concurrency(), this.terminationFlag);

        HugeLongArray localTargets = makeLocalTargets(hugeLongArray);
        return collection.stream().collect(Collectors.toMap(
            Function.identity(),
            metric -> Double.valueOf(metric.compute(localTargets, predictedClasses, multiset))
        ));
    }

    /**
     * Builds an array that holds, at position {@code k}, the target property value of the node
     * whose id is stored at position {@code k} of the given array.
     *
     * @param hugeLongArray ids of the nodes to look up
     * @return target values aligned with {@code hugeLongArray}
     */
    private HugeLongArray makeLocalTargets(HugeLongArray hugeLongArray) {
        HugeLongArray localTargets = HugeLongArray.newArray(hugeLongArray.size(), this.allocationTracker);
        NodeProperties targetValues = this.graph.nodeProperties(this.config.targetProperty());
        localTargets.setAll(position -> targetValues.longValue(hugeLongArray.get(position)));
        return localTargets;
    }
}
