package org.neo4j.gds.embeddings.graphsage.ddl4j.functions;

import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/ddl4j/functions/Softmax.class */
/**
 * Row-wise softmax of the matrix produced by the parent variable.
 *
 * <p>Each element is exponentiated and divided by the sum of the exponentials in its row.
 * Elements are addressed through the flat index {@code row * cols + col}, where {@code rows}
 * and {@code cols} are dimensions 0 and 1 of the parent variable.
 */
public class Softmax extends SingleParentVariable<Matrix> {
    /**
     * Initial value of every row sum before any element is added.
     * NOTE(review): presumably guards against dividing by an exact zero - confirm.
     */
    private static final double ROW_SUM_SEED = 1.0E-15d;

    private final int rows;
    private final int cols;

    /**
     * Creates a softmax node with the same dimensions as {@code variable}.
     *
     * @param variable the parent variable whose data is normalised
     */
    public Softmax(Variable<?> variable) {
        super(variable, variable.dimensions());
        this.rows = dimension(0);
        this.cols = dimension(1);
    }

    /**
     * Computes the softmax of the parent's data, one row at a time.
     *
     * <p>If an exponential or a running row sum becomes infinite, it is replaced by
     * {@link Double#MAX_VALUE}; in that case every row of the result is normalised a second
     * time via {@link #rescale(Matrix)} before being returned.
     *
     * @param computationContext context holding the parent's computed data
     * @return a new matrix containing the row-wise softmax values
     */
    @Override
    public Matrix apply(ComputationContext computationContext) {
        Matrix input = (Matrix) computationContext.data(parent());
        // Every entry of the output is overwritten below.
        Matrix output = input.zeros();
        boolean overflowed = false;

        for (int row = 0; row < rows; row++) {
            int rowStart = row * cols;
            double rowSum = ROW_SUM_SEED;

            for (int col = 0; col < cols; col++) {
                int index = rowStart + col;

                double exponential = Math.exp(input.dataAt(index));
                if (Double.isInfinite(exponential)) {
                    // Clamp so that the later division does not produce Infinity / Infinity.
                    exponential = Double.MAX_VALUE;
                    overflowed = true;
                }
                output.setDataAt(index, exponential);

                rowSum += exponential;
                if (Double.isInfinite(rowSum)) {
                    rowSum = Double.MAX_VALUE;
                    overflowed = true;
                }
            }

            divideRow(output, rowStart, rowSum);
        }

        if (overflowed) {
            // Clamped rows may no longer sum to one, so normalise once more.
            rescale(output);
        }
        return output;
    }

    /**
     * Divides every row of {@code matrix} in place by the sum of its entries
     * (plus {@link #ROW_SUM_SEED}).
     */
    private void rescale(Matrix matrix) {
        for (int row = 0; row < rows; row++) {
            int rowStart = row * cols;

            double rowSum = ROW_SUM_SEED;
            for (int col = 0; col < cols; col++) {
                rowSum += matrix.dataAt(rowStart + col);
            }

            divideRow(matrix, rowStart, rowSum);
        }
    }

    /** Divides the {@code cols} entries starting at flat index {@code rowStart} by {@code divisor}, in place. */
    private void divideRow(Matrix matrix, int rowStart, double divisor) {
        int rowEnd = rowStart + cols;
        for (int index = rowStart; index < rowEnd; index++) {
            matrix.setDataAt(index, matrix.dataAt(index) / divisor);
        }
    }

    /**
     * Propagates the gradient stored for this node back through the softmax.
     *
     * <p>With {@code s} being this node's data and {@code g} this node's gradient, entry
     * {@code (r, c)} of the result is the sum over {@code k} of
     * {@code s[r,k] * (delta(c,k) - s[r,c]) * g[r,k]}, where {@code delta} is 1 when both
     * indices are equal and 0 otherwise.
     *
     * @param variable           not read by this implementation
     * @param computationContext context holding this node's data and gradient
     * @return a new {@code rows x cols} matrix with the propagated gradient
     */
    @Override
    public Tensor<?> gradient(Variable<?> variable, ComputationContext computationContext) {
        Matrix softmax = (Matrix) computationContext.data(this);
        Matrix upstream = (Matrix) computationContext.gradient(this);
        Matrix result = Matrix.fill(0.0d, rows, cols);

        for (int row = 0; row < rows; row++) {
            int rowStart = row * cols;

            for (int col = 0; col < cols; col++) {
                int target = rowStart + col;
                double targetValue = softmax.dataAt(target);

                for (int other = 0; other < cols; other++) {
                    int source = rowStart + other;
                    double kronecker = (col == other) ? 1.0d : 0.0d;
                    double term = softmax.dataAt(source) * (kronecker - targetValue) * upstream.dataAt(source);
                    result.addDataAt(target, term);
                }
            }
        }
        return result;
    }
}
