package org.neo4j.gds.ml.models.logisticregression;

import org.immutables.value.Value;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.TrainingMethod;
import org.neo4j.gds.ml.models.logisticregression.ImmutableLogisticRegressionData;

/**
 * Immutable parameter container for a logistic regression classifier.
 *
 * <p>Holds a weight matrix whose row count is the number of modelled classes and whose column
 * count is the feature dimension, plus a bias vector with one entry per modelled class. The
 * number of modelled classes is either the full size of the class id map ({@link #standard}) or
 * one less ({@link #withReducedClassCount}).
 */
@ValueClass
public interface LogisticRegressionData extends Classifier.ClassifierData {

    /** Weight matrix: one row per modelled class, one column per feature. */
    Weights<Matrix> weights();

    /** Bias vector with one entry per modelled class (same length as the weight row count). */
    Weights<Vector> bias();

    /** Always {@link TrainingMethod#LogisticRegression}; derived, not stored by the builder. */
    @Override
    @Value.Derived
    default TrainingMethod trainerMethod() {
        return TrainingMethod.LogisticRegression;
    }

    /** Feature dimension, read from the column count (dimension 1) of the weight matrix. */
    @Override
    @Value.Derived
    default int featureDimension() {
        Weights<Matrix> weightMatrix = weights();
        return weightMatrix.dimension(1);
    }

    /**
     * Creates zero-initialised-shape model data with one weight row and one bias entry per class
     * in {@code classIdMap}.
     *
     * @param featureCount number of weight columns; later reported by {@link #featureDimension()}
     * @param classIdMap   class id mapping; its size determines the number of rows
     * @return freshly allocated model data
     */
    static LogisticRegressionData standard(int featureCount, LocalIdMap classIdMap) {
        int classCount = classIdMap.size();
        return create(classCount, featureCount, classIdMap);
    }

    /**
     * Like {@link #standard}, but allocates one row (and one bias entry) fewer than the number of
     * classes in {@code classIdMap}. NOTE(review): presumably the omitted class is handled
     * implicitly by the consumer of these weights — confirm against the predictor/objective.
     *
     * @param featureCount number of weight columns
     * @param classIdMap   class id mapping; rows allocated = {@code classIdMap.size() - 1}
     * @return freshly allocated model data
     */
    static LogisticRegressionData withReducedClassCount(int featureCount, LocalIdMap classIdMap) {
        int reducedClassCount = classIdMap.size() - 1;
        return create(reducedClassCount, featureCount, classIdMap);
    }

    /** Allocates a {@code classCount x featureCount} weight matrix and a {@code classCount} bias. */
    private static LogisticRegressionData create(int classCount, int featureCount, LocalIdMap classIdMap) {
        Weights<Vector> biasWeights = new Weights<>(new Vector(classCount));
        Weights<Matrix> weightMatrix = Weights.ofMatrix(classCount, featureCount);
        return ImmutableLogisticRegressionData.builder()
            .classIdMap(classIdMap)
            .weights(weightMatrix)
            .bias(biasWeights)
            .build();
    }

    /**
     * Estimates the memory held by the model data.
     *
     * @param isReduced        whether the data was built with one class fewer than
     *                         {@code numberOfClasses} (see {@link #withReducedClassCount})
     * @param numberOfClasses  number of classes in the class id map
     * @param featureDimension range of possible feature counts (weight matrix columns)
     * @return estimation covering the class id map, the weight matrix and the bias
     */
    static MemoryEstimation memoryEstimation(boolean isReduced, int numberOfClasses, MemoryRange featureDimension) {
        // The class id map always holds every class; only the weight/bias rows shrink when reduced.
        final int modelledClasses = isReduced ? numberOfClasses - 1 : numberOfClasses;
        return MemoryEstimations.builder("Logistic regression model data", LogisticRegressionData.class)
            .add("classIdMap", LocalIdMap.memoryEstimation(numberOfClasses))
            .fixed(
                "weights",
                featureDimension.apply(featureDim -> Weights.sizeInBytes(modelledClasses, Math.toIntExact(featureDim)))
            )
            .fixed("bias", Weights.sizeInBytes(modelledClasses, 1))
            .build();
    }

    /** Returns a new builder for the generated immutable implementation. */
    static ImmutableLogisticRegressionData.Builder builder() {
        return ImmutableLogisticRegressionData.builder();
    }
}
