package org.neo4j.gds.embeddings.graphsage.ddl4j.functions;

import java.util.List;
import org.neo4j.gds.embeddings.graphsage.ddl4j.AbstractVariable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Dimensions;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Scalar;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/ddl4j/functions/CrossEntropyLoss.class */
/**
 * Element-wise binary cross-entropy loss between a predictions matrix and a targets matrix,
 * averaged over the entries of the targets matrix.
 *
 * <p>For an entry with prediction {@code p} and target {@code t} the contribution is
 * {@code -(t * log(p) + (1 - t) * log(1 - p))}. The {@code log(p)} term is only included when
 * {@code p > 0}, and the {@code log(1 - p)} term only when {@code p < 1}, so a logarithm is only
 * ever taken of a positive argument. The gradient applies the same guards to the matching terms.
 *
 * <p>NOTE(review): predictions are presumably probabilities in [0, 1]; this class does not check
 * that, nor that predictions and targets have the same size.
 */
public class CrossEntropyLoss extends AbstractVariable<Scalar> {
    private final Variable<Matrix> predictions;
    private final Variable<Matrix> targets;

    /**
     * Creates a scalar-valued loss node over the given inputs.
     *
     * @param predictions variable producing the predicted values
     * @param targets     variable producing the target values
     */
    public CrossEntropyLoss(Variable<Matrix> predictions, Variable<Matrix> targets) {
        super(List.of(predictions, targets), Dimensions.scalar());
        this.predictions = predictions;
        this.targets = targets;
    }

    /**
     * Computes the mean cross-entropy.
     *
     * <p>Iterates over {@code targets.totalSize()} entries and divides the negated sum by that
     * count. An empty targets matrix therefore yields {@code NaN} (0 / 0).
     *
     * @param ctx context holding the current values of the input variables
     * @return the loss as a scalar
     */
    @Override
    public Scalar apply(ComputationContext ctx) {
        Matrix predicted = (Matrix) ctx.data(predictions);
        Matrix expected = (Matrix) ctx.data(targets);

        double logLikelihood = 0.0d;
        for (int idx = 0; idx < expected.totalSize(); idx++) {
            double p = predicted.dataAt(idx);
            double t = expected.dataAt(idx);
            // Each log term is skipped at the boundary where its argument would be non-positive.
            if (p > 0.0d) {
                logLikelihood += t * Math.log(p);
            }
            if (p < 1.0d) {
                logLikelihood += (1.0d - t) * Math.log(1.0d - p);
            }
        }
        return new Scalar(-logLikelihood / expected.totalSize());
    }

    /**
     * Returns the gradient of this loss with respect to {@code variable}.
     *
     * <p>For the predictions variable, entry {@code i} is
     * {@code (1 / N) * (-t / p + (1 - t) / (1 - p))}, where {@code N} is the size of the returned
     * matrix and each term is guarded exactly like its log term in {@link #apply}. For any other
     * variable (including the targets) a zero tensor derived from that variable's data is returned.
     *
     * @param variable the variable to differentiate with respect to
     * @param ctx      context holding the current values of the input variables
     * @return the gradient tensor
     */
    @Override
    public Tensor<?> gradient(Variable<?> variable, ComputationContext ctx) {
        if (variable == predictions) {
            return predictionsGradient(ctx);
        }
        return ctx.data(variable).zeros();
    }

    // Builds d(loss)/d(predictions) entry by entry into a fresh matrix shaped like the predictions.
    private Matrix predictionsGradient(ComputationContext ctx) {
        Matrix predicted = (Matrix) ctx.data(predictions);
        Matrix expected = (Matrix) ctx.data(targets);
        Matrix result = predicted.zeros();
        double scale = 1.0d / result.totalSize();

        for (int idx = 0; idx < result.totalSize(); idx++) {
            double p = predicted.dataAt(idx);
            double t = expected.dataAt(idx);
            double partial = 0.0d;
            if (p > 0.0d) {
                partial -= t / p;
            }
            if (p < 1.0d) {
                partial += (1.0d - t) / (1.0d - p);
            }
            result.setDataAt(idx, scale * partial);
        }
        return result;
    }
}
