package com.github.hakenadu.javalangchains.chains.retrieval.lucene;

import com.github.hakenadu.javalangchains.chains.retrieval.RetrievalChain;
import com.github.hakenadu.javalangchains.util.PromptConstants;
import java.io.Closeable;
import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.queryparser.classic.ParseException;
import org.apache.lucene.queryparser.classic.QueryParser;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.store.Directory;

/* loaded from: input_file:com/github/hakenadu/javalangchains/chains/retrieval/lucene/LuceneRetrievalChain.class */
/**
 * {@link RetrievalChain} that looks up documents for a question in a Lucene index.
 *
 * <p>The index reader is opened once at construction time, and hits are ranked with
 * {@link BM25Similarity}. Each hit is converted to a {@code Map<String, String>} by the
 * configured document creator. The map is then enriched with the original question under
 * {@link PromptConstants#QUESTION}.
 *
 * <p>Call {@link #close()} to release the underlying {@link IndexReader}. The {@link Directory}
 * passed to the constructor is not closed by this class.
 */
public class LuceneRetrievalChain extends RetrievalChain implements Closeable {

    /** Turns the incoming question into a Lucene {@link Query}. */
    private final Function<String, Query> queryCreator;

    /** Turns a stored Lucene {@link Document} into the key/value map emitted by {@link #run}. */
    private final Function<Document, Map<String, String>> documentCreator;

    /** Reader over the index; owned by this chain and closed in {@link #close()}. */
    private final IndexReader indexReader;

    /** Searcher over {@link #indexReader}, configured with {@link BM25Similarity}. */
    private final IndexSearcher indexSearcher;

    /**
     * Creates a chain that searches the index stored in {@code directory}.
     *
     * @param directory        directory containing an existing Lucene index
     * @param maxDocumentCount forwarded to {@link RetrievalChain}; {@link #run} uses
     *                         {@code getMaxDocumentCount()} as the number of hits to request
     * @param queryCreator     maps the question to a Lucene query
     * @param documentCreator  maps a retrieved Lucene document to the emitted result map
     * @throws NullPointerException  if {@code directory}, {@code queryCreator} or
     *                               {@code documentCreator} is {@code null}
     * @throws IllegalStateException if the index reader cannot be opened
     */
    public LuceneRetrievalChain(Directory directory, int maxDocumentCount,
            Function<String, Query> queryCreator,
            Function<Document, Map<String, String>> documentCreator) {
        super(maxDocumentCount);
        Objects.requireNonNull(directory, "directory");
        this.queryCreator = Objects.requireNonNull(queryCreator, "queryCreator");
        this.documentCreator = Objects.requireNonNull(documentCreator, "documentCreator");
        try {
            this.indexReader = DirectoryReader.open(directory);
        } catch (IOException e) {
            throw new IllegalStateException("could not open indexReader", e);
        }
        this.indexSearcher = new IndexSearcher(this.indexReader);
        this.indexSearcher.setSimilarity(new BM25Similarity());
    }

    /**
     * Creates a chain using the default document creator.
     *
     * <p>The default creator maps each stored field name to its string value.
     *
     * @param directory        directory containing an existing Lucene index
     * @param maxDocumentCount forwarded to {@link RetrievalChain}
     * @param queryCreator     maps the question to a Lucene query
     */
    public LuceneRetrievalChain(Directory directory, int maxDocumentCount,
            Function<String, Query> queryCreator) {
        this(directory, maxDocumentCount, queryCreator, LuceneRetrievalChain::createDocument);
    }

    /**
     * Creates a chain using the default query creator and the default document creator.
     *
     * <p>The default query creator parses the question against the
     * {@link PromptConstants#CONTENT} field with a {@link StandardAnalyzer}.
     *
     * @param directory        directory containing an existing Lucene index
     * @param maxDocumentCount forwarded to {@link RetrievalChain}
     */
    public LuceneRetrievalChain(Directory directory, int maxDocumentCount) {
        this(directory, maxDocumentCount, LuceneRetrievalChain::createQuery,
                LuceneRetrievalChain::createDocument);
    }

    /**
     * Creates a chain with all defaults and a maximum document count of 4.
     *
     * @param directory directory containing an existing Lucene index
     */
    public LuceneRetrievalChain(Directory directory) {
        this(directory, 4);
    }

    /**
     * Searches the index for {@code question} and returns the top hits.
     *
     * <p>Only the search itself runs eagerly. Stored documents are loaded lazily while the
     * returned stream is consumed, so consume the stream before calling {@link #close()}.
     *
     * @param question the question to search for; converted by the configured query creator
     * @return one map per hit, as produced by the document creator, copied and extended with
     *         {@link PromptConstants#QUESTION} set to {@code question}
     * @throws IllegalStateException if the search fails (thrown immediately), or if a hit's
     *                               document cannot be loaded (thrown during stream consumption)
     */
    @Override // com.github.hakenadu.javalangchains.chains.Chain
    public Stream<Map<String, String>> run(String question) {
        Query query = this.queryCreator.apply(question);
        try {
            return Arrays.stream(this.indexSearcher.search(query, getMaxDocumentCount()).scoreDocs)
                    .map(scoreDoc -> loadDocument(scoreDoc.doc))
                    .map(this.documentCreator)
                    .map(document -> withQuestion(document, question));
        } catch (IOException e) {
            throw new IllegalStateException("error processing search for query " + query, e);
        }
    }

    /**
     * Closes the underlying index reader. The {@link Directory} itself is left open.
     *
     * @throws IOException if closing the reader fails
     */
    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() throws IOException {
        this.indexReader.close();
    }

    /** Loads the stored document for hit {@code docId}, wrapping I/O errors as unchecked. */
    private Document loadDocument(int docId) {
        try {
            return this.indexSearcher.doc(docId);
        } catch (IOException e) {
            throw new IllegalStateException("could not process document " + docId, e);
        }
    }

    /**
     * Returns a mutable copy of {@code document} with the question added.
     *
     * <p>The question is stored under {@link PromptConstants#QUESTION}. The document
     * creator's own map is never modified.
     */
    private static Map<String, String> withQuestion(Map<String, String> document, String question) {
        Map<String, String> enriched = new LinkedHashMap<>(document);
        enriched.put(PromptConstants.QUESTION, question);
        return enriched;
    }

    /**
     * Default document creator: maps each stored field name to its string value.
     *
     * <p>The following rules apply, in field order:
     * <ul>
     *   <li>Fields whose {@code stringValue()} is {@code null} are skipped. Collecting them
     *       would otherwise throw a {@code NullPointerException}.</li>
     *   <li>If a field name occurs more than once (multi-valued field), only the first value is
     *       kept. Previously this threw a "Duplicate key" {@code IllegalStateException}.</li>
     *   <li>The field order of the document is preserved.</li>
     * </ul>
     */
    private static Map<String, String> createDocument(Document document) {
        return document.getFields().stream()
                .filter(field -> field.stringValue() != null)
                .collect(Collectors.toMap(
                        field -> field.name(),
                        field -> field.stringValue(),
                        (first, duplicate) -> first,
                        LinkedHashMap::new));
    }

    /**
     * Default query creator: parses {@code searchTerm} against {@link PromptConstants#CONTENT}.
     *
     * <p>Parsing uses the classic {@link QueryParser} with a {@link StandardAnalyzer}. If the raw
     * text is not valid query syntax (common for free-form questions), parsing is retried once
     * with all query-syntax characters escaped. The analyzer is closed after parsing, since the
     * resulting query no longer needs it.
     *
     * @throws IllegalStateException if neither the raw nor the escaped text can be parsed
     */
    private static Query createQuery(String searchTerm) {
        try (StandardAnalyzer analyzer = new StandardAnalyzer()) {
            QueryParser parser = new QueryParser(PromptConstants.CONTENT, analyzer);
            try {
                return parser.parse(searchTerm);
            } catch (ParseException syntaxError) {
                try {
                    return parser.parse(QueryParser.escape(searchTerm));
                } catch (ParseException e) {
                    e.addSuppressed(syntaxError);
                    throw new IllegalStateException("could not create query for searchTerm " + searchTerm, e);
                }
            }
        }
    }
}
