package org.neo4j.gds.ml.gradientdescent;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.optimizer.AdamOptimizer;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.TensorFunctions;
import org.neo4j.gds.ml.splitting.EdgeSplitter;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/gradientdescent/Training.class */
/**
 * Drives a gradient-descent training loop for an {@link Objective}.
 *
 * <p>One pass over the batches produces a total loss and a list of averaged weight gradients.
 * The gradients are handed to an {@link AdamOptimizer}, after which the next pass is executed.
 * The loop ends when the {@link TrainingStopper} built from the configuration reports that it
 * has terminated.
 */
public class Training {
    private final GradientDescentConfig config;
    private final ProgressTracker progressTracker;
    private final long trainSize;
    private final TerminationFlag terminationFlag;

    /**
     * Batch consumer that evaluates the objective's loss on every batch it receives and
     * accumulates the loss values and the per-weight gradients.
     *
     * <p>The accumulator fields are plain, unsynchronized state.
     * NOTE(review): this presumably relies on each instance being driven by a single worker at a
     * time inside {@code BatchQueue.parallelConsume} - confirm against that implementation.
     */
    public static class ObjectiveUpdateConsumer implements Consumer<Batch> {
        private final Objective<?> objective;
        private final long trainSize;
        private final List<? extends Tensor<?>> summedWeightGradients;
        private int consumedBatches = 0;
        private double lossSum = 0.0;

        /**
         * Creates a consumer with one gradient accumulator per weight of the objective.
         *
         * @param objective the objective whose loss and weight gradients are accumulated
         * @param trainSize value forwarded unchanged to {@code Objective.loss} for every batch
         */
        ObjectiveUpdateConsumer(Objective<?> objective, long trainSize) {
            this.objective = objective;
            this.trainSize = trainSize;
            // One accumulator tensor per weight, created with the same dimensions as the weight data.
            this.summedWeightGradients = objective
                .weights()
                .stream()
                .map(weights -> weights.data().createWithSameDimensions())
                .collect(Collectors.toList());
        }

        /**
         * Runs a forward and a backward pass of the loss for {@code batch}, then adds the
         * resulting loss value and weight gradients to the running sums.
         *
         * @param batch the batch to evaluate
         */
        @Override
        public void accept(Batch batch) {
            Variable<Scalar> loss = this.objective.loss(batch, this.trainSize);
            ComputationContext ctx = new ComputationContext();

            this.lossSum += ctx.forward(loss).value();
            ctx.backward(loss);

            // Gradients are read in the order of objective.weights(), which is also the order
            // in which the accumulators were created, so they can be matched by index.
            List<? extends Tensor<?>> batchGradients = this.objective
                .weights()
                .stream()
                .map(ctx::gradient)
                .collect(Collectors.toList());

            for (int idx = 0; idx < this.summedWeightGradients.size(); idx++) {
                this.summedWeightGradients.get(idx).addInPlace(batchGradients.get(idx));
            }
            this.consumedBatches++;
        }

        /** @return the gradient sums, one tensor per weight of the objective */
        List<? extends Tensor<?>> summedWeightGradients() {
            return this.summedWeightGradients;
        }

        /** @return how many batches this consumer has accepted so far */
        int consumedBatches() {
            return this.consumedBatches;
        }

        /** @return the sum of the loss values of all accepted batches */
        double lossSum() {
            return this.lossSum;
        }
    }

    /**
     * @param gradientDescentConfig supplies the learning rate, the stopping criteria and the
     *                              maximum number of epochs reported in the final log message
     * @param progressTracker       receives the per-epoch and the summary log messages
     * @param trainSize             value forwarded to {@code Objective.loss} for every batch
     * @param terminationFlag       checked once per epoch and handed to the batch queue
     */
    public Training(
        GradientDescentConfig gradientDescentConfig,
        ProgressTracker progressTracker,
        long trainSize,
        TerminationFlag terminationFlag
    ) {
        this.config = gradientDescentConfig;
        this.progressTracker = progressTracker;
        this.trainSize = trainSize;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Convenience overload that wraps {@code variableDimension} in a single-valued
     * {@link MemoryRange} and delegates to {@link #memoryEstimation(MemoryRange, int)}.
     *
     * @param variableDimension value turned into a {@link MemoryRange}
     * @param fixedDimension    value passed through unchanged
     * @return the memory estimation for training
     */
    public static MemoryEstimation memoryEstimation(int variableDimension, int fixedDimension) {
        return memoryEstimation(MemoryRange.of(variableDimension), fixedDimension);
    }

    /**
     * Builds a memory estimation named after this class, consisting of an "updater" component
     * sized via {@code AdamOptimizer.sizeInBytes} and a per-thread "weight gradients" component
     * sized via {@code Weights.sizeInBytes}.
     *
     * <p>NOTE(review): the two arguments are presumably the dimensions of the weight matrix
     * (for example number of classes and a range for the number of features) - confirm against
     * the callers and the {@code sizeInBytes} signatures.
     *
     * @param variableDimensionRange range whose values are passed as the second argument of the
     *                               {@code sizeInBytes} calls; each value must fit into an
     *                               {@code int}, otherwise {@link Math#toIntExact(long)} throws
     * @param fixedDimension         passed as the first argument of the {@code sizeInBytes} calls
     * @return the combined memory estimation
     */
    public static MemoryEstimation memoryEstimation(MemoryRange variableDimensionRange, int fixedDimension) {
        MemoryRange updaterRange = variableDimensionRange.apply(
            value -> AdamOptimizer.sizeInBytes(fixedDimension, Math.toIntExact(value))
        );
        MemoryRange weightGradientsRange = variableDimensionRange.apply(
            value -> Weights.sizeInBytes(fixedDimension, Math.toIntExact(value))
        );

        return MemoryEstimations.builder(Training.class.getSimpleName())
            .add(MemoryEstimations.of("updater", updaterRange))
            .perThread("weight gradients", weightGradientsRange)
            .build();
    }

    /**
     * Trains the weights of {@code objective}.
     *
     * <p>An initial pass computes the starting loss and gradients. Every epoch then applies the
     * previous gradients through the optimizer, runs a new pass over a fresh queue obtained from
     * {@code queueSupplier}, and registers the new loss with the stopper. When the stopper
     * reports termination, a summary message is logged.
     *
     * @param objective     the objective providing the loss and the weights to update
     * @param queueSupplier called once per pass to obtain the batches to consume
     * @param concurrency   number of consumers created per pass; also passed to
     *                      {@code BatchQueue.parallelConsume}
     */
    public void train(Objective<?> objective, Supplier<BatchQueue> queueSupplier, int concurrency) {
        AdamOptimizer updater = new AdamOptimizer(objective.weights(), this.config.learningRate());
        TrainingStopper stopper = TrainingStopper.defaultStopper(this.config);

        List<ObjectiveUpdateConsumer> consumers = executeBatches(concurrency, objective, queueSupplier.get());
        List<? extends Tensor<?>> previousGradients = avgWeightGradients(consumers);
        double initialLoss = totalLoss(consumers);
        this.progressTracker.logMessage(
            StringFormatting.formatWithLocale("Initial loss %s", new Object[] {initialLoss})
        );

        int epochs = 0;
        // If the stopper is already terminated before the first epoch, no epoch loss exists,
        // so the initial loss doubles as the last known loss in the summary message.
        double lastLoss = initialLoss;

        while (!stopper.terminated()) {
            this.terminationFlag.assertRunning();

            updater.update(previousGradients);

            consumers = executeBatches(concurrency, objective, queueSupplier.get());
            previousGradients = avgWeightGradients(consumers);
            lastLoss = totalLoss(consumers);
            epochs++;

            stopper.registerLoss(lastLoss);
            this.progressTracker.logMessage(
                StringFormatting.formatWithLocale("Epoch %d with loss %s", new Object[] {epochs, lastLoss})
            );
        }

        boolean converged = stopper.converged();
        this.progressTracker.logMessage(
            StringFormatting.formatWithLocale(
                "%s after %d out of %d epochs. Initial loss: %s, Last loss: %s.%s",
                new Object[] {
                    converged ? "converged" : "terminated",
                    epochs,
                    this.config.maxEpochs(),
                    initialLoss,
                    lastLoss,
                    converged ? "" : " Did not converge"
                }
            )
        );
    }

    /**
     * Creates {@code concurrency} fresh consumers and lets {@code batches} feed them.
     *
     * @return the consumers, holding the accumulated losses and gradients of this pass
     */
    private List<ObjectiveUpdateConsumer> executeBatches(int concurrency, Objective<?> objective, BatchQueue batches) {
        List<ObjectiveUpdateConsumer> consumers = new ArrayList<>(concurrency);
        for (int idx = 0; idx < concurrency; idx++) {
            consumers.add(new ObjectiveUpdateConsumer(objective, this.trainSize));
        }
        batches.parallelConsume(concurrency, consumers, this.terminationFlag);
        return consumers;
    }

    /**
     * Averages the gradient sums of all consumers via {@code TensorFunctions.averageTensors},
     * passing the total number of batches consumed across all of them.
     */
    private List<? extends Tensor<?>> avgWeightGradients(List<ObjectiveUpdateConsumer> consumers) {
        List<List<? extends Tensor<?>>> batchedGradients = new ArrayList<>(consumers.size());
        int numberOfBatches = 0;
        for (ObjectiveUpdateConsumer consumer : consumers) {
            batchedGradients.add(consumer.summedWeightGradients());
            numberOfBatches += consumer.consumedBatches();
        }
        return TensorFunctions.averageTensors(batchedGradients, numberOfBatches);
    }

    /** Sums the loss accumulated by every consumer of a pass. */
    private double totalLoss(List<ObjectiveUpdateConsumer> consumers) {
        return consumers.stream().mapToDouble(ObjectiveUpdateConsumer::lossSum).sum();
    }
}
