package com.huaweicloud.pangu.dev.sdk.vectorstore;

import com.alibaba.fastjson.JSON;
import com.huaweicloud.pangu.dev.sdk.api.memory.bo.Document;
import com.huaweicloud.pangu.dev.sdk.api.memory.config.VectorStoreConfig;
import com.huaweicloud.pangu.dev.sdk.exception.PanguDevSDKException;
import com.huaweicloud.pangu.dev.sdk.utils.CommonUtil;
import com.huaweicloud.pangu.dev.sdk.utils.SecurityUtil;
import java.io.IOException;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.conn.ssl.NoopHostnameVerifier;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
import org.apache.http.nio.conn.ssl.SSLIOSessionStrategy;
import org.apache.http.ssl.SSLContextBuilder;
import org.apache.http.ssl.TrustStrategy;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.indices.CreateIndexRequest;
import org.elasticsearch.client.indices.GetIndexRequest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.mapper.TextFieldMapper;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/huaweicloud/pangu/dev/sdk/vectorstore/CSSVectorStore.class */
/**
 * {@link VectorStore} backed by an Elasticsearch-compatible CSS cluster, accessed through the
 * Elasticsearch {@link RestHighLevelClient}.
 *
 * <p>Each stored document holds three fields in the configured index: the text, its embedding
 * (mapped with field type {@code "vector"}) and optional metadata. Similarity search runs a
 * {@code script_score} query using the cluster-side {@code vector_score} script.
 */
public class CSSVectorStore extends VectorStore {
    private static final Logger log = LoggerFactory.getLogger(CSSVectorStore.class);

    /** Field type of the embedding field, also used as the script language for scoring. */
    private static final String STORE_TYPE = "vector";

    protected VectorStoreConfig vectorStoreConfig;
    protected RestHighLevelClient client;

    /**
     * Creates the store and its REST client.
     *
     * @param vectorStoreConfig store configuration; its server URL may list several endpoints
     *     separated by commas. When both user and password are non-empty, basic auth over the
     *     (trust-all) TLS strategy built by {@link #getHttpClientConfigCallback} is enabled.
     */
    public CSSVectorStore(VectorStoreConfig vectorStoreConfig) {
        super(vectorStoreConfig);
        this.vectorStoreConfig = vectorStoreConfig;
        String url = vectorStoreConfig.getServerInfo().getUrl();
        String user = vectorStoreConfig.getServerInfo().getUser();
        String password = vectorStoreConfig.getServerInfo().getPassword();

        RestClientBuilder builder = RestClient.builder(constructHttpHosts(Arrays.asList(url.split(","))));
        if (StringUtils.isNotEmpty(user) && StringUtils.isNotEmpty(password)) {
            builder.setHttpClientConfigCallback(getHttpClientConfigCallback(user, password));
        }
        this.client = new RestHighLevelClient(builder);
    }

    /**
     * Embeds {@code texts} with the configured embedding model and stores them.
     *
     * @param texts texts to store
     * @param metadatas optional per-text metadata; may be {@code null} or shorter than {@code texts}
     * @return ids of the documents that were actually stored
     */
    @Override // com.huaweicloud.pangu.dev.sdk.vectorstore.VectorStore
    protected List<String> addTextsFull(List<String> texts, List<Map<String, Object>> metadatas) {
        return bulkCreate(texts, metadatas, this.vectorStoreConfig.getEmbedding().embedDocuments(texts));
    }

    /**
     * Embeds question/answer records and stores each record serialized as JSON text (no metadata).
     *
     * @param qaRecords records to store; each is serialized with fastjson as the document text
     * @param weights passed through unchanged to the embedding model's {@code embedQADocuments}
     * @return ids of the documents that were actually stored
     */
    @Override // com.huaweicloud.pangu.dev.sdk.vectorstore.VectorStore
    protected List<String> addQATextsFull(List<Map<String, String>> qaRecords, Map<String, Integer> weights) {
        List<List<Float>> embeddings = this.vectorStoreConfig.getEmbedding().embedQADocuments(qaRecords, weights);
        List<String> texts = new ArrayList<>(qaRecords.size());
        for (Map<String, String> qaRecord : qaRecords) {
            texts.add(JSON.toJSONString(qaRecord));
        }
        return bulkCreate(texts, null, embeddings);
    }

    /**
     * Returns up to {@code i} documents most similar to {@code str}, dropping hits scored below {@code f}.
     *
     * @param str query text
     * @param i maximum number of hits requested from the cluster
     * @param f minimum score a hit must reach to be kept
     * @return matching documents (empty if the search failed)
     */
    @Override // com.huaweicloud.pangu.dev.sdk.api.memory.vector.Vector
    public List<Document> similaritySearch(String str, int i, float f) {
        List<Document> similaritySearchWithScore = similaritySearchWithScore(str, i);
        log.debug("query = {}, search result = {}", str, similaritySearchWithScore);
        similaritySearchWithScore.removeIf(document -> document.getScore() < (double) f);
        return similaritySearchWithScore;
    }

    /**
     * Writes one document per embedding in a single bulk request, then refreshes the index.
     *
     * @return ids of the documents the cluster reported as stored; empty when nothing was written
     */
    private List<String> bulkCreate(List<String> texts, List<Map<String, Object>> metadatas,
            List<List<Float>> embeddings) {
        if (embeddings == null || embeddings.isEmpty()) {
            // Nothing to store; this also avoids sending an empty BulkRequest, which Elasticsearch rejects.
            return new ArrayList<>();
        }
        if (!createIndex(embeddings.get(0).size())) {
            log.error("add text failed: index could not be created");
            return new ArrayList<>();
        }

        String indexName = this.vectorStoreConfig.getIndexName();
        BulkRequest bulkRequest = new BulkRequest();
        List<String> ids = new ArrayList<>(embeddings.size());
        for (int i = 0; i < embeddings.size(); i++) {
            Map<String, Object> source = new HashMap<>();
            source.put(this.vectorStoreConfig.getTextKey(), texts.get(i));
            source.put(this.vectorStoreConfig.getVectorKey(), embeddings.get(i));
            if (metadatas != null && i < metadatas.size()) {
                source.put(this.vectorStoreConfig.getMetadataKey(), metadatas.get(i));
            }
            String id = SecurityUtil.getUUID();
            bulkRequest.add(new IndexRequest(indexName).id(id).source(source));
            ids.add(id);
        }

        BulkResponse bulkResponse;
        try {
            bulkResponse = this.client.bulk(bulkRequest, RequestOptions.DEFAULT);
        } catch (IOException e) {
            // The write did not complete, so none of the generated ids can be reported as stored.
            log.error("add text failed", e);
            return new ArrayList<>();
        }

        List<String> storedIds = ids;
        if (bulkResponse.hasFailures()) {
            log.error("add text partially failed: {}", bulkResponse.buildFailureMessage());
            storedIds = new ArrayList<>(ids.size());
            for (BulkItemResponse item : bulkResponse) {
                if (!item.isFailed()) {
                    storedIds.add(item.getId());
                }
            }
        }

        try {
            this.client.indices().refresh(new RefreshRequest(indexName), RequestOptions.DEFAULT);
        } catch (IOException e) {
            // The bulk write already succeeded; only immediate searchability is affected.
            log.warn("refresh after adding texts failed for index {}", indexName, e);
        }
        return storedIds;
    }

    /**
     * Embeds {@code query} and runs a {@code script_score} search scored by the cluster's
     * {@code vector_score} script using the configured distance metric.
     *
     * @return up to {@code topK} hits with their scores; empty (but mutable) if the search failed
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // metadata is read back from the untyped _source map
    private List<Document> similaritySearchWithScore(String query, int topK) {
        List<Float> queryVector = this.vectorStoreConfig.getEmbedding().embedQuery(query);
        String vectorKey = this.vectorStoreConfig.getVectorKey();

        Map<String, Object> scriptParams = new HashMap<>();
        scriptParams.put("field", vectorKey);
        scriptParams.put(vectorKey, queryVector);
        scriptParams.put("metric", this.vectorStoreConfig.getDistanceStrategy().getText());
        Script script = new Script(ScriptType.INLINE, STORE_TYPE, "vector_score", scriptParams);

        SearchRequest searchRequest = new SearchRequest();
        searchRequest.indices(this.vectorStoreConfig.getIndexName());
        searchRequest.source(new SearchSourceBuilder()
                .query(QueryBuilders.scriptScoreQuery(QueryBuilders.matchAllQuery(), script))
                .size(topK));

        List<Document> results = new ArrayList<>();
        try {
            for (SearchHit hit : this.client.search(searchRequest, RequestOptions.DEFAULT).getHits().getHits()) {
                Map<String, Object> source = hit.getSourceAsMap();
                results.add(Document.builder()
                        .pageContent((String) source.get(this.vectorStoreConfig.getTextKey()))
                        .score((double) hit.getScore())
                        .metadata((Map) source.get(this.vectorStoreConfig.getMetadataKey()))
                        .build());
            }
        } catch (IOException | ElasticsearchStatusException e) {
            log.error("search data error: {}", e.getMessage(), e);
        }
        return results;
    }

    /**
     * Converts endpoint strings into {@link HttpHost}s. Surrounding whitespace is ignored and blank
     * entries (e.g. from a trailing comma) are skipped.
     */
    private HttpHost[] constructHttpHosts(List<String> list) {
        return list.stream()
                .map(String::trim)
                .filter(StringUtils::isNotEmpty)
                .map(CommonUtil::parseUri)
                .map(uri -> new HttpHost(uri.getHost(), uri.getPort(), uri.getScheme()))
                .toArray(HttpHost[]::new);
    }

    /**
     * Checks whether the configured index exists.
     *
     * @return {@code true} if the cluster reports the index; {@code false} if it does not or if the
     *     check itself failed with an I/O error
     */
    public boolean indexExists() {
        try {
            return this.client.indices().exists(
                    new GetIndexRequest(this.vectorStoreConfig.getIndexName()), RequestOptions.DEFAULT);
        } catch (IOException e) {
            log.warn("failed to check whether index {} exists", this.vectorStoreConfig.getIndexName(), e);
            return false;
        }
    }

    /**
     * Creates the configured index (with {@code index.vector} enabled) unless it already exists.
     *
     * @param dimension embedding dimension written into the vector field mapping
     * @return {@code true} if the index exists or was created; {@code false} on an I/O error
     */
    private boolean createIndex(int dimension) {
        if (indexExists()) {
            return true;
        }
        Map<String, Object> vectorField = new HashMap<>();
        vectorField.put("type", STORE_TYPE);
        vectorField.put("dimension", dimension);
        vectorField.put("indexing", true);
        vectorField.put("metric", this.vectorStoreConfig.getDistanceStrategy().getText());

        Map<String, Object> textField = new HashMap<>();
        textField.put("type", TextFieldMapper.CONTENT_TYPE);

        Map<String, Object> properties = new HashMap<>();
        properties.put(this.vectorStoreConfig.getVectorKey(), vectorField);
        properties.put(this.vectorStoreConfig.getTextKey(), textField);

        Map<String, Object> mapping = new HashMap<>();
        mapping.put("properties", properties);

        CreateIndexRequest createIndexRequest = new CreateIndexRequest(this.vectorStoreConfig.getIndexName());
        createIndexRequest.settings(Settings.builder().put("index.vector", true));
        createIndexRequest.mapping(mapping);
        try {
            this.client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
            return true;
        } catch (IOException e) {
            log.error("create index fail", e);
            return false;
        }
    }

    /**
     * Deletes the configured index if it exists.
     *
     * @return {@code true} if the index is absent or was deleted; {@code false} on an I/O error
     */
    @Override // com.huaweicloud.pangu.dev.sdk.vectorstore.VectorStore
    protected boolean delIndex() {
        if (!indexExists()) {
            return true;
        }
        try {
            this.client.indices().delete(
                    new DeleteIndexRequest(this.vectorStoreConfig.getIndexName()), RequestOptions.DEFAULT);
            return true;
        } catch (IOException e) {
            log.error("delete index fail", e);
            return false;
        }
    }

    /**
     * Builds a client callback that enables basic auth with the given credentials over TLS.
     *
     * @throws PanguDevSDKException if the SSL context cannot be built
     */
    private RestClientBuilder.HttpClientConfigCallback getHttpClientConfigCallback(String str, String str2) {
        BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider();
        credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(str, str2));
        SSLIOSessionStrategy sslStrategy = buildTrustAllSslStrategy();
        return httpClientBuilder -> {
            httpClientBuilder.disableAuthCaching();
            httpClientBuilder.setSSLStrategy(sslStrategy);
            httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
            return httpClientBuilder;
        };
    }

    /** Builds a TLS strategy that accepts any server certificate and skips hostname verification. */
    private static SSLIOSessionStrategy buildTrustAllSslStrategy() {
        // NOTE(review): trusting every certificate and disabling hostname verification means TLS
        // gives encryption but no server authentication (MITM-able). Consider loading the
        // cluster's CA certificate into the trust material instead.
        try {
            return new SSLIOSessionStrategy(
                    new SSLContextBuilder().loadTrustMaterial((KeyStore) null, (chain, authType) -> true).build(),
                    NoopHostnameVerifier.INSTANCE);
        } catch (KeyManagementException | KeyStoreException | NoSuchAlgorithmException e) {
            throw new PanguDevSDKException("ssl config failed; " + e.getMessage());
        }
    }
}
