package org.neo4j.genai.dbs.providers;

import com.fasterxml.jackson.core.JsonProcessingException;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpRequest;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.neo4j.genai.dbs.RequestConfig;
import org.neo4j.genai.dbs.RowMappingConfig;
import org.neo4j.genai.dbs.VectorDatabaseProvider;
import org.neo4j.genai.dbs.VectorDatabaseRequest;
import org.neo4j.genai.dbs.VectorDatabases;
import org.neo4j.genai.util.GenAIProcedureException;
import org.neo4j.genai.util.HttpService;
import org.neo4j.genai.util.JsonUtils;

/* loaded from: input_file:org/neo4j/genai/dbs/providers/Pinecone.class */
public final class Pinecone implements VectorDatabaseProvider {
    private static final EnumSet<VectorDatabaseProvider.Command> DATA_PLANE_COMMANDS = EnumSet.of(VectorDatabaseProvider.Command.QUERY, VectorDatabaseProvider.Command.GET, VectorDatabaseProvider.Command.UPSERT, VectorDatabaseProvider.Command.DELETE);

    @Override // org.neo4j.genai.dbs.VectorDatabaseProvider
    public <T> VectorDatabaseRequest<T> createRequestFor(VectorDatabaseProvider.Command command, String str, String str2, Map<String, Object> map, Map<String, Object> map2) {
        VectorDatabases.ProcedureArguments procedureArguments = (VectorDatabases.ProcedureArguments) map2.get("procedureArguments");
        Function function = builder -> {
            return addAuthorizationHeader(map, builder);
        };
        if (command == VectorDatabaseProvider.Command.CREATE_COLLECTION) {
            return createCreateCollectionRequest(str, str2, function, map, map2);
        }
        if (command == VectorDatabaseProvider.Command.GET_COLLECTION_METADATA) {
            return createGetCollectionMetadataRequest(str, str2, function);
        }
        if (command == VectorDatabaseProvider.Command.GET) {
            return createGetRequest(str, function, map2);
        }
        if (command == VectorDatabaseProvider.Command.QUERY) {
            return createQueryRequest(str, function, map2, procedureArguments);
        }
        if (command == VectorDatabaseProvider.Command.DELETE_COLLECTION) {
            return createDeleteCollectionRequest(str, str2, function);
        }
        if (command == VectorDatabaseProvider.Command.DELETE) {
            return createDeleteRequest(str, function, map2);
        }
        if (command == VectorDatabaseProvider.Command.UPSERT) {
            return createCreateRequest(str, function, map2);
        }
        throw new UnsupportedOperationException();
    }

    private static <T> VectorDatabaseRequest<T> createCreateCollectionRequest(String str, String str2, Function<HttpRequest.Builder, HttpRequest.Builder> function, Map<String, Object> map, Map<String, Object> map2) {
        return new VectorDatabaseRequest<>(URI.create(str + "/indexes"), function.andThen(builder -> {
            try {
                return builder.header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(JsonUtils.getObjectMapper().writeValueAsString(Map.of("name", str2, "dimension", map2.get("size"), "metric", ((String) map2.get("similarity")).toLowerCase(Locale.ROOT), "spec", map.getOrDefault("spec", Map.of()))))).build();
            } catch (JsonProcessingException e) {
                throw new RuntimeException((Throwable) e);
            }
        }), inputStream -> {
            return VectorDatabases.StatusDTO.ok(null);
        });
    }

    @Override // org.neo4j.genai.dbs.VectorDatabaseProvider
    public HttpRequest.Builder addAuthorizationHeader(Map<String, Object> map, HttpRequest.Builder builder) {
        Map map2 = (Map) RequestConfig.Keys.HEADERS.get(Map.class, map);
        if (map2 != null && map2.containsKey(RequestConfig.Keys.AUTHORIZATION.key())) {
            builder.header("Api-Key", (String) map2.get(RequestConfig.Keys.AUTHORIZATION.key()));
        } else if (map.containsKey(RequestConfig.Keys.TOKEN.key())) {
            builder = builder.header("Api-Key", (String) map.get(RequestConfig.Keys.TOKEN.key()));
        }
        return builder;
    }

    private static <T> VectorDatabaseRequest<T> createGetCollectionMetadataRequest(String str, String str2, Function<HttpRequest.Builder, HttpRequest.Builder> function) {
        return new VectorDatabaseRequest<>(URI.create(str + "/indexes/" + str2), function.andThen((v0) -> {
            return v0.build();
        }), HttpService.DEFAULT_RESPONSE_TO_MAP_TRANSFORMER.andThen(VectorDatabases.InfoDTO::of));
    }

    private static <T> VectorDatabaseRequest<T> createGetRequest(String str, Function<HttpRequest.Builder, HttpRequest.Builder> function, Map<String, Object> map) {
        URI create = URI.create(str + "/vectors/fetch?ids=" + ((String) ((List) map.get("ids")).stream().map((v0) -> {
            return v0.toString();
        }).collect(Collectors.joining("&ids="))));
        Function andThen = function.andThen((v0) -> {
            return v0.GET();
        }).andThen((v0) -> {
            return v0.build();
        });
        RowMappingConfig rowMappingConfig = (RowMappingConfig) map.get("rowMappingConfig");
        return new VectorDatabaseRequest<>(create, andThen, inputStream -> {
            try {
                return ((Map) ((Map) JsonUtils.getObjectMapper().readValue(inputStream, JsonUtils.TYPE_REF_MAP_STRING_OBJECT)).get("vectors")).values().stream().map(map2 -> {
                    HashMap hashMap = new HashMap();
                    hashMap.put(rowMappingConfig.idKey(), map2.get("id"));
                    hashMap.put(rowMappingConfig.metadataKey(), map2.get("metadata"));
                    hashMap.put(rowMappingConfig.vectorKey(), map2.get("values"));
                    return hashMap;
                });
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        });
    }

    private static <T> VectorDatabaseRequest<T> createQueryRequest(String str, Function<HttpRequest.Builder, HttpRequest.Builder> function, Map<String, Object> map, VectorDatabases.ProcedureArguments procedureArguments) {
        RowMappingConfig rowMappingConfig = (RowMappingConfig) map.get("rowMappingConfig");
        List list = (List) map.get("vector");
        long longValue = ((Long) map.get("limit")).longValue();
        HashMap hashMap = new HashMap();
        hashMap.put("includeMetadata", true);
        hashMap.put("vector", list);
        hashMap.put("topK", Long.valueOf(longValue));
        hashMap.put("includeValues", Boolean.valueOf(procedureArguments.hasVector()));
        ((Optional) map.get("filter")).ifPresent(obj -> {
            hashMap.put("filter", obj);
        });
        return new VectorDatabaseRequest<>(URI.create(str + "/query"), function.andThen(builder -> {
            builder.header("Content-Type", "application/json");
            try {
                builder.POST(HttpRequest.BodyPublishers.ofString(JsonUtils.getObjectMapper().writeValueAsString(hashMap)));
                return builder.build();
            } catch (JsonProcessingException e) {
                throw new RuntimeException((Throwable) e);
            }
        }), inputStream -> {
            try {
                return ((List) ((Map) JsonUtils.getObjectMapper().readValue(inputStream, JsonUtils.TYPE_REF_MAP_STRING_OBJECT)).get("matches")).stream().map(map2 -> {
                    HashMap hashMap2 = new HashMap();
                    hashMap2.put(rowMappingConfig.metadataKey(), map2.get("metadata"));
                    hashMap2.put(rowMappingConfig.scoreKey(), map2.get("score"));
                    hashMap2.put(rowMappingConfig.idKey(), map2.get("id"));
                    hashMap2.put(rowMappingConfig.vectorKey(), map2.get("values"));
                    return hashMap2;
                });
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        });
    }

    @Override // org.neo4j.genai.dbs.VectorDatabaseProvider
    public Optional<VectorDatabaseRequest<String>> getCollectionSpecificHost(VectorDatabaseProvider.Command command, String str, String str2, Map<String, Object> map) {
        if (!DATA_PLANE_COMMANDS.contains(command)) {
            return super.getCollectionSpecificHost(command, str, str2, map);
        }
        Function function = builder -> {
            return addAuthorizationHeader(map, builder);
        };
        return Optional.of(new VectorDatabaseRequest(URI.create(str + "/indexes/" + str2), function.andThen((v0) -> {
            return v0.build();
        }), HttpService.DEFAULT_RESPONSE_TO_MAP_TRANSFORMER.andThen(map2 -> {
            return "https://" + String.valueOf(map2.get("host"));
        })));
    }

    private static <T> VectorDatabaseRequest<T> createCreateRequest(String str, Function<HttpRequest.Builder, HttpRequest.Builder> function, Map<String, Object> map) {
        Set of = Set.of("id", "vector", "metadata");
        return new VectorDatabaseRequest<>(URI.create(str + "/vectors/upsert"), function.andThen(builder -> {
            try {
                ArrayList arrayList = new ArrayList();
                for (Map map2 : (List) map.get("vectors")) {
                    HashMap hashMap = new HashMap();
                    hashMap.put("id", map2.get("id"));
                    hashMap.put("metadata", map2.get("metadata"));
                    hashMap.put("values", map2.get("vector"));
                    map2.forEach((str2, obj) -> {
                        if (of.contains(str2)) {
                            hashMap.put(str2, obj);
                        }
                    });
                    arrayList.add(hashMap);
                }
                return builder.header("Content-Type", "application/json").POST(HttpRequest.BodyPublishers.ofString(JsonUtils.getObjectMapper().writeValueAsString(Map.of("vectors", arrayList)))).build();
            } catch (JsonProcessingException e) {
                throw new RuntimeException((Throwable) e);
            }
        }), inputStream -> {
            return VectorDatabases.StatusDTO.ok(null);
        });
    }

    private static <T> VectorDatabaseRequest<T> createDeleteCollectionRequest(String str, String str2, Function<HttpRequest.Builder, HttpRequest.Builder> function) {
        return new VectorDatabaseRequest<>(URI.create(str + "/indexes/" + str2), function.andThen((v0) -> {
            return v0.DELETE();
        }).andThen((v0) -> {
            return v0.build();
        }), inputStream -> {
            return VectorDatabases.StatusDTO.ok(null);
        });
    }

    private static <T> VectorDatabaseRequest<T> createDeleteRequest(String str, Function<HttpRequest.Builder, HttpRequest.Builder> function, Map<String, Object> map) {
        try {
            String writeValueAsString = JsonUtils.getObjectMapper().writeValueAsString(Map.of("ids", map.get("ids")));
            return new VectorDatabaseRequest<>(URI.create(str + "/vectors/delete"), function.andThen(builder -> {
                return builder.POST(HttpRequest.BodyPublishers.ofString(writeValueAsString));
            }).andThen((v0) -> {
                return v0.build();
            }), inputStream -> {
                return VectorDatabases.StatusDTO.ok(null);
            });
        } catch (JsonProcessingException e) {
            throw new GenAIProcedureException("Failed to create body for batch deletion of vectors");
        }
    }
}
