package org.neo4j.gds.embeddings.graphsage;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Random;
import java.util.TreeMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.PassthroughVariable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.Weights;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Scalar;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;
import org.neo4j.graphalgo.annotation.ValueClass;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.concurrency.ParallelUtil;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.paged.HugeObjectArray;
import org.neo4j.graphalgo.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/GraphSageModelTrainer.class */
/**
 * Trains the layers of a GraphSAGE model on a graph using mini-batches and an {@link AdamOptimizer}.
 *
 * <p>{@link #train} builds fresh layers from the configured {@link LayerConfig}s and evaluates the
 * summed loss once before training. It then runs up to {@code epochs} epochs and stops early once
 * the relative change of the summed loss between two consecutive evaluations is below
 * {@code tolerance}.
 *
 * <p>NOTE(review): {@code train} keeps per-run state ({@code layers}, {@code relationshipWeights},
 * {@code degreeProbabilityNormalizer}) in instance fields. One instance should therefore not run
 * {@code train} concurrently with itself.
 */
public class GraphSageModelTrainer {
    /** Exponent applied to node degrees when building the negative-sampling distribution. */
    private static final double DEGREE_EXPONENT = 0.75d;

    // (Re)initialised at the start of every train() call.
    private Layer[] layers;
    private final boolean useWeights;
    private final BatchProvider batchProvider;
    private final double learningRate;
    private final double tolerance;
    private final int negativeSampleWeight;
    private final int concurrency;
    private final int epochs;
    private final int maxIterations;
    private final int maxSearchDepth;
    private final List<LayerConfig> layerConfigs;
    private final FeatureFunction featureFunction;
    private final Collection<Weights<? extends Tensor<?>>> labelProjectionWeights;
    private final ProgressLogger progressLogger;
    // Sum over all nodes of degree^DEGREE_EXPONENT; set by train(), used by negative sampling.
    private double degreeProbabilityNormalizer;
    // Set by train(): graph relationship property when weighted, UNWEIGHTED otherwise.
    private RelationshipWeights relationshipWeights;

    /** Immutable summary of one {@link GraphSageModelTrainer#train} run. */
    @ValueClass
    public interface ModelTrainResult {
        /** Summed loss over all batches, evaluated before the first epoch. */
        double startLoss();

        /**
         * Summed loss after each completed epoch, keyed {@code "Epoch: <0-based epoch index>"}.
         * NOTE(review): {@code train} stores these in a {@code TreeMap}, so keys sort as strings
         * ("Epoch: 10" before "Epoch: 2").
         */
        Map<String, Double> epochLosses();

        /** The layers as they stand at the end of training. */
        Layer[] layers();

        static ModelTrainResult of(double startLoss, Map<String, Double> epochLosses, Layer[] layers) {
            return ImmutableModelTrainResult.builder()
                .startLoss(startLoss)
                .epochLosses(epochLosses)
                .layers(layers)
                .build();
        }
    }

    /**
     * Creates a trainer that uses {@code GraphSageHelper::features} as feature function and no
     * label projection weights.
     *
     * @param config         training configuration
     * @param progressLogger logger for progress and debug output
     */
    public GraphSageModelTrainer(GraphSageTrainConfig config, ProgressLogger progressLogger) {
        this(config, progressLogger, GraphSageHelper::features, Collections.emptyList());
    }

    /**
     * Creates a trainer.
     *
     * @param config                 training configuration (layers, batch size, optimizer and
     *                               stopping parameters, sampling parameters)
     * @param progressLogger         logger for progress and debug output
     * @param featureFunction        passed to {@code GraphSageHelper.embeddings} when computing
     *                               embeddings for a batch
     * @param labelProjectionWeights extra weights that are optimized together with the layer weights
     */
    public GraphSageModelTrainer(
        GraphSageTrainConfig config,
        ProgressLogger progressLogger,
        FeatureFunction featureFunction,
        Collection<Weights<? extends Tensor<?>>> labelProjectionWeights
    ) {
        this.layerConfigs = config.layerConfigs();
        this.batchProvider = new BatchProvider(config.batchSize());
        this.learningRate = config.learningRate();
        this.tolerance = config.tolerance();
        this.negativeSampleWeight = config.negativeSampleWeight();
        this.concurrency = config.concurrency();
        this.epochs = config.epochs();
        this.maxIterations = config.maxIterations();
        this.maxSearchDepth = config.searchDepth();
        this.featureFunction = featureFunction;
        this.labelProjectionWeights = labelProjectionWeights;
        this.progressLogger = progressLogger;
        this.useWeights = config.relationshipWeightProperty() != null;
    }

    /**
     * Trains fresh layers on {@code graph}.
     *
     * @param graph    graph to train on; its relationship property is used as weight when a
     *                 relationship weight property was configured
     * @param features per-node feature arrays, forwarded to the embedding computation
     * @return the loss before training, the loss after each epoch, and the trained layers
     */
    public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
        progressLogger.logStart();
        Map<String, Double> epochLosses = new TreeMap<>();

        this.relationshipWeights = useWeights ? graph::relationshipProperty : RelationshipWeights.UNWEIGHTED;
        Optional<RelationshipWeights> maybeWeights = useWeights
            ? Optional.of(this.relationshipWeights)
            : Optional.empty();
        this.layers = layerConfigs.stream()
            .map(layerConfig -> LayerFactory.createLayer(layerConfig, maybeWeights))
            .toArray(Layer[]::new);

        this.degreeProbabilityNormalizer = LongStream.range(0L, graph.nodeCount())
            .mapToDouble(nodeId -> Math.pow(graph.degree(nodeId), DEGREE_EXPONENT))
            .sum();

        double startLoss = evaluateLoss(graph, features, batchProvider, -1);
        double previousLoss = startLoss;
        for (int epoch = 0; epoch < epochs; epoch++) {
            String epochMessage = ":: Epoch " + (epoch + 1);
            progressLogger.logStart(epochMessage);
            trainEpoch(graph, features, epoch);
            double newLoss = evaluateLoss(graph, features, batchProvider, epoch);
            epochLosses.put(StringFormatting.formatWithLocale("Epoch: %d", epoch), newLoss);
            progressLogger.logFinish(epochMessage);
            if (hasConverged(previousLoss, newLoss)) {
                break;
            }
            previousLoss = newLoss;
        }

        progressLogger.logFinish();
        return ModelTrainResult.of(startLoss, epochLosses, layers);
    }

    /**
     * Checks whether the relative change from {@code previousLoss} to {@code currentLoss} is below
     * {@code tolerance}.
     * NOTE(review): a {@code previousLoss} of 0 yields NaN/Infinity, which never counts as converged.
     */
    private boolean hasConverged(double previousLoss, double currentLoss) {
        return Math.abs((currentLoss - previousLoss) / previousLoss) < tolerance;
    }

    /**
     * Runs one epoch: every batch from the batch provider is trained with {@link #trainOnBatch}.
     * Batches are consumed in parallel with {@code concurrency} and share one {@link AdamOptimizer}.
     */
    private void trainEpoch(Graph graph, HugeObjectArray<double[]> features, int epoch) {
        AdamOptimizer updater = new AdamOptimizer(getWeights(), learningRate);
        AtomicInteger batchCounter = new AtomicInteger(0);
        ParallelUtil.parallelStreamConsume(batchProvider.stream(graph), concurrency, batchStream ->
            batchStream.forEach(batch ->
                trainOnBatch(batch, graph, features, updater, epoch, batchCounter.incrementAndGet())
            )
        );
    }

    /**
     * Optimizes the loss of a single batch for at most {@code maxIterations} iterations.
     *
     * <p>Each layer's random state is regenerated first, then the loss variable for this batch is
     * built once. Every iteration runs a forward pass and stops if the loss converged. Otherwise it
     * runs a backward pass and applies one optimizer step.
     *
     * @param batch      node ids of the batch
     * @param updater    optimizer shared by all batches of the epoch
     * @param epoch      epoch index, used for logging only
     * @param batchIndex running batch counter, used for logging only
     */
    private void trainOnBatch(
        long[] batch,
        Graph graph,
        HugeObjectArray<double[]> features,
        AdamOptimizer updater,
        int epoch,
        int batchIndex
    ) {
        for (Layer layer : layers) {
            layer.generateNewRandomState();
        }
        Variable<Scalar> lossFunction = lossFunction(batch, graph, features);

        // Starting from MAX_VALUE guarantees the first iteration never counts as converged.
        double newLoss = Double.MAX_VALUE;
        progressLogger.getLog().debug("Epoch %d\tBatch %d, Initial loss: %.10f", epoch, batchIndex, newLoss);

        int iteration = 0;
        while (iteration < maxIterations) {
            String iterationMessage = ":: Iteration " + (iteration + 1);
            progressLogger.logStart(iterationMessage);

            double oldLoss = newLoss;
            ComputationContext ctx = new ComputationContext();
            newLoss = ctx.forward(lossFunction).dataAt(0);
            boolean converged = hasConverged(oldLoss, newLoss);
            if (!converged) {
                ctx.backward(lossFunction);
                updater.update(ctx);
            }

            progressLogger.logFinish(iterationMessage);
            if (converged) {
                break;
            }
            iteration++;
        }

        progressLogger.getLog().debug(
            "Epoch %d\tBatch %d LOSS: %.10f at iteration %d",
            epoch,
            batchIndex,
            newLoss,
            iteration
        );
    }

    /**
     * Computes the loss summed over all batches (forward pass only), in parallel with
     * {@code concurrency}.
     *
     * @param epoch epoch index used in the debug message ({@code -1} before training)
     * @return the summed loss
     */
    private double evaluateLoss(
        Graph graph,
        HugeObjectArray<double[]> features,
        BatchProvider batches,
        int epoch
    ) {
        DoubleAdder totalLoss = new DoubleAdder();
        ParallelUtil.parallelStreamConsume(batches.stream(graph), concurrency, batchStream ->
            batchStream.forEach(batch -> totalLoss.add(
                new ComputationContext().forward(lossFunction(batch, graph, features)).dataAt(0)
            ))
        );
        double loss = totalLoss.doubleValue();
        progressLogger.getLog().debug("Loss after epoch %s: %s", epoch, loss);
        return loss;
    }

    /**
     * Builds the loss variable for a batch.
     *
     * <p>The node id array passed to the embedding and loss computation is the batch nodes,
     * followed by one sampled neighbour per batch node, followed by as many negative samples as
     * there are batch nodes.
     */
    private Variable<Scalar> lossFunction(long[] batch, Graph graph, HugeObjectArray<double[]> features) {
        long[] totalBatch = LongStream.concat(
            Arrays.stream(batch),
            LongStream.concat(neighborBatch(graph, batch), negativeBatch(graph, batch.length))
        ).toArray();

        return new PassthroughVariable(new GraphSageLoss(
            relationshipWeights,
            GraphSageHelper.embeddings(graph, totalBatch, features, layers, featureFunction),
            totalBatch,
            negativeSampleWeight
        ));
    }

    /**
     * For every batch node, walks a random number of steps (1..{@code maxSearchDepth}) via the
     * neighborhood sampler and returns the node reached.
     *
     * <p>Each step samples a neighbour of the node reached so far. The walk stops early when the
     * sampler yields no neighbour, and the last reached node is returned in that case.
     * NOTE(review): the third {@code sampleOne} argument ({@code 0L}) is kept from the original;
     * its meaning (presumably a random seed) is defined by the sampler classes. TODO confirm.
     */
    private LongStream neighborBatch(Graph graph, long[] batch) {
        return Arrays.stream(batch).map(nodeId -> {
            int searchDepth = ThreadLocalRandom.current().nextInt(maxSearchDepth) + 1;
            long currentNode = nodeId;
            for (int step = 0; step < searchDepth; step++) {
                // Sample from the node reached so far, not from the start node, so the walk advances.
                OptionalLong sample = (useWeights ? new WeightedNeighborhoodSampler() : new UniformNeighborhoodSampler())
                    .sampleOne(graph, currentNode, 0L);
                if (!sample.isPresent()) {
                    break;
                }
                currentNode = sample.getAsLong();
            }
            return currentNode;
        });
    }

    /**
     * Draws {@code batchSize} negative nodes. Each node is chosen with probability
     * {@code degree^0.75 / degreeProbabilityNormalizer}, using a {@link Random} seeded from the
     * first layer's random state.
     */
    private LongStream negativeBatch(Graph graph, int batchSize) {
        Random rand = new Random(layers[0].randomState());
        return IntStream.range(0, batchSize).mapToLong(ignored -> sampleNegativeNode(graph, rand.nextDouble()));
    }

    /**
     * Maps {@code randomValue} in [0, 1) to a node id by scanning the cumulative
     * degree^0.75 distribution.
     *
     * @throws RuntimeException if no node has a positive degree
     */
    private long sampleNegativeNode(Graph graph, double randomValue) {
        double cumulativeProbability = 0.0d;
        long lastPositiveNode = -1L;
        for (long nodeId = 0; nodeId < graph.nodeCount(); nodeId++) {
            double weight = Math.pow(graph.degree(nodeId), DEGREE_EXPONENT);
            if (weight > 0) {
                lastPositiveNode = nodeId;
            }
            cumulativeProbability += weight / degreeProbabilityNormalizer;
            if (randomValue < cumulativeProbability) {
                return nodeId;
            }
        }
        // The floating-point sum can end slightly below 1.0, so a value in that gap belongs to the
        // last node with positive probability mass.
        if (lastPositiveNode >= 0) {
            return lastPositiveNode;
        }
        throw new RuntimeException("This happens when there are no relationships in the Graph. This condition is checked by the calling procedure.");
    }

    /** Returns the label projection weights followed by the weights of every layer, in layer order. */
    private List<Weights<? extends Tensor<?>>> getWeights() {
        List<Weights<? extends Tensor<?>>> weights = new ArrayList<>(labelProjectionWeights);
        weights.addAll(Arrays.stream(layers)
            .flatMap(layer -> layer.weights().stream())
            .collect(Collectors.toList()));
        return weights;
    }
}
