package org.neo4j.gds.ml.models.logisticregression;

import java.io.Serializable;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.logisticregression.ImmutableLogisticRegressionData;

/**
 * Immutable model data for a logistic regression classifier: a weight matrix plus a bias vector.
 *
 * <p>The concrete implementation is generated by Immutables as
 * {@code ImmutableLogisticRegressionData}. Obtain instances through
 * {@link #standard(int, int)}, {@link #withReducedClassCount(int, int)} or {@link #builder()}.
 */
@ValueClass
public interface LogisticRegressionData extends Classifier.ClassifierData, Serializable {

    /**
     * Weight matrix. Its second dimension is reported as {@link #featureDimension()}.
     * When created by the factories in this interface, it has one row per modelled class.
     */
    Weights<Matrix> weights();

    /**
     * Bias vector. When created by the factories in this interface, it has the same
     * length as the number of rows of {@link #weights()}.
     */
    Weights<Vector> bias();

    /** Always {@link TrainingMethod#LogisticRegression}. */
    @Override
    @Value.Derived
    default TrainingMethod trainerMethod() {
        return TrainingMethod.LogisticRegression;
    }

    /** Number of input features, taken from the second dimension of {@link #weights()}. */
    @Override
    @Value.Derived
    default int featureDimension() {
        return weights().dimension(1);
    }

    /**
     * Creates model data with one weight row (and one bias entry) per class.
     *
     * @param featureCount    number of input features (weight-matrix columns)
     * @param numberOfClasses number of classes; also stored as {@code numberOfClasses}
     * @return freshly allocated model data
     */
    static LogisticRegressionData standard(int featureCount, int numberOfClasses) {
        return create(numberOfClasses, featureCount, false);
    }

    /**
     * Creates model data with one weight row (and one bias entry) fewer than the class count.
     * {@code numberOfClasses} itself is still stored unreduced.
     * NOTE(review): presumably the omitted class acts as an implicit reference class; confirm
     * with the predictor/objective that consumes this data.
     *
     * @param featureCount    number of input features (weight-matrix columns)
     * @param numberOfClasses number of classes; also stored as {@code numberOfClasses}
     * @return freshly allocated model data
     */
    static LogisticRegressionData withReducedClassCount(int featureCount, int numberOfClasses) {
        return create(numberOfClasses, featureCount, true);
    }

    /**
     * Shared factory behind {@link #standard} and {@link #withReducedClassCount}.
     * Note the argument order: class count first, feature count second.
     */
    private static LogisticRegressionData create(int numberOfClasses, int featureCount, boolean reduced) {
        int rows = weightRows(numberOfClasses, reduced);
        Weights<Matrix> weights = Weights.ofMatrix(rows, featureCount);
        Weights<Vector> bias = new Weights<>(new Vector(rows));
        return ImmutableLogisticRegressionData.builder()
            .weights(weights)
            .numberOfClasses(numberOfClasses)
            .bias(bias)
            .build();
    }

    /**
     * Estimates the memory held by model data of the given shape.
     *
     * @param isReduced        whether the reduced layout ({@link #withReducedClassCount}) is used
     * @param numberOfClasses  number of classes
     * @param featureDimension range of possible feature counts (weight-matrix columns)
     * @return an estimation covering the weight matrix and the bias vector
     * @throws ArithmeticException if a feature count in the range does not fit in an {@code int}
     */
    static MemoryEstimation memoryEstimation(boolean isReduced, int numberOfClasses, MemoryRange featureDimension) {
        int rows = weightRows(numberOfClasses, isReduced);
        return MemoryEstimations.builder("Logistic regression model data")
            .fixed("weights", featureDimension.apply(features -> Weights.sizeInBytes(rows, Math.toIntExact(features))))
            .fixed("bias", Weights.sizeInBytes(rows, 1))
            .build();
    }

    /** Returns a fresh builder for the generated immutable implementation. */
    static ImmutableLogisticRegressionData.Builder builder() {
        return ImmutableLogisticRegressionData.builder();
    }

    /**
     * Number of weight rows (and bias entries) for the given class count and layout.
     * Shared by allocation and memory estimation so the two cannot disagree.
     */
    private static int weightRows(int numberOfClasses, boolean reduced) {
        return reduced ? numberOfClasses - 1 : numberOfClasses;
    }
}
