package org.neo4j.gds.ml.linkmodels.logisticregression;

import java.util.List;
import org.immutables.value.Value;
import org.neo4j.gds.ml.LinkFeatureCombiner;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.linkmodels.logisticregression.ImmutableLinkLogisticRegressionData;
import org.neo4j.graphalgo.annotation.ValueClass;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.utils.mem.MemoryEstimation;
import org.neo4j.graphalgo.core.utils.mem.MemoryEstimations;
import org.neo4j.graphalgo.core.utils.mem.MemoryUsage;

/**
 * Model data for link logistic regression.
 *
 * <p>It bundles three things: a weights matrix, the {@link LinkFeatureCombiner} that maps a node
 * feature dimension to a link feature dimension, and the node feature dimension that combiner
 * is applied to.
 *
 * <p>Instances are created through {@link #builder()}. That builder comes from
 * {@code ImmutableLinkLogisticRegressionData}, presumably generated by the Immutables annotation
 * processor from this interface.
 */
@ValueClass
public interface LinkLogisticRegressionData {

    /**
     * Estimates the memory held by one instance of this model data.
     *
     * @param numberOfFeatures column count of the single-row weights matrix. This is presumably
     *     the link feature dimension, matching the {@code 1 x n} weights built in {@link #from}.
     *     TODO confirm against callers.
     * @return a "model data" estimation with one fixed cost for the instance and one for the
     *     weights
     */
    static MemoryEstimation memoryEstimation(int numberOfFeatures) {
        return MemoryEstimations
            .builder("model data")
            .fixed("instance", MemoryUsage.sizeOfInstance(ImmutableLinkLogisticRegressionData.class))
            .fixed("weights", Weights.sizeInBytes(1, numberOfFeatures))
            .build();
    }

    /** Weights of the model. {@link #from} creates this as a matrix with a single row. */
    Weights<Matrix> weights();

    /** Combiner used to derive the link feature dimension from the node feature dimension. */
    LinkFeatureCombiner linkFeatureCombiner();

    /**
     * Link feature dimension.
     *
     * <p>Derived rather than stored: it is whatever the combiner reports for
     * {@link #nodeFeatureDimension()}.
     */
    @Value.Derived
    default int linkFeatureDimension() {
        LinkFeatureCombiner combiner = linkFeatureCombiner();
        return combiner.linkFeatureDimension(nodeFeatureDimension());
    }

    /** Number of node features, as counted from the node property extractors in {@link #from}. */
    int nodeFeatureDimension();

    /**
     * Builds model data sized for the given graph and feature properties.
     *
     * <p>The node feature count comes from the property extractors for {@code featureProperties}.
     * The weights are created with {@code Weights.ofMatrix(1, d)}, where {@code d} is the link
     * feature dimension the combiner reports for that node feature count.
     *
     * @param graph graph whose node properties are inspected
     * @param featureProperties names of the node properties used as features
     * @param linkFeatureCombiner combiner stored in the result; it also sizes the weights
     * @return a new, fully built instance
     */
    static LinkLogisticRegressionData from(
        Graph graph,
        List<String> featureProperties,
        LinkFeatureCombiner linkFeatureCombiner
    ) {
        int nodeFeatureCount = FeatureExtraction.featureCount(
            FeatureExtraction.propertyExtractors(graph, featureProperties)
        );
        int linkFeatureCount = linkFeatureCombiner.linkFeatureDimension(nodeFeatureCount);

        return builder()
            .weights(Weights.ofMatrix(1, linkFeatureCount))
            .linkFeatureCombiner(linkFeatureCombiner)
            .nodeFeatureDimension(nodeFeatureCount)
            .build();
    }

    /** Returns a new builder backed by {@code ImmutableLinkLogisticRegressionData}. */
    static ImmutableLinkLogisticRegressionData.Builder builder() {
        return ImmutableLinkLogisticRegressionData.builder();
    }
}
