package org.apache.lucene.classification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;

/* loaded from: input_file:WEB-INF/lib/lucene-classification-6.6.3.jar:org/apache/lucene/classification/SimpleNaiveBayesClassifier.class */
/**
 * Naive Bayes style {@link Classifier} that scores every term of a "class" field against the
 * tokens of an input text, using document counts taken from a Lucene index.
 *
 * <p>For each candidate class the score is {@code log(prior) + log(likelihood)}; the raw scores
 * are then normalized so that they sum to one (see {@link #normClassificationResults(List)}).
 */
public class SimpleNaiveBayesClassifier implements Classifier<BytesRef> {
    /** Reader over the index used for all statistics. */
    protected final IndexReader indexReader;
    /** Names of the fields whose content is compared against the input text. */
    protected final String[] textFieldNames;
    /** Name of the field whose terms are the candidate classes. */
    protected final String classFieldName;
    /** Analyzer used to tokenize the input text. */
    protected final Analyzer analyzer;
    /** Searcher built on top of {@link #indexReader}. */
    protected final IndexSearcher indexSearcher;
    /** Optional filter query ANDed into the searches issued by this classifier; may be null. */
    protected final Query query;

    /**
     * Creates a classifier working on the given index.
     *
     * @param indexReader reader over the index used for statistics and searches
     * @param analyzer    analyzer used to tokenize the text to classify
     * @param query       optional query restricting the documents taken into account, or null
     * @param str         name of the field holding the classes
     * @param strArr      names of the fields holding the text
     */
    public SimpleNaiveBayesClassifier(IndexReader indexReader, Analyzer analyzer, Query query, String str, String... strArr) {
        this.indexReader = indexReader;
        this.indexSearcher = new IndexSearcher(this.indexReader);
        this.textFieldNames = strArr;
        this.classFieldName = str;
        this.analyzer = analyzer;
        this.query = query;
    }

    /**
     * Returns the class with the highest normalized score for the given text.
     *
     * @param str text to classify
     * @return the best scoring result, or null when no candidate class was found
     * @throws IOException if reading the index or analyzing the text fails
     */
    @Override
    public ClassificationResult<BytesRef> assignClass(String str) throws IOException {
        ClassificationResult<BytesRef> best = null;
        double bestScore = -Double.MAX_VALUE;
        for (ClassificationResult<BytesRef> candidate : assignClassNormalizedList(str)) {
            if (candidate.getScore() > bestScore) {
                best = candidate;
                bestScore = candidate.getScore();
            }
        }
        return best;
    }

    /**
     * Returns all candidate classes for the given text, sorted by their natural ordering.
     *
     * @param str text to classify
     * @return the sorted list of results, possibly empty
     * @throws IOException if reading the index or analyzing the text fails
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String str) throws IOException {
        List<ClassificationResult<BytesRef>> results = assignClassNormalizedList(str);
        Collections.sort(results);
        return results;
    }

    /**
     * Returns at most {@code i} candidate classes for the given text, sorted by their natural
     * ordering.
     *
     * @param str text to classify
     * @param i   maximum number of results to return; values larger than the number of available
     *            classes simply yield all of them (a negative value is passed through to
     *            {@link List#subList(int, int)}, which rejects it)
     * @return a view over the first results of the sorted list
     * @throws IOException if reading the index or analyzing the text fails
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String str, int i) throws IOException {
        List<ClassificationResult<BytesRef>> results = assignClassNormalizedList(str);
        Collections.sort(results);
        // Clamp the upper bound: asking for more classes than exist must not throw.
        return results.subList(0, Math.min(i, results.size()));
    }

    /**
     * Scores every non-empty term of the class field against the given text and normalizes the
     * scores.
     *
     * @param str text to classify
     * @return the normalized results; empty when the class field has no terms
     * @throws IOException if reading the index or analyzing the text fails
     */
    protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String str) throws IOException {
        List<ClassificationResult<BytesRef>> results = new ArrayList<>();
        Terms classes = MultiFields.getTerms(this.indexReader, this.classFieldName);
        if (classes != null) {
            TermsEnum classesEnum = classes.iterator();
            String[] tokens = tokenize(str);
            int docsWithClass = countDocsWithClass();
            BytesRef current;
            while ((current = classesEnum.next()) != null) {
                if (current.length > 0) {
                    Term classTerm = new Term(this.classFieldName, current);
                    double score = calculateLogPrior(classTerm, docsWithClass)
                            + calculateLogLikelihood(tokens, classTerm, docsWithClass);
                    results.add(new ClassificationResult<>(classTerm.bytes(), score));
                }
            }
        }
        return normClassificationResults(results);
    }

    /**
     * Counts the documents that have a value in the class field.
     *
     * <p>The doc count stored for the class field is used when available; otherwise a wildcard
     * search on the class field (restricted by {@link #query} when set) is executed and its hits
     * are counted.
     *
     * @return the number of documents with a class
     * @throws IOException if reading the index fails
     */
    public int countDocsWithClass() throws IOException {
        Terms classes = MultiFields.getTerms(this.indexReader, this.classFieldName);
        if (classes != null && classes.getDocCount() != -1) {
            return classes.getDocCount();
        }
        // Statistic not available: fall back to counting the hits of classField:*
        BooleanQuery.Builder countQuery = new BooleanQuery.Builder();
        countQuery.add(new BooleanClause(new WildcardQuery(new Term(this.classFieldName, "*")), BooleanClause.Occur.MUST));
        if (this.query != null) {
            countQuery.add(this.query, BooleanClause.Occur.MUST);
        }
        TotalHitCountCollector hitCounter = new TotalHitCountCollector();
        this.indexSearcher.search(countQuery.build(), hitCounter);
        return hitCounter.getTotalHits();
    }

    /**
     * Tokenizes the given text once per configured text field and concatenates the tokens.
     *
     * @param str text to tokenize
     * @return all tokens produced for all text fields, in production order
     * @throws IOException if the analysis fails
     */
    public String[] tokenize(String str) throws IOException {
        List<String> tokens = new ArrayList<>();
        for (String fieldName : this.textFieldNames) {
            // try-with-resources closes the stream on both the normal and the failure path
            try (TokenStream tokenStream = this.analyzer.tokenStream(fieldName, str)) {
                CharTermAttribute termAttribute = tokenStream.addAttribute(CharTermAttribute.class);
                tokenStream.reset();
                while (tokenStream.incrementToken()) {
                    tokens.add(termAttribute.toString());
                }
                tokenStream.end();
            }
        }
        return tokens.toArray(new String[0]);
    }

    /**
     * Sums, over all tokens, {@code log((hits + 1) / (textTermFreqForClass + docsWithClass))}
     * where {@code hits} is the number of documents of the class containing the token.
     */
    private double calculateLogLikelihood(String[] tokens, Term classTerm, int docsWithClass) throws IOException {
        if (tokens.length == 0) {
            return 0.0d;
        }
        // The denominator only depends on the class, not on the token: compute it once.
        double denominator = getTextTermFreqForClass(classTerm) + docsWithClass;
        double logLikelihood = 0.0d;
        for (String token : tokens) {
            int hits = getWordFreqForClass(token, classTerm);
            // +1 keeps the argument of the logarithm strictly positive for unseen tokens
            logLikelihood += Math.log((hits + 1) / denominator);
        }
        return logLikelihood;
    }

    /**
     * Estimates the amount of text terms belonging to the class: the average number of postings
     * per document, summed over the text fields, times the number of documents of the class.
     */
    private double getTextTermFreqForClass(Term classTerm) throws IOException {
        double avgPostingsPerDoc = 0.0d;
        for (String fieldName : this.textFieldNames) {
            Terms fieldTerms = MultiFields.getTerms(this.indexReader, fieldName);
            if (fieldTerms == null) {
                // Field has no terms in this index: it contributes nothing to the average.
                continue;
            }
            int docCount = fieldTerms.getDocCount();
            if (docCount <= 0) {
                // -1 is treated as "not available" elsewhere in this class; 0 would divide by zero.
                continue;
            }
            // Divide as double: integer division would truncate the average.
            avgPostingsPerDoc += fieldTerms.getSumDocFreq() / (double) docCount;
        }
        return avgPostingsPerDoc * this.indexReader.docFreq(classTerm);
    }

    /**
     * Counts the documents that contain the token in at least one text field, carry the given
     * class term and, when set, match {@link #query}.
     */
    private int getWordFreqForClass(String token, Term classTerm) throws IOException {
        BooleanQuery.Builder anyTextField = new BooleanQuery.Builder();
        for (String fieldName : this.textFieldNames) {
            anyTextField.add(new BooleanClause(new TermQuery(new Term(fieldName, token)), BooleanClause.Occur.SHOULD));
        }
        BooleanQuery.Builder wordInClass = new BooleanQuery.Builder();
        wordInClass.add(new BooleanClause(anyTextField.build(), BooleanClause.Occur.MUST));
        wordInClass.add(new BooleanClause(new TermQuery(classTerm), BooleanClause.Occur.MUST));
        if (this.query != null) {
            wordInClass.add(this.query, BooleanClause.Occur.MUST);
        }
        TotalHitCountCollector hitCounter = new TotalHitCountCollector();
        this.indexSearcher.search(wordInClass.build(), hitCounter);
        return hitCounter.getTotalHits();
    }

    /** Returns {@code log(docFreq(classTerm)) - log(docsWithClass)}. */
    private double calculateLogPrior(Term classTerm, int docsWithClass) throws IOException {
        return Math.log(docCount(classTerm)) - Math.log(docsWithClass);
    }

    /** Returns the number of documents containing the given class term. */
    private int docCount(Term classTerm) throws IOException {
        return this.indexReader.docFreq(classTerm);
    }

    /**
     * Converts log scores into scores that sum to one.
     *
     * <p>The given list is sorted in place. The score of its first element is subtracted before
     * exponentiating and added back after taking the logarithm of the sum; this shift does not
     * change the result mathematically.
     * NOTE(review): presumably {@code ClassificationResult} orders by descending score, making the
     * first element the maximum (the numerically stable choice) - confirm against its compareTo.
     *
     * @param list raw (log) results; must be modifiable because it is sorted in place
     * @return a new list with the normalized results, in the sorted order; empty for empty input
     */
    public ArrayList<ClassificationResult<BytesRef>> normClassificationResults(List<ClassificationResult<BytesRef>> list) {
        ArrayList<ClassificationResult<BytesRef>> normalized = new ArrayList<>();
        if (list.isEmpty()) {
            return normalized;
        }
        Collections.sort(list);
        double referenceScore = list.get(0).getScore();
        double shiftedSum = 0.0d;
        for (ClassificationResult<BytesRef> result : list) {
            shiftedSum += Math.exp(result.getScore() - referenceScore);
        }
        // log of the sum of exp(score) over all results
        double logTotal = referenceScore + Math.log(shiftedSum);
        for (ClassificationResult<BytesRef> result : list) {
            normalized.add(new ClassificationResult<>(result.getAssignedClass(), Math.exp(result.getScore() - logTotal)));
        }
        return normalized;
    }
}
