package org.neo4j.gds.ml.nodemodels;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.neo4j.gds.ml.batch.BatchQueue;
import org.neo4j.gds.ml.nodemodels.logisticregression.MultiClassNLRTrainConfig;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeClassificationTrainConfig;
import org.neo4j.gds.ml.nodemodels.metrics.Metric;
import org.neo4j.gds.ml.nodemodels.multiclasslogisticregression.MultiClassNLRData;
import org.neo4j.gds.ml.nodemodels.multiclasslogisticregression.MultiClassNLRPredictor;
import org.neo4j.gds.ml.nodemodels.multiclasslogisticregression.MultiClassNLRTrain;
import org.neo4j.gds.ml.nodemodels.multiclasslogisticregression.NodeClassificationPredictConsumer;
import org.neo4j.gds.ml.splitting.FractionSplitter;
import org.neo4j.gds.ml.splitting.NodeSplit;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.util.ShuffleUtil;
import org.neo4j.graphalgo.Algorithm;
import org.neo4j.graphalgo.annotation.ValueClass;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.NodeProperties;
import org.neo4j.graphalgo.core.model.Model;
import org.neo4j.graphalgo.core.utils.mem.AllocationTracker;
import org.neo4j.graphalgo.core.utils.paged.HugeLongArray;
import org.neo4j.logging.Log;

/**
 * Trains a multi-class node logistic regression model for node classification.
 *
 * <p>Steps performed by {@link #compute()}:
 * <ol>
 *   <li>All node ids are shuffled (seeded by {@code config.randomSeed()} when present) and split
 *       into an outer training set and a holdout test set of size {@code config.holdoutFraction()}.</li>
 *   <li>Every candidate parameter map in {@code config.params()} is evaluated with stratified k-fold
 *       cross-validation on the outer training set.</li>
 *   <li>The winning parameters (highest entry per {@link ModelStats#COMPARE_AVERAGE} for the first
 *       configured metric on the validation folds) are used to train a model on the outer training
 *       set, which is scored on both the outer training set and the test set.</li>
 *   <li>A final model is trained on all nodes with the winning parameters and returned together
 *       with the collected metrics.</li>
 * </ol>
 */
public class NodeClassificationTrain extends Algorithm<NodeClassificationTrain, Model<MultiClassNLRData, NodeClassificationTrainConfig>> {

    public static final String MODEL_TYPE = "multiClassNodeLogisticRegression";

    private final Graph graph;
    private final NodeClassificationTrainConfig config;
    private final Log log;
    private final AllocationTracker allocationTracker = AllocationTracker.empty();

    /**
     * Outcome of cross-validated model selection.
     */
    @ValueClass
    public interface ModelSelectResult {

        /** The candidate parameter map that won model selection. */
        Map<String, Object> bestParameters();

        /** Per metric, one {@link ModelStats} per candidate parameter map, scored on the fold training sets. */
        Map<Metric, List<ModelStats>> trainStats();

        /** Per metric, one {@link ModelStats} per candidate parameter map, scored on the validation folds. */
        Map<Metric, List<ModelStats>> validationStats();

        static ModelSelectResult of(
            Map<String, Object> bestParameters,
            Map<Metric, List<ModelStats>> trainStats,
            Map<Metric, List<ModelStats>> validationStats
        ) {
            return ImmutableModelSelectResult.of(bestParameters, trainStats, validationStats);
        }
    }

    /**
     * Accumulates the min, max and sum of each metric's score across the cross-validation folds
     * for a single candidate parameter map.
     */
    private static final class ModelStatsBuilder {
        private final Map<Metric, Double> min = new HashMap<>();
        private final Map<Metric, Double> max = new HashMap<>();
        private final Map<Metric, Double> sum = new HashMap<>();
        private final Map<String, Object> modelParams;
        private final int numberOfSplits;

        ModelStatsBuilder(Map<String, Object> modelParams, int numberOfSplits) {
            this.modelParams = modelParams;
            this.numberOfSplits = numberOfSplits;
        }

        /** Records one fold's score for {@code metric}. */
        void update(Metric metric, double score) {
            min.merge(metric, score, Math::min);
            max.merge(metric, score, Math::max);
            sum.merge(metric, score, Double::sum);
        }

        /**
         * Summarizes the recorded scores of {@code metric}. The average is {@code sum / numberOfSplits}.
         * Assumes {@link #update} was called at least once for the metric.
         */
        ModelStats modelStats(Metric metric) {
            double average = sum.get(metric) / numberOfSplits;
            return ImmutableModelStats.of(modelParams, average, min.get(metric), max.get(metric));
        }
    }

    public NodeClassificationTrain(Graph graph, NodeClassificationTrainConfig nodeClassificationTrainConfig, Log log) {
        this.graph = graph;
        this.config = nodeClassificationTrainConfig;
        this.log = log;
    }

    /**
     * Runs model selection and training, and returns the final model trained on all nodes.
     */
    @Override
    public Model<MultiClassNLRData, NodeClassificationTrainConfig> compute() {
        HugeLongArray allNodeIds = HugeLongArray.newArray(graph.nodeCount(), allocationTracker);
        allNodeIds.setAll(nodeId -> nodeId);
        ShuffleUtil.shuffleHugeLongArray(allNodeIds, getRandomDataGenerator());

        NodeSplit outerSplit = new FractionSplitter().split(allNodeIds, 1.0d - config.holdoutFraction());
        HugeLongArray globalTargets = makeGlobalTargets();

        ModelSelectResult modelSelectResult = modelSelect(outerSplit.trainSet(), globalTargets);
        Map<String, Object> bestParameters = modelSelectResult.bestParameters();

        // Score the winning parameters on the outer split before retraining on every node.
        MultiClassNLRData outerTrainModel = trainModel(outerSplit.trainSet(), bestParameters);
        Map<Metric, MetricData> metrics = mergeMetrics(
            modelSelectResult,
            computeMetrics(globalTargets, outerSplit.trainSet(), outerTrainModel),
            computeMetrics(globalTargets, outerSplit.testSet(), outerTrainModel)
        );

        MultiClassNLRData finalModel = trainModel(allNodeIds, bestParameters);
        return Model.of(
            config.username(),
            config.modelName(),
            MODEL_TYPE,
            graph.schema(),
            finalModel,
            config,
            NodeClassificationModelInfo.of(finalModel.classIdMap().originalIdsList(), bestParameters, metrics)
        );
    }

    /**
     * @deprecated decompiler-generated alias of {@link #compute()}; kept only for source
     * compatibility. Call {@code compute()} instead.
     */
    @Deprecated
    public Model<MultiClassNLRData, NodeClassificationTrainConfig> m13compute() {
        return compute();
    }

    /** Creates a random generator, reseeded with {@code config.randomSeed()} when one is set. */
    private RandomDataGenerator getRandomDataGenerator() {
        RandomDataGenerator random = new RandomDataGenerator();
        config.randomSeed().ifPresent(random::reSeed);
        return random;
    }

    /**
     * Combines the cross-validation statistics with the outer-train and test scores for every
     * metric present in the validation statistics.
     */
    private Map<Metric, MetricData> mergeMetrics(
        ModelSelectResult modelSelectResult,
        Map<Metric, Double> outerTrainMetrics,
        Map<Metric, Double> testMetrics
    ) {
        Map<Metric, List<ModelStats>> trainStats = modelSelectResult.trainStats();
        Map<Metric, List<ModelStats>> validationStats = modelSelectResult.validationStats();
        return validationStats.keySet().stream().collect(Collectors.toMap(
            metric -> metric,
            metric -> MetricData.of(
                trainStats.get(metric),
                validationStats.get(metric),
                outerTrainMetrics.get(metric),
                testMetrics.get(metric)
            )
        ));
    }

    /**
     * Evaluates every candidate parameter map with stratified k-fold cross-validation.
     *
     * @param trainNodeIds  node ids available for cross-validation (the outer training set)
     * @param globalTargets target class per node, indexed by node id
     * @return the winning parameters and the per-metric train/validation statistics
     */
    private ModelSelectResult modelSelect(HugeLongArray trainNodeIds, HugeLongArray globalTargets) {
        List<NodeSplit> folds = new StratifiedKFoldSplitter(config.validationFolds(), trainNodeIds, globalTargets).splits();
        Map<Metric, List<ModelStats>> trainStats = initStatsMap();
        Map<Metric, List<ModelStats>> validationStats = initStatsMap();

        config.params().forEach(modelParams -> {
            ModelStatsBuilder validationStatsBuilder = new ModelStatsBuilder(modelParams, folds.size());
            ModelStatsBuilder trainStatsBuilder = new ModelStatsBuilder(modelParams, folds.size());

            for (NodeSplit fold : folds) {
                HugeLongArray foldTrainSet = fold.trainSet();
                HugeLongArray validationSet = fold.testSet();
                MultiClassNLRData foldModel = trainModel(foldTrainSet, modelParams);
                computeMetrics(globalTargets, validationSet, foldModel).forEach(validationStatsBuilder::update);
                computeMetrics(globalTargets, foldTrainSet, foldModel).forEach(trainStatsBuilder::update);
            }

            config.metrics().forEach(metric -> {
                validationStats.get(metric).add(validationStatsBuilder.modelStats(metric));
                trainStats.get(metric).add(trainStatsBuilder.modelStats(metric));
            });
        });

        // The first configured metric drives model selection.
        Metric primaryMetric = config.metrics().get(0);
        ModelStats winner = Collections.max(validationStats.get(primaryMetric), ModelStats.COMPARE_AVERAGE);
        return ModelSelectResult.of(winner.params(), trainStats, validationStats);
    }

    /** Creates a map with an empty, mutable stats list for every configured metric. */
    private Map<Metric, List<ModelStats>> initStatsMap() {
        Map<Metric, List<ModelStats>> statsMap = new HashMap<>();
        config.metrics().forEach(metric -> statsMap.put(metric, new ArrayList<>()));
        return statsMap;
    }

    /**
     * Predicts classes for {@code evaluationSet} with the given model and scores every configured metric.
     *
     * @param globalTargets target class per node, indexed by node id
     * @param evaluationSet node ids to predict and score
     * @param modelData     trained model
     */
    private Map<Metric, Double> computeMetrics(
        HugeLongArray globalTargets,
        HugeLongArray evaluationSet,
        MultiClassNLRData modelData
    ) {
        HugeLongArray localTargets = makeLocalTargets(evaluationSet);
        HugeLongArray predictedClasses = HugeLongArray.newArray(evaluationSet.size(), allocationTracker);
        // NOTE(review): the null argument presumably disables an optional output of the consumer
        // (e.g. per-class probabilities) — confirm against NodeClassificationPredictConsumer.
        new BatchQueue(evaluationSet.size()).parallelConsume(
            new NodeClassificationPredictConsumer(
                graph,
                evaluationSet::get,
                predictor(modelData),
                null,
                predictedClasses,
                progressLogger
            ),
            config.concurrency()
        );
        return config.metrics().stream().collect(Collectors.toMap(
            metric -> metric,
            metric -> metric.compute(localTargets, predictedClasses, globalTargets)
        ));
    }

    private MultiClassNLRPredictor predictor(MultiClassNLRData modelData) {
        return new MultiClassNLRPredictor(modelData, config.featureProperties());
    }

    /** Trains a model on {@code nodeIds} using the given parameter map. */
    private MultiClassNLRData trainModel(HugeLongArray nodeIds, Map<String, Object> modelParams) {
        MultiClassNLRTrainConfig trainConfig = MultiClassNLRTrainConfig.of(
            config.featureProperties(),
            config.targetProperty(),
            config.concurrency(),
            modelParams
        );
        return new MultiClassNLRTrain(graph, nodeIds, trainConfig, log).compute();
    }

    /** Reads the target property of every node into an array indexed by node id. */
    private HugeLongArray makeGlobalTargets() {
        HugeLongArray globalTargets = HugeLongArray.newArray(graph.nodeCount(), allocationTracker);
        NodeProperties targetProperty = graph.nodeProperties(config.targetProperty());
        globalTargets.setAll(targetProperty::longValue);
        return globalTargets;
    }

    /**
     * Reads the target property of each node in {@code nodeIds}. Position {@code i} of the result
     * holds the target of node {@code nodeIds.get(i)}.
     */
    private HugeLongArray makeLocalTargets(HugeLongArray nodeIds) {
        HugeLongArray localTargets = HugeLongArray.newArray(nodeIds.size(), allocationTracker);
        NodeProperties targetProperty = graph.nodeProperties(config.targetProperty());
        localTargets.setAll(index -> targetProperty.longValue(nodeIds.get(index)));
        return localTargets;
    }

    @Override
    public NodeClassificationTrain me() {
        return this;
    }

    /**
     * @deprecated decompiler-generated alias of {@link #me()}; kept only for source
     * compatibility. Call {@code me()} instead.
     */
    @Deprecated
    public NodeClassificationTrain m12me() {
        return me();
    }

    @Override
    public void release() {
        // Nothing to release: all arrays are local to compute().
    }
}
