package org.neo4j.gds.ml.nodemodels.logisticregression;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.Sigmoid;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.ml.Predictor;
import org.neo4j.gds.ml.batch.Batch;
import org.neo4j.gds.ml.features.BiasFeature;
import org.neo4j.gds.ml.features.FeatureExtraction;
import org.neo4j.gds.ml.features.FeatureExtractor;
import org.neo4j.graphalgo.api.Graph;

/* loaded from: input_file:org/neo4j/gds/ml/nodemodels/logisticregression/NodeLogisticRegressionPredictor.class */
/**
 * Applies a trained node logistic regression model to batches of nodes.
 *
 * <p>Predictions are computed as {@code sigmoid(features x weights^T)}. The feature matrix is
 * built from the node properties listed in {@link NodeLogisticRegressionData#nodePropertyKeys()},
 * followed by a trailing {@link BiasFeature}.
 */
public class NodeLogisticRegressionPredictor implements Predictor<List<Double>, NodeLogisticRegressionData> {
    private final NodeLogisticRegressionData modelData;

    /**
     * Creates a predictor backed by the given trained model data.
     *
     * @param nodeLogisticRegressionData the model's weights and the node property keys used as features
     */
    public NodeLogisticRegressionPredictor(NodeLogisticRegressionData nodeLogisticRegressionData) {
        this.modelData = nodeLogisticRegressionData;
    }

    /**
     * Returns the model data this predictor was constructed with.
     *
     * @return the model data
     */
    @Override
    public NodeLogisticRegressionData modelData() {
        return this.modelData;
    }

    /**
     * Runs a forward pass of the prediction graph for the given batch.
     *
     * <p>The prediction matrix's backing data is returned as a boxed list.
     * NOTE(review): presumably one sigmoid output per node in the batch; confirm against
     * the shape of {@code weights()}.
     *
     * @param graph the graph whose node properties supply the features
     * @param batch the nodes to predict for
     * @return the flattened prediction values, boxed
     */
    @Override
    public List<Double> predict(Graph graph, Batch batch) {
        // DoubleStream.boxed() + Collectors.toList() already yields List<Double>; no raw cast needed.
        Matrix predictions = new ComputationContext().forward(predictionsVariable(graph, batch));
        return Arrays.stream(predictions.data())
            .boxed()
            .collect(Collectors.toList());
    }

    /**
     * Builds the (not yet evaluated) computation graph {@code sigmoid(features x weights^T)}.
     *
     * @param graph the graph whose node properties supply the features
     * @param batch the nodes to build features for
     * @return a variable that evaluates to the prediction matrix
     */
    public Variable<Matrix> predictionsVariable(Graph graph, Batch batch) {
        Variable<Matrix> featureMatrix = features(graph, batch, this.modelData.nodePropertyKeys());
        return new Sigmoid(MatrixMultiplyWithTransposedSecondOperand.of(featureMatrix, this.modelData.weights()));
    }

    /**
     * Extracts the feature matrix for the batch: one extractor per listed node property plus a
     * trailing bias feature.
     *
     * @param graph             the graph to read node properties from
     * @param batch             the nodes to extract features for
     * @param featureProperties names of the node properties to use as features
     * @return a variable holding the extracted feature matrix
     */
    public Variable<Matrix> features(Graph graph, Batch batch, List<String> featureProperties) {
        return FeatureExtraction.extract(batch, featureExtractors(graph, featureProperties));
    }

    /**
     * Builds the ordered list of feature extractors: the property extractors first, then a
     * {@link BiasFeature} so the last weight column acts as the intercept term.
     */
    private List<FeatureExtractor> featureExtractors(Graph graph, List<String> featureProperties) {
        List<FeatureExtractor> extractors = new ArrayList<>();
        extractors.addAll(FeatureExtraction.propertyExtractors(graph, featureProperties));
        extractors.add(new BiasFeature());
        return extractors;
    }
}
