package org.apache.solr.update.processor;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.document.DocumentClassifier;
import org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier;
import org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.util.BytesRef;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.update.AddUpdateCommand;
import org.apache.solr.update.processor.ClassificationUpdateProcessorFactory;

/* loaded from: input_file:WEB-INF/lib/solr-core-6.6.3.jar:org/apache/solr/update/processor/ClassificationUpdateProcessor.class */
/**
 * Update processor that adds predicted class values to incoming documents.
 *
 * <p>For every added document that has no value in the training class field, the configured
 * {@link DocumentClassifier} is asked for classes and each returned class is appended, as a
 * string, to the predicted class field. Documents that already have a value in the training
 * class field are passed down the chain without any added field.
 */
class ClassificationUpdateProcessor extends UpdateRequestProcessor {
    private final String trainingClassField;
    private final String predictedClassField;
    private final int maxOutputClasses;
    private final DocumentClassifier<BytesRef> classifier;

    /**
     * Reads the field names and limits from the parameters and builds the classifier for the
     * requested algorithm.
     *
     * @param classificationUpdateProcessorParams source of the training/predicted field names, the
     *     input field names (optionally written as {@code name^boost}), the algorithm and its settings
     * @param updateRequestProcessor next processor in the chain, handed to the superclass
     * @param indexReader reader handed to the classifier
     * @param indexSchema schema used to look up the query analyzer of each input field
     * @throws IllegalArgumentException if the algorithm is neither {@code KNN} nor {@code BAYES}
     */
    public ClassificationUpdateProcessor(ClassificationUpdateProcessorParams classificationUpdateProcessorParams, UpdateRequestProcessor updateRequestProcessor, IndexReader indexReader, IndexSchema indexSchema) {
        super(updateRequestProcessor);
        this.trainingClassField = classificationUpdateProcessorParams.getTrainingClassField();
        this.predictedClassField = classificationUpdateProcessorParams.getPredictedClassField();
        this.maxOutputClasses = classificationUpdateProcessorParams.getMaxPredictedClasses();
        String[] inputFieldNames = classificationUpdateProcessorParams.getInputFieldNames();
        ClassificationUpdateProcessorFactory.Algorithm algorithm = classificationUpdateProcessorParams.getAlgorithm();

        // Field name -> query analyzer. Kept as a raw map, as before; giving it type arguments
        // needs the Lucene Analyzer class, which this file does not import.
        HashMap fieldToAnalyzer = new HashMap();
        // The schema is queried with the bare field name, so any "^boost" suffix is stripped
        // first. The classifiers below still receive the names exactly as configured.
        for (String fieldName : removeBoost(inputFieldNames)) {
            fieldToAnalyzer.put(fieldName, indexSchema.getField(fieldName).getType().getQueryAnalyzer());
        }

        // NOTE(review): the second constructor argument is passed as null in both cases;
        // presumably the classifier falls back to its own default - confirm against the Lucene API.
        switch (algorithm) {
            case KNN:
                this.classifier = new KNearestNeighborDocumentClassifier(indexReader, null, classificationUpdateProcessorParams.getTrainingFilterQuery(), classificationUpdateProcessorParams.getK(), classificationUpdateProcessorParams.getMinDf(), classificationUpdateProcessorParams.getMinTf(), this.trainingClassField, fieldToAnalyzer, inputFieldNames);
                break;
            case BAYES:
                this.classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, this.trainingClassField, fieldToAnalyzer, inputFieldNames);
                break;
            default:
                // Fail at construction instead of leaving the classifier unset, which would
                // otherwise only show up as a NullPointerException on the first processAdd.
                throw new IllegalArgumentException("Unsupported classification algorithm: " + algorithm);
        }
    }

    /**
     * Returns a copy of the given field names with everything from the first {@code '^'} onwards
     * removed from each entry. Entries without a {@code '^'} are copied unchanged.
     *
     * @param boostedFieldNames field names, each optionally followed by {@code ^boost}
     * @return a new array of the same length holding the bare field names
     */
    private static String[] removeBoost(String[] boostedFieldNames) {
        String[] bareFieldNames = new String[boostedFieldNames.length];
        for (int i = 0; i < boostedFieldNames.length; i++) {
            String entry = boostedFieldNames[i];
            // indexOf/substring rather than split("\\^")[0]: split returns an empty array for an
            // entry made up only of '^' characters, so indexing element 0 would throw.
            int caret = entry.indexOf('^');
            bareFieldNames[i] = caret < 0 ? entry : entry.substring(0, caret);
        }
        return bareFieldNames;
    }

    /**
     * Adds predicted classes to the document when it has no value in the training class field,
     * then delegates to the superclass.
     *
     * @param addUpdateCommand the add command whose document is inspected and possibly extended
     * @throws IOException if classification or the superclass call fails with an I/O error
     */
    @Override // org.apache.solr.update.processor.UpdateRequestProcessor
    public void processAdd(AddUpdateCommand addUpdateCommand) throws IOException {
        SolrInputDocument solrInputDocument = addUpdateCommand.getSolrInputDocument();
        Document luceneDocument = addUpdateCommand.getLuceneDocument();
        if (solrInputDocument.getFieldValue(this.trainingClassField) == null) {
            List<ClassificationResult<BytesRef>> predictions = this.classifier.getClasses(luceneDocument, this.maxOutputClasses);
            // A null result is treated as "nothing to add".
            if (predictions != null) {
                for (ClassificationResult<BytesRef> prediction : predictions) {
                    solrInputDocument.addField(this.predictedClassField, prediction.getAssignedClass().utf8ToString());
                }
            }
        }
        super.processAdd(addUpdateCommand);
    }
}
