package com.github.hakenadu.javalangchains.chains.data.retrieval;

import com.github.hakenadu.javalangchains.util.PromptConstants;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;

/**
 * {@link RetrievalChain} that retrieves documents from a relational database via JDBC.
 *
 * <p>For each input question, a query builder produces a SQL query plus its positional
 * ({@code ?}) parameters. The query is executed as a {@link PreparedStatement}, and every
 * resulting row is converted into a document map by a {@link DocumentCreator}. Each document
 * is additionally tagged with the original question under {@link PromptConstants#QUESTION}.
 * The number of rows fetched is capped by {@link #getMaxDocumentCount()}.
 */
public class JdbcRetrievalChain extends RetrievalChain {

    /**
     * Provides the connection used by {@link #run(String)}. This class never closes the
     * connection it obtains.
     */
    private final Supplier<Connection> connectionSupplier;

    /** Maps an input question to a SQL query and its positional parameters. */
    private final Function<String, Pair<String, List<Object>>> queryBuilder;

    /** Converts the current row of a result set into a document map. */
    private final DocumentCreator documentCreator;

    /**
     * Converts the row a {@link ResultSet} is currently positioned on into a document.
     */
    @FunctionalInterface
    public interface DocumentCreator {

        /**
         * Creates a document from the current row.
         *
         * @param resultSet result set positioned on the row to convert. The chain iterates it
         *                  with {@code next()}, so implementations should neither advance nor
         *                  close it.
         * @return the document for the current row. The chain copies it before adding the
         *         question, so an unmodifiable map may be returned.
         * @throws SQLException if reading from the result set fails
         */
        Map<String, String> create(ResultSet resultSet) throws SQLException;
    }

    /**
     * Creates a chain with a fully custom query and document mapping.
     *
     * @param connectionSupplier supplies the connection queried on each run; the chain does not
     *                           close it
     * @param queryBuilder       maps the input question to a SQL query and its positional
     *                           parameters (bound in list order starting at index 1)
     * @param documentCreator    converts each result row into a document
     * @param maxDocumentCount   passed to {@link RetrievalChain}; used as the statement's
     *                           maximum row count
     * @throws NullPointerException if {@code connectionSupplier}, {@code queryBuilder} or
     *                              {@code documentCreator} is {@code null}
     */
    public JdbcRetrievalChain(final Supplier<Connection> connectionSupplier,
            final Function<String, Pair<String, List<Object>>> queryBuilder,
            final DocumentCreator documentCreator, final int maxDocumentCount) {
        super(maxDocumentCount);
        this.connectionSupplier = Objects.requireNonNull(connectionSupplier, "connectionSupplier");
        this.queryBuilder = Objects.requireNonNull(queryBuilder, "queryBuilder");
        this.documentCreator = Objects.requireNonNull(documentCreator, "documentCreator");
    }

    /**
     * Creates a chain that selects rows of {@code table} whose {@code contentColumn} contains
     * any whitespace-separated word of the question (see {@link #createQuery}).
     *
     * <p>Each row becomes a document mapping every non-NULL column name to the string form of
     * its value.
     *
     * <p>{@code table} and {@code contentColumn} are inserted into the SQL text verbatim. They
     * must come from trusted configuration, never from user input.
     *
     * @param connectionSupplier supplies the connection queried on each run; the chain does not
     *                           close it
     * @param table              name of the table to search
     * @param contentColumn      name of the column matched against the question's words
     * @param maxDocumentCount   maximum number of rows (documents) fetched per run
     */
    public JdbcRetrievalChain(final Supplier<Connection> connectionSupplier, final String table,
            final String contentColumn, final int maxDocumentCount) {
        // A lambda cannot match the (String, String) overload, so no cast is needed here.
        this(connectionSupplier, question -> createQuery(question, table, contentColumn),
                JdbcRetrievalChain::documentFromResultSet, maxDocumentCount);
    }

    /**
     * Creates a chain that searches the {@code Documents} table, matching against the column
     * named by {@link PromptConstants#CONTENT}.
     *
     * @param connectionSupplier supplies the connection queried on each run; the chain does not
     *                           close it
     * @param maxDocumentCount   maximum number of rows (documents) fetched per run
     */
    public JdbcRetrievalChain(final Supplier<Connection> connectionSupplier, final int maxDocumentCount) {
        this(connectionSupplier, "Documents", PromptConstants.CONTENT, maxDocumentCount);
    }

    /**
     * Runs the configured query for {@code input} and returns one document per result row.
     *
     * <p>The statement and result set are always closed. The connection is left open.
     *
     * @param input the question to retrieve documents for
     * @return the retrieved documents, each containing {@code input} under
     *         {@link PromptConstants#QUESTION}
     * @throws IllegalStateException if preparing or executing the statement, or reading a row,
     *                               fails with an {@link SQLException}
     */
    @Override // com.github.hakenadu.javalangchains.chains.Chain
    public Stream<Map<String, String>> run(final String input) {
        // NOTE(review): the connection is intentionally not closed here because the supplier may
        // hand out a shared connection. If it returns pooled/fresh connections, they leak; confirm
        // the intended ownership with the callers.
        final Connection connection = this.connectionSupplier.get();

        final Pair<String, List<Object>> queryAndParams = this.queryBuilder.apply(input);
        final String query = queryAndParams.getLeft();
        final List<Object> params = queryAndParams.getRight();

        try (PreparedStatement statement = connection.prepareStatement(query)) {
            statement.setMaxRows(getMaxDocumentCount());
            for (int index = 0; index < params.size(); index++) {
                // JDBC parameter indices are 1-based.
                statement.setObject(index + 1, params.get(index));
            }

            try (ResultSet resultSet = statement.executeQuery()) {
                final List<Map<String, String>> documents = new ArrayList<>();
                while (resultSet.next()) {
                    // Copy so that DocumentCreators may return unmodifiable maps.
                    final Map<String, String> document =
                            new LinkedHashMap<>(this.documentCreator.create(resultSet));
                    document.put(PromptConstants.QUESTION, input);
                    documents.add(document);
                }
                // Rows are materialised before the result set is closed, so the stream stays valid.
                return documents.stream();
            }
        } catch (final SQLException sqlException) {
            throw new IllegalStateException("error creating / executing database statement", sqlException);
        }
    }

    /**
     * Default {@link DocumentCreator}: maps each column name of the current row to the string
     * form of its value. Columns whose value is SQL {@code NULL} are omitted instead of causing
     * a {@link NullPointerException}.
     *
     * @param resultSet result set positioned on the row to convert
     * @return a mutable map of column name to value string
     * @throws SQLException if reading metadata or column values fails
     */
    private static Map<String, String> documentFromResultSet(final ResultSet resultSet) throws SQLException {
        final ResultSetMetaData metaData = resultSet.getMetaData();
        final int columnCount = metaData.getColumnCount();
        final Map<String, String> document = new HashMap<>();
        for (int column = 1; column <= columnCount; column++) {
            final Object value = resultSet.getObject(column);
            if (value != null) {
                document.put(metaData.getColumnName(column), value.toString());
            }
        }
        return document;
    }

    /**
     * Builds a query selecting all rows of {@code table} whose {@code contentColumn} matches any
     * word of {@code question}.
     *
     * <p>The question is split on whitespace and empty tokens are dropped. Each remaining word
     * {@code w} becomes the LIKE pattern {@code %w%}. All patterns are passed as a single
     * {@code String[]} parameter bound to {@code LIKE ANY (?)}. This relies on the database and
     * driver supporting {@code LIKE ANY} over an array parameter (e.g. PostgreSQL). A blank
     * question yields an empty array, which matches no rows.
     *
     * <p>NOTE(review): {@code %} and {@code _} inside the question's words are not escaped and
     * act as LIKE wildcards.
     *
     * <p>{@code table} and {@code contentColumn} are formatted into the SQL text verbatim and
     * must be trusted identifiers.
     *
     * @param question      the user question
     * @param table         table to select from
     * @param contentColumn column matched against the question's words
     * @return the SQL text and its single positional parameter (the pattern array)
     */
    public static Pair<String, List<Object>> createQuery(final String question, final String table,
            final String contentColumn) {
        final String query = String.format("SELECT * FROM %s WHERE %s LIKE ANY (?)", table, contentColumn);

        final String[] likePatterns = Arrays.stream(question.split("\\s+"))
                .filter(word -> !word.isEmpty())
                .map(word -> String.format("%%%s%%", word))
                .toArray(String[]::new);

        final List<Object> params = Collections.singletonList(likePatterns);
        return Pair.of(query, params);
    }
}
