package dev.langchain4j.store.embedding.oracle;

import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import java.sql.BatchUpdateException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.SQLType;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import javax.sql.DataSource;
import oracle.jdbc.OracleStatement;
import oracle.jdbc.OracleType;
import oracle.sql.json.OracleJsonDecimal;
import oracle.sql.json.OracleJsonFactory;
import oracle.sql.json.OracleJsonObject;
import oracle.sql.json.OracleJsonValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:dev/langchain4j/store/embedding/oracle/OracleEmbeddingStore.class */
/**
 * An {@link EmbeddingStore} that persists {@link TextSegment} embeddings in an Oracle Database table.
 *
 * <p>Vectors are bound as {@link OracleType#VECTOR_FLOAT32}, segment text is written to the table's text column and
 * segment {@link Metadata} is written to the metadata column as an {@link OracleJsonObject}. Searches rank rows with
 * {@code VECTOR_DISTANCE(..., COSINE)} and use an {@code APPROXIMATE} fetch unless exact search is enabled.
 *
 * <p>Every operation obtains its own {@link Connection} from the configured {@link DataSource} and closes it, along
 * with every statement and result set it opened, before returning — on success and on failure. {@link SQLException}s
 * are rethrown wrapped in a {@link RuntimeException}.
 *
 * <p>NOTE(review): no method calls {@code commit()}; writes rely on the connections handed out by the data source
 * being in auto-commit mode — confirm for pooled data sources.
 */
public final class OracleEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(OracleEmbeddingStore.class);

    /** Source of the connection used by each operation. */
    private final DataSource dataSource;
    /** Table (and its column names) that backs this store. */
    private final EmbeddingTable table;
    /** When {@code true}, searches omit the {@code APPROXIMATE} keyword from the {@code FETCH} clause. */
    private final boolean isExactSearch;

    /**
     * Builder for {@link OracleEmbeddingStore}.
     *
     * <p>A data source and an embedding table must be set before {@link #build()} is called; indexes and exact
     * search are optional.
     */
    public static class Builder {
        private DataSource dataSource;
        private EmbeddingTable embeddingTable;
        private boolean isExactSearch = false;
        private Index[] indexes;

        private Builder() {
        }

        /**
         * Sets the data source from which the store obtains connections.
         *
         * @param dataSource the data source; must not be null
         * @return this builder
         */
        public Builder dataSource(DataSource dataSource) {
            ValidationUtils.ensureNotNull(dataSource, "dataSource");
            this.dataSource = dataSource;
            return this;
        }

        /**
         * Uses the table with the given name, configured with {@link CreateOption#CREATE_NONE}.
         *
         * @param tableName name of the table; must not be null
         * @return this builder
         */
        public Builder embeddingTable(String tableName) {
            return embeddingTable(tableName, CreateOption.CREATE_NONE);
        }

        /**
         * Uses the table with the given name, configured with the given create option.
         *
         * @param tableName    name of the table; must not be null
         * @param createOption option passed to the {@link EmbeddingTable} builder; must not be null
         * @return this builder
         */
        public Builder embeddingTable(String tableName, CreateOption createOption) {
            ValidationUtils.ensureNotNull(tableName, "tableName");
            ValidationUtils.ensureNotNull(createOption, "createOption");
            return embeddingTable(EmbeddingTable.builder().name(tableName).createOption(createOption).build());
        }

        /**
         * Uses a fully configured {@link EmbeddingTable}.
         *
         * @param embeddingTable the table definition; must not be null
         * @return this builder
         */
        public Builder embeddingTable(EmbeddingTable embeddingTable) {
            ValidationUtils.ensureNotNull(embeddingTable, "embeddingTable");
            this.embeddingTable = embeddingTable;
            return this;
        }

        /**
         * Configures a single index obtained from {@code Index.ivfIndexBuilder()} with the given create option.
         * This replaces any indexes previously set through {@link #index(Index...)}.
         *
         * @param createOption create option for the index
         * @return this builder
         */
        public Builder vectorIndex(CreateOption createOption) {
            return index(((IVFIndexBuilder) Index.ivfIndexBuilder().createOption(createOption)).build());
        }

        /**
         * Sets the indexes that are created when the store is built, replacing any previously configured ones.
         *
         * @param indexes the indexes; the array is copied so later changes to it do not affect this builder
         * @return this builder
         */
        public Builder index(Index... indexes) {
            this.indexes = indexes == null ? null : indexes.clone();
            return this;
        }

        /**
         * Selects exact ({@code true}) or approximate ({@code false}, the default) nearest-neighbour search.
         *
         * @param isExactSearch whether searches should be exact
         * @return this builder
         */
        public Builder exactSearch(boolean isExactSearch) {
            this.isExactSearch = isExactSearch;
            return this;
        }

        /**
         * Builds the store. Building invokes {@code EmbeddingTable.create} and then {@code Index.create} for every
         * configured index.
         *
         * @return the new store
         * @throws RuntimeException wrapping any {@link SQLException} raised while creating the table or indexes
         */
        public OracleEmbeddingStore build() {
            ValidationUtils.ensureNotNull(this.dataSource, "dataSource");
            ValidationUtils.ensureNotNull(this.embeddingTable, "embeddingTable");
            return new OracleEmbeddingStore(this);
        }
    }

    private OracleEmbeddingStore(Builder builder) {
        this.dataSource = builder.dataSource;
        this.table = builder.embeddingTable;
        this.isExactSearch = builder.isExactSearch;
        try {
            table.create(dataSource);
            createIndex(builder);
        } catch (SQLException e) {
            throw uncheckSQLException(e);
        }
    }

    /** Creates every index configured on the builder, in order; does nothing when none were configured. */
    private static void createIndex(Builder builder) throws SQLException {
        if (builder.indexes == null) {
            return;
        }
        for (Index index : builder.indexes) {
            index.create(builder.dataSource, builder.embeddingTable);
        }
    }

    /**
     * Stores an embedding under a freshly generated id.
     *
     * @param embedding the embedding; must not be null
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding) {
        ValidationUtils.ensureNotNull(embedding, "embedding");
        return addAll(Collections.singletonList(embedding)).get(0);
    }

    /**
     * Stores embeddings, without text segments, in a single JDBC batch, each under a freshly generated id.
     *
     * @param embeddings the embeddings; neither the list nor any element may be null
     * @return the generated ids, in the same order as {@code embeddings}
     */
    @Override
    public List<String> addAll(List<Embedding> embeddings) {
        ValidationUtils.ensureNotNull(embeddings, "embeddings");
        if (embeddings.isEmpty()) {
            // Nothing to insert: avoid borrowing a connection for an empty batch.
            return Collections.emptyList();
        }
        String[] ids = new String[embeddings.size()];
        String sql = "INSERT INTO " + table.name()
                + "(" + table.idColumn() + ", " + table.embeddingColumn() + ") VALUES (?, ?)";

        try (Connection connection = dataSource.getConnection();
             PreparedStatement statement = connection.prepareStatement(sql)) {
            for (int i = 0; i < embeddings.size(); i++) {
                Embedding embedding = ensureIndexNotNull(embeddings, i, "embeddings");
                String id = Utils.randomUUID();
                ids[i] = id;
                statement.setString(1, id);
                statement.setObject(2, embedding.vector(), OracleType.VECTOR_FLOAT32);
                statement.addBatch();
            }
            statement.executeBatch();
        } catch (SQLException e) {
            throw uncheckSQLException(e);
        }
        return Arrays.asList(ids);
    }

    /**
     * Stores an embedding together with the text segment it was computed from, under a freshly generated id.
     *
     * @param embedding   the embedding; must not be null
     * @param textSegment the segment; must not be null
     * @return the generated id
     */
    @Override
    public String add(Embedding embedding, TextSegment textSegment) {
        ValidationUtils.ensureNotNull(embedding, "embedding");
        ValidationUtils.ensureNotNull(textSegment, "textSegment");
        return addAll(Collections.singletonList(embedding), Collections.singletonList(textSegment)).get(0);
    }

    /**
     * Inserts rows of (id, embedding, segment text, segment metadata) in a single JDBC batch.
     *
     * @param ids        the ids to store; must be non-null, the same size as {@code embeddings}, with no null entry
     * @param embeddings the embeddings; must be non-null with no null entry
     * @param embedded   the segments; must be non-null, the same size as {@code embeddings}, with no null entry
     * @throws IllegalArgumentException if the list sizes differ or an entry is null
     */
    @Override
    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        ValidationUtils.ensureNotNull(embeddings, "embeddings");
        ValidationUtils.ensureNotNull(embedded, "embedded");
        if (embeddings.size() != embedded.size()) {
            throw new IllegalArgumentException("embeddings.size() " + embeddings.size()
                    + " is not equal to embedded.size() " + embedded.size());
        }
        ValidationUtils.ensureNotNull(ids, "ids");
        if (ids.size() != embeddings.size()) {
            // Without this check a short ids list fails half-way through building the batch.
            throw new IllegalArgumentException("ids.size() " + ids.size()
                    + " is not equal to embeddings.size() " + embeddings.size());
        }
        if (embeddings.isEmpty()) {
            return;
        }
        String sql = "INSERT INTO " + table.name() + "("
                + String.join(", ", table.idColumn(), table.embeddingColumn(), table.textColumn(), table.metadataColumn())
                + ") VALUES (?, ?, ?, ?)";

        try (Connection connection = dataSource.getConnection();
             PreparedStatement statement = connection.prepareStatement(sql)) {
            for (int i = 0; i < embeddings.size(); i++) {
                String id = ensureIndexNotNull(ids, i, "ids");
                Embedding embedding = ensureIndexNotNull(embeddings, i, "embeddings");
                TextSegment segment = ensureIndexNotNull(embedded, i, "embedded");
                statement.setString(1, id);
                statement.setObject(2, embedding.vector(), OracleType.VECTOR_FLOAT32);
                statement.setObject(3, segment.text());
                statement.setObject(4, getOsonFromMetadata(segment.metadata()));
                statement.addBatch();
            }
            statement.executeBatch();
        } catch (SQLException e) {
            throw uncheckSQLException(e);
        }
    }

    /**
     * Stores an embedding under the given id using {@code MERGE}: a new row is inserted when the id is absent,
     * otherwise only the embedding column of the existing row is updated (its text and metadata are left as-is).
     *
     * @param id        the id; must not be null
     * @param embedding the embedding; must not be null
     */
    @Override
    public void add(String id, Embedding embedding) {
        ValidationUtils.ensureNotNull(id, "id");
        ValidationUtils.ensureNotNull(embedding, "embedding");
        String sql = "MERGE INTO " + table.name()
                + " existing USING (SELECT ? as id, ? as embedding) new ON (new.id = existing." + table.idColumn()
                + ") WHEN MATCHED THEN UPDATE SET existing." + table.embeddingColumn()
                + " = new.embedding WHEN NOT MATCHED THEN INSERT (" + table.idColumn() + ", "
                + table.embeddingColumn() + ") VALUES (new.id, new.embedding)";

        try (Connection connection = dataSource.getConnection();
             PreparedStatement statement = connection.prepareStatement(sql)) {
            statement.setString(1, id);
            statement.setObject(2, embedding.vector(), OracleType.VECTOR_FLOAT32);
            statement.execute();
        } catch (SQLException e) {
            throw uncheckSQLException(e);
        }
    }

    /**
     * Deletes the rows with the given ids, batching one {@code DELETE} per id.
     *
     * @param ids the ids; must be non-empty and contain no null
     */
    @Override
    public void removeAll(Collection<String> ids) {
        ValidationUtils.ensureNotEmpty(ids, "ids");
        String sql = "DELETE FROM " + table.name() + " WHERE " + table.idColumn() + " = ?";

        try (Connection connection = dataSource.getConnection();
             PreparedStatement statement = connection.prepareStatement(sql)) {
            for (String id : ids) {
                ValidationUtils.ensureNotNull(id, "id");
                statement.setString(1, id);
                statement.addBatch();
            }
            statement.executeBatch();
        } catch (SQLException e) {
            throw uncheckSQLException(e);
        }
    }

    /**
     * Deletes every row matched by the filter. Metadata keys in the filter are mapped to SQL by
     * {@code EmbeddingTable.mapMetadataKey}.
     *
     * @param filter the filter; must not be null
     */
    @Override
    public void removeAll(Filter filter) {
        ValidationUtils.ensureNotNull(filter, "filter");
        SQLFilter sqlFilter = SQLFilters.create(filter, table::mapMetadataKey);
        String sql = "DELETE FROM " + table.name() + sqlFilter.asWhereClause();

        try (Connection connection = dataSource.getConnection();
             PreparedStatement statement = connection.prepareStatement(sql)) {
            sqlFilter.setParameters(statement, 1);
            statement.executeUpdate();
        } catch (SQLException e) {
            throw uncheckSQLException(e);
        }
    }

    /** Removes every row from the table with {@code TRUNCATE TABLE}. */
    @Override
    public void removeAll() {
        try (Connection connection = dataSource.getConnection();
             Statement statement = connection.createStatement()) {
            statement.execute("TRUNCATE TABLE " + table.name());
        } catch (SQLException e) {
            throw uncheckSQLException(e);
        }
    }

    /**
     * Returns up to {@code request.maxResults()} rows ordered by ascending cosine distance to the query embedding,
     * optionally restricted by the request's filter.
     *
     * <p>Each distance {@code d} is converted to the score {@code 1 - d / 2} (assumes the COSINE distance lies in
     * [0, 2], giving a score in [0, 1] — confirm). Rows whose score is below {@code request.minScore()} are dropped.
     *
     * @param request the search request; must not be null
     * @return the matches, best first
     */
    @Override
    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
        ValidationUtils.ensureNotNull(request, "request");
        SQLFilter sqlFilter = SQLFilters.create(request.filter(), table::mapMetadataKey);
        int maxResults = request.maxResults();
        String sql = "SELECT VECTOR_DISTANCE(" + table.embeddingColumn() + ", ?, COSINE) distance, "
                + String.join(", ", table.idColumn(), table.embeddingColumn(), table.textColumn(), table.metadataColumn())
                + " FROM " + table.name()
                + sqlFilter.asWhereClause()
                + " ORDER BY distance FETCH " + (isExactSearch ? "" : " APPROXIMATE")
                + " FIRST " + maxResults + " ROWS ONLY";

        try (Connection connection = dataSource.getConnection();
             PreparedStatement statement = connection.prepareStatement(sql)) {
            // NOTE(review): -107 is presumably the vendor code for a FLOAT32 vector (cf. OracleType.VECTOR_FLOAT32
            // used by the insert methods) — confirm.
            statement.setObject(1, request.queryEmbedding().vector(), -107);
            sqlFilter.setParameters(statement, 2);
            statement.setFetchSize(maxResults);

            // Declare the select-list types up front so the driver does not need to describe the query:
            // 101 = BINARY_DOUBLE distance, 12 = VARCHAR id, -107 = vector (see note above),
            // 2005 = CLOB text, 2016 = JSON metadata. LOBs are prefetched in full.
            OracleStatement oracleStatement = statement.unwrap(OracleStatement.class);
            oracleStatement.defineColumnType(1, 101);
            oracleStatement.defineColumnType(2, 12);
            oracleStatement.defineColumnType(3, -107, Integer.MAX_VALUE);
            oracleStatement.defineColumnType(4, 2005, Integer.MAX_VALUE);
            oracleStatement.defineColumnType(5, 2016, Integer.MAX_VALUE);
            oracleStatement.setLobPrefetchSize(Integer.MAX_VALUE);

            List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>(maxResults);
            try (ResultSet resultSet = statement.executeQuery()) {
                while (resultSet.next()) {
                    double score = 1.0d - (resultSet.getDouble("distance") / 2.0d);
                    // Rows arrive ORDER BY distance, so scores only decrease: the first one below the
                    // threshold ends the scan.
                    if (score < request.minScore()) {
                        break;
                    }
                    matches.add(toEmbeddingMatch(resultSet, score));
                }
            }
            return new EmbeddingSearchResult<>(matches);
        } catch (SQLException e) {
            throw uncheckSQLException(e);
        }
    }

    /**
     * Converts the current row of a search result set into a match with the given score. Rows whose text column
     * is null get no segment; otherwise the segment's metadata is read from the metadata column.
     */
    private EmbeddingMatch<TextSegment> toEmbeddingMatch(ResultSet resultSet, double score) throws SQLException {
        String id = resultSet.getString(table.idColumn());
        float[] vector = resultSet.getObject(table.embeddingColumn(), float[].class);
        String text = resultSet.getString(table.textColumn());
        TextSegment segment = null;
        if (text != null) {
            OracleJsonObject oson = resultSet.getObject(table.metadataColumn(), OracleJsonObject.class);
            segment = new TextSegment(text, getMetadataFromOson(oson));
        }
        return new EmbeddingMatch<>(score, id, new Embedding(vector), segment);
    }

    /**
     * Converts metadata to an OSON object for the metadata column.
     *
     * <p>{@link Integer}, {@link Long} and {@link Double} values are stored as JSON numbers and {@link Float} values
     * as JSON floats; any other {@link Number} subtype is rejected. Non-numeric values are stored as their
     * {@code toString()}.
     *
     * @param metadata the metadata, or null
     * @return the OSON object, or null when {@code metadata} is null
     * @throws IllegalArgumentException for an unsupported {@link Number} subtype
     */
    private static OracleJsonObject getOsonFromMetadata(Metadata metadata) {
        if (metadata == null) {
            return null;
        }
        OracleJsonFactory factory = new OracleJsonFactory();
        OracleJsonObject oson = factory.createObject();
        for (Map.Entry<String, ?> entry : metadata.toMap().entrySet()) {
            String key = entry.getKey();
            Object value = entry.getValue();
            if (value instanceof Integer) {
                oson.put(key, ((Integer) value).intValue());
            } else if (value instanceof Long) {
                oson.put(key, ((Long) value).longValue());
            } else if (value instanceof Float) {
                oson.put(key, factory.createFloat(((Float) value).floatValue()));
            } else if (value instanceof Double) {
                oson.put(key, ((Double) value).doubleValue());
            } else if (value instanceof Number) {
                throw unrecognizedMetadata(key, value);
            } else {
                oson.put(key, value.toString());
            }
        }
        return oson;
    }

    /**
     * Converts an OSON object read from the metadata column back to {@link Metadata}.
     *
     * <p>JSON strings, floats and doubles map to the corresponding Java types, and decimals are converted by
     * {@link #putDecimal}. Values of any other JSON type are stored as their {@code toString()}.
     *
     * @param oson the OSON object, or null
     * @return the metadata; empty when {@code oson} is null
     */
    private static Metadata getMetadataFromOson(OracleJsonObject oson) {
        Metadata metadata = new Metadata();
        if (oson == null) {
            return metadata;
        }
        for (Map.Entry<String, OracleJsonValue> entry : oson.entrySet()) {
            String key = entry.getKey();
            OracleJsonValue value = entry.getValue();
            switch (value.getOracleJsonType()) {
                case STRING:
                    metadata.put(key, value.asJsonString().getString());
                    break;
                case DECIMAL:
                    // Must not fall through into FLOAT: asJsonFloat() cannot be applied to a decimal value.
                    putDecimal(metadata, key, value.asJsonDecimal());
                    break;
                case FLOAT:
                    metadata.put(key, value.asJsonFloat().floatValue());
                    break;
                case DOUBLE:
                    metadata.put(key, value.asJsonDouble().doubleValue());
                    break;
                default:
                    metadata.put(key, value.toString());
                    break;
            }
        }
        return metadata;
    }

    /**
     * Stores an OSON decimal in metadata: as {@code int} or {@code long} when that is its target type, otherwise as
     * its string form.
     */
    private static void putDecimal(Metadata metadata, String key, OracleJsonDecimal decimal) {
        switch (decimal.getTargetType()) {
            case INT:
                metadata.put(key, decimal.intValue());
                break;
            case LONG:
                metadata.put(key, decimal.longValue());
                break;
            default:
                metadata.put(key, decimal.toString());
                break;
        }
    }

    /**
     * Wraps a {@link SQLException} in a {@link RuntimeException}, delegating {@link BatchUpdateException}s to the
     * overload below.
     */
    private static RuntimeException uncheckSQLException(SQLException exception) {
        if (exception instanceof BatchUpdateException) {
            return uncheckSQLException((BatchUpdateException) exception);
        }
        return new RuntimeException(exception);
    }

    /**
     * Wraps a {@link BatchUpdateException} in a {@link RuntimeException}. The chained exception returned by
     * {@code getNextException()} is wrapped instead when present.
     */
    private static RuntimeException uncheckSQLException(BatchUpdateException exception) {
        SQLException next = exception.getNextException();
        return new RuntimeException(next != null ? next : exception);
    }

    /**
     * Returns {@code list.get(index)}.
     *
     * @throws IllegalArgumentException naming {@code listName} when that element is null
     */
    private static <T> T ensureIndexNotNull(List<T> list, int index, String listName) {
        T element = list.get(index);
        if (element == null) {
            throw new IllegalArgumentException("null entry at index " + index + " in " + listName);
        }
        return element;
    }

    /** Builds the exception thrown for a metadata value whose type cannot be stored as OSON. */
    private static IllegalArgumentException unrecognizedMetadata(String key, Object value) {
        return new IllegalArgumentException("Unrecognized object type in Metadata with key \"" + key
                + "\" and value \"" + String.valueOf(value) + "\" of class " + value.getClass().getSimpleName());
    }

    /**
     * Returns a new builder for this store.
     *
     * @return a builder with no data source or table configured
     */
    public static Builder builder() {
        return new Builder();
    }
}
