package org.neo4j.gds.ml.nodemodels.logisticregression;

import java.util.ArrayList;
import java.util.List;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.Predictor;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.features.BiasFeature;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.features.FeatureExtractor;
import org.neo4j.gds.ml.core.functions.Constant;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.Softmax;
import org.neo4j.gds.ml.core.tensor.Matrix;

/* loaded from: input_file:org/neo4j/gds/ml/nodemodels/logisticregression/NodeLogisticRegressionPredictor.class */
/**
 * Predictor that applies a trained node logistic-regression model to batches of nodes.
 *
 * <p>For a batch, it builds a feature matrix from the configured node properties plus a bias
 * feature. It multiplies that matrix by the transposed model weights and applies a softmax to
 * the product.
 */
public class NodeLogisticRegressionPredictor implements Predictor<Matrix, NodeLogisticRegressionData> {
    private final NodeLogisticRegressionData modelData;
    private final List<String> featureProperties;

    /**
     * Estimates the memory needed to build the prediction variable for one batch.
     *
     * <p>The estimate is the sum of:
     * <ul>
     *   <li>the feature extractors;</li>
     *   <li>the feature constant of size {@code batchSize x featureCount};</li>
     *   <li>the product with a {@code classCount x featureCount} weights matrix (transposed);</li>
     *   <li>the softmax over a {@code batchSize x classCount} result.</li>
     * </ul>
     *
     * @param batchSize    number of rows of the feature matrix (nodes per batch)
     * @param featureCount number of columns of the feature matrix and of the weights matrix
     *                     (NOTE(review): confirm whether this includes the bias column)
     * @param classCount   number of rows of the weights matrix (presumably the number of classes)
     * @return estimated size in bytes
     */
    public static long sizeOfPredictionsVariableInBytes(int batchSize, int featureCount, int classCount) {
        int[] featuresDimensions = Dimensions.matrix(batchSize, featureCount);
        int[] weightsDimensions = Dimensions.matrix(classCount, featureCount);
        return sizeOfFeatureExtractorsInBytes(featureCount)
            + Constant.sizeInBytes(featuresDimensions)
            + MatrixMultiplyWithTransposedSecondOperand.sizeInBytes(featuresDimensions, weightsDimensions)
            // softmax output has one row per batch element and one column per weights row
            + Softmax.sizeInBytes(featuresDimensions[0], weightsDimensions[0]);
    }

    /**
     * Estimates the memory of the feature extractors: the property extractors plus one {@link BiasFeature}.
     *
     * @param featureCount feature count passed through to {@link FeatureExtraction#memoryUsageInBytes}
     * @return estimated size in bytes
     */
    private static long sizeOfFeatureExtractorsInBytes(int featureCount) {
        return FeatureExtraction.memoryUsageInBytes(featureCount) + MemoryUsage.sizeOfInstance(BiasFeature.class);
    }

    /**
     * Creates a predictor.
     *
     * @param modelData         trained model data; its weights are used in every prediction
     * @param featureProperties node property names to extract as features (stored by reference, not copied)
     */
    public NodeLogisticRegressionPredictor(NodeLogisticRegressionData modelData, List<String> featureProperties) {
        this.modelData = modelData;
        this.featureProperties = featureProperties;
    }

    /**
     * Returns the model data this predictor was constructed with.
     *
     * @return the model data
     */
    @Override
    public NodeLogisticRegressionData modelData() {
        return this.modelData;
    }

    /**
     * Evaluates the softmax predictions for the given batch of nodes.
     *
     * <p>A fresh {@link ComputationContext} is created for every call.
     *
     * @param graph graph that the node features are read from
     * @param batch nodes to predict
     * @return the forward-evaluated prediction matrix
     */
    @Override
    public Matrix predict(Graph graph, Batch batch) {
        return new ComputationContext().forward(predictionsVariable(graph, batch));
    }

    /**
     * Builds the unevaluated computation graph {@code softmax(features * weights^T)} for a batch.
     *
     * <p>Kept public for backward compatibility; the decompiler reported it as originally package-private.
     *
     * @param graph graph that the node features are read from
     * @param batch nodes to build the predictions for
     * @return the prediction variable
     */
    public Variable<Matrix> predictionsVariable(Graph graph, Batch batch) {
        return new Softmax(MatrixMultiplyWithTransposedSecondOperand.of(features(graph, batch), this.modelData.weights()));
    }

    /** Extracts the feature matrix for the batch using the property extractors plus the bias feature. */
    private Constant<Matrix> features(Graph graph, Batch batch) {
        return FeatureExtraction.extract(batch, featureExtractors(graph));
    }

    /**
     * Creates the feature extractors.
     *
     * @return one extractor per configured property, followed by a single {@link BiasFeature}
     */
    private List<FeatureExtractor> featureExtractors(Graph graph) {
        List<FeatureExtractor> extractors = new ArrayList<>(FeatureExtraction.propertyExtractors(graph, this.featureProperties));
        // bias column is always appended last
        extractors.add(new BiasFeature());
        return extractors;
    }
}
