package dev.langchain4j.store.embedding.azure.search;

import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.credential.TokenCredential;
import com.azure.core.util.Context;
import com.azure.search.documents.SearchClient;
import com.azure.search.documents.SearchClientBuilder;
import com.azure.search.documents.SearchDocument;
import com.azure.search.documents.indexes.SearchIndexClient;
import com.azure.search.documents.indexes.SearchIndexClientBuilder;
import com.azure.search.documents.indexes.models.HnswAlgorithmConfiguration;
import com.azure.search.documents.indexes.models.HnswParameters;
import com.azure.search.documents.indexes.models.SearchField;
import com.azure.search.documents.indexes.models.SearchFieldDataType;
import com.azure.search.documents.indexes.models.SearchIndex;
import com.azure.search.documents.indexes.models.SemanticConfiguration;
import com.azure.search.documents.indexes.models.SemanticField;
import com.azure.search.documents.indexes.models.SemanticPrioritizedFields;
import com.azure.search.documents.indexes.models.SemanticSearch;
import com.azure.search.documents.indexes.models.VectorSearch;
import com.azure.search.documents.indexes.models.VectorSearchAlgorithmMetric;
import com.azure.search.documents.indexes.models.VectorSearchProfile;
import com.azure.search.documents.models.IndexingResult;
import com.azure.search.documents.models.SearchOptions;
import com.azure.search.documents.models.SearchResult;
import com.azure.search.documents.models.VectorQuery;
import com.azure.search.documents.models.VectorSearchOptions;
import com.azure.search.documents.models.VectorizedQuery;
import com.azure.search.documents.util.SearchPagedIterable;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.rag.content.retriever.azure.search.AzureAiSearchFilterMapper;
import dev.langchain4j.rag.content.retriever.azure.search.DefaultAzureAiSearchFilterMapper;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import dev.langchain4j.store.embedding.azure.search.Document;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base {@link EmbeddingStore} backed by an Azure AI Search index.
 *
 * <p>Every entry is stored as a search document with an {@code id} key, an optional
 * {@code content} text, a {@code content_vector} embedding and a complex {@code metadata} field
 * whose {@code attributes} collection holds the segment metadata as string key/value pairs.
 *
 * <p>Subclasses must call {@link #initialize} before using the store: it creates the
 * {@link SearchClient} that every read and write goes through.
 */
public abstract class AbstractAzureAiSearchEmbeddingStore implements EmbeddingStore<TextSegment> {

    private static final Logger log = LoggerFactory.getLogger(AbstractAzureAiSearchEmbeddingStore.class);

    /** Index name used when neither an index nor an index name is supplied. */
    public static final String DEFAULT_INDEX_NAME = "vectorsearch";

    static final String DEFAULT_FIELD_ID = "id";
    protected static final String DEFAULT_FIELD_CONTENT = "content";
    // Static like the other field-name constants; access through an instance still compiles.
    protected static final String DEFAULT_FIELD_CONTENT_VECTOR = "content_vector";
    protected static final String DEFAULT_FIELD_METADATA = "metadata";
    protected static final String DEFAULT_FIELD_METADATA_SOURCE = "source";
    protected static final String DEFAULT_FIELD_METADATA_ATTRS = "attributes";
    protected static final String SEMANTIC_SEARCH_CONFIG_NAME = "semantic-search-config";
    protected static final String VECTOR_ALGORITHM_NAME = "vector-search-algorithm";
    protected static final String VECTOR_SEARCH_PROFILE_NAME = "vector-search-profile";

    /** Whether this store may create, update and delete its index. */
    private boolean createOrUpdateIndex;
    /** Index management client; only built when {@link #createOrUpdateIndex} is true. */
    private SearchIndexClient searchIndexClient;
    protected SearchClient searchClient;
    private String indexName;
    protected AzureAiSearchFilterMapper filterMapper;

    /**
     * Configures the store: picks the filter mapper, resolves the index name, builds the Azure
     * clients and, when allowed, creates or updates the index.
     *
     * <p>NOTE(review): the decompiler reported this method as {@code protected} in the original
     * source; it is kept {@code public} here so the visible interface does not change.
     *
     * @param endpoint            Azure AI Search service endpoint; must not be null
     * @param keyCredential       API key credential; used in preference to {@code tokenCredential}
     *                            when non-null
     * @param tokenCredential     token credential used when {@code keyCredential} is null
     * @param createOrUpdateIndex whether this store may create, update and delete the index
     * @param dimensions          embedding size for the default index layout; 0 creates a
     *                            text-only index (only used when {@code index} is null)
     * @param index               a fully specified index, or null to use the default layout
     * @param indexName           name of the default index; null or blank selects
     *                            {@link #DEFAULT_INDEX_NAME}
     * @param filterMapper        maps langchain4j filters to Azure filter expressions; null selects
     *                            {@link DefaultAzureAiSearchFilterMapper}
     * @throws IllegalArgumentException if both {@code index} and {@code indexName} are given
     */
    public void initialize(String endpoint, AzureKeyCredential keyCredential, TokenCredential tokenCredential,
                           boolean createOrUpdateIndex, int dimensions, SearchIndex index, String indexName,
                           AzureAiSearchFilterMapper filterMapper) {
        ValidationUtils.ensureNotNull(endpoint, "endpoint");
        this.filterMapper = filterMapper == null ? new DefaultAzureAiSearchFilterMapper() : filterMapper;

        if (index != null && Utils.isNotNullOrBlank(indexName)) {
            // A supplied index carries its own name; indexName only names the default index.
            throw new IllegalArgumentException("index and indexName cannot be both defined");
        }
        if (index != null) {
            // Use the supplied index's name even when this store may not manage the index;
            // otherwise searches would silently target DEFAULT_INDEX_NAME.
            this.indexName = index.getName();
        } else if (Utils.isNotNullOrBlank(indexName)) {
            this.indexName = indexName;
        } else {
            this.indexName = DEFAULT_INDEX_NAME;
        }
        this.createOrUpdateIndex = createOrUpdateIndex;

        if (keyCredential != null) {
            if (createOrUpdateIndex) {
                this.searchIndexClient = new SearchIndexClientBuilder()
                        .endpoint(endpoint)
                        .credential(keyCredential)
                        .buildClient();
            }
            this.searchClient = new SearchClientBuilder()
                    .endpoint(endpoint)
                    .credential(keyCredential)
                    .indexName(this.indexName)
                    .buildClient();
        } else {
            if (createOrUpdateIndex) {
                this.searchIndexClient = new SearchIndexClientBuilder()
                        .endpoint(endpoint)
                        .credential(tokenCredential)
                        .buildClient();
            }
            this.searchClient = new SearchClientBuilder()
                    .endpoint(endpoint)
                    .credential(tokenCredential)
                    .indexName(this.indexName)
                    .buildClient();
        }

        if (createOrUpdateIndex) {
            if (index == null) {
                createOrUpdateIndex(dimensions);
            } else {
                createOrUpdateIndex(index);
            }
        }
    }

    /**
     * Creates or updates the index using the default field layout.
     *
     * <p>When {@code dimensions > 0}, the layout includes a {@code content_vector} field with an
     * HNSW/cosine vector profile and a semantic configuration over {@code content}. Otherwise only
     * the text and metadata fields are created.
     *
     * @param dimensions size of the embedding vectors; 0 creates a text-only index
     * @throws IllegalArgumentException if this store was configured without createOrUpdateIndex
     */
    public void createOrUpdateIndex(int dimensions) {
        if (!this.createOrUpdateIndex) {
            throw new IllegalArgumentException("createOrUpdateIndex is false, so the index cannot be created or updated");
        }
        if (dimensions == 0) {
            log.info("Dimensions is 0, so the index will only be created for full text search");
        }
        SearchIndex index = new SearchIndex(this.indexName).setFields(defaultIndexFields(dimensions));
        if (dimensions > 0) {
            index.setVectorSearch(defaultVectorSearch()).setSemanticSearch(defaultSemanticSearch());
        }
        this.searchIndexClient.createOrUpdateIndex(index);
    }

    /** Builds the default field list: id, content, optional vector field and metadata. */
    private static List<SearchField> defaultIndexFields(int dimensions) {
        List<SearchField> fields = new ArrayList<>();
        fields.add(new SearchField(DEFAULT_FIELD_ID, SearchFieldDataType.STRING)
                .setKey(true)
                .setFilterable(true));
        fields.add(new SearchField(DEFAULT_FIELD_CONTENT, SearchFieldDataType.STRING)
                .setSearchable(true)
                .setFilterable(true));
        if (dimensions > 0) {
            fields.add(new SearchField(DEFAULT_FIELD_CONTENT_VECTOR, SearchFieldDataType.collection(SearchFieldDataType.SINGLE))
                    .setSearchable(true)
                    .setVectorSearchDimensions(dimensions)
                    .setVectorSearchProfileName(VECTOR_SEARCH_PROFILE_NAME));
        }
        SearchField attributes = new SearchField(DEFAULT_FIELD_METADATA_ATTRS, SearchFieldDataType.collection(SearchFieldDataType.COMPLEX))
                .setFields(Arrays.asList(
                        new SearchField("key", SearchFieldDataType.STRING).setFilterable(true),
                        new SearchField("value", SearchFieldDataType.STRING).setFilterable(true)));
        fields.add(new SearchField(DEFAULT_FIELD_METADATA, SearchFieldDataType.COMPLEX)
                .setFields(Arrays.asList(
                        new SearchField(DEFAULT_FIELD_METADATA_SOURCE, SearchFieldDataType.STRING).setFilterable(true),
                        attributes)));
        return fields;
    }

    /** Builds the HNSW (cosine, m=4, efSearch=500, efConstruction=400) vector search section. */
    private static VectorSearch defaultVectorSearch() {
        HnswParameters parameters = new HnswParameters()
                .setMetric(VectorSearchAlgorithmMetric.COSINE)
                .setM(4)
                .setEfSearch(500)
                .setEfConstruction(400);
        return new VectorSearch()
                .setAlgorithms(Collections.singletonList(
                        new HnswAlgorithmConfiguration(VECTOR_ALGORITHM_NAME).setParameters(parameters)))
                .setProfiles(Collections.singletonList(
                        new VectorSearchProfile(VECTOR_SEARCH_PROFILE_NAME, VECTOR_ALGORITHM_NAME)));
    }

    /** Builds the semantic configuration that prioritises the content field. */
    private static SemanticSearch defaultSemanticSearch() {
        SemanticPrioritizedFields prioritizedFields = new SemanticPrioritizedFields()
                .setContentFields(new SemanticField(DEFAULT_FIELD_CONTENT))
                .setKeywordsFields(new SemanticField(DEFAULT_FIELD_CONTENT));
        return new SemanticSearch()
                .setDefaultConfigurationName(SEMANTIC_SEARCH_CONFIG_NAME)
                .setConfigurations(Collections.singletonList(
                        new SemanticConfiguration(SEMANTIC_SEARCH_CONFIG_NAME, prioritizedFields)));
    }

    /**
     * Creates or updates the given, fully specified index.
     *
     * @throws IllegalArgumentException if this store was configured without createOrUpdateIndex
     */
    void createOrUpdateIndex(SearchIndex index) {
        if (!this.createOrUpdateIndex) {
            throw new IllegalArgumentException("createOrUpdateIndex is false, so the index cannot be created or updated");
        }
        this.searchIndexClient.createOrUpdateIndex(index);
    }

    /**
     * Deletes this store's index.
     *
     * @throws IllegalArgumentException if this store was configured without createOrUpdateIndex
     */
    public void deleteIndex() {
        if (!this.createOrUpdateIndex) {
            throw new IllegalArgumentException("createOrUpdateIndex is false, so the index cannot be deleted");
        }
        this.searchIndexClient.deleteIndex(this.indexName);
    }

    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, null);
        return id;
    }

    @Override
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = randomIds(embeddings);
        addAllInternal(ids, embeddings, null);
        return ids;
    }

    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = randomIds(embeddings);
        addAllInternal(ids, embeddings, embedded);
        return ids;
    }

    /** Generates one random id per embedding; a null list yields an empty (mutable) list. */
    private static List<String> randomIds(List<Embedding> embeddings) {
        if (embeddings == null) {
            return new ArrayList<>();
        }
        return embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
    }

    /**
     * Runs a pure vector (k-nearest-neighbours) query against the {@code content_vector} field.
     *
     * <p>Hits whose relevance score, converted by {@link #fromAzureScoreToRelevanceScore(double)},
     * is below {@code request.minScore()} are dropped.
     */
    @Override
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        VectorQuery vectorQuery = new VectorizedQuery(request.queryEmbedding().vectorAsList())
                .setFields(DEFAULT_FIELD_CONTENT_VECTOR)
                .setKNearestNeighborsCount(request.maxResults());
        SearchOptions options = new SearchOptions()
                .setFilter(this.filterMapper.map(request.filter()))
                .setVectorSearchOptions(new VectorSearchOptions().setQueries(vectorQuery));
        // No search text: this is a vector-only query.
        SearchPagedIterable results = this.searchClient.search(null, options, Context.NONE);

        List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();
        for (SearchResult result : results) {
            double score = fromAzureScoreToRelevanceScore(result.getScore());
            if (score < request.minScore()) {
                continue;
            }
            matches.add(toEmbeddingMatch(score, result.getDocument(SearchDocument.class)));
        }
        return new EmbeddingSearchResult<>(matches);
    }

    /** Converts one retrieved search document into an {@link EmbeddingMatch}. */
    private EmbeddingMatch<TextSegment> toEmbeddingMatch(double score, SearchDocument document) {
        String id = (String) document.get(DEFAULT_FIELD_ID);

        Embedding embedding = null;
        Object rawVector = document.get(DEFAULT_FIELD_CONTENT_VECTOR);
        if (rawVector instanceof List) {
            // Deserialized JSON numbers: read them through Number so any numeric type works.
            @SuppressWarnings("unchecked")
            List<? extends Number> vector = (List<? extends Number>) rawVector;
            embedding = Embedding.from(doublesListToFloatArray(vector));
        }

        String content = (String) document.get(DEFAULT_FIELD_CONTENT);
        TextSegment segment = Utils.isNotNullOrBlank(content)
                ? TextSegment.textSegment(content, toMetadata(document.get(DEFAULT_FIELD_METADATA)))
                : null;
        return new EmbeddingMatch<>(score, id, embedding, segment);
    }

    /**
     * Rebuilds segment metadata from the document's {@code metadata.attributes} key/value list.
     * A missing or malformed metadata field, or a pair with a null key or value, contributes
     * nothing. This can happen with documents indexed outside this store.
     */
    private static Metadata toMetadata(Object metadataField) {
        Map<String, String> values = new HashMap<>();
        if (metadataField instanceof Map) {
            Object attributes = ((Map<?, ?>) metadataField).get(DEFAULT_FIELD_METADATA_ATTRS);
            if (attributes instanceof List) {
                for (Object attribute : (List<?>) attributes) {
                    if (!(attribute instanceof Map)) {
                        continue;
                    }
                    Map<?, ?> pair = (Map<?, ?>) attribute;
                    Object key = pair.get("key");
                    Object value = pair.get("value");
                    if (key != null && value != null) {
                        values.put(String.valueOf(key), String.valueOf(value));
                    }
                }
            }
        }
        return Metadata.from(values);
    }

    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        addAllInternal(
                Collections.singletonList(id),
                Collections.singletonList(embedding),
                textSegment == null ? null : Collections.singletonList(textSegment));
    }

    /**
     * Uploads one document per id/embedding pair, with optional text and metadata.
     *
     * @throws AzureAiSearchRuntimeException if Azure reports a failure for any document
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("Empty embeddings - no ops");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");

        List<Document> documents = new ArrayList<>(ids.size());
        for (int i = 0; i < ids.size(); i++) {
            documents.add(toDocument(ids.get(i), embeddings.get(i), embedded == null ? null : embedded.get(i)));
        }

        for (IndexingResult indexingResult : this.searchClient.uploadDocuments(documents).getResults()) {
            if (!indexingResult.isSucceeded()) {
                throw new AzureAiSearchRuntimeException("Failed to add embedding: " + indexingResult.getErrorMessage());
            }
            log.debug("Added embedding: {}", indexingResult.getKey());
        }
    }

    /**
     * Maps one entry to a search document.
     *
     * <p>Metadata values are stored as strings via {@link String#valueOf(Object)}.
     */
    private static Document toDocument(String id, Embedding embedding, TextSegment segment) {
        Document document = new Document();
        document.setId(id);
        document.setContentVector(embedding.vectorAsList());
        if (segment != null) {
            document.setContent(segment.text());
            List<Document.Metadata.Attribute> attributes = new ArrayList<>();
            for (Map.Entry<String, ?> entry : segment.metadata().toMap().entrySet()) {
                Document.Metadata.Attribute attribute = new Document.Metadata.Attribute();
                attribute.setKey(entry.getKey());
                attribute.setValue(String.valueOf(entry.getValue()));
                attributes.add(attribute);
            }
            Document.Metadata metadata = new Document.Metadata();
            metadata.setAttributes(attributes);
            document.setMetadata(metadata);
        }
        return document;
    }

    /**
     * Narrows a list of numbers to a {@code float[]} of the same length and order.
     *
     * <p>Accepts any {@link Number} subtype, because deserialized JSON numbers are not guaranteed
     * to be {@link Double}.
     */
    float[] doublesListToFloatArray(List<? extends Number> values) {
        float[] result = new float[values.size()];
        int i = 0;
        for (Number value : values) {
            result[i++] = value.floatValue();
        }
        return result;
    }

    /**
     * Converts an Azure AI Search cosine-metric vector score into a langchain4j relevance score.
     *
     * <p>Azure reports {@code score = 1 / (1 + distance)}, where {@code distance = 1 - cosine}.
     * Solving for the cosine similarity gives {@code 1 - (1 - score) / score}. That value is then
     * mapped to [0, 1] by {@link RelevanceScore#fromCosineSimilarity(double)}.
     *
     * <p>NOTE(review): the decompiler reported this method as {@code protected}; it is kept
     * {@code public} to preserve the visible interface.
     */
    public static double fromAzureScoreToRelevanceScore(double azureScore) {
        double cosineSimilarity = 1.0d - (1.0d - azureScore) / azureScore;
        return RelevanceScore.fromCosineSimilarity(cosineSimilarity);
    }
}
