package org.neo4j.gds.ml.models.randomforest;

import com.carrotsearch.hppc.BitSet;
import java.util.List;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeAtomicLongArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.decisiontree.DecisionTreeClassifierTrainer;
import org.neo4j.gds.ml.decisiontree.DecisionTreeLoss;
import org.neo4j.gds.ml.decisiontree.DecisionTreePredictor;
import org.neo4j.gds.ml.decisiontree.DecisionTreeTrainerConfig;
import org.neo4j.gds.ml.decisiontree.DecisionTreeTrainerConfigImpl;
import org.neo4j.gds.ml.decisiontree.FeatureBagger;
import org.neo4j.gds.ml.decisiontree.GiniIndex;
import org.neo4j.gds.ml.models.ClassifierTrainer;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.splitting.EdgeSplitter;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/models/randomforest/RandomForestClassifierTrainer.class */
/**
 * Trains a random forest classifier: a collection of decision trees, each fitted (with Gini-index
 * loss and feature bagging) on either the whole training set or a bootstrapped sample of it.
 * Trees are trained concurrently; optionally an out-of-bag error is accumulated alongside.
 */
public class RandomForestClassifierTrainer implements ClassifierTrainer {
    private final LocalIdMap classIdMap;
    private final RandomForestTrainerConfig config;
    private final int concurrency;
    private final boolean computeOutOfBagError;
    // Source of per-tree random generators (one split() per tree).
    private final SplittableRandom random;
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;
    // Populated by train() only when computeOutOfBagError is set.
    private Optional<Double> outOfBagError = Optional.empty();

    /**
     * Trains a single decision tree of the forest. When {@code maybePredictions} is present, the
     * trained tree, the train set and the bit set of train-set positions used for this tree are
     * handed to {@code OutOfBagErrorMetric.addPredictionsForTree}, which records into that array.
     *
     * @param <LOSS> the loss function used by the decision tree trainer
     */
    public static class TrainDecisionTreeTask<LOSS extends DecisionTreeLoss> implements Runnable {
        // Set by run(); null until the task has been executed.
        private DecisionTreePredictor<Integer> trainedTree;
        private final Optional<HugeAtomicLongArray> maybePredictions;
        private final DecisionTreeTrainerConfig decisionTreeTrainConfig;
        private final RandomForestTrainerConfig randomForestTrainConfig;
        private final SplittableRandom random;
        private final Features allFeatureVectors;
        private final HugeLongArray allLabels;
        private final LocalIdMap classIdMap;
        private final LOSS lossFunction;
        private final ReadOnlyHugeLongArray trainSet;
        private final ProgressTracker progressTracker;
        // Shared across all tasks of one train() call, used only for progress messages.
        private final AtomicInteger numberOfTreesTrained;

        /**
         * The data one tree is trained on.
         * {@code trainSetIndices} has one bit per train-set position, set for positions used by
         * this tree; {@code allVectorsIndices} holds the entries the tree is actually trained on.
         */
        @ValueClass
        public interface BootstrappedDataset {
            BitSet trainSetIndices();

            ReadOnlyHugeLongArray allVectorsIndices();
        }

        TrainDecisionTreeTask(
            Optional<HugeAtomicLongArray> maybePredictions,
            DecisionTreeTrainerConfig decisionTreeTrainConfig,
            RandomForestTrainerConfig randomForestTrainConfig,
            SplittableRandom random,
            Features allFeatureVectors,
            HugeLongArray allLabels,
            LocalIdMap classIdMap,
            LOSS lossFunction,
            ReadOnlyHugeLongArray trainSet,
            ProgressTracker progressTracker,
            AtomicInteger numberOfTreesTrained
        ) {
            this.maybePredictions = maybePredictions;
            this.decisionTreeTrainConfig = decisionTreeTrainConfig;
            this.randomForestTrainConfig = randomForestTrainConfig;
            this.random = random;
            this.allFeatureVectors = allFeatureVectors;
            this.allLabels = allLabels;
            this.classIdMap = classIdMap;
            this.lossFunction = lossFunction;
            this.trainSet = trainSet;
            this.progressTracker = progressTracker;
            this.numberOfTreesTrained = numberOfTreesTrained;
        }

        /**
         * Estimates the memory needed to train one decision tree.
         *
         * @param maxDepth                maximum tree depth
         * @param minSplitSize            minimum split size passed to the tree trainer
         * @param numberOfTrainingSamples size of the train set
         * @param numberOfClasses         number of classes
         * @param numberOfBaggedFeatures  number of features sampled by the feature bagger
         * @param numberOfSamplesRatio    bootstrap sample size relative to the train set;
         *                                {@code 0} means the whole train set is used as-is
         * @return the estimated memory range for one task
         */
        public static MemoryRange memoryEstimation(
            int maxDepth,
            int minSplitSize,
            long numberOfTrainingSamples,
            int numberOfClasses,
            int numberOfBaggedFeatures,
            double numberOfSamplesRatio
        ) {
            // Mirrors bootstrappedDataset(): a ratio of 0 trains on the full train set without
            // copying it, otherwise a sample of ceil(ratio * n) indices is materialised.
            boolean useWholeTrainSet = Double.compare(numberOfSamplesRatio, 0.0d) == 0;
            long usedNumberOfTrainingSamples = useWholeTrainSet
                ? numberOfTrainingSamples
                : (long) Math.ceil(numberOfSamplesRatio * numberOfTrainingSamples);
            long bootstrappedIndicesSize = useWholeTrainSet
                ? 0L
                : HugeLongArray.memoryEstimation(usedNumberOfTrainingSamples);

            // The bit set is always allocated with one bit per train-set position.
            MemoryRange bootstrappedDatasetEstimation = MemoryRange.of(bootstrappedIndicesSize)
                .add(MemoryUsage.sizeOfBitset(numberOfTrainingSamples));

            return MemoryRange.of(MemoryUsage.sizeOfInstance(TrainDecisionTreeTask.class))
                .add(FeatureBagger.memoryEstimation(numberOfBaggedFeatures))
                .add(DecisionTreeClassifierTrainer.memoryEstimation(
                    maxDepth,
                    minSplitSize,
                    usedNumberOfTrainingSamples,
                    numberOfBaggedFeatures,
                    numberOfClasses
                ))
                .add(bootstrappedDatasetEstimation);
        }

        /** Returns the tree produced by {@link #run()}, or {@code null} if it has not run yet. */
        public DecisionTreePredictor<Integer> trainedTree() {
            return this.trainedTree;
        }

        @Override
        public void run() {
            int featureDimension = allFeatureVectors.featureDimension();
            FeatureBagger featureBagger = new FeatureBagger(
                random,
                featureDimension,
                randomForestTrainConfig.maxFeaturesRatio(featureDimension)
            );
            DecisionTreeClassifierTrainer decisionTreeTrainer = new DecisionTreeClassifierTrainer(
                lossFunction,
                allFeatureVectors,
                allLabels,
                classIdMap,
                decisionTreeTrainConfig,
                featureBagger
            );

            BootstrappedDataset bootstrappedDataset = bootstrappedDataset();
            this.trainedTree = decisionTreeTrainer.train(bootstrappedDataset.allVectorsIndices());

            maybePredictions.ifPresent(predictions -> OutOfBagErrorMetric.addPredictionsForTree(
                trainedTree,
                classIdMap,
                allFeatureVectors,
                trainSet,
                bootstrappedDataset.trainSetIndices(),
                predictions
            ));

            progressTracker.logMessage(StringFormatting.formatWithLocale(
                "Trained decision tree %d out of %d",
                numberOfTreesTrained.incrementAndGet(),
                randomForestTrainConfig.numberOfDecisionTrees()
            ));
        }

        private BootstrappedDataset bootstrappedDataset() {
            // One bit per train-set position, marking the positions this tree is trained on.
            BitSet trainSetIndices = new BitSet(trainSet.size());
            ReadOnlyHugeLongArray allVectorsIndices;

            if (Double.compare(randomForestTrainConfig.numberOfSamplesRatio(), 0.0d) == 0) {
                // A ratio of 0 means no sampling: the tree is trained on the whole train set.
                allVectorsIndices = trainSet;
                // BitSet.set(start, end) is half-open, so start at 0 to mark every position as
                // used; starting at 1 would leave position 0 looking out-of-bag.
                trainSetIndices.set(0L, trainSet.size());
            } else {
                allVectorsIndices = DatasetBootstrapper.bootstrap(
                    random,
                    randomForestTrainConfig.numberOfSamplesRatio(),
                    trainSet,
                    trainSetIndices
                );
            }
            return ImmutableBootstrappedDataset.of(trainSetIndices, allVectorsIndices);
        }
    }

    /**
     * @param concurrency          number of decision trees trained in parallel
     * @param classIdMap           mapping of class labels
     * @param config               random forest hyper-parameters
     * @param computeOutOfBagError whether {@link #train} also computes the out-of-bag error
     * @param randomSeed           seed for reproducible training; a random seed is drawn if empty
     * @param progressTracker      receives one log message per trained tree
     * @param terminationFlag      checked while running the training tasks
     */
    public RandomForestClassifierTrainer(
        int concurrency,
        LocalIdMap classIdMap,
        RandomForestTrainerConfig config,
        boolean computeOutOfBagError,
        Optional<Long> randomSeed,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        this.classIdMap = classIdMap;
        this.config = config;
        this.concurrency = concurrency;
        this.computeOutOfBagError = computeOutOfBagError;
        this.random = new SplittableRandom(randomSeed.orElseGet(() -> new SplittableRandom().nextLong()));
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Estimates the memory needed for training.
     *
     * @param numberOfTrainingSamples maps the graph's node count to the train-set size
     * @param numberOfClasses         number of classes
     * @param featureDimension        range of possible feature dimensions
     * @param config                  random forest hyper-parameters
     * @return the training memory estimation
     */
    public static MemoryEstimation memoryEstimation(
        LongUnaryOperator numberOfTrainingSamples,
        int numberOfClasses,
        MemoryRange featureDimension,
        RandomForestTrainerConfig config
    ) {
        int minNumberOfBaggedFeatures = (int) Math.ceil(
            config.maxFeaturesRatio((int) featureDimension.min) * featureDimension.min
        );
        int maxNumberOfBaggedFeatures = (int) Math.ceil(
            config.maxFeaturesRatio((int) featureDimension.max) * featureDimension.max
        );

        return MemoryEstimations.builder("Training", RandomForestClassifierTrainer.class)
            .add(RandomForestClassifierData.memoryEstimation(numberOfTrainingSamples, config))
            .rangePerNode(
                "GiniIndex Loss",
                nodeCount -> GiniIndex.memoryEstimation(numberOfTrainingSamples.applyAsLong(nodeCount))
            )
            .perGraphDimension("Decision tree training", (graphDimensions, concurrency) -> {
                long trainSetSize = numberOfTrainingSamples.applyAsLong(graphDimensions.nodeCount());
                MemoryRange fewestFeatures = TrainDecisionTreeTask.memoryEstimation(
                    config.maxDepth(),
                    config.minSplitSize(),
                    trainSetSize,
                    numberOfClasses,
                    minNumberOfBaggedFeatures,
                    config.numberOfSamplesRatio()
                );
                MemoryRange mostFeatures = TrainDecisionTreeTask.memoryEstimation(
                    config.maxDepth(),
                    config.minSplitSize(),
                    trainSetSize,
                    numberOfClasses,
                    maxNumberOfBaggedFeatures,
                    config.numberOfSamplesRatio()
                );
                // Up to `concurrency` trees are trained at the same time.
                return fewestFeatures.union(mostFeatures).times(concurrency.intValue());
            })
            .build();
    }

    /**
     * Trains {@code config.numberOfDecisionTrees()} trees concurrently and bundles them into a
     * {@link RandomForestClassifier}. If enabled, also computes the out-of-bag error, which is
     * afterwards available via {@link #outOfBagError()}.
     */
    @Override
    public RandomForestClassifier train(Features features, HugeLongArray labels, ReadOnlyHugeLongArray trainSet) {
        // One counter per (class, train-set position) pair for out-of-bag predictions.
        Optional<HugeAtomicLongArray> maybePredictions = computeOutOfBagError
            ? Optional.of(HugeAtomicLongArray.newArray(classIdMap.size() * trainSet.size()))
            : Optional.empty();

        DecisionTreeTrainerConfig decisionTreeTrainConfig = DecisionTreeTrainerConfigImpl.builder()
            .maxDepth(config.maxDepth())
            .minSplitSize(config.minSplitSize())
            .build();

        GiniIndex lossFunction = GiniIndex.fromOriginalLabels(labels, classIdMap);
        AtomicInteger numberOfTreesTrained = new AtomicInteger(0);

        List<TrainDecisionTreeTask<GiniIndex>> tasks = IntStream.range(0, config.numberOfDecisionTrees())
            .mapToObj(unused -> new TrainDecisionTreeTask<>(
                maybePredictions,
                decisionTreeTrainConfig,
                config,
                // Each tree gets its own independent generator derived from the trainer's seed.
                random.split(),
                features,
                labels,
                classIdMap,
                lossFunction,
                trainSet,
                progressTracker,
                numberOfTreesTrained
            ))
            .collect(Collectors.toList());

        ParallelUtil.runWithConcurrency(concurrency, tasks, terminationFlag, Pools.DEFAULT);

        this.outOfBagError = maybePredictions.map(predictions ->
            OutOfBagErrorMetric.evaluate(trainSet, classIdMap, labels, concurrency, predictions)
        );

        List<DecisionTreePredictor<Integer>> decisionTrees = tasks.stream()
            .map(TrainDecisionTreeTask::trainedTree)
            .collect(Collectors.toList());

        return new RandomForestClassifier(decisionTrees, classIdMap, features.featureDimension());
    }

    /**
     * Returns the out-of-bag error computed by the last {@link #train} call.
     *
     * @throws IllegalAccessError if the error has not been computed
     */
    double outOfBagError() {
        // NOTE(review): IllegalStateException would be the idiomatic type; IllegalAccessError is
        // kept so existing callers/tests expecting it keep working.
        return this.outOfBagError
            .orElseThrow(() -> new IllegalAccessError("Out of bag error has not been computed."));
    }
}
