package org.neo4j.gds.ml.linkmodels.logisticregression;

import org.apache.commons.lang3.mutable.MutableInt;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.MatrixConstant;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.Sigmoid;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.ml.Batch;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.NodeProperties;
import org.neo4j.graphalgo.api.nodeproperties.ValueType;

/* loaded from: input_file:org/neo4j/gds/ml/linkmodels/logisticregression/LinkLogisticRegressionBase.class */
public class LinkLogisticRegressionBase {

    /**
     * Model state read by this class: the weight variable, the node property keys, the number of
     * feature columns and the link feature combiner.
     */
    protected final LinkLogisticRegressionData modelData;

    /**
     * Creates a base backed by the given model data.
     *
     * <p>Note: the decompiler widened this constructor to {@code public}; it was originally
     * package-private. It is kept public so existing callers keep compiling.
     *
     * @param linkLogisticRegressionData model state (weights, feature layout, combiner)
     */
    public LinkLogisticRegressionBase(LinkLogisticRegressionData linkLogisticRegressionData) {
        this.modelData = linkLogisticRegressionData;
    }

    /**
     * Builds the prediction expression {@code sigmoid(features x weights^T)} for a feature matrix.
     *
     * <p>Note: originally {@code protected}; widened to {@code public} by the decompiler.
     *
     * @param matrixConstant feature matrix, presumably one row per link as produced by
     *     {@link #features(Graph, Batch)}
     * @return a sigmoid variable over the product of the features and the transposed weights
     */
    public Variable<Matrix> predictions(MatrixConstant matrixConstant) {
        return new Sigmoid(MatrixMultiplyWithTransposedSecondOperand.of(matrixConstant, this.modelData.weights()));
    }

    /**
     * Builds the feature matrix for every relationship of every node in {@code batch}.
     *
     * <p>Each row holds the combined node features of one relationship's endpoints, laid out with
     * a stride of {@code numberOfFeatures}. The last column of every row is set to {@code 1.0}
     * (presumably the bias/intercept term). Rows follow relationship traversal order of the batch's
     * node ids.
     *
     * <p>Note: originally {@code protected}; widened to {@code public} by the decompiler.
     *
     * @param graph graph whose relationships and node properties are read
     * @param batch node ids whose outgoing relationships form the rows
     * @return a {@code rows x numberOfFeatures} matrix, where {@code rows} is the sum of the
     *     degrees of the batch's nodes
     * @throws IllegalStateException if the matrix cannot be backed by a single {@code double[]}
     *     (reduce the batch size), or if a combined link vector is wider than a matrix row
     */
    public MatrixConstant features(Graph graph, Batch batch) {
        Graph concurrentCopy = graph.concurrentCopy();
        int numberOfFeatures = this.modelData.numberOfFeatures();

        // Sum degrees as a long so a large batch cannot silently wrap around an int.
        long[] relationshipCount = {0L};
        batch.nodeIds().forEach(nodeId -> relationshipCount[0] += graph.degree(nodeId.longValue()));

        int rows;
        int size;
        try {
            rows = Math.toIntExact(relationshipCount[0]);
            size = Math.multiplyExact(rows, numberOfFeatures);
        } catch (ArithmeticException e) {
            throw new IllegalStateException(
                "Feature matrix with " + relationshipCount[0] + " rows and " + numberOfFeatures
                    + " columns exceeds the maximum array size; reduce the batch size",
                e
            );
        }

        double[] features = new double[size];
        MutableInt row = new MutableInt();
        batch.nodeIds().forEach(nodeId -> {
            concurrentCopy.forEachRelationship(nodeId.longValue(), (source, target) -> {
                double[] linkFeatures = this.modelData
                    .linkFeatureCombiner()
                    .combine(nodeFeatures(graph, source), nodeFeatures(graph, target));
                setLinkFeatures(linkFeatures, features, row.intValue(), numberOfFeatures);
                row.increment();
                return true;
            });
        });

        // Constant last column; (i + 1) * numberOfFeatures <= size, so this cannot overflow.
        for (int i = 0; i < rows; i++) {
            features[(i + 1) * numberOfFeatures - 1] = 1.0d;
        }
        return new MatrixConstant(features, rows, numberOfFeatures);
    }

    /**
     * Concatenates the configured node properties of one node into a zero-padded vector.
     *
     * <p>The vector has length {@code numberOfFeatures}. Properties are written in the order of
     * {@code nodePropertyKeys()}: a {@code DOUBLE} property takes one slot, a {@code DOUBLE_ARRAY}
     * property takes as many slots as its array length. Unfilled trailing slots stay {@code 0.0}.
     *
     * @throws IllegalStateException if a property is neither {@code DOUBLE} nor {@code DOUBLE_ARRAY}
     */
    private double[] nodeFeatures(Graph graph, long nodeId) {
        double[] features = new double[this.modelData.numberOfFeatures()];
        MutableInt offset = new MutableInt();
        this.modelData.nodePropertyKeys().forEach(propertyKey -> {
            NodeProperties properties = graph.nodeProperties(propertyKey);
            ValueType valueType = properties.valueType();
            // Compare enum constants directly instead of dispatching on ordinal().
            if (valueType == ValueType.DOUBLE_ARRAY) {
                double[] values = properties.doubleArrayValue(nodeId);
                System.arraycopy(values, 0, features, offset.intValue(), values.length);
                offset.add(values.length);
            } else if (valueType == ValueType.DOUBLE) {
                features[offset.intValue()] = properties.doubleValue(nodeId);
                offset.increment();
            } else {
                throw new IllegalStateException("Link Logistic Regression requires double or double array node properties, not " + valueType);
            }
        });
        return features;
    }

    /**
     * Copies one link's combined feature vector into row {@code row} of the row-major matrix
     * buffer {@code features}, using {@code rowWidth} as the row stride so the layout matches the
     * matrix shape and the bias column.
     *
     * @throws IllegalStateException if the combined vector is wider than a matrix row
     */
    private static void setLinkFeatures(double[] linkFeatures, double[] features, int row, int rowWidth) {
        if (linkFeatures.length > rowWidth) {
            throw new IllegalStateException(
                "Combined link features have " + linkFeatures.length
                    + " entries but the feature matrix has only " + rowWidth + " columns"
            );
        }
        System.arraycopy(linkFeatures, 0, features, row * rowWidth, linkFeatures.length);
    }
}
