package dev.langchain4j.store.embedding.mariadb;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.sql.DataSource;
import org.mariadb.jdbc.MariaDbDataSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:dev/langchain4j/store/embedding/mariadb/MariaDbEmbeddingStore.class */
/**
 * {@link EmbeddingStore} that keeps {@link TextSegment} embeddings in a MariaDB table, using the
 * {@code VECTOR} column type and the {@code vec_distance_*} SQL functions for similarity search.
 *
 * <p>Instances are created through {@link #builder()}. Table and column names supplied to the
 * builder are validated and quoted via {@code MariaDbValidator}; {@code null} or empty names fall
 * back to the {@code DEFAULT_*} constants. Every operation borrows a connection from the configured
 * {@link DataSource} and closes it (together with its statements and result sets) before returning,
 * including when the operation fails.
 */
public class MariaDbEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(MariaDbEmbeddingStore.class);

    public static final String DEFAULT_TABLE_NAME = "vector_store";
    public static final String DEFAULT_COLUMN_EMBEDDING = "embedding";
    public static final String DEFAULT_COLUMN_ID = "id";
    public static final String DEFAULT_COLUMN_CONTENT = "content";

    private final DataSource datasource;
    private final String table;
    private final MariaDBDistanceType distanceType;
    private final String idFieldName;
    private final String embeddingFieldName;
    private final String contentFieldName;
    final MetadataHandler metadataHandler;

    /**
     * Builder for {@link MariaDbEmbeddingStore}.
     *
     * <p>Either {@link #datasource(DataSource)} or {@link #url(String)} (optionally with
     * {@link #user(String)} / {@link #password(String)}) must be set; when a data source is given,
     * the URL settings are ignored. {@link #dimension(Integer)} is always required.
     * {@link #createTable(boolean)} and {@link #dropTableFirst(boolean)} default to {@code false},
     * and the distance type defaults to {@code COSINE}.
     */
    public static final class Builder {
        private String table;
        private MariaDBDistanceType distanceType;
        private String idFieldName;
        private String embeddingFieldName;
        private String contentFieldName;
        private MetadataStorageConfig metadataStorageConfig;
        private boolean dropTableFirst;
        private boolean createTable = false;
        private Integer dimension;
        private DataSource datasource;
        private String url;
        private String user;
        private String password;

        /** JDBC URL used to build a {@link MariaDbDataSource} when no data source is supplied. */
        public Builder url(String str) {
            this.url = str;
            return this;
        }

        /** User name for the URL-based data source. */
        public Builder user(String str) {
            this.user = str;
            return this;
        }

        /** Password for the URL-based data source. */
        public Builder password(String str) {
            this.password = str;
            return this;
        }

        /** Data source to use; takes precedence over {@link #url(String)}. */
        public Builder datasource(DataSource dataSource) {
            this.datasource = dataSource;
            return this;
        }

        /** Table name; defaults to {@value MariaDbEmbeddingStore#DEFAULT_TABLE_NAME}. */
        public Builder table(String str) {
            this.table = str;
            return this;
        }

        /** Distance function used by {@code search}; defaults to {@code COSINE}. */
        public Builder distanceType(MariaDBDistanceType mariaDBDistanceType) {
            this.distanceType = mariaDBDistanceType;
            return this;
        }

        /** Id column name; defaults to {@value MariaDbEmbeddingStore#DEFAULT_COLUMN_ID}. */
        public Builder idFieldName(String str) {
            this.idFieldName = str;
            return this;
        }

        /** Embedding column name; defaults to {@value MariaDbEmbeddingStore#DEFAULT_COLUMN_EMBEDDING}. */
        public Builder embeddingFieldName(String str) {
            this.embeddingFieldName = str;
            return this;
        }

        /** Text content column name; defaults to {@value MariaDbEmbeddingStore#DEFAULT_COLUMN_CONTENT}. */
        public Builder contentFieldName(String str) {
            this.contentFieldName = str;
            return this;
        }

        /** How segment metadata is stored; defaults to {@code DefaultMetadataStorageConfig.defaultConfig()}. */
        public Builder metadataStorageConfig(MetadataStorageConfig metadataStorageConfig) {
            this.metadataStorageConfig = metadataStorageConfig;
            return this;
        }

        /** Whether to run {@code DROP TABLE IF EXISTS} on the table when the store is built. */
        public Builder dropTableFirst(boolean z) {
            this.dropTableFirst = z;
            return this;
        }

        /** Whether to run {@code CREATE TABLE IF NOT EXISTS} when the store is built. */
        public Builder createTable(boolean z) {
            this.createTable = z;
            return this;
        }

        /** Number of vector dimensions; required. */
        public Builder dimension(Integer num) {
            this.dimension = num;
            return this;
        }

        /**
         * Builds the store, initialising the table according to the drop/create flags.
         *
         * @return a new store
         * @throws IllegalArgumentException if neither a data source nor a URL is set, or the URL is
         *     rejected by the MariaDB driver
         */
        public MariaDbEmbeddingStore build() {
            // Resolve into a local so building does not mutate the builder: a later url(...) change
            // on the same builder must still take effect on the next build().
            DataSource effectiveDataSource = this.datasource;
            if (effectiveDataSource == null) {
                if (this.url == null) {
                    throw new IllegalArgumentException("set datasource or url ");
                }
                MariaDbDataSource mariaDbDataSource = new MariaDbDataSource();
                try {
                    mariaDbDataSource.setUrl(this.url);
                    mariaDbDataSource.setUser(this.user);
                    mariaDbDataSource.setPassword(this.password);
                } catch (SQLException e) {
                    throw new IllegalArgumentException("Wrong url configuring builder: '%s'".formatted(this.url), e);
                }
                effectiveDataSource = mariaDbDataSource;
            }
            return new MariaDbEmbeddingStore(effectiveDataSource, this);
        }
    }

    private MariaDbEmbeddingStore(DataSource dataSource, Builder builder) {
        this.datasource = ValidationUtils.ensureNotNull(dataSource, "datasource");
        this.table = validateAndEnquoteIdentifier(builder.table, DEFAULT_TABLE_NAME);
        this.contentFieldName = validateAndEnquoteIdentifier(builder.contentFieldName, DEFAULT_COLUMN_CONTENT);
        this.embeddingFieldName = validateAndEnquoteIdentifier(builder.embeddingFieldName, DEFAULT_COLUMN_EMBEDDING);
        this.idFieldName = validateAndEnquoteIdentifier(builder.idFieldName, DEFAULT_COLUMN_ID);
        MetadataStorageConfig metadataConfig =
                Utils.getOrDefault(builder.metadataStorageConfig, DefaultMetadataStorageConfig.defaultConfig());
        this.metadataHandler = MetadataHandlerFactory.get(metadataConfig, this.datasource);
        this.distanceType = builder.distanceType == null ? MariaDBDistanceType.COSINE : builder.distanceType;
        int dimension = ValidationUtils.ensureNotNull(builder.dimension, "dimension");
        initTable(builder.dropTableFirst, builder.createTable, dimension);
    }

    /**
     * Returns {@code defaultName} when {@code name} is {@code null} or empty; otherwise validates
     * and quotes {@code name} via {@code MariaDbValidator}.
     */
    private String validateAndEnquoteIdentifier(String name, String defaultName) {
        return (name == null || name.isEmpty()) ? defaultName : MariaDbValidator.validateAndEnquoteIdentifier(name, false);
    }

    /**
     * Optionally drops and/or creates the backing table.
     *
     * @param dropTableFirst when {@code true}, executes {@code DROP TABLE IF EXISTS} on the table
     * @param createTable when {@code true}, executes {@code CREATE TABLE IF NOT EXISTS} (including a
     *     vector index on the embedding column) and asks the metadata handler to create its indexes
     * @param dimension number of vector dimensions; must be greater than zero when
     *     {@code createTable} is {@code true}
     * @throws RuntimeException wrapping the {@link SQLException} if any statement fails
     */
    protected void initTable(boolean dropTableFirst, boolean createTable, int dimension) {
        // Validate before touching the database so an invalid dimension cannot leave the table
        // dropped but not recreated.
        if (createTable) {
            ValidationUtils.ensureGreaterThanZero(dimension, "dimension");
        }
        try (Connection connection = this.datasource.getConnection();
                Statement statement = connection.createStatement()) {
            if (dropTableFirst) {
                statement.executeUpdate("DROP TABLE IF EXISTS " + this.table);
            }
            if (createTable) {
                // Remove spaces, quotes, backticks, backslashes and non-printable characters (such as
                // identifier quoting) so the derived index name is a plain identifier.
                String indexName = (this.table + "_" + this.embeddingFieldName)
                        .replaceAll("[ \\`\"'\\\\\\P{Print}]", "");
                statement.executeUpdate(String.format(
                        "CREATE TABLE IF NOT EXISTS %s (%s UUID NOT NULL DEFAULT uuid() PRIMARY KEY, %s VECTOR(%s) NOT NULL, %s TEXT NULL, %s, VECTOR INDEX %s_idx (%s) ) ENGINE=InnoDB COLLATE uca1400_ai_cs",
                        this.table,
                        this.idFieldName,
                        this.embeddingFieldName,
                        dimension,
                        this.contentFieldName,
                        this.metadataHandler.columnDefinitionsString(),
                        indexName,
                        this.embeddingFieldName));
                this.metadataHandler.createMetadataIndexes(statement, this.table);
            }
        } catch (SQLException e) {
            throw new RuntimeException(String.format("Failed to execute '%s'", "init"), e);
        }
    }

    /** Stores {@code embedding} under an id from {@link Utils#randomUUID()} and returns that id. */
    @Override
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, null);
        return id;
    }

    /** Stores {@code embedding} under {@code id}, replacing any existing row with the same id. */
    @Override
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /** Stores {@code embedding} with its segment under a generated id and returns that id. */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /** Stores all embeddings under generated ids and returns the ids in input order. */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).toList();
        addAllInternal(ids, embeddings, null);
        return ids;
    }

    /** Stores embeddings paired index-wise with {@code embedded}; returns the generated ids. */
    @Override
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = embeddings.stream().map(ignored -> Utils.randomUUID()).toList();
        addAllInternal(ids, embeddings, embedded);
        return ids;
    }

    /** Stores embeddings (and optional segments) under the caller-supplied ids. */
    @Override
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        addAllInternal(ids, embeddings, embedded);
    }

    /**
     * Deletes the rows whose id is in {@code ids}.
     *
     * @throws IllegalArgumentException if {@code ids} is empty or any id is not a valid UUID string
     * @throws RuntimeException wrapping the {@link SQLException} if the delete fails
     */
    @Override
    public void removeAll(Collection<String> ids) {
        ValidationUtils.ensureNotEmpty(ids, "ids");
        // Parse every id before opening a connection so malformed input fails fast. The CREATE
        // TABLE in initTable declares the id column as UUID.
        List<String> uuids = ids.stream().map(UUID::fromString).map(UUID::toString).toList();
        String placeholders = String.join(",", Collections.nCopies(uuids.size(), "?"));
        String sql = String.format("DELETE FROM %s WHERE %s IN (%s)", this.table, this.idFieldName, placeholders);
        try (Connection connection = this.datasource.getConnection();
                PreparedStatement statement = connection.prepareStatement(sql)) {
            for (int i = 0; i < uuids.size(); i++) {
                statement.setString(i + 1, uuids.get(i));
            }
            statement.executeUpdate();
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deletes the rows matching {@code filter}, translated to SQL by the metadata handler.
     *
     * @throws RuntimeException wrapping the {@link SQLException} if the delete fails
     */
    @Override
    public void removeAll(Filter filter) {
        ValidationUtils.ensureNotNull(filter, "filter");
        String sql = String.format("DELETE FROM %s WHERE %s", this.table, this.metadataHandler.whereClause(filter));
        try (Connection connection = this.datasource.getConnection();
                PreparedStatement statement = connection.prepareStatement(sql)) {
            statement.executeUpdate();
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /** Removes every row with {@code TRUNCATE TABLE}. */
    @Override
    public void removeAll() {
        try (Connection connection = this.datasource.getConnection();
                Statement statement = connection.createStatement()) {
            statement.executeUpdate(String.format("TRUNCATE TABLE %s", this.table));
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns up to {@code maxResults} rows whose score is at least {@code minScore}, ordered by
     * descending score.
     *
     * <p>The score is computed in SQL as {@code (2 - vec_distance_<type>(embedding, query)) / 2}.
     * An optional filter is translated by the metadata handler and ANDed onto the score condition.
     * A {@link TextSegment} is attached only when the content column is non-blank.
     *
     * @throws RuntimeException wrapping the {@link SQLException} if the query fails
     */
    @Override
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddingSearchRequest) {
        Embedding queryEmbedding = embeddingSearchRequest.queryEmbedding();
        int maxResults = embeddingSearchRequest.maxResults();
        double minScore = embeddingSearchRequest.minScore();
        Filter filter = embeddingSearchRequest.filter();

        String whereClause = filter == null ? null : this.metadataHandler.whereClause(filter);
        String filterClause = (whereClause == null || whereClause.isEmpty()) ? "" : "and " + whereClause + " ";
        String sql = String.format(
                "SELECT * FROM (select %s, %s, %s, (2 - vec_distance_%s(%s, ?)) / 2 as score, %s from %s) as t where score >= ? %sorder by score desc LIMIT %s",
                this.idFieldName,
                this.embeddingFieldName,
                this.contentFieldName,
                this.distanceType.name().toLowerCase(Locale.ROOT),
                this.embeddingFieldName,
                String.join(",", this.metadataHandler.escapedColumnsName()),
                this.table,
                filterClause,
                maxResults);

        List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();
        try (Connection connection = this.datasource.getConnection();
                PreparedStatement statement = connection.prepareStatement(sql)) {
            statement.setObject(1, queryEmbedding.vector());
            statement.setDouble(2, minScore);
            try (ResultSet rs = statement.executeQuery()) {
                while (rs.next()) {
                    String id = rs.getString(1);
                    Embedding embedding = new Embedding(rs.getObject(2, float[].class));
                    String content = rs.getString(3);
                    double score = rs.getDouble(4);
                    TextSegment segment = Utils.isNotNullOrBlank(content)
                            ? TextSegment.from(content, this.metadataHandler.fromResultSet(rs))
                            : null;
                    matches.add(new EmbeddingMatch<>(score, id, embedding, segment));
                }
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
        return new EmbeddingSearchResult<>(matches);
    }

    /** Single-row convenience wrapper around {@link #addAllInternal}. */
    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        addAllInternal(
                Collections.singletonList(id),
                Collections.singletonList(embedding),
                textSegment == null ? null : Collections.singletonList(textSegment));
    }

    /**
     * Upserts rows in one JDBC batch: on a duplicate key the embedding and content columns (plus
     * whatever the metadata handler's insert clause adds) are overwritten.
     *
     * <p>A {@code null} {@code embedded} list, or a {@code null} element in it, stores NULL content
     * and NULL metadata columns for that row. Empty or {@code null} {@code ids}/{@code embeddings}
     * are logged and ignored.
     *
     * @throws IllegalArgumentException if the list sizes disagree
     * @throws RuntimeException wrapping any failure raised while writing the batch
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("Empty embeddings - no ops");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");

        var metadataColumns = this.metadataHandler.escapedColumnsName();
        int metadataColumnCount = metadataColumns.size();
        String sql = String.format(
                "INSERT INTO %s (%s, %s, %s, %s) VALUES (?, ?, ?, %s) ON DUPLICATE KEY UPDATE %s = VALUES(%s), %s = VALUES(%s)%s",
                this.table,
                this.idFieldName,
                this.embeddingFieldName,
                this.contentFieldName,
                String.join(",", metadataColumns),
                String.join(",", Collections.nCopies(metadataColumnCount, "?")),
                this.embeddingFieldName,
                this.embeddingFieldName,
                this.contentFieldName,
                this.contentFieldName,
                this.metadataHandler.insertClause());

        try (Connection connection = this.datasource.getConnection();
                PreparedStatement statement = connection.prepareStatement(sql)) {
            for (int i = 0; i < ids.size(); i++) {
                statement.setString(1, ids.get(i));
                statement.setObject(2, embeddings.get(i).vector());
                TextSegment segment = embedded == null ? null : embedded.get(i);
                if (segment == null) {
                    statement.setNull(3, Types.VARCHAR);
                    // Metadata parameters start right after id, embedding and content.
                    for (int param = 4; param < 4 + metadataColumnCount; param++) {
                        statement.setNull(param, Types.OTHER);
                    }
                } else {
                    statement.setString(3, segment.text());
                    this.metadataHandler.setMetadata(statement, 4, segment.metadata());
                }
                statement.addBatch();
            }
            statement.executeBatch();
        } catch (Exception e) {
            // Broad catch kept from the original: every failure (including any raised by the
            // metadata handler) surfaces as a RuntimeException.
            throw new RuntimeException(e);
        }
    }

    /** Creates a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }
}
