package org.neo4j.gds.embeddings.graphsage.ddl4j;

import java.util.HashMap;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.PassthroughVariable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.TensorFactory;

/**
 * Evaluation state for a graph of {@link Variable}s.
 *
 * <p>{@link #forward(Variable)} caches the tensor computed for each variable. {@link #backward(Variable)}
 * accumulates, per variable, the gradient contributed by all of its children.
 *
 * <p>Both caches are {@link ConcurrentHashMap}s. This class performs no other synchronization.
 */
public class ComputationContext {
    /** Forward-pass results, keyed by the variable that produced them. */
    private final Map<Variable<?>, Tensor<?>> data = new ConcurrentHashMap<>();
    /** Gradients accumulated by the most recent {@link #backward(Variable)} call. */
    private final Map<Variable<?>, Tensor<?>> gradients = new ConcurrentHashMap<>();

    /**
     * A pending unit of back-propagation work: add {@code child}'s gradient contribution with
     * respect to {@code variable} into the accumulated gradient of {@code variable}.
     */
    public static class BackPropTask {
        final Variable<?> variable;
        final Variable<?> child;

        BackPropTask(Variable<?> variable, Variable<?> child) {
            this.variable = variable;
            this.child = child;
        }
    }

    /**
     * Computes the value of {@code variable}, or returns the cached value if it was already
     * computed in this context.
     *
     * <p>Parents that have no cached value are evaluated first (recursively), so every parent
     * value is present in this context before {@code variable.apply(this)} runs.
     *
     * @param variable the variable to evaluate
     * @return the (cached) forward value of {@code variable}
     */
    public Tensor<?> forward(Variable<?> variable) {
        for (Variable<?> parent : variable.parents()) {
            if (!data.containsKey(parent)) {
                // forward() stores its own result in `data`, so no explicit put is needed here.
                forward(parent);
            }
        }
        // NOTE(review): ConcurrentHashMap#computeIfAbsent forbids the mapping function from
        // updating other mappings of this map. This assumes apply() only reads the parent values
        // cached above and does not call forward() itself -- confirm against Variable impls.
        return data.computeIfAbsent(variable, v -> v.apply(this));
    }

    /**
     * Returns the cached forward value of {@code variable}, or {@code null} if it has not been
     * computed in this context.
     */
    public Tensor<?> data(Variable<?> variable) {
        return data.get(variable);
    }

    /**
     * Returns the gradient accumulated for {@code variable} by the last {@link #backward(Variable)}
     * call, or {@code null} if none was accumulated.
     */
    public Tensor<?> gradient(Variable<?> variable) {
        return gradients.get(variable);
    }

    /**
     * Back-propagates from {@code variable} through all ancestors that require a gradient.
     *
     * <p>Any gradients from a previous call are discarded. The root is seeded through a
     * {@link PassthroughVariable} wrapping it. That variable presumably supplies the initial
     * gradient (TODO confirm).
     *
     * @param variable the root; it must have rank 1 and a forward value of total size 1
     * @throws IllegalArgumentException if {@code variable} does not have rank 1, or its forward
     *     value does not have total size 1
     * @throws IllegalStateException if {@code variable} has rank 1 but has not been evaluated
     *     with {@link #forward(Variable)} in this context
     */
    public void backward(Variable<?> variable) {
        boolean rankOne = variable.dimensions().length == 1;
        Tensor<?> value = data(variable);
        if (rankOne && value == null) {
            throw new IllegalStateException("Backward requires forward() to have been called on the variable first.");
        }
        if (!rankOne || value.totalSize() != 1) {
            throw new IllegalArgumentException("Backward requires a variable with rank 1 and single dimension of size 1.");
        }
        gradients.clear();

        Queue<BackPropTask> queue = new LinkedBlockingQueue<>();
        PassthroughVariable seed = new PassthroughVariable(variable);
        queue.add(new BackPropTask(variable, seed));

        Map<Variable<?>, AtomicInteger> pendingChildren = new HashMap<>();
        initUpstream(seed, pendingChildren);
        backward(queue, pendingChildren);
    }

    /**
     * Drains {@code queue}. Each task adds one child's gradient contribution to a variable. Once
     * every counted child of that variable has contributed, the variable's gradient-requiring
     * parents are enqueued.
     */
    private void backward(Queue<BackPropTask> queue, Map<Variable<?>, AtomicInteger> pendingChildren) {
        while (!queue.isEmpty()) {
            BackPropTask task = queue.poll();
            Variable<?> variable = task.variable;
            updateGradient(variable, task.child.gradient(variable, this));
            // NOTE(review): assumes `variable` was registered by initUpstream. A root whose
            // requireGradient() is false is never registered, and this line would then throw an NPE.
            if (pendingChildren.get(variable).decrementAndGet() == 0) {
                for (Variable<?> parent : variable.parents()) {
                    if (parent.requireGradient()) {
                        queue.offer(new BackPropTask(parent, variable));
                    }
                }
            }
        }
    }

    /**
     * Counts, for each gradient-requiring ancestor of {@code variable}, how many child edges
     * lead into it. An ancestor is recursed into only the first time it is seen.
     */
    private void initUpstream(Variable<?> variable, Map<Variable<?>, AtomicInteger> pendingChildren) {
        for (Variable<?> parent : variable.parents()) {
            if (!parent.requireGradient()) {
                continue;
            }
            if (!pendingChildren.containsKey(parent)) {
                initUpstream(parent, pendingChildren);
                pendingChildren.put(parent, new AtomicInteger(0));
            }
            pendingChildren.get(parent).incrementAndGet();
        }
    }

    /** Adds {@code tensor} into the gradient of {@code variable}, starting from a zero tensor. */
    private void updateGradient(Variable<?> variable, Tensor<?> tensor) {
        // computeIfAbsent builds the zero tensor only on a variable's first contribution.
        gradients.computeIfAbsent(variable, v -> TensorFactory.constant(0.0d, v.dimensions()))
            .addInPlace(tensor);
    }
}
