package org.neo4j.gds.ml.nodemodels;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.neo4j.gds.ml.batch.BatchQueue;
import org.neo4j.gds.ml.nodemodels.metrics.Metric;
import org.neo4j.gds.ml.nodemodels.multiclasslogisticregression.MultiClassNLRData;
import org.neo4j.gds.ml.nodemodels.multiclasslogisticregression.MultiClassNLRPredictor;
import org.neo4j.gds.ml.nodemodels.multiclasslogisticregression.MultiClassNLRTrain;
import org.neo4j.gds.ml.nodemodels.multiclasslogisticregression.MultiClassNLRTrainConfig;
import org.neo4j.gds.ml.splitting.FractionSplitter;
import org.neo4j.gds.ml.splitting.NodeSplit;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.util.ShuffleUtil;
import org.neo4j.graphalgo.Algorithm;
import org.neo4j.graphalgo.annotation.ValueClass;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.NodeProperties;
import org.neo4j.graphalgo.core.model.Model;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.mem.AllocationTracker;
import org.neo4j.graphalgo.core.utils.paged.HugeLongArray;
import org.neo4j.graphalgo.utils.StringFormatting;
import org.openjdk.jol.util.Multiset;

/**
 * Trains a multi-class logistic regression node classification model with model selection.
 *
 * <p>The pipeline implemented by {@link #compute()}:
 * <ol>
 *   <li>shuffles all node ids and splits them with {@link FractionSplitter} into a remainder
 *       ("train") set and a holdout ("test") set governed by {@code config.holdoutFraction()};</li>
 *   <li>splits the remainder with {@link StratifiedKFoldSplitter} into
 *       {@code config.validationFolds()} folds and, for every candidate parameter map in
 *       {@code config.params()}, trains on each fold's train set and scores it on both the fold's
 *       validation set and its train set;</li>
 *   <li>picks the candidate whose validation stats for the first configured metric are maximal
 *       under {@link ModelStats#COMPARE_AVERAGE};</li>
 *   <li>trains the winner on the remainder set and scores it on the remainder and holdout sets;</li>
 *   <li>retrains the winner on all nodes and returns it wrapped in a {@link Model}.</li>
 * </ol>
 */
public class NodeClassificationTrain extends Algorithm<NodeClassificationTrain, Model<MultiClassNLRData, NodeClassificationTrainConfig>> {

    /** Model type identifier stored in the returned {@link Model}. */
    public static final String MODEL_TYPE = "multiClassNodeLogisticRegression";

    /**
     * {@code maxEpochs} value assumed when a parameter map does not specify one. It is only used
     * here for progress reporting. NOTE(review): presumably mirrors the trainer's own default;
     * confirm against MultiClassNLRTrainConfig.
     */
    private static final int DEFAULT_MAX_EPOCHS = 100;

    private final Graph graph;
    private final NodeClassificationTrainConfig config;
    private final AllocationTracker allocationTracker;

    /**
     * Outcome of cross-validated model selection: the winning parameter map plus, per metric,
     * one {@link ModelStats} entry per candidate (in {@code config.params()} order) for the
     * fold train sets and the fold validation sets.
     */
    @ValueClass
    public interface ModelSelectResult {
        Map<String, Object> bestParameters();

        Map<Metric, List<ModelStats>> trainStats();

        Map<Metric, List<ModelStats>> validationStats();

        static ModelSelectResult of(
            Map<String, Object> bestParameters,
            Map<Metric, List<ModelStats>> trainStats,
            Map<Metric, List<ModelStats>> validationStats
        ) {
            return ImmutableModelSelectResult.of(bestParameters, trainStats, validationStats);
        }
    }

    /**
     * Accumulates, per metric, the minimum, maximum and sum of the scores reported across the
     * cross-validation splits of a single candidate parameter map.
     */
    private static final class ModelStatsBuilder {
        private final Map<Metric, Double> min = new HashMap<>();
        private final Map<Metric, Double> max = new HashMap<>();
        private final Map<Metric, Double> sum = new HashMap<>();
        private final Map<String, Object> modelParams;
        private final int numberOfSplits;

        ModelStatsBuilder(Map<String, Object> modelParams, int numberOfSplits) {
            this.modelParams = modelParams;
            this.numberOfSplits = numberOfSplits;
        }

        /** Folds one split's score for {@code metric} into the running min/max/sum. */
        void update(Metric metric, double score) {
            min.merge(metric, score, Math::min);
            max.merge(metric, score, Math::max);
            sum.merge(metric, score, Double::sum);
        }

        /**
         * Builds the stats for {@code metric}. The average divides the accumulated sum by the
         * split count given at construction, so every split must have called {@link #update}.
         */
        ModelStats modelStats(Metric metric) {
            double average = sum.get(metric) / numberOfSplits;
            return ImmutableModelStats.of(modelParams, average, min.get(metric), max.get(metric));
        }
    }

    public NodeClassificationTrain(
        Graph graph,
        NodeClassificationTrainConfig nodeClassificationTrainConfig,
        AllocationTracker allocationTracker,
        ProgressLogger progressLogger
    ) {
        this.graph = graph;
        this.config = nodeClassificationTrainConfig;
        this.allocationTracker = allocationTracker;
        // progressLogger is the field inherited from Algorithm.
        this.progressLogger = progressLogger;
    }

    /**
     * Runs shuffle/split, cross-validated model selection, evaluation of the selected parameters
     * and a final retrain on all nodes.
     *
     * @return a model whose data was trained on every node of the graph, carrying the winning
     *     parameters and the merged cross-validation / remainder / holdout metric results
     */
    @Override
    public Model<MultiClassNLRData, NodeClassificationTrainConfig> compute() {
        HugeLongArray globalTargets = makeGlobalTargets();
        Multiset<Long> globalClassCounts = countClassesGlobally();
        List<Metric> metrics = createMetrics(globalClassCounts);

        progressLogger.logStart(":: Shuffle and Split");
        HugeLongArray allNodeIds = HugeLongArray.newArray(graph.nodeCount(), allocationTracker);
        allNodeIds.setAll(nodeId -> nodeId);
        ShuffleUtil.shuffleHugeLongArray(allNodeIds, getRandomDataGenerator());
        NodeSplit outerSplit = new FractionSplitter().split(allNodeIds, 1.0d - config.holdoutFraction());
        List<NodeSplit> innerSplits = new StratifiedKFoldSplitter(
            config.validationFolds(),
            outerSplit.trainSet(),
            globalTargets,
            config.randomSeed()
        ).splits();
        progressLogger.logFinish(":: Shuffle and Split");

        ModelSelectResult modelSelectResult = selectBestModel(innerSplits, globalClassCounts, metrics);
        Map<String, Object> bestParameters = modelSelectResult.bestParameters();
        int bestMaxEpochs = maxEpochs(bestParameters);

        progressLogger.logStart(":: Train Selected on Remainder");
        progressLogger.reset(bestMaxEpochs);
        MultiClassNLRData remainderModel = trainModel(outerSplit.trainSet(), bestParameters);
        progressLogger.logFinish(":: Train Selected on Remainder");

        progressLogger.logStart(":: Evaluate Selected Model");
        progressLogger.reset(outerSplit.testSet().size() + outerSplit.trainSet().size());
        Map<Metric, Double> testMetrics = computeMetrics(globalClassCounts, outerSplit.testSet(), remainderModel, metrics);
        Map<Metric, Double> outerTrainMetrics = computeMetrics(globalClassCounts, outerSplit.trainSet(), remainderModel, metrics);
        progressLogger.logFinish(":: Evaluate Selected Model");
        Map<Metric, MetricData> metricResults = mergeMetricResults(modelSelectResult, outerTrainMetrics, testMetrics);

        progressLogger.logStart(":: Retrain Selected Model");
        progressLogger.reset(bestMaxEpochs);
        MultiClassNLRData finalModel = trainModel(allNodeIds, bestParameters);
        progressLogger.logFinish(":: Retrain Selected Model");

        return Model.of(
            config.username(),
            config.modelName(),
            MODEL_TYPE,
            graph.schema(),
            finalModel,
            config,
            NodeClassificationModelInfo.of(finalModel.classIdMap().originalIdsList(), bestParameters, metricResults)
        );
    }

    /**
     * Alias of {@link #compute()} kept for callers that use the decompiler-assigned name.
     */
    public Model<MultiClassNLRData, NodeClassificationTrainConfig> m15compute() {
        return compute();
    }

    /**
     * Cross-validates every candidate in {@code config.params()} over {@code nodeSplits} and
     * selects the best one.
     *
     * <p>For each candidate and split, a model is trained on the split's train set and scored on
     * both its test (validation) set and its train set. Selection takes the maximum, under
     * {@link ModelStats#COMPARE_AVERAGE}, of the validation stats for the first metric. This
     * therefore assumes at least one metric and one candidate.
     */
    private ModelSelectResult selectBestModel(
        List<NodeSplit> nodeSplits,
        Multiset<Long> globalClassCounts,
        List<Metric> metrics
    ) {
        Map<Metric, List<ModelStats>> trainStats = initStatsMap(metrics);
        Map<Metric, List<ModelStats>> validationStats = initStatsMap(metrics);

        int candidateCount = config.params().size();
        for (int candidateIndex = 0; candidateIndex < candidateCount; candidateIndex++) {
            String candidateMessage = StringFormatting.formatWithLocale(
                ":: Model Candidate %s of %s", candidateIndex + 1, candidateCount);
            Map<String, Object> modelParams = config.params().get(candidateIndex);
            ModelStatsBuilder validationStatsBuilder = new ModelStatsBuilder(modelParams, nodeSplits.size());
            ModelStatsBuilder trainStatsBuilder = new ModelStatsBuilder(modelParams, nodeSplits.size());

            for (int splitIndex = 0; splitIndex < nodeSplits.size(); splitIndex++) {
                NodeSplit split = nodeSplits.get(splitIndex);
                // The prefix is passed as an argument, not concatenated into the template,
                // so it is never interpreted as a format string.
                String splitMessage = StringFormatting.formatWithLocale(
                    "%s :: Split %s of %s", candidateMessage, splitIndex + 1, nodeSplits.size());
                HugeLongArray trainSet = split.trainSet();
                HugeLongArray validationSet = split.testSet();

                progressLogger.logStart(splitMessage + " :: Train");
                int maxEpochs = maxEpochs(modelParams);
                progressLogger.logMessage(StringFormatting.formatWithLocale(
                    "%s :: Train :: Max iterations: %s", splitMessage, maxEpochs));
                progressLogger.reset(maxEpochs);
                MultiClassNLRData modelData = trainModel(trainSet, modelParams);
                progressLogger.logFinish(splitMessage + " :: Train");

                progressLogger.logStart(splitMessage + " :: Evaluate");
                progressLogger.reset(validationSet.size() + trainSet.size());
                computeMetrics(globalClassCounts, validationSet, modelData, metrics)
                    .forEach(validationStatsBuilder::update);
                computeMetrics(globalClassCounts, trainSet, modelData, metrics)
                    .forEach(trainStatsBuilder::update);
                progressLogger.logFinish(splitMessage + " :: Evaluate");
            }

            for (Metric metric : metrics) {
                validationStats.get(metric).add(validationStatsBuilder.modelStats(metric));
                trainStats.get(metric).add(trainStatsBuilder.modelStats(metric));
            }
        }

        progressLogger.logStart(":: Select Model");
        ModelStats best = Collections.max(validationStats.get(metrics.get(0)), ModelStats.COMPARE_AVERAGE);
        progressLogger.logFinish(":: Select Model");

        return ModelSelectResult.of(best.params(), trainStats, validationStats);
    }

    /** Reads {@code maxEpochs} from a parameter map, falling back to {@link #DEFAULT_MAX_EPOCHS}. */
    private static int maxEpochs(Map<String, Object> modelParams) {
        return ((Number) modelParams.getOrDefault("maxEpochs", DEFAULT_MAX_EPOCHS)).intValue();
    }

    /** Expands every configured metric specification against the set of observed class ids. */
    private List<Metric> createMetrics(Multiset<Long> globalClassCounts) {
        return config.metrics().stream()
            .flatMap(specification -> specification.createMetrics(globalClassCounts.keys()))
            .collect(Collectors.toList());
    }

    /** Creates a random generator, re-seeded with {@code config.randomSeed()} when present. */
    private RandomDataGenerator getRandomDataGenerator() {
        RandomDataGenerator random = new RandomDataGenerator();
        config.randomSeed().ifPresent(random::reSeed);
        return random;
    }

    /**
     * Combines, per metric, the cross-validation train/validation stats with the scores of the
     * remainder-trained model on the remainder ({@code outerTrainMetrics}) and holdout
     * ({@code testMetrics}) sets.
     */
    private Map<Metric, MetricData> mergeMetricResults(
        ModelSelectResult modelSelectResult,
        Map<Metric, Double> outerTrainMetrics,
        Map<Metric, Double> testMetrics
    ) {
        return modelSelectResult.validationStats().keySet().stream().collect(Collectors.toMap(
            metric -> metric,
            metric -> MetricData.of(
                modelSelectResult.trainStats().get(metric),
                modelSelectResult.validationStats().get(metric),
                outerTrainMetrics.get(metric),
                testMetrics.get(metric)
            )
        ));
    }

    /** Returns a mutable map from each metric to a fresh, empty stats list. */
    private Map<Metric, List<ModelStats>> initStatsMap(List<Metric> metrics) {
        Map<Metric, List<ModelStats>> statsMap = new HashMap<>();
        metrics.forEach(metric -> statsMap.put(metric, new ArrayList<>()));
        return statsMap;
    }

    /**
     * Predicts a class for every node in {@code evaluationSet} in parallel, then scores the
     * predictions against the nodes' target property with every metric.
     */
    private Map<Metric, Double> computeMetrics(
        Multiset<Long> globalClassCounts,
        HugeLongArray evaluationSet,
        MultiClassNLRData modelData,
        List<Metric> metrics
    ) {
        HugeLongArray localTargets = makeLocalTargets(evaluationSet);
        HugeLongArray predictedClasses = HugeLongArray.newArray(evaluationSet.size(), allocationTracker);
        NodeClassificationPredictConsumer consumer = new NodeClassificationPredictConsumer(
            graph,
            evaluationSet::get,
            predictor(modelData),
            null, // NOTE(review): presumably an optional output for probabilities; not needed here — confirm
            predictedClasses,
            config.featureProperties(),
            progressLogger
        );
        new BatchQueue(evaluationSet.size()).parallelConsume(consumer, config.concurrency());
        return metrics.stream().collect(Collectors.toMap(
            metric -> metric,
            metric -> metric.compute(localTargets, predictedClasses, globalClassCounts)
        ));
    }

    private MultiClassNLRPredictor predictor(MultiClassNLRData modelData) {
        return new MultiClassNLRPredictor(modelData, config.featureProperties());
    }

    /** Trains a model on the given node ids using one candidate parameter map. */
    private MultiClassNLRData trainModel(HugeLongArray trainSet, Map<String, Object> modelParams) {
        MultiClassNLRTrainConfig trainConfig = MultiClassNLRTrainConfig.of(
            config.featureProperties(),
            config.targetProperty(),
            config.concurrency(),
            modelParams
        );
        return new MultiClassNLRTrain(graph, trainSet, trainConfig, progressLogger).compute();
    }

    /** Counts how many nodes in the whole graph carry each target-property value. */
    private Multiset<Long> countClassesGlobally() {
        Multiset<Long> classCounts = new Multiset<>();
        NodeProperties targets = graph.nodeProperties(config.targetProperty());
        for (long nodeId = 0; nodeId < graph.nodeCount(); nodeId++) {
            classCounts.add(targets.longValue(nodeId));
        }
        return classCounts;
    }

    /** Returns an array indexed by node id holding each node's target-property value. */
    private HugeLongArray makeGlobalTargets() {
        HugeLongArray targets = HugeLongArray.newArray(graph.nodeCount(), allocationTracker);
        NodeProperties targetProperty = graph.nodeProperties(config.targetProperty());
        targets.setAll(targetProperty::longValue);
        return targets;
    }

    /**
     * Returns an array where position {@code i} holds the target-property value of node
     * {@code nodeIds.get(i)}.
     */
    private HugeLongArray makeLocalTargets(HugeLongArray nodeIds) {
        HugeLongArray targets = HugeLongArray.newArray(nodeIds.size(), allocationTracker);
        NodeProperties targetProperty = graph.nodeProperties(config.targetProperty());
        targets.setAll(index -> targetProperty.longValue(nodeIds.get(index)));
        return targets;
    }

    @Override
    public NodeClassificationTrain me() {
        return this;
    }

    /**
     * Alias of {@link #me()} kept for callers that use the decompiler-assigned name.
     */
    public NodeClassificationTrain m14me() {
        return me();
    }

    /** No-op: this class holds nothing that needs explicit releasing. */
    @Override
    public void release() {
    }
}
