package apoc.ml;

import apoc.ApocConfig;
import apoc.Extended;
import apoc.ml.bedrock.BedrockInvokeConfig;
import apoc.result.StringResult;
import com.fasterxml.jackson.core.JsonProcessingException;
import java.net.MalformedURLException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import org.neo4j.graphdb.QueryExecutionException;
import org.neo4j.graphdb.Transaction;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.logging.Log;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Procedures that ask a chat-completion model (via {@code OpenAI.executeRequest} against the
 * {@code chat/completions} endpoint) to explain the current graph schema or to translate a
 * natural-language question into a Cypher statement.
 *
 * <p>Configuration keys read from the {@code conf} map: {@code apiKey}, the model key
 * ({@link BedrockInvokeConfig#MODEL}, default {@code gpt-3.5-turbo}), {@code sample} (schema
 * sampling), {@code retries} ({@code query} only, default 3) and {@code count} ({@code cypher}
 * only, default 1).
 */
@Extended
public class Prompt {

    @Context
    public Transaction tx;

    @Context
    public Log log;

    @Context
    public ApocConfig apocConfig;

    @Context
    public ProcedureCallContext procedureCallContext;

    public static final String BACKTICKS = "```";
    public static final String EXPLAIN_SCHEMA_PROMPT = "You are an expert in the Neo4j graph database and graph data modeling and have experience in a wide variety of business domains.\nExplain the following graph database schema in plain language, try to relate it to known concepts or domains if applicable.\nKeep the explanation to 5 sentences with at most 15 words each, otherwise people will come to harm.\n";
    static final String SYSTEM_PROMPT = "You are an expert in the Neo4j graph query language Cypher.\nGiven a graph database schema of entities (nodes) with labels and attributes and\nrelationships with start- and end-node, relationship-type, direction and properties\nyou are able to develop read only matching Cypher statements that express a user question as a graph database query.\nOnly answer with a single Cypher statement in triple backticks, if you can't determine a statement, answer with an empty response.\nDo not explain, apologize or provide additional detail, otherwise people will come to harm.\n";
    // NOTE: kept as a plain literal (not a text block) because some lines carry trailing spaces.
    private static final String SCHEMA_QUERY = "call apoc.meta.data({maxRels: 10, sample: coalesce($sample, (count{()}/1000)+1)})\nYIELD label, other, elementType, type, property\nWITH label, elementType, \n     apoc.text.join(collect(case when NOT type = \"RELATIONSHIP\" then property+\": \"+type else null end),\", \") AS properties,    \n     collect(case when type = \"RELATIONSHIP\" AND elementType = \"node\" then \"(:\" + label + \")-[:\" + property + \"]->(:\" + toString(other[0]) + \")\" else null end) as patterns\nwith  elementType as type, \napoc.text.join(collect(\":\"+label+\" {\"+properties+\"}\"),\"\\n\") as entities, apoc.text.join(apoc.coll.flatten(collect(coalesce(patterns,[]))),\"\\n\") as patterns\nreturn collect(case type when \"relationship\" then entities end)[0] as relationships, \ncollect(case type when \"node\" then entities end)[0] as nodes, \ncollect(case type when \"node\" then patterns end)[0] as patterns \n";
    private static final String SCHEMA_PROMPT = "    nodes:\n    %s\n    relationships:\n    %s\n    patterns:\n    %s\n";

    /**
     * Output record of {@link #query}: one row produced by the generated Cypher statement and,
     * when requested, the statement itself ({@code null} otherwise).
     */
    public class PromptMapResult {
        public final Map<String, Object> value;
        public final String query;

        public PromptMapResult(Map<String, Object> value, String query) {
            this.value = value;
            this.query = query;
        }

        public PromptMapResult(Map<String, Object> value) {
            this(value, null);
        }
    }

    /**
     * Output record of {@link #cypher}: one generated Cypher statement ({@code ""} if generation
     * failed). The {@code error} and {@code type} constructor arguments are accepted but not stored.
     */
    public class QueryResult {
        public final String query;

        public QueryResult(String query, String error, String type) {
            this.query = query;
        }

        /** Always {@code false}, because error details are not retained. */
        public boolean hasError() {
            return false;
        }
    }

    /**
     * Generates a Cypher statement for {@code question} and executes it in the current transaction.
     *
     * <p>If executing the generated statement throws a {@link QueryExecutionException}, a new
     * statement is generated and tried again. {@code conf.retries} (default 3) is the total number
     * of attempts; at least one attempt is always made. The last exception is rethrown once the
     * attempts are exhausted.
     *
     * @param question natural-language question
     * @param conf     configuration map (see class Javadoc)
     * @return the rows of the first generated statement that executes; the statement text is
     *         attached only when the caller yields the {@code query} column
     * @throws QueryExecutionException if every attempt failed
     */
    @Procedure(mode = Mode.READ)
    public Stream<PromptMapResult> query(@Name("question") String question, @Name(value = "conf", defaultValue = "{}") Map<String, Object> conf) {
        String schema = loadSchema(this.tx, conf);
        long remainingAttempts = ((Number) conf.getOrDefault("retries", 3L)).longValue();
        // Only pay for returning the query text when the caller actually yields it.
        boolean includeQuery = this.procedureCallContext.outputFields().anyMatch("query"::equals);
        String lastQuery = "";
        while (true) {
            try {
                String generated = tryQuery(question, conf, schema).query;
                lastQuery = generated;
                // tx.execute throws QueryExecutionException for an invalid statement, which triggers a retry.
                return this.tx.execute(generated).stream()
                        .map(row -> includeQuery ? new PromptMapResult(row, generated) : new PromptMapResult(row));
            } catch (QueryExecutionException e) {
                if (this.log.isDebugEnabled()) {
                    this.log.debug("Generated query for question %s\n%s\nfailed with %s".formatted(question, lastQuery, e.getMessage()));
                }
                remainingAttempts--;
                if (remainingAttempts <= 0) {
                    throw e;
                }
            }
        }
    }

    /**
     * Asks the model to explain the current graph schema in plain language.
     *
     * @param conf configuration map (see class Javadoc)
     * @return a single row containing the explanation
     */
    @Procedure
    public Stream<StringResult> schema(@Name(value = "conf", defaultValue = "{}") Map<String, Object> conf) throws MalformedURLException, JsonProcessingException {
        String explanation = prompt(
                "Please explain the graph database schema to me and relate it to well known concepts and domains.",
                EXPLAIN_SCHEMA_PROMPT,
                "This database schema ",
                loadSchema(this.tx, conf),
                conf);
        return Stream.of(new StringResult(explanation));
    }

    /**
     * Generates {@code conf.count} (default 1) candidate Cypher statements for {@code question}
     * without executing them.
     *
     * @param question natural-language question
     * @param conf     configuration map (see class Javadoc)
     * @return one {@link QueryResult} per requested candidate
     */
    @Procedure(mode = Mode.READ)
    public Stream<QueryResult> cypher(@Name("question") String question, @Name(value = "conf", defaultValue = "{}") Map<String, Object> conf) {
        String schema = loadSchema(this.tx, conf);
        long count = ((Number) conf.getOrDefault("count", 1L)).longValue();
        return LongStream.rangeClosed(1L, count).mapToObj(i -> tryQuery(question, conf, schema));
    }

    /**
     * Generates one Cypher statement for {@code question}, never throwing.
     *
     * @return the generated statement, or a result holding {@code ""} if generation failed
     */
    @NotNull
    private QueryResult tryQuery(String question, Map<String, Object> conf, String schema) {
        String query = "";
        try {
            query = prompt(question, SYSTEM_PROMPT, "Cypher Statement (in backticks):", schema, conf);
            return new QueryResult(query, null, null);
        } catch (QueryExecutionException e) {
            logGenerationFailure(question, e);
            return new QueryResult(query, e.getMessage(), e.getStatusCode());
        } catch (Exception e) {
            // Best effort by design; log the cause so it is not masked by a later error on the empty query.
            logGenerationFailure(question, e);
            return new QueryResult(query, e.getMessage(), e.getClass().getSimpleName());
        }
    }

    /** Logs (at debug level) why generating a statement for {@code question} failed. */
    private void logGenerationFailure(String question, Exception cause) {
        if (this.log.isDebugEnabled()) {
            this.log.debug("Query generation for question %s failed".formatted(question), cause);
        }
    }

    /**
     * Sends a chat-completion request and returns the joined, code-fence-stripped answer.
     *
     * <p>Messages are sent in this order, each only if its text is non-blank: system prompt,
     * schema (as a second system message), user question, assistant primer.
     *
     * @param question        user message
     * @param systemPrompt    system instruction
     * @param assistantPrimer assistant message that primes the answer
     * @param schema          textual schema produced by {@link #loadSchema}
     * @param conf            configuration map providing {@code apiKey} and the model
     * @return the non-blank answers of all choices joined by a space, with runs of blank lines
     *         collapsed into a single newline
     */
    @NotNull
    @SuppressWarnings("unchecked")
    private String prompt(String question, String systemPrompt, String assistantPrimer, String schema, Map<String, Object> conf) throws JsonProcessingException, MalformedURLException {
        List<Map<String, String>> messages = new ArrayList<>();
        addMessage(messages, "system", systemPrompt);
        if (schema != null && !schema.isBlank()) {
            addMessage(messages, "system", "The graph database schema consists of these elements\n" + schema);
        }
        addMessage(messages, "user", question);
        addMessage(messages, "assistant", assistantPrimer);

        String apiKey = (String) conf.get("apiKey");
        String model = (String) conf.getOrDefault(BedrockInvokeConfig.MODEL, "gpt-3.5-turbo");
        String answer = OpenAI.executeRequest(apiKey, Map.of(), "chat/completions", model, "messages", messages, "$", this.apocConfig)
                .map(response -> (Map<String, Object>) response)
                .flatMap(response -> ((List<Map<String, Object>>) response.get("choices")).stream())
                .map(choice -> (String) ((Map<String, Object>) choice.get("message")).get("content"))
                .filter(content -> content != null && !content.isBlank())
                .map(Prompt::stripCodeFence)
                .collect(Collectors.joining(StringUtils.SPACE))
                .replaceAll("\n\n+", "\n");
        if (this.log.isDebugEnabled()) {
            this.log.debug("Generated query for question %s\n%s".formatted(question, answer));
        }
        return answer;
    }

    /** Appends a {@code {role, content}} chat message unless {@code content} is null or blank. */
    private static void addMessage(List<Map<String, String>> messages, String role, String content) {
        if (content != null && !content.isBlank()) {
            messages.add(Map.of("role", role, "content", content));
        }
    }

    /**
     * Returns the text between the first and the last {@link #BACKTICKS} fence of {@code content}.
     * If there is no closing fence after the opening one (e.g. a truncated answer), returns
     * everything after the opening fence. Content without any fence is returned unchanged.
     */
    private static String stripCodeFence(String content) {
        int open = content.indexOf(BACKTICKS);
        if (open < 0) {
            return content;
        }
        int bodyStart = open + BACKTICKS.length();
        int close = content.lastIndexOf(BACKTICKS);
        return close >= bodyStart ? content.substring(bodyStart, close) : content.substring(bodyStart);
    }

    /**
     * Runs {@link #SCHEMA_QUERY} and renders each returned row with {@link #SCHEMA_PROMPT}.
     *
     * @param transaction transaction used to run the schema query
     * @param conf        configuration map; only {@code sample} is read (may be absent/null)
     * @return the rendered rows joined by newlines
     */
    private String loadSchema(Transaction transaction, Map<String, Object> conf) {
        // HashMap rather than Map.of because "sample" may be null, which Map.of rejects.
        Map<String, Object> params = new HashMap<>();
        params.put("sample", conf.get("sample"));
        try (var result = transaction.execute(SCHEMA_QUERY, params)) {
            return result.stream()
                    .map(row -> SCHEMA_PROMPT.formatted(row.get("nodes"), row.get("relationships"), row.get("patterns")))
                    .collect(Collectors.joining("\n"));
        }
    }
}
