package org.neo4j.genai.dbs.providers;

import com.fasterxml.jackson.core.JsonProcessingException;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.UnaryOperator;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.neo4j.genai.dbs.CollectionNotFoundException;
import org.neo4j.genai.dbs.RowMappingConfig;
import org.neo4j.genai.dbs.VectorDatabaseProvider;
import org.neo4j.genai.dbs.VectorDatabaseRequest;
import org.neo4j.genai.dbs.VectorDatabases;
import org.neo4j.genai.util.GenAIProcedureException;
import org.neo4j.genai.util.HttpService;
import org.neo4j.genai.util.JsonUtils;

/* loaded from: input_file:org/neo4j/genai/dbs/providers/ChromaDb.class */
public class ChromaDb implements VectorDatabaseProvider {
    private static final UnaryOperator<String> CREATE_BASE_URI;
    private static final BinaryOperator<String> CREATE_COLLECTION_BASE_URI = (str, str2) -> {
        return ((String) CREATE_BASE_URI.apply(str)) + "/" + str2;
    };
    private static final BiFunction<String, String, String> CREATE_GET_POINTS_BASE_URI = CREATE_COLLECTION_BASE_URI.andThen(str -> {
        return str + "/get";
    });
    private static final BiFunction<String, String, String> CREATE_UPSERT_POINTS_BASE_URI = CREATE_COLLECTION_BASE_URI.andThen(str -> {
        return str + "/upsert";
    });
    private static final BiFunction<String, String, String> CREATE_QUERY_BASE_URI = CREATE_COLLECTION_BASE_URI.andThen(str -> {
        return str + "/query";
    });
    private static final String IDS_KEY = "ids";

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/neo4j/genai/dbs/providers/ChromaDb$Result.class */
    public static final class Result extends Record {
        private final URI target;
        private final Function<HttpRequest.Builder, HttpRequest> requestCustomizer;

        private Result(URI uri, Function<HttpRequest.Builder, HttpRequest> function) {
            this.target = uri;
            this.requestCustomizer = function;
        }

        @Override // java.lang.Record
        public final String toString() {
            return (String) ObjectMethods.bootstrap(MethodHandles.lookup(), "toString", MethodType.methodType(String.class, Result.class), Result.class, "target;requestCustomizer", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$Result;->target:Ljava/net/URI;", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$Result;->requestCustomizer:Ljava/util/function/Function;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final int hashCode() {
            return (int) ObjectMethods.bootstrap(MethodHandles.lookup(), "hashCode", MethodType.methodType(Integer.TYPE, Result.class), Result.class, "target;requestCustomizer", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$Result;->target:Ljava/net/URI;", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$Result;->requestCustomizer:Ljava/util/function/Function;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final boolean equals(Object obj) {
            return (boolean) ObjectMethods.bootstrap(MethodHandles.lookup(), "equals", MethodType.methodType(Boolean.TYPE, Result.class, Object.class), Result.class, "target;requestCustomizer", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$Result;->target:Ljava/net/URI;", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$Result;->requestCustomizer:Ljava/util/function/Function;").dynamicInvoker().invoke(this, obj) /* invoke-custom */;
        }

        public URI target() {
            return this.target;
        }

        public Function<HttpRequest.Builder, HttpRequest> requestCustomizer() {
            return this.requestCustomizer;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/neo4j/genai/dbs/providers/ChromaDb$UpsertPayload.class */
    public static final class UpsertPayload extends Record {
        private final List<List<Double>> embeddings;
        private final List<Map<String, Object>> metadatas;
        private final List<Object> ids;

        private UpsertPayload(List<List<Double>> list, List<Map<String, Object>> list2, List<Object> list3) {
            this.embeddings = list;
            this.metadatas = list2;
            this.ids = list3;
        }

        @Override // java.lang.Record
        public final String toString() {
            return (String) ObjectMethods.bootstrap(MethodHandles.lookup(), "toString", MethodType.methodType(String.class, UpsertPayload.class), UpsertPayload.class, "embeddings;metadatas;ids", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$UpsertPayload;->embeddings:Ljava/util/List;", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$UpsertPayload;->metadatas:Ljava/util/List;", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$UpsertPayload;->ids:Ljava/util/List;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final int hashCode() {
            return (int) ObjectMethods.bootstrap(MethodHandles.lookup(), "hashCode", MethodType.methodType(Integer.TYPE, UpsertPayload.class), UpsertPayload.class, "embeddings;metadatas;ids", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$UpsertPayload;->embeddings:Ljava/util/List;", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$UpsertPayload;->metadatas:Ljava/util/List;", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$UpsertPayload;->ids:Ljava/util/List;").dynamicInvoker().invoke(this) /* invoke-custom */;
        }

        @Override // java.lang.Record
        public final boolean equals(Object obj) {
            return (boolean) ObjectMethods.bootstrap(MethodHandles.lookup(), "equals", MethodType.methodType(Boolean.TYPE, UpsertPayload.class, Object.class), UpsertPayload.class, "embeddings;metadatas;ids", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$UpsertPayload;->embeddings:Ljava/util/List;", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$UpsertPayload;->metadatas:Ljava/util/List;", "FIELD:Lorg/neo4j/genai/dbs/providers/ChromaDb$UpsertPayload;->ids:Ljava/util/List;").dynamicInvoker().invoke(this, obj) /* invoke-custom */;
        }

        public List<List<Double>> embeddings() {
            return this.embeddings;
        }

        public List<Map<String, Object>> metadatas() {
            return this.metadatas;
        }

        public List<Object> ids() {
            return this.ids;
        }
    }

    @Override // org.neo4j.genai.dbs.VectorDatabaseProvider
    public <T> VectorDatabaseRequest<T> createRequestFor(VectorDatabaseProvider.Command command, String str, String str2, Map<String, Object> map, Map<String, Object> map2) {
        RowMappingConfig rowMappingConfig = (RowMappingConfig) map2.get("rowMappingConfig");
        VectorDatabases.ProcedureArguments procedureArguments = (VectorDatabases.ProcedureArguments) map2.get("procedureArguments");
        Function function = builder -> {
            return addHttpVersion(addAuthorizationHeader(map, builder));
        };
        if (command == VectorDatabaseProvider.Command.GET_COLLECTION_METADATA) {
            return createGetCollectionMetadataRequest(str, str2, function);
        }
        if (command == VectorDatabaseProvider.Command.GET) {
            return createGetRequest(str, str2, map2, procedureArguments, rowMappingConfig, function);
        }
        if (command == VectorDatabaseProvider.Command.QUERY) {
            return createQueryRequest(str, str2, map2, procedureArguments, rowMappingConfig, function);
        }
        if (command == VectorDatabaseProvider.Command.UPSERT) {
            return createUpsertRequest(str, str2, map, map2, function);
        }
        if (command == VectorDatabaseProvider.Command.DELETE_COLLECTION) {
            return createDeleteCollectionRequest(str, str2, function);
        }
        if (command == VectorDatabaseProvider.Command.CREATE_COLLECTION) {
            return createCreateCollectionRequest(str, str2, map2, function);
        }
        if (command == VectorDatabaseProvider.Command.DELETE) {
            return createDeleteRequest(str, str2, map, map2, function);
        }
        throw new UnsupportedOperationException();
    }

    private static <T> VectorDatabaseRequest<T> createDeleteRequest(String str, String str2, Map<String, Object> map, Map<String, Object> map2, Function<HttpRequest.Builder, HttpRequest.Builder> function) {
        return new VectorDatabaseRequest<>(URI.create(((String) CREATE_COLLECTION_BASE_URI.apply(str, str2)) + "/delete"), function.andThen(builder -> {
            try {
                return builder.POST(HttpRequest.BodyPublishers.ofString(JsonUtils.getObjectMapper().writeValueAsString(Map.of(IDS_KEY, map2.get(IDS_KEY))))).build();
            } catch (JsonProcessingException e) {
                throw new RuntimeException((Throwable) e);
            }
        }), inputStream -> {
            return VectorDatabases.StatusDTO.ok(null);
        });
    }

    private static <T> VectorDatabaseRequest<T> createCreateCollectionRequest(String str, String str2, Map<String, Object> map, Function<HttpRequest.Builder, HttpRequest.Builder> function) {
        return new VectorDatabaseRequest<>(URI.create((String) CREATE_BASE_URI.apply(str)), function.andThen(builder -> {
            try {
                return builder.POST(HttpRequest.BodyPublishers.ofString(JsonUtils.getObjectMapper().writeValueAsString(Map.of("name", str2, "metadata", Map.of("size", map.get("size"), "hnsw:space", map.get("similarity").toString().toLowerCase()))))).build();
            } catch (JsonProcessingException e) {
                throw new RuntimeException((Throwable) e);
            }
        }), inputStream -> {
            return VectorDatabases.StatusDTO.ok(null);
        });
    }

    private static <T> VectorDatabaseRequest<T> createUpsertRequest(String str, String str2, Map<String, Object> map, Map<String, Object> map2, Function<HttpRequest.Builder, HttpRequest.Builder> function) {
        return new VectorDatabaseRequest<>(URI.create(CREATE_UPSERT_POINTS_BASE_URI.apply(str, str2)), function.andThen(builder -> {
            try {
                ArrayList arrayList = new ArrayList();
                ArrayList arrayList2 = new ArrayList();
                ArrayList arrayList3 = new ArrayList();
                for (Map map3 : (List) map2.get("vectors")) {
                    arrayList.add((List) map3.get("vector"));
                    arrayList3.add(map3.get("id"));
                    arrayList2.add((Map) map3.get("metadata"));
                }
                return builder.POST(HttpRequest.BodyPublishers.ofString(JsonUtils.getObjectMapper().writeValueAsString(new UpsertPayload(arrayList, arrayList2, arrayList3)))).build();
            } catch (JsonProcessingException e) {
                throw new RuntimeException((Throwable) e);
            }
        }), inputStream -> {
            return VectorDatabases.StatusDTO.ok(null);
        });
    }

    @Override // org.neo4j.genai.dbs.VectorDatabaseProvider
    public BiFunction<Integer, String, Optional<GenAIProcedureException>> getProviderSpecificStatusHandler(String str) {
        return (num, str2) -> {
            boolean contains = str2.contains("Collection " + str + " does not exist.");
            return (num.intValue() == 400 && contains) ? Optional.of(new CollectionNotFoundException(str)) : (num.intValue() == 500 && contains) ? Optional.of(new GenAIProcedureException("API request forbidden (HTTP response code: 403); check your credentials.", (Integer) 500)) : Optional.empty();
        };
    }

    private static <T> VectorDatabaseRequest<T> createQueryRequest(String str, String str2, Map<String, Object> map, VectorDatabases.ProcedureArguments procedureArguments, RowMappingConfig rowMappingConfig, Function<HttpRequest.Builder, HttpRequest.Builder> function) {
        URI create = URI.create(CREATE_QUERY_BASE_URI.apply(str, str2));
        Function<? super HttpRequest.Builder, ? extends V> function2 = builder -> {
            try {
                HashMap hashMap = new HashMap(Map.of("query_embeddings", List.of(map.get("vector")), "n_results", map.get("limit"), "include", List.of("metadatas", "embeddings", "distances")));
                hashMap.put("where", ((Optional) map.get("filter")).orElseGet(Map::of));
                return builder.POST(HttpRequest.BodyPublishers.ofString(JsonUtils.getObjectMapper().writeValueAsString(hashMap))).build();
            } catch (JsonProcessingException e) {
                throw new RuntimeException((Throwable) e);
            }
        };
        return new VectorDatabaseRequest<>(create, function.andThen(function2), inputStream -> {
            try {
                Map map2 = (Map) JsonUtils.getObjectMapper().readValue(inputStream, JsonUtils.TYPE_REF_MAP_STRING_OBJECT);
                List list = (List) ((List) map2.get(IDS_KEY)).get(0);
                List list2 = (List) ((List) map2.get("embeddings")).get(0);
                List list3 = (List) ((List) map2.get("metadatas")).get(0);
                List list4 = (List) ((List) map2.get("distances")).get(0);
                return IntStream.rangeClosed(0, list.size() - 1).mapToObj(i -> {
                    HashMap hashMap = new HashMap();
                    hashMap.put(rowMappingConfig.idKey(), list.get(i));
                    hashMap.put(rowMappingConfig.metadataKey(), list3.get(i));
                    hashMap.put(rowMappingConfig.scoreKey(), list4.get(i));
                    if (procedureArguments.allResults() && list2 != null && list2.get(i) != null) {
                        hashMap.put(rowMappingConfig.vectorKey(), list2.get(i));
                    }
                    return hashMap;
                });
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        });
    }

    private static <T> VectorDatabaseRequest<T> createGetRequest(String str, String str2, Map<String, Object> map, VectorDatabases.ProcedureArguments procedureArguments, RowMappingConfig rowMappingConfig, Function<HttpRequest.Builder, HttpRequest.Builder> function) {
        Result result = new Result(URI.create(CREATE_GET_POINTS_BASE_URI.apply(str, str2)), function.andThen(builder -> {
            try {
                return builder.POST(HttpRequest.BodyPublishers.ofString(JsonUtils.getObjectMapper().writeValueAsString(Map.of(IDS_KEY, map.get(IDS_KEY), "include", List.of("metadatas", "embeddings"))))).build();
            } catch (JsonProcessingException e) {
                throw new RuntimeException((Throwable) e);
            }
        }));
        return new VectorDatabaseRequest<>(result.target(), result.requestCustomizer(), inputStream -> {
            try {
                Map map2 = (Map) JsonUtils.getObjectMapper().readValue(inputStream, JsonUtils.TYPE_REF_MAP_STRING_OBJECT);
                List list = (List) map2.get(IDS_KEY);
                List list2 = (List) map2.get("embeddings");
                List list3 = (List) map2.get("metadatas");
                return IntStream.rangeClosed(0, list.size() - 1).mapToObj(i -> {
                    HashMap hashMap = new HashMap();
                    hashMap.put(rowMappingConfig.idKey(), list.get(i));
                    hashMap.put(rowMappingConfig.metadataKey(), list3.get(i));
                    if (procedureArguments.allResults() && list2 != null && list2.get(i) != null) {
                        hashMap.put(rowMappingConfig.vectorKey(), list2.get(i));
                    }
                    return hashMap;
                });
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        });
    }

    private static <T> VectorDatabaseRequest<T> createGetCollectionMetadataRequest(String str, String str2, Function<HttpRequest.Builder, HttpRequest.Builder> function) {
        return new VectorDatabaseRequest<>(URI.create((String) CREATE_COLLECTION_BASE_URI.apply(str, str2)), function.andThen((v0) -> {
            return v0.build();
        }), HttpService.DEFAULT_RESPONSE_TO_MAP_TRANSFORMER.andThen(VectorDatabases.InfoDTO::of));
    }

    private static <T> VectorDatabaseRequest<T> createDeleteCollectionRequest(String str, String str2, Function<HttpRequest.Builder, HttpRequest.Builder> function) {
        URI create = URI.create((String) CREATE_COLLECTION_BASE_URI.apply(str, str2));
        Function<? super HttpRequest.Builder, ? extends V> function2 = builder -> {
            return builder.DELETE().build();
        };
        return new VectorDatabaseRequest<>(create, function.andThen(function2), HttpService.DEFAULT_RESPONSE_TO_MAP_TRANSFORMER.andThen(map -> {
            return new VectorDatabases.StatusDTO("ok", Map.of());
        }));
    }

    private static HttpRequest.Builder addHttpVersion(HttpRequest.Builder builder) {
        return builder.version(HttpClient.Version.HTTP_1_1);
    }

    static {
        String str = "%s/api/v1/collections";
        CREATE_BASE_URI = obj -> {
            return "%s/api/v1/collections".formatted(obj);
        };
    }
}
