package org.neo4j.gds.ml.models.linearregression;

import java.util.List;
import org.apache.commons.lang3.mutable.MutableInt;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.functions.Constant;
import org.neo4j.gds.ml.core.functions.ConstantScale;
import org.neo4j.gds.ml.core.functions.ElementSum;
import org.neo4j.gds.ml.core.functions.L2NormSquared;
import org.neo4j.gds.ml.core.functions.MeanSquareError;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.ml.gradientdescent.Objective;
import org.neo4j.gds.ml.models.Features;

/* loaded from: input_file:org/neo4j/gds/ml/models/linearregression/LinearRegressionObjective.class */
/**
 * Training objective for linear regression.
 *
 * <p>The per-batch loss is the mean squared error between the model's predictions and the
 * batch's target values, plus an L2 penalty on the model weights (the bias is not penalized).
 * The model data optimized by this objective is created once in the constructor and exposed
 * through {@link #weights()} and {@link #modelData()}.
 */
public class LinearRegressionObjective implements Objective<LinearRegressionData> {
    private final Features features;
    private final HugeDoubleArray targets;
    private final LinearRegressionData modelData;
    private final double penalty;

    /**
     * Creates the objective and its model data, sized via
     * {@code LinearRegressionData.of(features.featureDimension())}.
     *
     * @param features feature source used to build each batch's feature matrix
     * @param targets  target values, read with {@code targets.get(nodeId)} for each node id in a batch
     * @param penalty  L2 regularization strength applied to the weights
     */
    public LinearRegressionObjective(Features features, HugeDoubleArray targets, double penalty) {
        this.features = features;
        this.targets = targets;
        this.modelData = LinearRegressionData.of(features.featureDimension());
        this.penalty = penalty;
    }

    /**
     * Returns the trainable parameters of the model.
     *
     * @return the weights followed by the bias
     */
    @Override
    public List<Weights<? extends Tensor<?>>> weights() {
        return List.of(modelData.weights(), modelData.bias());
    }

    /**
     * Builds the loss for one batch as the sum of the mean squared error and the weight penalty.
     *
     * @param batch     the batch whose node ids select features and targets
     * @param trainSize total number of training examples, used to scale the penalty
     * @return a scalar variable holding {@code MSE + penaltyForBatch}
     */
    @Override
    public Variable<Scalar> loss(Batch batch, long trainSize) {
        MeanSquareError meanSquareError = new MeanSquareError(
            new LinearRegressor(modelData).predictionsVariable(Objective.batchFeatureMatrix(batch, features)),
            batchTargets(batch)
        );
        return new ElementSum(List.of(meanSquareError, penaltyForBatch(batch, trainSize)));
    }

    /**
     * Squared L2 norm of the weights, scaled by {@code batch.size() * penalty / trainSize}.
     *
     * <p>Presumably the batch-size fraction is there so that, when the batches of an epoch
     * partition the training set, the per-batch penalties add up to {@code penalty * ||w||^2};
     * confirm against the training loop. Note that {@code trainSize == 0} yields a non-finite
     * scale factor (double division).
     */
    private Variable<Scalar> penaltyForBatch(Batch batch, long trainSize) {
        double scale = batch.size() * penalty / trainSize;
        return new ConstantScale<>(new L2NormSquared<>(modelData.weights()), scale);
    }

    /**
     * Collects the target values of the batch, in the iteration order of its node ids,
     * into a vector wrapped as a {@link Constant}.
     */
    private Constant<Vector> batchTargets(Batch batch) {
        Vector batchedTargets = new Vector(batch.size());
        // Position within the batch; node ids themselves are arbitrary indices into `targets`.
        MutableInt position = new MutableInt();
        batch.nodeIds().forEach(nodeId ->
            batchedTargets.setDataAt(position.getAndIncrement(), targets.get(nodeId.longValue()))
        );
        return new Constant<>(batchedTargets);
    }

    /**
     * Returns the model data (weights and bias) being trained by this objective.
     */
    @Override
    public LinearRegressionData modelData() {
        return modelData;
    }
}
