package apoc.ml.watson;

import apoc.ApocConfig;
import apoc.Extended;
import apoc.ExtendedApocConfig;
import apoc.ml.MLUtil;
import apoc.ml.aws.AwsSignatureV4Generator;
import apoc.ml.watson.WatsonHandler;
import apoc.result.MapResult;
import apoc.util.JsonUtil;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.jsoup.helper.HttpConnection;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

@Extended
/* loaded from: input_file:apoc/ml/watson/Watson.class */
public class Watson {
    static final String PROJECT_ID_KEY = "project_id";
    static final String SPACE_ID_KEY = "space_id";
    static final String MODEL_ID_KEY = "model_id";
    static final String WML_INSTANCE_CRN_KEY = "wml_instance_crn";
    static final String DEFAULT_COMPLETION_MODEL_ID = "ibm/granite-13b-chat-v2";
    static final String DEFAULT_EMBEDDING_MODEL_ID = "ibm/slate-30m-english-rtrvr";
    static final String DEFAULT_VERSION_DATE = "2023-05-29";
    static final String DEFAULT_REGION = "eu-de";

    @Context
    public ApocConfig apocConfig;

    @Context
    public URLAccessChecker urlAccessChecker;

    /* loaded from: input_file:apoc/ml/watson/Watson$EmbeddingResult.class */
    public static final class EmbeddingResult extends Record {
        private final long index;
        private final String text;
        private final List<Double> embedding;

        public EmbeddingResult(long j, String str, List<Double> list) {
            this.index = j;
            this.text = str;
            this.embedding = list;
        }

        @Override // java.lang.Record
        public final String toString() {
            return (String) ObjectMethods.bootstrap(MethodHandles.lookup(), "toString", MethodType.methodType(String.class, EmbeddingResult.class), EmbeddingResult.class, "index;text;embedding", "FIELD:Lapoc/ml/watson/Watson$EmbeddingResult;->index:J", "FIELD:Lapoc/ml/watson/Watson$EmbeddingResult;->text:Ljava/lang/String;", "FIELD:Lapoc/ml/watson/Watson$EmbeddingResult;->embedding:Ljava/util/List;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final int hashCode() {
            return (int) ObjectMethods.bootstrap(MethodHandles.lookup(), "hashCode", MethodType.methodType(Integer.TYPE, EmbeddingResult.class), EmbeddingResult.class, "index;text;embedding", "FIELD:Lapoc/ml/watson/Watson$EmbeddingResult;->index:J", "FIELD:Lapoc/ml/watson/Watson$EmbeddingResult;->text:Ljava/lang/String;", "FIELD:Lapoc/ml/watson/Watson$EmbeddingResult;->embedding:Ljava/util/List;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final boolean equals(Object obj) {
            return (boolean) ObjectMethods.bootstrap(MethodHandles.lookup(), "equals", MethodType.methodType(Boolean.TYPE, EmbeddingResult.class, Object.class), EmbeddingResult.class, "index;text;embedding", "FIELD:Lapoc/ml/watson/Watson$EmbeddingResult;->index:J", "FIELD:Lapoc/ml/watson/Watson$EmbeddingResult;->text:Ljava/lang/String;", "FIELD:Lapoc/ml/watson/Watson$EmbeddingResult;->embedding:Ljava/util/List;").dynamicInvoker().invoke(this, obj) /* invoke-custom */;
        }

        public long index() {
            return this.index;
        }

        public String text() {
            return this.text;
        }

        public List<Double> embedding() {
            return this.embedding;
        }
    }

    @Procedure("apoc.ml.watson.embedding")
    @Description("apoc.ml.watson.embedding([texts], $configuration) - returns the embeddings for a given text")
    public Stream<EmbeddingResult> embedding(@Name("texts") List<String> list, @Name("accessToken") String str, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> map) {
        if (list == null) {
            throw new RuntimeException(MLUtil.ERROR_NULL_INPUT);
        }
        AtomicInteger atomicInteger = new AtomicInteger();
        return executeRequest(list, str, map, WatsonHandler.Type.EMBEDDING.get()).flatMap(map2 -> {
            return ((List) map2.get("results")).stream();
        }).map(map3 -> {
            int andIncrement = atomicInteger.getAndIncrement();
            return new EmbeddingResult(andIncrement, (String) list.get(andIncrement), (List) map3.get("embedding"));
        });
    }

    @Procedure("apoc.ml.watson.chat")
    @Description("apoc.ml.watson.chat(messages, accessToken, $configuration) - prompts the completion API")
    public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Object>> list, @Name("accessToken") String str, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> map) throws Exception {
        if (list == null) {
            throw new RuntimeException(MLUtil.ERROR_NULL_INPUT);
        }
        return completion((String) list.stream().map(map2 -> {
            Object obj = map2.get("role");
            Object obj2 = map2.get("content");
            if (obj == null || obj2 == null) {
                throw new RuntimeException("The `messages` items must have the keys: `role` and `content`");
            }
            return obj + ": " + obj2;
        }).collect(Collectors.joining("\n\n")), str, map);
    }

    @Procedure("apoc.ml.watson.completion")
    @Description("apoc.ml.watson.completion(prompt, accessToken, $configuration) - prompts the completion API")
    public Stream<MapResult> completion(@Name("prompt") String str, @Name("accessToken") String str2, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> map) throws Exception {
        if (str == null) {
            throw new RuntimeException(MLUtil.ERROR_NULL_INPUT);
        }
        return executeRequest(str, str2, map, WatsonHandler.Type.COMPLETION.get()).map(MapResult::new);
    }

    private Stream<Map> executeRequest(Object obj, String str, Map<String, Object> map, WatsonHandler watsonHandler) {
        try {
            if (!map.containsKey(PROJECT_ID_KEY) && !map.containsKey(SPACE_ID_KEY) && !map.containsKey(WML_INSTANCE_CRN_KEY)) {
                String string = this.apocConfig.getString(ExtendedApocConfig.APOC_ML_WATSON_PROJECT_ID, null);
                if (string == null) {
                    throw new RuntimeException("The body request has none of %s, %s, and %s and the APOC config `%s` is not present.%nPlease, define one of these".formatted(PROJECT_ID_KEY, SPACE_ID_KEY, WML_INSTANCE_CRN_KEY, ExtendedApocConfig.APOC_ML_WATSON_PROJECT_ID));
                }
                map.put(PROJECT_ID_KEY, string);
            }
            return JsonUtil.loadJson(watsonHandler.getEndpoint(map), Map.of(HttpConnection.CONTENT_TYPE, "application/json", "accept", "application/json", AwsSignatureV4Generator.AUTHORIZATION_KEY, "Bearer " + str), JsonUtil.OBJECT_MAPPER.writeValueAsString(watsonHandler.getPayload(map, obj)), "$", true, List.of(), this.urlAccessChecker).map(obj2 -> {
                return (Map) obj2;
            });
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
