package org.neo4j.gds.ml.pipeline.node.classification.predict;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.neo4j.gds.AlgorithmFactory;
import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.WriteProc;
import org.neo4j.gds.api.properties.nodes.DoubleArrayNodePropertyValues;
import org.neo4j.gds.api.properties.nodes.LongNodePropertyValues;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.write.NodeProperty;
import org.neo4j.gds.executor.AlgorithmSpec;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.executor.NewConfigFunction;
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
import org.neo4j.gds.ml.pipeline.node.classification.NodeClassificationPipelineCompanion;
import org.neo4j.gds.ml.pipeline.node.classification.predict.NodeClassificationPredictPipelineExecutor;
import org.neo4j.gds.result.AbstractResultBuilder;
import org.neo4j.gds.results.MemoryEstimateResult;
import org.neo4j.gds.results.StandardWriteResult;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

@GdsCallable(name = "gds.beta.pipeline.nodeClassification.predict.write", description = NodeClassificationPipelineCompanion.PREDICT_DESCRIPTION, executionMode = ExecutionMode.WRITE_NODE_PROPERTY)
/**
 * Procedure {@code gds.beta.pipeline.nodeClassification.predict.write}: runs a trained
 * node classification pipeline and writes the predicted class (and, when produced, the
 * predicted class probabilities) back to the database as node properties.
 */
public class NodeClassificationPipelineWriteProc extends WriteProc<NodeClassificationPredictPipelineExecutor, NodeClassificationPredictPipelineExecutor.NodeClassificationPipelineResult, WriteResult, NodeClassificationPredictPipelineWriteConfig> {

    /**
     * Result row of the write procedure: the standard write timings plus the number of
     * node properties written.
     */
    public static final class WriteResult extends StandardWriteResult {
        public final long nodePropertiesWritten;

        /** Collects timings/counters during execution and assembles a {@link WriteResult}. */
        static class Builder extends AbstractResultBuilder<WriteResult> {
            Builder() {
            }

            @Override
            public WriteResult build() {
                return new WriteResult(
                    this.preProcessingMillis,
                    this.computeMillis,
                    this.writeMillis,
                    this.nodePropertiesWritten,
                    this.config.toMap()
                );
            }
        }

        /**
         * @param preProcessingMillis   time spent before the computation
         * @param computeMillis         time spent computing predictions
         * @param writeMillis           time spent writing properties
         * @param nodePropertiesWritten number of node properties written
         * @param configuration         the effective procedure configuration
         */
        WriteResult(
            long preProcessingMillis,
            long computeMillis,
            long writeMillis,
            long nodePropertiesWritten,
            Map<String, Object> configuration
        ) {
            // This procedure has no post-processing phase, hence the constant 0 for it.
            super(preProcessingMillis, computeMillis, 0L, writeMillis, configuration);
            this.nodePropertiesWritten = nodePropertiesWritten;
        }
    }

    @Procedure(name = "gds.beta.pipeline.nodeClassification.predict.write", mode = Mode.WRITE)
    @Description(NodeClassificationPipelineCompanion.PREDICT_DESCRIPTION)
    public Stream<WriteResult> write(@Name("graphName") String str, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> map) {
        // Mutates the user configuration in place before it is parsed.
        PipelineCompanion.preparePipelineConfig(str, map);
        return write(compute(str, map));
    }

    @Procedure(name = "gds.beta.pipeline.nodeClassification.predict.write.estimate", mode = Mode.READ)
    @Description(NodeClassificationPipelineCompanion.ESTIMATE_PREDICT_DESCRIPTION)
    public Stream<MemoryEstimateResult> estimate(@Name("graphNameOrConfiguration") Object obj, @Name("algoConfiguration") Map<String, Object> map) {
        PipelineCompanion.preparePipelineConfig(obj, map);
        return computeEstimate(obj, map);
    }

    /**
     * Sets the model catalog used to resolve the trained pipeline and returns this
     * procedure as an {@link AlgorithmSpec}.
     */
    public AlgorithmSpec<NodeClassificationPredictPipelineExecutor, NodeClassificationPredictPipelineExecutor.NodeClassificationPipelineResult, NodeClassificationPredictPipelineWriteConfig, Stream<WriteResult>, AlgorithmFactory<?, NodeClassificationPredictPipelineExecutor, NodeClassificationPredictPipelineWriteConfig>> withModelCatalog(ModelCatalog modelCatalog) {
        setModelCatalog(modelCatalog);
        return this;
    }

    /**
     * Builds the node properties to write.
     *
     * <p>The first property is always the predicted class under {@code writeProperty}. If
     * the result carries predicted probabilities, a second property is added under
     * {@code predictedProbabilityProperty}.
     *
     * @param computationResult the finished computation
     * @return the properties to write, in write order
     */
    @Override
    protected List<NodeProperty> nodePropertyList(ComputationResult<NodeClassificationPredictPipelineExecutor, NodeClassificationPredictPipelineExecutor.NodeClassificationPipelineResult, NodeClassificationPredictPipelineWriteConfig> computationResult) {
        NodeClassificationPredictPipelineWriteConfig config = computationResult.config();
        NodeClassificationPredictPipelineExecutor.NodeClassificationPipelineResult result = computationResult.result();

        LongNodePropertyValues predictedClasses = result.predictedClasses().asNodeProperties();

        List<NodeProperty> nodeProperties = new ArrayList<>();
        nodeProperties.add(NodeProperty.of(config.writeProperty(), predictedClasses));

        result.predictedProbabilities().ifPresent(probabilities -> {
            // The property name is only resolved when probabilities exist, so a missing
            // name is only an error in that case.
            String probabilityProperty = config.predictedProbabilityProperty().orElseThrow();
            nodeProperties.add(NodeProperty.of(probabilityProperty, new DoubleArrayNodePropertyValues() {
                @Override
                public long size() {
                    return computationResult.graph().nodeCount();
                }

                @Override
                public double[] doubleArrayValue(long nodeId) {
                    return probabilities.get(nodeId);
                }
            }));
        });
        return nodeProperties;
    }

    @Override
    protected AbstractResultBuilder<WriteResult> resultBuilder(ComputationResult<NodeClassificationPredictPipelineExecutor, NodeClassificationPredictPipelineExecutor.NodeClassificationPipelineResult, NodeClassificationPredictPipelineWriteConfig> computationResult, ExecutionContext executionContext) {
        return new WriteResult.Builder();
    }

    /** Parses the user configuration via {@link #newConfigFunction()}. */
    @Override
    public NodeClassificationPredictPipelineWriteConfig newConfig(String str, CypherMapWrapper cypherMapWrapper) {
        return newConfigFunction().apply(str, cypherMapWrapper);
    }

    @Override
    public NewConfigFunction<NodeClassificationPredictPipelineWriteConfig> newConfigFunction() {
        return new NodeClassificationPredictNewWriteConfigFn(modelCatalog());
    }

    @Override
    public GraphStoreAlgorithmFactory<NodeClassificationPredictPipelineExecutor, NodeClassificationPredictPipelineWriteConfig> algorithmFactory() {
        return new NodeClassificationPredictPipelineAlgorithmFactory(executionContext(), modelCatalog());
    }
}
