package apoc.ml;

import apoc.ApocConfig;
import apoc.Extended;
import apoc.ml.VertexAIHandler;
import apoc.ml.aws.AwsSignatureV4Generator;
import apoc.result.MapResult;
import apoc.result.ObjectResult;
import apoc.util.JsonUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.novell.ldap.events.edir.EdirEventConstant;
import java.net.MalformedURLException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import org.jsoup.helper.HttpConnection;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Neo4j procedures that call Google Vertex AI models: embeddings, text completion, chat,
 * streaming generation and fully custom requests.
 *
 * <p>Every procedure authenticates with a caller-supplied access token (sent as a Bearer
 * Authorization header). Building the URL, the request body and the JSON path of the result is
 * delegated to the {@link VertexAIHandler} obtained from {@link VertexAIHandler.Type}.
 */
@Extended
public class VertexAI {

    /** Configuration keys forwarded as generation parameters by the generative procedures. */
    private static final List<String> GENERATION_PARAMETER_KEYS =
            List.of("temperature", "topK", "topP", "maxOutputTokens");

    /** Shared serializer for request bodies; an ObjectMapper is thread-safe once configured. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    @Context
    public URLAccessChecker urlAccessChecker;

    @Context
    public ApocConfig apocConfig;

    /** One row of {@code apoc.ml.vertexai.embedding}: an input text and its embedding vector. */
    public static class EmbeddingResult {
        /** Zero-based position of {@link #text} in the input list. */
        public final long index;
        /** The input text that was embedded. */
        public final String text;
        /** The {@code values} array returned by the model for this text. */
        public final List<Double> embedding;

        public EmbeddingResult(long index, String text, List<Double> embedding) {
            this.index = index;
            this.text = text;
            this.embedding = embedding;
        }
    }

    /**
     * Sends a {@link VertexAIHandler.Type#PREDICT} request; see the overload taking a handler type.
     */
    private Stream<Object> executeRequest(String accessToken, String project, Map<String, Object> configuration,
                                          String defaultModel, Object inputs, Collection<String> retainConfigKeys,
                                          URLAccessChecker urlAccessChecker) throws JsonProcessingException {
        return executeRequest(accessToken, project, configuration, defaultModel, inputs, retainConfigKeys,
                urlAccessChecker, VertexAIHandler.Type.PREDICT);
    }

    /**
     * Builds and sends one Vertex AI request and returns the JSON values selected by the handler's
     * JSON path.
     *
     * @param accessToken      token placed in the {@code Authorization: Bearer ...} header; must not be blank
     * @param project          project identifier handed to the handler for URL construction
     * @param configuration    procedure configuration; an optional {@code headers} map is merged with
     *                         the default JSON and Authorization headers (user values win)
     * @param defaultModel     model used by the handler when the configuration does not name one
     *                         (presumably; the selection logic lives in {@link VertexAIHandler})
     * @param inputs           request inputs, turned into the body by the handler
     * @param retainConfigKeys configuration keys the handler may forward as parameters
     * @param urlAccessChecker checker passed through to {@link JsonUtil#loadJson}
     * @param type             selects the handler that shapes URL, body and result path
     * @throws IllegalArgumentException if {@code accessToken} is null or blank
     * @throws JsonProcessingException  if the request body cannot be serialized
     */
    @SuppressWarnings("unchecked")
    private Stream<Object> executeRequest(String accessToken, String project, Map<String, Object> configuration,
                                          String defaultModel, Object inputs, Collection<String> retainConfigKeys,
                                          URLAccessChecker urlAccessChecker, VertexAIHandler.Type type)
            throws JsonProcessingException {
        if (accessToken == null || accessToken.isBlank()) {
            throw new IllegalArgumentException("Access Token must not be empty");
        }
        // Work on a copy: the configured headers map may be immutable, and the caller's
        // configuration must not receive our defaults (notably the Authorization token).
        Object configuredHeaders = configuration.get("headers");
        Map<String, Object> headers = configuredHeaders == null
                ? new HashMap<>()
                : new HashMap<>((Map<String, Object>) configuredHeaders);
        headers.putIfAbsent(HttpConnection.CONTENT_TYPE, "application/json");
        headers.putIfAbsent("Accept", "application/json");
        headers.putIfAbsent(AwsSignatureV4Generator.AUTHORIZATION_KEY, "Bearer " + accessToken);

        VertexAIHandler handler = type.get();
        var url = handler.getFullUrl(configuration, this.apocConfig, defaultModel, project);
        String payload = OBJECT_MAPPER.writeValueAsString(handler.getBody(inputs, configuration, retainConfigKeys));
        return JsonUtil.loadJson(url, headers, payload, handler.getJsonPath(), true, List.of(), urlAccessChecker);
    }

    /**
     * Flattens one value produced by {@link JsonUtil#loadJson} into its elements. Each value is
     * expected to be a list of JSON objects.
     */
    @SuppressWarnings("unchecked")
    private static Stream<Map<String, Object>> asRows(Object response) {
        return ((List<Map<String, Object>>) response).stream();
    }

    @Procedure("apoc.ml.vertexai.embedding")
    @Description("apoc.ml.vertexai.embedding([texts], accessToken, project, configuration) - returns the embeddings for a given text")
    @SuppressWarnings("unchecked")
    public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts,
                                                @Name("accessToken") String accessToken,
                                                @Name("project") String project,
                                                @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
        List<Map<String, String>> instances = texts.stream().map(text -> Map.of("content", text)).toList();
        Stream<Object> responses = executeRequest(accessToken, project, configuration, "textembedding-gecko",
                instances, List.of(), this.urlAccessChecker);
        // Pairs the i-th prediction with the i-th input text; this assumes the API returns
        // predictions in input order.
        AtomicInteger position = new AtomicInteger();
        return responses
                .flatMap(VertexAI::asRows)
                .map(prediction -> {
                    Map<String, Object> embeddings = (Map<String, Object>) prediction.get("embeddings");
                    int index = position.getAndIncrement();
                    return new EmbeddingResult(index, texts.get(index), (List<Double>) embeddings.get("values"));
                });
    }

    @Procedure("apoc.ml.vertexai.completion")
    @Description("apoc.ml.vertexai.completion(prompt, accessToken, project, configuration) - prompts the completion API")
    public Stream<MapResult> completion(@Name("prompt") String prompt,
                                        @Name("accessToken") String accessToken,
                                        @Name("project") String project,
                                        @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
        return executeRequest(accessToken, project, configuration, "text-bison",
                List.of(Map.of("prompt", prompt)), GENERATION_PARAMETER_KEYS, this.urlAccessChecker)
                .flatMap(VertexAI::asRows)
                .map(MapResult::new);
    }

    /**
     * Builds the generation parameters for a request, keeping only the keys in {@code retainKeys}.
     *
     * <p>Defaults: temperature 0.3, maxOutputTokens 256, maxDecodeSteps 200, topP 0.8, topK 40.
     * A key that is absent from {@code config}, or mapped to {@code null}, gets its default.
     *
     * @param config     user configuration supplying overrides
     * @param retainKeys parameter names to keep in the result
     * @return a new mutable map containing the retained parameters
     */
    public static Map<String, Object> getParameters(Map<String, Object> config, Collection<String> retainKeys) {
        Map<String, Object> parameters = new HashMap<>();
        parameters.put("temperature", valueOrDefault(config, "temperature", 0.3d));
        parameters.put("maxOutputTokens", valueOrDefault(config, "maxOutputTokens", 256));
        // Literal 200: the decompiled source had an unrelated LDAP constant substituted here.
        parameters.put("maxDecodeSteps", valueOrDefault(config, "maxDecodeSteps", 200));
        parameters.put("topP", valueOrDefault(config, "topP", 0.8d));
        parameters.put("topK", valueOrDefault(config, "topK", 40));
        parameters.keySet().retainAll(retainKeys);
        return parameters;
    }

    /** Returns {@code config.get(key)}, or {@code fallback} when that value is absent or null. */
    private static Object valueOrDefault(Map<String, Object> config, String key, Object fallback) {
        Object value = config.get(key);
        return value != null ? value : fallback;
    }

    @Procedure("apoc.ml.vertexai.chat")
    @Description("apoc.ml.vertexai.chat(messages, accessToken, project, configuration, context, examples) - prompts the chat completion API")
    public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, String>> messages,
                                            @Name("accessToken") String accessToken,
                                            @Name("project") String project,
                                            @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration,
                                            @Name(value = "context", defaultValue = "") String context,
                                            @Name(value = "examples", defaultValue = "[]") List<Map<String, Map<String, String>>> examples) throws Exception {
        return executeRequest(accessToken, project, configuration, "chat-bison",
                List.of(Map.of("context", context, "examples", examples, "messages", messages)),
                GENERATION_PARAMETER_KEYS, this.urlAccessChecker)
                .flatMap(VertexAI::asRows)
                .map(MapResult::new);
    }

    @Procedure("apoc.ml.vertexai.stream")
    @Description("apoc.ml.vertexai.stream(messages, accessToken, project, configuration) - prompts the streaming API")
    public Stream<MapResult> stream(@Name("messages") List<Map<String, String>> messages,
                                    @Name("accessToken") String accessToken,
                                    @Name("project") String project,
                                    @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
        return executeRequest(accessToken, project, configuration, "gemini-pro", messages,
                GENERATION_PARAMETER_KEYS, this.urlAccessChecker, VertexAIHandler.Type.STREAM)
                .flatMap(VertexAI::asRows)
                .map(MapResult::new);
    }

    @Procedure("apoc.ml.vertexai.custom")
    @Description("apoc.ml.vertexai.custom(body, accessToken, project, configuration) - prompts a customizable API")
    public Stream<ObjectResult> custom(@Name("body") Map<String, Object> body,
                                       @Name("accessToken") String accessToken,
                                       @Name("project") String project,
                                       @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
        // The body is forwarded as-is; no configuration keys are retained as parameters.
        return executeRequest(accessToken, project, configuration, "gemini-pro", body,
                Collections.emptyList(), this.urlAccessChecker, VertexAIHandler.Type.CUSTOM)
                .map(ObjectResult::new);
    }
}
