package io.quarkiverse.langchain4j.redis;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Json;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.quarkiverse.langchain4j.redis.runtime.RedisSchema;
import io.quarkus.redis.datasource.ReactiveRedisDataSource;
import io.quarkus.redis.datasource.json.ReactiveJsonCommands;
import io.quarkus.redis.datasource.keys.KeyScanArgs;
import io.smallrye.mutiny.Uni;
import io.vertx.mutiny.redis.client.Command;
import io.vertx.mutiny.redis.client.Request;
import io.vertx.mutiny.redis.client.Response;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.jboss.logging.Logger;

/* loaded from: input_file:io/quarkiverse/langchain4j/redis/RedisEmbeddingStore.class */
/**
 * {@link EmbeddingStore} backed by Redis JSON documents and a RediSearch vector index.
 * <p>
 * Each entry is written with {@code JSON.SET <prefix><id> $ {...}}. The document holds the vector under
 * {@link RedisSchema#getVectorFieldName()}. When a {@link TextSegment} is supplied, it also holds the text under
 * {@link RedisSchema#getScalarFieldName()} and the segment's metadata entries as top-level fields.
 * <p>
 * Searches run a KNN {@code FT.SEARCH} against the index named by {@link RedisSchema#getIndexName()}. That index
 * is created on construction when it is not already listed by the server.
 * <p>
 * Every operation blocks the calling thread until Redis answers ({@code await().indefinitely()}).
 */
public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {

    /** Key under which each search result entry exposes its returned attributes. */
    public static final String EXTRA_ATTRIBUTES = "extra_attributes";
    /** Key under which each search result entry exposes the Redis key of the matched document. */
    public static final String ID = "id";

    private static final Logger LOG = Logger.getLogger(RedisEmbeddingStore.class);
    /** Alias given to the KNN distance in the search query; read back from each result. */
    private static final String SCORE_FIELD_NAME = "vector_score";

    private final ReactiveRedisDataSource ds;
    private final RedisSchema schema;

    /** Fluent builder for {@link RedisEmbeddingStore}. */
    public static class Builder {
        private ReactiveRedisDataSource redisClient;
        private RedisSchema schema;

        /**
         * @param reactiveRedisDataSource the data source used for every Redis call
         * @return this builder
         */
        public Builder dataSource(ReactiveRedisDataSource reactiveRedisDataSource) {
            this.redisClient = reactiveRedisDataSource;
            return this;
        }

        /**
         * @param redisSchema index name, key prefix and field layout of the stored documents
         * @return this builder
         */
        public Builder schema(RedisSchema redisSchema) {
            this.schema = redisSchema;
            return this;
        }

        /**
         * Creates the store. This also creates the search index if it is missing, which blocks on Redis.
         *
         * @return a new store
         */
        public RedisEmbeddingStore build() {
            return new RedisEmbeddingStore(redisClient, schema);
        }
    }

    /** @return a new, empty {@link Builder} */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates the store and makes sure the search index described by {@code redisSchema} exists.
     *
     * @param reactiveRedisDataSource the data source used for every Redis call
     * @param redisSchema index name, key prefix and field layout of the stored documents
     * @throws NullPointerException if either argument is {@code null}
     */
    public RedisEmbeddingStore(ReactiveRedisDataSource reactiveRedisDataSource, RedisSchema redisSchema) {
        this.ds = Objects.requireNonNull(reactiveRedisDataSource, "dataSource");
        this.schema = Objects.requireNonNull(redisSchema, "schema");
        createIndexIfDoesNotExist();
    }

    /**
     * Creates the JSON search index ({@code FT.CREATE ... ON JSON PREFIX 1 <prefix> SCHEMA ...}) unless the
     * server already lists an index with the configured name.
     */
    private void createIndexIfDoesNotExist() {
        List<String> existingIndexes = ds.search().ft_list()
                .onFailure().invoke(failure -> {
                    // Give a targeted hint when the server rejects the search command outright.
                    // The message may be null, so guard it rather than masking the real failure with an NPE.
                    String message = failure.getMessage();
                    if (message != null && message.contains("unknown command")) {
                        LOG.error("The Redis server does not seem to support RediSearch. Please install the RediSearch module. If using containers, we suggest to use the redis/redis-stack images.");
                    }
                })
                .await().indefinitely();

        if (existingIndexes.contains(schema.getIndexName())) {
            LOG.debugf("Index in Redis already exists: %s", schema.getIndexName());
            return;
        }

        Request create = Request.cmd(Command.FT_CREATE)
                .arg(schema.getIndexName())
                .arg("ON").arg("JSON")
                .arg("PREFIX").arg("1").arg(schema.getPrefix())
                .arg("SCHEMA");
        // The schema appends its own field definitions after SCHEMA.
        schema.defineFields(create);
        if (LOG.isDebugEnabled()) {
            LOG.debug("Creating index with command: " + create.toString().replaceAll("\r\n", " "));
        }
        ds.getRedis().send(create).await().indefinitely();
    }

    /**
     * Stores an embedding without text under a freshly generated id.
     *
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        add(id, embedding);
        return id;
    }

    /** Stores an embedding without text under the given id. */
    @Override
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /**
     * Stores an embedding together with its text segment under a freshly generated id.
     *
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /**
     * Stores embeddings without text, each under a freshly generated id.
     *
     * @return the generated ids, in the same order as {@code embeddings}
     * @throws IllegalArgumentException if {@code embeddings} is empty
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = randomIds(embeddings.size());
        addAllInternal(ids, embeddings, null);
        return ids;
    }

    /**
     * Stores embeddings with their text segments, each pair under a freshly generated id.
     *
     * @param embeddings the vectors to store
     * @param textSegments the segments matching {@code embeddings} index by index, or {@code null} for none
     * @return the generated ids, in the same order as {@code embeddings}
     * @throws IllegalArgumentException if the lists are empty or their sizes differ
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> textSegments) {
        List<String> ids = randomIds(embeddings.size());
        addAllInternal(ids, embeddings, textSegments);
        return ids;
    }

    /** Generates {@code count} random ids. */
    private static List<String> randomIds(int count) {
        return Stream.generate(Utils::randomUUID).limit(count).collect(Collectors.toList());
    }

    /** Single-entry convenience wrapper around {@link #addAllInternal}. */
    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        addAllInternal(Collections.singletonList(id), Collections.singletonList(embedding),
                textSegment == null ? null : Collections.singletonList(textSegment));
    }

    /**
     * Writes one JSON document per id. The writes are issued together and the call waits for all of them,
     * failing fast on the first error.
     *
     * @throws IllegalArgumentException if {@code ids} is empty or the list sizes differ
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
        if (ids.isEmpty() || ids.size() != embeddings.size()
                || (textSegments != null && textSegments.size() != embeddings.size())) {
            throw new IllegalArgumentException("ids, embeddings and embedded must be non-empty and of the same size");
        }
        ReactiveJsonCommands<String> json = ds.json();
        List<Uni<Void>> writes = new ArrayList<>(ids.size());
        for (int i = 0; i < ids.size(); i++) {
            Embedding embedding = embeddings.get(i);
            TextSegment segment = textSegments == null ? null : textSegments.get(i);

            Map<String, Object> document = new HashMap<>();
            if (segment != null) {
                // Metadata goes in first so that a metadata key sharing the text/vector field name
                // cannot overwrite the actual text or embedding.
                document.putAll(segment.metadata().asMap());
                document.put(schema.getScalarFieldName(), segment.text());
            }
            document.put(schema.getVectorFieldName(), embedding.vector());

            writes.add(json.jsonSet(schema.getPrefix() + ids.get(i), "$", document));
        }
        Uni.join().all(writes).andFailFast().await().indefinitely();
    }

    /**
     * Finds the stored embeddings closest to {@code referenceEmbedding}.
     * <p>
     * {@code FT.SEARCH} returns at most 10 documents unless {@code LIMIT} is given, so {@code LIMIT 0 maxResults}
     * is sent explicitly. Without it, any {@code maxResults > 10} was silently truncated. {@code SORTBY} on the
     * distance makes the result order explicit.
     *
     * @param referenceEmbedding the query vector
     * @param maxResults number of nearest neighbours to request
     * @param minScore matches whose relevance score is below this value are dropped
     * @return the matches ordered by ascending distance, i.e. descending score
     */
    @Override
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        String knnQuery = String.format("*=>[ KNN %d @%s $BLOB AS %s ]",
                maxResults, schema.getVectorFieldName(), SCORE_FIELD_NAME);
        Request search = Request.cmd(Command.FT_SEARCH)
                .arg(schema.getIndexName())
                .arg(knnQuery)
                .arg("SORTBY").arg(SCORE_FIELD_NAME).arg("ASC")
                .arg("LIMIT").arg("0").arg(Integer.toString(maxResults))
                .arg("PARAMS").arg("2")
                .arg("BLOB").arg(toByteArray(referenceEmbedding.vector()))
                .arg("DIALECT").arg("2");
        Response response = ds.getRedis().send(search).await().indefinitely();

        return StreamSupport.stream(response.get("results").spliterator(), false)
                .map(this::toEmbeddingMatch)
                .filter(match -> match.score() >= minScore)
                .collect(Collectors.toList());
    }

    /** Deletes every key that starts with the schema prefix (found with {@code SCAN MATCH <prefix>*}). */
    public void deleteAll() {
        Set<String> keys = ds.key()
                .scan(new KeyScanArgs().match(schema.getPrefix() + "*"))
                .toMulti()
                .collect().asSet()
                .await().indefinitely();
        if (keys.isEmpty()) {
            return;
        }
        Request delete = Request.cmd(Command.DEL);
        keys.forEach(delete::arg);
        ds.getRedis().send(delete).await().indefinitely();
        LOG.debugf("Deleted %d keys", keys.size());
    }

    /**
     * Encodes a vector as consecutive little-endian 4-byte floats. This binary form is sent as the
     * {@code $BLOB} query parameter.
     *
     * @param fArr the vector components
     * @return a new array of {@code 4 * fArr.length} bytes
     */
    public static byte[] toByteArray(float[] fArr) {
        byte[] bytes = new byte[4 * fArr.length];
        ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().put(fArr);
        return bytes;
    }

    /**
     * Converts one search result entry into an {@link EmbeddingMatch}.
     * <p>
     * The id is the Redis key with the schema prefix stripped. A text segment is attached only when the document
     * contains the scalar (text) field. Its metadata is rebuilt from the schema's declared metadata fields that
     * are present in the document.
     */
    private EmbeddingMatch<TextSegment> toEmbeddingMatch(Response result) {
        Response attributes = result.get(EXTRA_ATTRIBUTES);
        JsonNode document = parseDocument(attributes);

        Embedding embedding = new Embedding(
                Json.fromJson(document.get(schema.getVectorFieldName()).toString(), float[].class));
        // Maps a distance in [0, 2] linearly onto a score in [0, 1] (0 -> 1, 2 -> 0).
        // NOTE(review): this assumes the index uses cosine distance; confirm against RedisSchema.
        double distance = attributes.get(SCORE_FIELD_NAME).toDouble();
        double score = (2.0d - distance) / 2.0d;
        String id = result.get(ID).toString().substring(schema.getPrefix().length());

        JsonNode text = document.get(schema.getScalarFieldName());
        TextSegment segment = null;
        if (text != null) {
            Map<String, String> metadata = schema.getMetadataFields().stream()
                    .filter(document::has)
                    .collect(Collectors.toMap(field -> field, field -> document.get(field).asText()));
            segment = new TextSegment(text.asText(), Metadata.from(metadata));
        }
        return new EmbeddingMatch<>(score, id, embedding, segment);
    }

    /** Parses the full JSON document returned under the {@code $} path of a result's attributes. */
    private static JsonNode parseDocument(Response attributes) {
        try {
            return QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER.readTree(attributes.get("$").toString());
        } catch (JsonProcessingException e) {
            throw new RuntimeException("Unable to parse the JSON document returned by Redis", e);
        }
    }
}
