package org.neo4j.gds.ml.gradientdescent;

import java.util.List;
import java.util.PrimitiveIterator;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.functions.Constant;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.models.Features;

/* loaded from: input_file:org/neo4j/gds/ml/gradientdescent/Objective.class */
/**
 * An objective function optimized by gradient descent.
 *
 * <p>Implementations expose the trainable weights, a loss over a batch, and the model
 * data they produce.
 *
 * @param <DATA> the type returned by {@link #modelData()}
 */
public interface Objective<DATA> {

    /**
     * Returns the weights of this objective.
     *
     * @return the list of weight variables
     */
    List<Weights<? extends Tensor<?>>> weights();

    /**
     * Builds the loss variable for the given batch.
     *
     * @param batch the batch to compute the loss over
     * @param j     NOTE(review): meaning not visible here (decompiled parameter name);
     *              presumably a size or count used by the loss, so confirm against
     *              implementations
     * @return a scalar-valued variable representing the loss
     */
    Variable<Scalar> loss(Batch batch, long j);

    /**
     * Returns the model data associated with this objective.
     *
     * @return the model data
     */
    DATA modelData();

    /**
     * Collects the feature vectors of a batch into a constant matrix.
     *
     * <p>The matrix has {@code batch.size()} rows and {@code features.featureDimension()}
     * columns. Row {@code r} holds {@code features.get(id)} for the {@code r}-th id
     * yielded by {@code batch.elementIds()}.
     *
     * <p>NOTE(review): assumes {@code elementIds()} yields at most {@code batch.size()}
     * ids. Confirm this against the {@code Batch} implementations.
     *
     * @param batch    the batch whose element ids select the rows
     * @param features the feature source, indexed by element id
     * @return a constant wrapping the populated feature matrix
     */
    static Constant<Matrix> batchFeatureMatrix(Batch batch, Features features) {
        Matrix featureMatrix = new Matrix(batch.size(), features.featureDimension());
        PrimitiveIterator.OfLong ids = batch.elementIds();
        for (int row = 0; ids.hasNext(); row++) {
            long elementId = ids.nextLong();
            featureMatrix.setRow(row, features.get(elementId));
        }
        return new Constant<>(featureMatrix);
    }
}
