package org.neo4j.gds.ml;

import java.util.concurrent.atomic.DoubleAdder;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Scalar;
import org.neo4j.graphalgo.utils.StringFormatting;
import org.neo4j.logging.Log;

/**
 * Epoch-based training loop for an {@link Objective}.
 *
 * <p>Before the first epoch, the loss is evaluated once over a fresh {@link BatchQueue}. Each
 * epoch then:
 * <ol>
 *   <li>consumes a fresh queue in parallel, running forward + backward passes and applying
 *   an {@link Updater} to each batch;</li>
 *   <li>evaluates the total loss over another fresh queue.</li>
 * </ol>
 * The epoch's loss is reported to the {@link TrainingStopper} obtained from the
 * {@link TrainingSettings}. Training ends when that stopper reports {@code terminated()}.
 */
public class Training {
    private final TrainingSettings settings;
    private final Log log;
    private final long trainSize;

    /**
     * Consumer that adds the loss of every batch it receives to a shared accumulator.
     *
     * <p>Only a forward pass is computed; no backward pass is run and no weights are updated.
     */
    public static class LossEvalConsumer implements Consumer<Batch> {
        private final Objective objective;
        private final DoubleAdder totalLoss;
        private final long trainSize;

        /**
         * @param objective objective whose loss is evaluated for each batch
         * @param totalLoss accumulator receiving each batch's loss; a {@link DoubleAdder} is used so
         *                  that concurrent callers can add to it without extra locking
         * @param trainSize value forwarded as the second argument of {@code Objective.loss}
         */
        LossEvalConsumer(Objective objective, DoubleAdder totalLoss, long trainSize) {
            this.objective = objective;
            this.totalLoss = totalLoss;
            this.trainSize = trainSize;
        }

        @Override
        public void accept(Batch batch) {
            Variable<Scalar> loss = objective.loss(batch, trainSize);
            totalLoss.add(new ComputationContext().forward(loss).value());
        }
    }

    /**
     * Consumer that performs one optimisation step per batch.
     *
     * <p>For each batch it runs a forward and a backward pass of the loss in a fresh
     * {@link ComputationContext} and then hands that context to its {@link Updater}.
     */
    static class ObjectiveUpdateConsumer implements Consumer<Batch> {
        private final Objective objective;
        private final Updater updater;
        private final long trainSize;

        /**
         * @param objective objective whose loss is differentiated for each batch
         * @param updater   updater applied to the computation context after the backward pass
         * @param trainSize value forwarded as the second argument of {@code Objective.loss}
         */
        ObjectiveUpdateConsumer(Objective objective, Updater updater, long trainSize) {
            this.objective = objective;
            this.updater = updater;
            this.trainSize = trainSize;
        }

        @Override
        public void accept(Batch batch) {
            Variable<Scalar> loss = objective.loss(batch, trainSize);
            ComputationContext ctx = new ComputationContext();
            ctx.forward(loss);
            ctx.backward(loss);
            updater.update(ctx);
        }
    }

    /**
     * @param settings  source of the updater(s), the shared-updater flag and the stopper
     * @param log       receives debug messages with per-epoch and final losses
     * @param trainSize value forwarded to {@code Objective.loss} for every batch
     */
    public Training(TrainingSettings settings, Log log, long trainSize) {
        this.settings = settings;
        this.log = log;
        this.trainSize = trainSize;
    }

    /**
     * Trains {@code objective} until the configured stopper terminates.
     *
     * @param objective     the objective whose weights are trained
     * @param queueSupplier called once per loss evaluation and once per training epoch; must
     *                      return a fresh queue each time, since every queue is fully consumed
     * @param concurrency   degree of parallelism passed to {@code BatchQueue.parallelConsume};
     *                      also the number of updaters created when updaters are not shared
     */
    public void train(Objective objective, Supplier<BatchQueue> queueSupplier, int concurrency) {
        // Either a single updater used by every worker, or one updater per worker index.
        boolean shared = settings.sharedUpdater();
        Updater sharedUpdater = shared ? settings.updater(objective.weights()) : null;
        Updater[] perWorkerUpdaters = shared ? null : createUpdaters(objective, concurrency);

        TrainingStopper stopper = settings.stopper();
        double initialLoss = evaluateLoss(objective, queueSupplier.get(), concurrency);
        double lastLoss = initialLoss;
        int epoch = 0;

        while (!stopper.terminated()) {
            trainEpoch(objective, queueSupplier.get(), concurrency, sharedUpdater, perWorkerUpdaters);
            lastLoss = evaluateLoss(objective, queueSupplier.get(), concurrency);
            stopper.registerLoss(lastLoss);
            epoch++;
            log.debug(StringFormatting.formatWithLocale("Loss: %s, After Epoch: %d", lastLoss, epoch));
        }

        boolean converged = stopper.converged();
        log.debug(StringFormatting.formatWithLocale(
            "Training %s after %d epochs. Initial loss: %s, Last loss: %s.%s",
            converged ? "converged" : "terminated",
            epoch,
            initialLoss,
            lastLoss,
            // Leading space keeps the sentence separated from the preceding "%s." in the template.
            converged ? "" : " Did not converge"
        ));
    }

    /** Creates {@code count} independent updaters, one per worker index. */
    private Updater[] createUpdaters(Objective objective, int count) {
        Updater[] updaters = new Updater[count];
        for (int idx = 0; idx < count; idx++) {
            updaters[idx] = settings.updater(objective.weights());
        }
        return updaters;
    }

    /** Sums the forward-pass loss of every batch in {@code batches}. */
    private double evaluateLoss(Objective objective, BatchQueue batches, int concurrency) {
        DoubleAdder totalLoss = new DoubleAdder();
        batches.parallelConsume(new LossEvalConsumer(objective, totalLoss, trainSize), concurrency);
        return totalLoss.doubleValue();
    }

    /**
     * Runs one training epoch over {@code batches}.
     *
     * <p>The factory argument supplied by {@code parallelConsume} is used to index
     * {@code perWorkerUpdaters}; presumably it is a worker id in {@code [0, concurrency)} —
     * confirm against {@code BatchQueue}.
     */
    private void trainEpoch(
        Objective objective,
        BatchQueue batches,
        int concurrency,
        Updater sharedUpdater,
        Updater[] perWorkerUpdaters
    ) {
        batches.parallelConsume(concurrency, workerId -> new ObjectiveUpdateConsumer(
            objective,
            settings.sharedUpdater() ? sharedUpdater : perWorkerUpdaters[workerId],
            trainSize
        ));
    }
}
