package org.neo4j.gds.ml.linkmodels.logisticregression;

import java.util.List;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.features.FeatureExtractor;
import org.neo4j.gds.ml.core.functions.Constant;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.Sigmoid;

/**
 * Scores candidate links with a trained link logistic regression model.
 *
 * <p>The model data and the feature extraction logic live in {@link LinkLogisticRegressionBase};
 * this class only combines the extracted features with the model weights and applies a sigmoid.
 */
public class LinkLogisticRegressionPredictor extends LinkLogisticRegressionBase {

    /**
     * Creates a predictor backed by the given model data.
     *
     * @param modelData          trained weights, forwarded to the base class
     * @param nodePropertyNames  forwarded unchanged to the base class
     *                           (presumably the node property names used as features — confirm)
     * @param extractors         feature extractors forwarded unchanged to the base class
     */
    public LinkLogisticRegressionPredictor(
        LinkLogisticRegressionData modelData,
        List<String> nodePropertyNames,
        List<FeatureExtractor> extractors
    ) {
        super(modelData, nodePropertyNames, extractors);
    }

    /**
     * Estimates the memory, in bytes, needed to score one batch.
     *
     * <p>The estimate is the sum of four terms:
     * <ul>
     *   <li>the feature extraction memory for {@code numberOfFeatures} features,</li>
     *   <li>a constant {@code batchSize x numberOfFeatures} feature matrix,</li>
     *   <li>the product of that matrix with a transposed {@code 1 x numberOfFeatures} operand,</li>
     *   <li>a sigmoid over a {@code batchSize x 1} result.</li>
     * </ul>
     *
     * <p>NOTE(review): the decompiler reports that this method was originally package-private.
     * It is kept {@code public} here so that existing callers keep compiling.
     *
     * @param batchSize         number of rows in the batch
     * @param numberOfFeatures  number of feature columns per row
     * @return estimated size of the batch computation in bytes
     */
    public static long sizeOfBatchInBytes(int batchSize, int numberOfFeatures) {
        int[] featureDimensions = Dimensions.matrix(batchSize, numberOfFeatures);

        long extractorBytes = sizeOfFeatureExtractorsInBytes(numberOfFeatures);
        long featureMatrixBytes = Constant.sizeInBytes(featureDimensions);
        long multiplyBytes = MatrixMultiplyWithTransposedSecondOperand.sizeInBytes(
            featureDimensions,
            Dimensions.matrix(1, numberOfFeatures)
        );
        long sigmoidBytes = Sigmoid.sizeInBytes(batchSize, 1);

        return extractorBytes + featureMatrixBytes + multiplyBytes + sigmoidBytes;
    }

    /** Returns the memory estimate reported by {@code FeatureExtraction} for the given feature count. */
    private static long sizeOfFeatureExtractorsInBytes(int numberOfFeatures) {
        return FeatureExtraction.memoryUsageInBytes(numberOfFeatures);
    }

    /** Returns the model data this predictor was constructed with. */
    public LinkLogisticRegressionData modelData() {
        return this.modelData;
    }

    /**
     * Computes the predicted link probability for a pair of nodes.
     *
     * <p>The first {@code weights.length - 1} weights are combined with the extracted features in a
     * dot product. The last weight is added as a bias term, and the sum is passed through a sigmoid.
     *
     * @param sourceNode  first node of the pair (presumably a node id — confirm against callers)
     * @param targetNode  second node of the pair (presumably a node id — confirm against callers)
     * @return the sigmoid of the weighted feature sum plus the bias
     * @throws ArrayIndexOutOfBoundsException if the weight array is empty, or if the extracted
     *         feature array holds fewer than {@code weights.length - 1} entries
     */
    public double predictedProbability(long sourceNode, long targetNode) {
        final double[] weights = this.modelData.weights().data().data();
        final double[] linkFeatures = features(sourceNode, targetNode);
        // The final weight slot holds the bias; every other slot pairs with one feature.
        final int biasIndex = weights.length - 1;

        // Accumulate in ascending index order, as before, so the floating-point result is unchanged.
        double affine = 0.0d;
        int idx = 0;
        while (idx < biasIndex) {
            affine += weights[idx] * linkFeatures[idx];
            idx++;
        }

        final double bias = weights[biasIndex];
        return Sigmoid.sigmoid(affine + bias);
    }
}
