package org.neo4j.gds.embeddings.graphsage.ddl4j.functions;

import java.util.ArrayList;
import java.util.Map;
import java.util.stream.IntStream;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.mult.MatrixMatrixMult_DDRM;
import org.neo4j.gds.embeddings.graphsage.ddl4j.AbstractVariable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;
import org.neo4j.graphalgo.NodeLabel;
import org.neo4j.graphalgo.core.utils.paged.HugeObjectArray;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/ddl4j/functions/LabelwiseFeatureProjection.class */
public class LabelwiseFeatureProjection extends AbstractVariable<Matrix> {
    private final long[] nodeIds;
    private final HugeObjectArray<double[]> features;
    private final Map<NodeLabel, Weights<? extends Tensor<?>>> weightsByLabel;
    private final int projectedFeatureSize;
    private final NodeLabel[] labels;

    /**
     * Creates a projection of node feature vectors that uses one weight variable per node label.
     *
     * @param nodeIds              ids of the nodes to project; output row {@code i} belongs to {@code nodeIds[i]}
     * @param features             feature vectors, looked up by node id
     * @param weightsByLabel       weight variable for each label; every value is registered as a parent of this variable
     * @param projectedFeatureSize number of columns of the output matrix
     * @param labels               label of each node, aligned by position with {@code nodeIds}
     */
    public LabelwiseFeatureProjection(
        long[] nodeIds,
        HugeObjectArray<double[]> features,
        Map<NodeLabel, Weights<? extends Tensor<?>>> weightsByLabel,
        int projectedFeatureSize,
        NodeLabel[] labels
    ) {
        super(new ArrayList<>(weightsByLabel.values()), new int[]{nodeIds.length, projectedFeatureSize});
        this.nodeIds = nodeIds;
        this.features = features;
        this.weightsByLabel = weightsByLabel;
        this.projectedFeatureSize = projectedFeatureSize;
        this.labels = labels;
    }

    /**
     * Projects each node's feature vector with the weight matrix of that node's label.
     *
     * <p>Row {@code i} of the result is {@code W * x}. {@code W} is the weight matrix registered for
     * {@code labels[i]} and {@code x} is the feature vector of {@code nodeIds[i]}. The first
     * {@code projectedFeatureSize} entries of that product are copied into the row.
     *
     * @throws IllegalStateException if no weights are registered for some node's label
     */
    @Override // org.neo4j.gds.embeddings.graphsage.ddl4j.Variable
    public Matrix apply(ComputationContext computationContext) {
        double[] result = new double[nodeIds.length * projectedFeatureSize];
        IntStream.range(0, nodeIds.length).forEach(nodeOffset -> {
            NodeLabel label = labels[nodeOffset];
            Weights<? extends Tensor<?>> weights = weightsByLabel.get(label);
            if (weights == null) {
                throw new IllegalStateException("No weights registered for node label " + label);
            }
            double[] nodeFeatures = features.get(nodeIds[nodeOffset]);

            DMatrixRMaj weightMatrix = DMatrixRMaj.wrap(weights.dimension(0), weights.dimension(1), weights.data().data());
            DMatrixRMaj featureRow = DMatrixRMaj.wrap(1, nodeFeatures.length, nodeFeatures);
            DMatrixRMaj product = new DMatrixRMaj(weights.dimension(0), 1);
            // W (d0 x d1) times featureRow^T (d1 x 1) yields a d0 x 1 column vector.
            MatrixMatrixMult_DDRM.multTransB(weightMatrix, featureRow, product);

            System.arraycopy(product.getData(), 0, result, nodeOffset * projectedFeatureSize, projectedFeatureSize);
        });
        return new Matrix(result, nodeIds.length, projectedFeatureSize);
    }

    /**
     * Computes the gradient of this projection with respect to one of its weight parents.
     *
     * <p>Only nodes whose label maps to {@code variable} (compared by reference) contribute. For each such
     * node: {@code grad[row][col] += features(node)[col] * upstream[node][row]}.
     *
     * @param variable           the weight variable to differentiate against
     * @param computationContext supplies the upstream gradient of this variable
     * @return a {@code variable.dimension(0) x variable.dimension(1)} gradient matrix
     */
    @Override // org.neo4j.gds.embeddings.graphsage.ddl4j.Variable
    public Tensor<?> gradient(Variable<?> variable, ComputationContext computationContext) {
        double[] upstream = computationContext.gradient(this).data();
        int rows = variable.dimension(0);
        int cols = variable.dimension(1);
        // Row stride of this variable's output, and therefore of its upstream gradient.
        int upstreamStride = dimension(1);
        double[] result = new double[rows * cols];
        // Sequential stream: accumulation into `result` is not thread-safe.
        IntStream.range(0, nodeIds.length).forEach(nodeOffset -> {
            if (weightsByLabel.get(labels[nodeOffset]) != variable) {
                return;
            }
            double[] nodeFeatures = features.get(nodeIds[nodeOffset]);
            int upstreamRowStart = nodeOffset * upstreamStride;
            for (int row = 0; row < rows; row++) {
                double upstreamValue = upstream[upstreamRowStart + row];
                int resultRowStart = row * cols;
                for (int col = 0; col < cols; col++) {
                    result[resultRowStart + col] += nodeFeatures[col] * upstreamValue;
                }
            }
        });
        return new Matrix(result, rows, cols);
    }
}
