package org.neo4j.gds.embeddings.graphsage;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainParameters;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.ConstantScale;
import org.neo4j.gds.ml.core.functions.ElementSum;
import org.neo4j.gds.ml.core.functions.L2NormSquared;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.optimizer.AdamOptimizer;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.TensorFunctions;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer.class */
public class GraphSageModelTrainer {
    // Seed handed to the batch sampler and the batch-picking RNG, and combined with the epoch
    // number to seed per-epoch sub-graph sampling. Taken from the train parameters when present,
    // otherwise drawn once in the constructor.
    private final long randomSeed;
    // Training settings read throughout this class (epochs, max iterations, batch size,
    // search depth, tolerance, learning rate, L2 penalty, negative sample weight, concurrency).
    private final GraphSageTrainParameters parameters;
    // Builds the feature variable that is fed into the embeddings computation graph of a batch.
    private final FeatureFunction featureFunction;
    // Caller-supplied weights that are trained together with the weights of the layers.
    private final Collection<Weights<Matrix>> labelProjectionWeights;
    // Executor on which the batch tasks of one iteration are run.
    private final ExecutorService executor;
    // Receives the sub-task begin/end events, per-batch progress and the per-iteration loss message.
    private final ProgressTracker progressTracker;
    // One layer per layer config returned by the train parameters; created in the constructor.
    private final Layer[] layers;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer$BatchTask.class */
    /**
     * Unit of work for a single batch: evaluates the loss variable, back-propagates it and
     * records the gradient of every weight variable.
     *
     * <p>{@link #loss()} and {@link #weightGradients()} only return meaningful values after
     * {@link #run()} has completed; before that they return {@code 0.0} and {@code null}.
     */
    public static class BatchTask implements Runnable {
        private final Variable<Scalar> lossFunction;
        private final List<Weights<? extends Tensor<?>>> weightVariables;
        private List<? extends Tensor<?>> weightGradients;
        private final ProgressTracker progressTracker;
        private double loss;

        /**
         * @param variable        loss variable to evaluate for this batch
         * @param list            weight variables whose gradients are collected, in this order
         * @param progressTracker tracker that is notified once per completed run
         */
        BatchTask(Variable<Scalar> variable, List<Weights<? extends Tensor<?>>> list, ProgressTracker progressTracker) {
            this.lossFunction = variable;
            this.weightVariables = list;
            this.progressTracker = progressTracker;
        }

        /**
         * Runs the forward and backward pass in a fresh {@link ComputationContext}, stores the
         * loss value and the per-weight gradients, then logs one unit of progress.
         */
        @Override
        public void run() {
            ComputationContext context = new ComputationContext();
            this.loss = context.forward(this.lossFunction).value();
            context.backward(this.lossFunction);
            // Gradients are collected in the same order as the weight variables.
            this.weightGradients = this.weightVariables.stream()
                .map(context::gradient)
                .collect(Collectors.toList());
            this.progressTracker.logProgress();
        }

        /**
         * @return the loss value computed by the last {@link #run()}
         */
        public double loss() {
            return this.loss;
        }

        /**
         * @return the gradients computed by the last {@link #run()}, one per weight variable
         */
        List<? extends Tensor<?>> weightGradients() {
            return this.weightGradients;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Outcome of one training epoch.
     *
     * <p>NOTE(review): instances are created via {@code ImmutableEpochResult.of(converged, losses)};
     * the factory's argument order presumably follows the declaration order below, so do not
     * reorder these accessors without checking the generated class.
     */
    @ValueClass
    /* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer$EpochResult.class */
    public interface EpochResult {
        /** Whether the epoch stopped because the loss change fell below the configured tolerance. */
        boolean converged();

        /** Average loss of each iteration that was run in the epoch, in execution order. */
        List<Double> losses();
    }

    /**
     * Training metrics stored as custom info of the trained model.
     *
     * <p>The two abstract attributes are {@link #iterationLossPerEpoch()} and
     * {@link #didConverge()}; everything else is derived from them.
     */
    @ValueClass
    public interface GraphSageTrainMetrics extends Model.CustomInfo {
        /**
         * @return metrics describing a training that ran no epochs and did not converge
         */
        static GraphSageTrainMetrics empty() {
            List<List<Double>> noEpochs = List.of();
            return ImmutableGraphSageTrainMetrics.of(noEpochs, false);
        }

        /**
         * @return the loss of the last iteration of every epoch
         */
        @Value.Derived
        default List<Double> epochLosses() {
            return iterationLossPerEpoch()
                .stream()
                .map(epochLosses -> epochLosses.get(epochLosses.size() - 1))
                .collect(Collectors.toList());
        }

        /** Losses of all iterations, grouped by epoch. */
        List<List<Double>> iterationLossPerEpoch();

        /** Whether training stopped because it converged. */
        boolean didConverge();

        /**
         * @return the number of epochs for which losses were recorded
         */
        @Value.Derived
        default int ranEpochs() {
            // An empty list already has size 0, so no separate empty check is needed.
            return iterationLossPerEpoch().size();
        }

        /**
         * @return the number of iterations recorded for every epoch
         */
        @Value.Derived
        default List<Integer> ranIterationsPerEpoch() {
            return iterationLossPerEpoch()
                .stream()
                .map(List::size)
                .collect(Collectors.toList());
        }

        /**
         * @return all metrics nested under the single key {@code "metrics"}
         */
        @Value.Auxiliary
        @Value.Derived
        default Map<String, Object> toMap() {
            Map<String, Object> metrics = Map.of(
                "epochLosses", epochLosses(),
                "iterationLossesPerEpoch", iterationLossPerEpoch(),
                "didConverge", didConverge(),
                "ranEpochs", ranEpochs(),
                "ranIterationsPerEpoch", ranIterationsPerEpoch()
            );
            return Map.of("metrics", metrics);
        }
    }

    /**
     * Result of a complete training run: the collected metrics plus the trained layers.
     */
    @ValueClass
    public interface ModelTrainResult {
        /** Metrics collected while training. */
        GraphSageTrainMetrics metrics();

        /** The layers handed to the factory method. */
        Layer[] layers();

        /**
         * Builds a result from the raw training outcome.
         *
         * @param list     losses of all iterations, grouped by epoch
         * @param z        whether training converged
         * @param layerArr the trained layers
         * @return the assembled result
         */
        static ModelTrainResult of(List<List<Double>> list, boolean z, Layer[] layerArr) {
            return ImmutableModelTrainResult
                .builder()
                .layers(layerArr)
                .metrics(ImmutableGraphSageTrainMetrics.of(list, z))
                .build();
        }
    }

    /**
     * Creates a trainer that uses a {@link SingleLabelFeatureFunction} and no label projection
     * weights.
     *
     * @param graphSageTrainParameters training settings
     * @param i                        value forwarded to {@code layerConfigs(int)} when creating the layers
     * @param executorService          executor used to run the batch tasks
     * @param progressTracker          tracker receiving progress events
     */
    public GraphSageModelTrainer(
        GraphSageTrainParameters graphSageTrainParameters,
        int i,
        ExecutorService executorService,
        ProgressTracker progressTracker
    ) {
        this(
            graphSageTrainParameters,
            executorService,
            progressTracker,
            new SingleLabelFeatureFunction(),
            Collections.emptyList(),
            i
        );
    }

    /**
     * Creates a trainer with an explicit feature function and additional trainable weights.
     *
     * @param graphSageTrainParameters training settings; also the source of the random seed and
     *                                 of the layer configurations
     * @param executorService          executor used to run the batch tasks
     * @param progressTracker          tracker receiving progress events
     * @param featureFunction          produces the feature variable for a batch
     * @param collection               extra weights trained together with the layer weights
     * @param i                        value forwarded to {@code layerConfigs(int)} when creating the layers
     */
    public GraphSageModelTrainer(
        GraphSageTrainParameters graphSageTrainParameters,
        ExecutorService executorService,
        ProgressTracker progressTracker,
        FeatureFunction featureFunction,
        Collection<Weights<Matrix>> collection,
        int i
    ) {
        this.parameters = graphSageTrainParameters;
        this.featureFunction = featureFunction;
        this.labelProjectionWeights = collection;
        this.executor = executorService;
        this.progressTracker = progressTracker;
        // Without a configured seed, pick one random seed that is then used for the whole training.
        this.randomSeed = graphSageTrainParameters
            .randomSeed()
            .orElseGet(() -> ThreadLocalRandom.current().nextLong());
        this.layers = graphSageTrainParameters
            .layerConfigs(i)
            .stream()
            .map(LayerFactory::createLayer)
            .toArray(Layer[]::new);
    }

    /**
     * Describes the progress-task tree of a training run: a "Prepare batches" leaf followed by
     * a "Train model" task that contains "Epoch" tasks, which in turn contain "Iteration" leaves.
     *
     * @param j  volume of the "Prepare batches" leaf
     * @param i  volume of every "Iteration" leaf
     * @param i2 third argument of the "Epoch" task (presumably its number of iterations)
     * @param i3 third argument of the "Train model" task (presumably its number of epochs)
     * @return the two top-level tasks, in execution order
     */
    public static List<Task> progressTasks(long j, int i, int i2, int i3) {
        Supplier<List<Task>> iterationTasks = () -> List.of(Tasks.leaf("Iteration", i));
        Supplier<List<Task>> epochTasks = () -> List.of(Tasks.iterativeDynamic("Epoch", iterationTasks, i2));

        Task prepareBatches = Tasks.leaf("Prepare batches", j);
        Task trainModel = Tasks.iterativeDynamic("Train model", epochTasks, i3);
        return List.of(prepareBatches, trainModel);
    }

    /**
     * Trains the layers of this trainer on the given graph.
     *
     * <p>Batches are sampled once up front. Training then runs epoch after epoch until either
     * the configured number of epochs is reached or an epoch reports convergence. The last
     * iteration loss of an epoch is handed to the next epoch as its starting loss.
     *
     * @param graph           graph to train on
     * @param hugeObjectArray node features handed to the feature function
     * @return the per-epoch iteration losses, the convergence flag and the trained layers
     */
    public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> hugeObjectArray) {
        // Everything the optimizer updates: the label projection weights first, then each layer's weights.
        ArrayList<Weights<? extends Tensor<?>>> trainableWeights = new ArrayList<>(this.labelProjectionWeights);
        for (Layer layer : this.layers) {
            trainableWeights.addAll(layer.weights());
        }

        this.progressTracker.beginSubTask("Prepare batches");
        BatchSampler batchSampler = new BatchSampler(graph, this.progressTracker);
        List<long[]> extendedBatches = batchSampler.extendedBatches(
            this.parameters.batchSize(),
            this.parameters.searchDepth(),
            this.randomSeed
        );
        SplittableRandom batchPicker = new SplittableRandom(this.randomSeed);
        this.progressTracker.endSubTask("Prepare batches");

        this.progressTracker.beginSubTask("Train model");
        boolean converged = false;
        List<List<Double>> lossesPerEpoch = new ArrayList<>();
        double previousEpochLoss = Double.NaN;
        int epochs = this.parameters.epochs();
        // When an epoch asks for more tasks than there are batches, build one task per batch
        // up front and reuse it instead of building a new task for every pick.
        boolean createTasksEagerly =
            this.parameters.batchesPerIteration(graph.nodeCount()) * this.parameters.maxIterations() > extendedBatches.size();

        for (int epoch = 1; epoch <= epochs && !converged; epoch++) {
            this.progressTracker.beginSubTask("Epoch");
            long epochSeed = epoch + this.randomSeed;

            Supplier<List<BatchTask>> iterationTasks;
            if (createTasksEagerly) {
                List<BatchTask> epochTasks = new ArrayList<>(extendedBatches.size());
                for (long[] batch : extendedBatches) {
                    epochTasks.add(createBatchTask(batch, graph, hugeObjectArray, this.layers, trainableWeights, epochSeed));
                }
                iterationTasks = () -> {
                    int taskCount = this.parameters.batchesPerIteration(graph.nodeCount());
                    List<BatchTask> picked = new ArrayList<>();
                    for (int k = 0; k < taskCount; k++) {
                        picked.add(epochTasks.get(batchPicker.nextInt(epochTasks.size())));
                    }
                    return picked;
                };
            } else {
                iterationTasks = () -> {
                    int taskCount = this.parameters.batchesPerIteration(graph.nodeCount());
                    List<BatchTask> created = new ArrayList<>();
                    for (int k = 0; k < taskCount; k++) {
                        long[] batch = extendedBatches.get(batchPicker.nextInt(extendedBatches.size()));
                        created.add(createBatchTask(batch, graph, hugeObjectArray, this.layers, trainableWeights, epochSeed));
                    }
                    return created;
                };
            }

            EpochResult epochResult = trainEpoch(iterationTasks, trainableWeights, previousEpochLoss);
            List<Double> epochLosses = epochResult.losses();
            lossesPerEpoch.add(epochLosses);
            previousEpochLoss = epochLosses.get(epochLosses.size() - 1).doubleValue();
            converged = epochResult.converged();
            this.progressTracker.endSubTask("Epoch");
        }

        this.progressTracker.endSubTask("Train model");
        return ModelTrainResult.of(lossesPerEpoch, converged, this.layers);
    }

    /**
     * Builds the task that computes loss and gradients for one extended batch.
     *
     * <p>Works on a concurrent copy of the graph, samples one sub-graph per layer, wires the
     * embeddings computation graph on top of the features and wraps it in a
     * {@code GraphSageLoss}. If the configured L2 penalty is positive, the penalty term is added.
     *
     * @param extendedBatch node ids of the extended batch
     * @param graph         graph to train on; a concurrent copy is used for all graph access
     * @param features      node features handed to the feature function
     * @param layers        layers to sample sub-graphs for and to build the computation graph from
     * @param weights       weight variables whose gradients the task collects
     * @param seed          seed for the sub-graph sampling
     * @return a task that has not been run yet
     */
    private BatchTask createBatchTask(
        long[] extendedBatch,
        Graph graph,
        HugeObjectArray<double[]> features,
        Layer[] layers,
        ArrayList<Weights<? extends Tensor<?>>> weights,
        long seed
    ) {
        Graph localGraph = graph.concurrentCopy();
        List<SubGraph> subGraphs = GraphSageHelper.subGraphsPerLayer(localGraph, extendedBatch, layers, seed);

        Variable<Scalar> batchLoss = new GraphSageLoss(
            SubGraph.relationshipWeightFunction(localGraph),
            GraphSageHelper.embeddingsComputationGraph(
                subGraphs,
                layers,
                // Features are looked up for the node ids of the last sub-graph in the list.
                this.featureFunction.apply(
                    localGraph,
                    subGraphs.get(subGraphs.size() - 1).originalNodeIds(),
                    features
                )
            ),
            extendedBatch,
            this.parameters.negativeSampleWeight()
        );

        Variable<Scalar> trainingLoss = this.parameters.penaltyL2() > 0.0d
            ? withL2Penalty(batchLoss, layers, extendedBatch, graph)
            : batchLoss;
        return new BatchTask(trainingLoss, weights, this.progressTracker);
    }

    /** Adds the scaled sum of squared L2 norms of all non-bias aggregator weights to the batch loss. */
    private Variable<Scalar> withL2Penalty(Variable<Scalar> batchLoss, Layer[] layers, long[] extendedBatch, Graph graph) {
        List<Variable<?>> squaredNorms = Arrays.stream(layers)
            .map(layer -> layer.aggregator().weightsWithoutBias())
            .flatMap(layerWeights -> layerWeights.stream().map(L2NormSquared::new))
            .collect(Collectors.toList());

        // Integer division on purpose. NOTE(review): presumably the extended batch holds three
        // equally sized parts, so a third of it is the original batch size - confirm in BatchSampler.
        int originalBatchSize = extendedBatch.length / 3;
        double penaltyScale = (this.parameters.penaltyL2() * originalBatchSize) / graph.nodeCount();

        return new ElementSum(List.of(batchLoss, new ConstantScale<>(new ElementSum(squaredNorms), penaltyScale)));
    }

    /**
     * Runs the iterations of one epoch with a fresh {@link AdamOptimizer}.
     *
     * <p>Each iteration fetches its batch tasks from the supplier, runs them concurrently,
     * averages their losses and logs the average. If the average differs from the previous loss
     * by less than the configured tolerance the epoch stops as converged, without applying that
     * iteration's gradients. Otherwise the averaged gradients are applied and the next iteration
     * starts, up to the configured maximum number of iterations.
     *
     * @param batchTaskSupplier provides the batch tasks of every iteration
     * @param weights           weights updated by the optimizer
     * @param previousEpochLoss loss to compare the first iteration against
     * @return the convergence flag and the average loss of every iteration that was run
     */
    private EpochResult trainEpoch(
        Supplier<List<BatchTask>> batchTaskSupplier,
        List<Weights<? extends Tensor<?>>> weights,
        double previousEpochLoss
    ) {
        AdamOptimizer optimizer = new AdamOptimizer(weights, this.parameters.learningRate());
        List<Double> iterationLosses = new ArrayList<>();
        double previousLoss = previousEpochLoss;
        boolean converged = false;
        int maxIterations = this.parameters.maxIterations();

        for (int iteration = 1; iteration <= maxIterations && !converged; iteration++) {
            this.progressTracker.beginSubTask("Iteration");

            List<BatchTask> batchTasks = batchTaskSupplier.get();
            RunWithConcurrency.builder()
                .concurrency(this.parameters.concurrency())
                .tasks(batchTasks)
                .executor(this.executor)
                .run();

            double averageLoss = batchTasks.stream().mapToDouble(BatchTask::loss).sum() / batchTasks.size();
            iterationLosses.add(averageLoss);
            this.progressTracker.logInfo(StringFormatting.formatWithLocale("Average loss per node: %.10f", averageLoss));

            converged = Math.abs(previousLoss - averageLoss) < this.parameters.tolerance();
            if (!converged) {
                previousLoss = averageLoss;
                optimizer.update(TensorFunctions.averageTensors(
                    batchTasks.stream().map(BatchTask::weightGradients).collect(Collectors.toList())
                ));
            }

            this.progressTracker.endSubTask("Iteration");
        }

        return ImmutableEpochResult.of(converged, iterationLosses);
    }
}
