package org.neo4j.gds.ml.linkmodels;

import java.util.function.Consumer;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.linkmodels.logisticregression.LinkLogisticRegressionPredictor;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/linkmodels/SignedProbabilitiesCollector.class */
/**
 * Collects signed predicted probabilities for the relationships of batches of nodes.
 *
 * <p>For every node id in a consumed {@link Batch}, each of its relationships in {@link #graph}
 * is visited. The predictor's probability for the (source, target) pair is added to
 * {@link SignedProbabilities}. The probability is multiplied by {@code +1} when the relationship
 * property is {@value #POSITIVE_LINK} and by {@code -1} when it is {@value #NEGATIVE_LINK}. Any
 * other property value is rejected with an {@link IllegalArgumentException}.
 */
public class SignedProbabilitiesCollector implements Consumer<Batch> {
    /** Relationship property value denoting target class 1. */
    private static final int POSITIVE_LINK = 1;
    /** Relationship property value denoting target class 0. */
    private static final int NEGATIVE_LINK = 0;
    /**
     * Fallback property value handed to {@code forEachRelationship}. It is deliberately not a
     * valid target class, so it is rejected by {@link #sign}.
     * NOTE(review): presumably used for relationships without a property value; confirm against Graph.
     */
    private static final double INVALID_PROPERTY_FALLBACK = -1.0d;

    private final Graph graph;
    private final LinkLogisticRegressionPredictor predictor;
    private final SignedProbabilities signedProbabilities;
    private final ProgressTracker progressTracker;

    /**
     * Creates a collector.
     *
     * @param graph                           graph whose relationships are visited and whose
     *                                        relationship property holds the target class (0 or 1)
     * @param linkLogisticRegressionPredictor predictor supplying the probability for a node pair
     * @param signedProbabilities             sink receiving one signed probability per relationship
     * @param progressTracker                 receives the batch size after each batch is processed
     */
    public SignedProbabilitiesCollector(
        Graph graph,
        LinkLogisticRegressionPredictor linkLogisticRegressionPredictor,
        SignedProbabilities signedProbabilities,
        ProgressTracker progressTracker
    ) {
        this.graph = graph;
        this.predictor = linkLogisticRegressionPredictor;
        this.signedProbabilities = signedProbabilities;
        this.progressTracker = progressTracker;
    }

    /**
     * Adds a signed probability for every relationship of every node in {@code batch}, then
     * reports {@code batch.size()} as progress.
     *
     * @throws IllegalArgumentException if a relationship property is neither 0 nor 1
     */
    @Override
    public void accept(Batch batch) {
        batch.nodeIds().forEach(nodeId -> {
            this.graph.forEachRelationship(nodeId.longValue(), INVALID_PROPERTY_FALLBACK, (source, target, property) -> {
                this.signedProbabilities.add(sign(property, source, target) * this.predictor.predictedProbability(source, target));
                return true; // keep visiting the remaining relationships
            });
        });
        this.progressTracker.logProgress(batch.size());
    }

    /**
     * Maps a relationship's target-class property to a sign.
     *
     * <p>The property is compared exactly, so fractional values are rejected, as are NaN and
     * values outside the int range. The previous {@code (int)} cast instead truncated such values
     * into a valid class: 0.5 became 0, and NaN became 0.
     *
     * @param targetClass  relationship property value; must be exactly 0 or 1
     * @param sourceNodeId source node of the relationship (used only in the error message)
     * @param targetNodeId target node of the relationship (used only in the error message)
     * @return {@code 1} for {@value #POSITIVE_LINK}, {@code -1} for {@value #NEGATIVE_LINK}
     * @throws IllegalArgumentException for any other property value
     */
    private static int sign(double targetClass, long sourceNodeId, long targetNodeId) {
        if (targetClass == POSITIVE_LINK) {
            return 1;
        }
        if (targetClass == NEGATIVE_LINK) {
            return -1;
        }
        throw new IllegalArgumentException(StringFormatting.formatWithLocale("Invalid property value %.4f on relationship (%d,%d). Valid values are 0 and 1 which represent target classes for links.", new Object[]{Double.valueOf(targetClass), Long.valueOf(sourceNodeId), Long.valueOf(targetNodeId)}));
    }
}
