package org.apache.solr.client.solrj.io.stream;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.solr.client.solrj.SolrRequest;
import org.apache.solr.client.solrj.impl.CloudSolrClient;
import org.apache.solr.client.solrj.impl.HttpSolrClient;
import org.apache.solr.client.solrj.io.ClassificationEvaluation;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.request.QueryRequest;
import org.apache.solr.common.cloud.Replica;
import org.apache.solr.common.cloud.Slice;
import org.apache.solr.common.cloud.ZkCoreNodeProps;
import org.apache.solr.common.cloud.ZkStateReader;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.util.ExecutorUtil;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.common.util.SolrjNamedThreadFactory;
import org.apache.solr.handler.CdcrParams;
import org.aspectj.weaver.model.AsmRelationshipUtils;

/* loaded from: input_file:WEB-INF/lib/solr-solrj-6.6.3.jar:org/apache/solr/client/solrj/io/stream/TextLogitStream.class */
/**
 * Streaming expression that trains a logistic-regression ("logit") text classifier over a
 * SolrCloud collection.
 *
 * <p>Each call to {@link #read()} performs one training iteration. The current terms, idfs and
 * weights are sent (with {@code fq={!tlogit}}) to one active replica of every shard. The weight
 * vectors the shards return are averaged, their errors are summed, and the resulting model is
 * emitted as a tuple. After {@code maxIterations} iterations an EOF tuple is returned. From the
 * second iteration on, the learning rate is halved when the error did not decrease and raised by
 * 5% otherwise.
 *
 * <p>Terms and idfs are loaded once from the child stream, whose tuples are read for
 * {@code term_s} and {@code idf_d} fields.
 */
public class TextLogitStream extends TupleStream implements Expressible {
    private static final long serialVersionUID = 1;

    // Parameter/response keys kept local so this class does not borrow constants from
    // unrelated libraries.
    private static final String ZK_HOST_PARAM = "zkHost";
    private static final String ERROR_KEY = "error";

    protected String zkHost;
    protected String collection;
    protected Map<String, String> params;
    protected String field;
    protected String name;
    protected String outcome;
    protected int positiveLabel;
    protected double threshold;
    protected List<Double> weights;
    protected int maxIterations;
    protected int iteration;
    protected double error;
    protected List<Double> idfs;
    protected ClassificationEvaluation evaluation;
    protected transient SolrClientCache cache;
    protected transient boolean isCloseCache;
    protected transient CloudSolrClient cloudSolrClient;
    protected transient StreamContext streamContext;
    protected ExecutorService executorService;
    protected TupleStream termsStream;
    private List<String> terms;
    private double learningRate = 0.01d;
    // Summed error of the previous iteration; drives the learning-rate adaptation in read().
    private double lastError = 0.0d;

    /**
     * Runs one training iteration against a single shard replica. The returned tuple holds the
     * shard's {@code error}, {@code weights} and {@code evaluation} entries.
     */
    public class LogitCall implements Callable<Tuple> {
        private final String baseUrl;
        private final String feature;
        private final List<String> terms;
        private final List<Double> weights;
        private final int iteration;
        private final String outcome;
        private final int positiveLabel;
        private final double learningRate;
        private final Map<String, String> paramsMap;

        /**
         * @param baseUrl core URL of the replica to query
         * @param paramsMap extra request parameters, passed through verbatim
         * @param feature value sent as the {@code feature} parameter
         * @param terms terms sent as the comma-separated {@code terms} parameter
         * @param weights current weights, or {@code null} when none have been computed yet
         * @param outcome value sent as the {@code outcome} parameter
         * @param positiveLabel value sent as the {@code positiveLabel} parameter
         * @param learningRate value sent as the {@code alpha} parameter
         * @param iteration value sent as the {@code iteration} parameter
         */
        public LogitCall(String baseUrl, Map<String, String> paramsMap, String feature, List<String> terms,
                         List<Double> weights, String outcome, int positiveLabel, double learningRate,
                         int iteration) {
            this.baseUrl = baseUrl;
            this.paramsMap = paramsMap;
            this.feature = feature;
            this.terms = terms;
            this.weights = weights;
            this.outcome = outcome;
            this.positiveLabel = positiveLabel;
            this.learningRate = learningRate;
            this.iteration = iteration;
        }

        @Override
        public Tuple call() throws Exception {
            ModifiableSolrParams solrParams = new ModifiableSolrParams();
            solrParams.add(CommonParams.DISTRIB, "false");
            solrParams.add("fq", "{!tlogit}");
            solrParams.add("feature", feature);
            // Qualified calls: an unqualified toString(...) would resolve to Object.toString().
            solrParams.add("terms", TextLogitStream.toString(terms));
            solrParams.add("idfs", TextLogitStream.toString(TextLogitStream.this.idfs));
            for (Map.Entry<String, String> entry : paramsMap.entrySet()) {
                solrParams.add(entry.getKey(), entry.getValue());
            }
            if (weights != null) {
                solrParams.add("weights", TextLogitStream.toString(weights));
            }
            solrParams.add("iteration", Integer.toString(iteration));
            solrParams.add("outcome", outcome);
            solrParams.add("positiveLabel", Integer.toString(positiveLabel));
            solrParams.add("threshold", Double.toString(TextLogitStream.this.threshold));
            solrParams.add("alpha", Double.toString(learningRate));

            HttpSolrClient client = TextLogitStream.this.cache.getHttpSolrClient(baseUrl);
            NamedList<?> logit = (NamedList<?>) new QueryRequest(solrParams, SolrRequest.METHOD.POST)
                    .process(client).getResponse().get("logit");
            if (logit == null) {
                throw new IOException("No 'logit' section in the response from " + baseUrl);
            }

            Map<String, Object> result = new HashMap<>();
            result.put(ERROR_KEY, ((Number) logit.get(ERROR_KEY)).doubleValue());
            result.put("weights", logit.get("weights"));
            result.put("evaluation", logit.get("evaluation"));
            return new Tuple(result);
        }
    }

    /**
     * In-memory stream that emits one tuple per term ({@code term_s}, with {@code score_f} fixed
     * at 1.0), followed by an EOF tuple.
     */
    public static class TermsStream extends TupleStream {
        private final List<String> terms;
        private Iterator<String> it;

        public TermsStream(List<String> terms) {
            this.terms = terms;
        }

        @Override
        public void setStreamContext(StreamContext streamContext) {
            // Nothing to configure: the terms are held in memory.
        }

        @Override
        public List<TupleStream> children() {
            return new ArrayList<>();
        }

        @Override
        public void open() throws IOException {
            it = terms.iterator();
        }

        @Override
        public void close() throws IOException {
        }

        @Override
        public Tuple read() throws IOException {
            Map<String, Object> fields = new HashMap<>();
            if (it.hasNext()) {
                fields.put("term_s", it.next());
                fields.put("score_f", 1.0d);
            } else {
                fields.put("EOF", true);
            }
            return new Tuple(fields);
        }

        @Override
        public StreamComparator getStreamSort() {
            return null;
        }

        @Override
        public Explanation toExplanation(StreamFactory streamFactory) throws IOException {
            return new StreamExplanation(getStreamNodeId().toString()).withFunctionName("non-expressible").withImplementingClass(getClass().getName()).withExpressionType(Explanation.ExpressionType.STREAM_SOURCE).withExpression("non-expressible");
        }
    }

    /**
     * Creates a stream from explicit settings. Training starts at iteration 0.
     *
     * @param zkHost ZooKeeper connection string of the cluster
     * @param collectionName collection to train against
     * @param params extra request parameters forwarded to every shard
     * @param name model name, used as the prefix of the emitted tuple ids
     * @param field value sent to shards as the {@code feature} parameter
     * @param termsStream stream supplying {@code term_s}/{@code idf_d} tuples
     * @param weights initial weights (terms + 1 values), or {@code null}
     * @param outcome value sent to shards as the {@code outcome} parameter
     * @param positiveLabel label treated as positive
     * @param threshold classification threshold sent to shards
     * @param maxIterations number of iterations to run before EOF
     */
    public TextLogitStream(String zkHost, String collectionName, Map params, String name, String field,
                           TupleStream termsStream, List<Double> weights, String outcome, int positiveLabel,
                           double threshold, int maxIterations) throws IOException {
        init(collectionName, zkHost, params, name, field, termsStream, weights, outcome, positiveLabel,
                threshold, maxIterations, 0);
    }

    /**
     * Creates a stream from a streaming expression of the form
     * {@code fn(collection, <terms stream>, name=..., field=..., outcome=..., maxIterations=..., ...)}.
     *
     * @throws IOException if required operands are missing or unknown operands are present
     */
    public TextLogitStream(StreamExpression expression, StreamFactory factory) throws IOException {
        String collectionName = factory.getValueOperand(expression, 0);
        List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
        StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, ZK_HOST_PARAM);
        List<StreamExpression> streamExpressions =
                factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class);

        // Every operand must be the collection name, a named parameter, or a stream.
        if (expression.getParameters().size() != 1 + namedParams.size() + streamExpressions.size()) {
            throw new IOException(String.format(Locale.ROOT, "invalid expression %s - unknown operands found", expression));
        }
        if (null == collectionName) {
            throw new IOException(String.format(Locale.ROOT, "invalid expression %s - collectionName expected as first operand", expression));
        }
        if (namedParams.isEmpty()) {
            throw new IOException(String.format(Locale.ROOT, "invalid expression %s - at least one named parameter expected. eg. 'q=*:*'", expression));
        }

        Map<String, String> requestParams = new HashMap<>();
        for (StreamExpressionNamedParameter namedParam : namedParams) {
            if (!namedParam.getName().equals(ZK_HOST_PARAM)) {
                requestParams.put(namedParam.getName(), namedParam.getParameter().toString().trim());
            }
        }

        // Model settings are consumed from the map; whatever remains is forwarded to the shards.
        String modelName = removeRequired(requestParams, "name");
        String featureField = removeRequired(requestParams, "field");
        if (streamExpressions.isEmpty()) {
            throw new IOException("features must be present for TextLogitStream");
        }
        TupleStream featuresStream = factory.constructStream(streamExpressions.get(0));
        int iterationLimit = Integer.parseInt(removeRequired(requestParams, "maxIterations"));
        String outcomeField = removeRequired(requestParams, "outcome");

        String positiveLabelParam = requestParams.remove("positiveLabel");
        int label = positiveLabelParam == null ? 1 : Integer.parseInt(positiveLabelParam);

        String thresholdParam = requestParams.remove("threshold");
        double cutoff = thresholdParam == null ? 0.5d : Double.parseDouble(thresholdParam);

        String iterationParam = requestParams.remove("iteration");
        int startIteration = iterationParam == null ? 0 : Integer.parseInt(iterationParam);

        List<Double> initialWeights = null;
        String weightsParam = requestParams.remove("weights");
        if (weightsParam != null) {
            initialWeights = new ArrayList<>();
            for (String weight : weightsParam.split(",")) {
                initialWeights.add(Double.parseDouble(weight));
            }
        }

        String resolvedZkHost = null;
        if (zkHostExpression == null) {
            resolvedZkHost = factory.getCollectionZkHost(collectionName);
        } else if (zkHostExpression.getParameter() instanceof StreamExpressionValue) {
            resolvedZkHost = ((StreamExpressionValue) zkHostExpression.getParameter()).getValue();
        }
        if (resolvedZkHost == null) {
            throw new IOException(String.format(Locale.ROOT, "invalid expression %s - zkHost not found for collection '%s'", expression, collectionName));
        }

        init(collectionName, resolvedZkHost, requestParams, modelName, featureField, featuresStream, initialWeights,
                outcomeField, label, cutoff, iterationLimit, startIteration);
    }

    /** Removes {@code key} from {@code map} and returns its value, failing if it is absent. */
    private static String removeRequired(Map<String, String> map, String key) throws IOException {
        String value = map.remove(key);
        if (value == null) {
            throw new IOException(key + " param cannot be null for TextLogitStream");
        }
        return value;
    }

    @Override
    public StreamExpressionParameter toExpression(StreamFactory streamFactory) throws IOException {
        return toExpression(streamFactory, true);
    }

    /**
     * Builds the streaming expression for this stream.
     *
     * @param includeStreams whether to embed the expression of a non-in-memory terms stream
     */
    private StreamExpression toExpression(StreamFactory factory, boolean includeStreams) throws IOException {
        StreamExpression expression = new StreamExpression(factory.getFunctionName(getClass()));
        expression.addParameter(collection);
        boolean inMemoryTerms = termsStream instanceof TermsStream;
        if (includeStreams && !inMemoryTerms) {
            if (!(termsStream instanceof Expressible)) {
                throw new IOException("This TextLogitStream contains a non-expressible TupleStream - it cannot be converted to an expression");
            }
            expression.addParameter(((Expressible) termsStream).toExpression(factory));
        }
        for (Map.Entry<String, String> entry : params.entrySet()) {
            expression.addParameter(new StreamExpressionNamedParameter(entry.getKey(), entry.getValue()));
        }
        expression.addParameter(new StreamExpressionNamedParameter("field", field));
        expression.addParameter(new StreamExpressionNamedParameter("name", name));
        if (inMemoryTerms) {
            // In-memory terms cannot be expressed as a stream, so they are inlined as a parameter.
            loadTerms();
            expression.addParameter(new StreamExpressionNamedParameter("terms", toString(terms)));
        }
        expression.addParameter(new StreamExpressionNamedParameter("outcome", outcome));
        if (weights != null) {
            expression.addParameter(new StreamExpressionNamedParameter("weights", toString(weights)));
        }
        expression.addParameter(new StreamExpressionNamedParameter("maxIterations", Integer.toString(maxIterations)));
        if (iteration > 0) {
            expression.addParameter(new StreamExpressionNamedParameter("iteration", Integer.toString(iteration)));
        }
        expression.addParameter(new StreamExpressionNamedParameter("positiveLabel", Integer.toString(positiveLabel)));
        expression.addParameter(new StreamExpressionNamedParameter("threshold", Double.toString(threshold)));
        expression.addParameter(new StreamExpressionNamedParameter(ZK_HOST_PARAM, zkHost));
        return expression;
    }

    private void init(String collectionName, String zkHost, Map<String, String> params, String name, String field,
                      TupleStream termsStream, List<Double> weights, String outcome, int positiveLabel,
                      double threshold, int maxIterations, int iteration) throws IOException {
        this.zkHost = zkHost;
        this.collection = collectionName;
        this.params = params;
        this.name = name;
        this.field = field;
        this.termsStream = termsStream;
        this.outcome = outcome;
        this.positiveLabel = positiveLabel;
        this.threshold = threshold;
        this.weights = weights;
        this.maxIterations = maxIterations;
        this.iteration = iteration;
    }

    @Override
    public void setStreamContext(StreamContext context) {
        this.cache = context.getSolrClientCache();
        this.streamContext = context;
        this.termsStream.setStreamContext(context);
    }

    /** Acquires a client cache (creating and owning one if none was supplied) and a worker pool. */
    @Override
    public void open() throws IOException {
        if (cache == null) {
            isCloseCache = true;
            cache = new SolrClientCache();
        } else {
            isCloseCache = false;
        }
        cloudSolrClient = cache.getCloudSolrClient(zkHost);
        executorService = ExecutorUtil.newMDCAwareCachedThreadPool(new SolrjNamedThreadFactory("TextLogitSolrStream"));
    }

    @Override
    public List<TupleStream> children() {
        List<TupleStream> kids = new ArrayList<>();
        kids.add(termsStream);
        return kids;
    }

    /**
     * Picks one randomly chosen active replica on a live node for every slice of the collection.
     *
     * @return the core URL of each chosen replica
     * @throws IOException if cluster state cannot be read or a slice has no usable replica
     */
    protected List<String> getShardUrls() throws IOException {
        try {
            ZkStateReader zkStateReader = cloudSolrClient.getZkStateReader();
            Collection<Slice> slices = CloudSolrStream.getSlices(collection, zkStateReader, false);
            Set<String> liveNodes = zkStateReader.getClusterState().getLiveNodes();
            Random random = new Random();
            List<String> coreUrls = new ArrayList<>();
            for (Slice slice : slices) {
                List<Replica> candidates = new ArrayList<>();
                for (Replica replica : slice.getReplicas()) {
                    if (replica.getState() == Replica.State.ACTIVE && liveNodes.contains(replica.getNodeName())) {
                        candidates.add(replica);
                    }
                }
                if (candidates.isEmpty()) {
                    throw new IOException("No active replica on a live node for slice " + slice.getName()
                            + " of collection " + collection);
                }
                Replica chosen = candidates.get(random.nextInt(candidates.size()));
                coreUrls.add(new ZkCoreNodeProps(chosen).getCoreUrl());
            }
            return coreUrls;
        } catch (IOException e) {
            throw e;
        } catch (Exception e) {
            throw new IOException(e);
        }
    }

    /** Submits one {@link LogitCall} per shard URL using the current model state. */
    private List<Future<Tuple>> callShards(List<String> baseUrls) throws IOException {
        List<Future<Tuple>> futures = new ArrayList<>();
        for (String baseUrl : baseUrls) {
            futures.add(executorService.submit(new LogitCall(baseUrl, params, field, terms, weights, outcome,
                    positiveLabel, learningRate, iteration)));
        }
        return futures;
    }

    /**
     * Releases the owned client cache and the worker pool, then closes the terms stream. Safe to
     * call even if {@link #open()} was never invoked.
     */
    @Override
    public void close() throws IOException {
        try {
            if (isCloseCache && cache != null) {
                cache.close();
                // Drop the closed cache so that a later open() builds a fresh one.
                cache = null;
            }
            if (executorService != null) {
                executorService.shutdown();
            }
        } finally {
            termsStream.close();
        }
    }

    @Override
    public StreamComparator getStreamSort() {
        return null;
    }

    @Override
    public Explanation toExplanation(StreamFactory factory) throws IOException {
        StreamExplanation explanation = new StreamExplanation(getStreamNodeId().toString());
        explanation.setFunctionName(factory.getFunctionName(getClass()));
        explanation.setImplementingClass(getClass().getName());
        explanation.setExpressionType(Explanation.ExpressionType.MACHINE_LEARNING_MODEL);
        explanation.setExpression(toExpression(factory).toString());
        explanation.addChild(termsStream.toExplanation(factory));
        return explanation;
    }

    /**
     * Reads all terms and idfs from the terms stream, once. The stream is always closed. The
     * results are only published when the whole stream was read successfully, so a failed load
     * can be retried.
     */
    public void loadTerms() throws IOException {
        if (terms != null) {
            return;
        }
        List<String> loadedTerms = new ArrayList<>();
        List<Double> loadedIdfs = new ArrayList<>();
        termsStream.open();
        try {
            for (Tuple tuple = termsStream.read(); !tuple.EOF; tuple = termsStream.read()) {
                loadedTerms.add(tuple.getString("term_s"));
                loadedIdfs.add(tuple.getDouble("idf_d"));
            }
        } finally {
            termsStream.close();
        }
        terms = loadedTerms;
        idfs = loadedIdfs;
    }

    /**
     * Runs one training iteration across all shards and returns the resulting model tuple, or an
     * EOF tuple once {@code maxIterations} iterations have been emitted.
     *
     * @throws IOException if terms cannot be loaded, the initial weight count does not match the
     *     number of terms + 1, or any shard call fails
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Tuple read() throws IOException {
        try {
            if (++iteration > maxIterations) {
                Map<String, Object> eof = new HashMap<>();
                eof.put("EOF", true);
                return new Tuple(eof);
            }
            if (idfs == null) {
                loadTerms();
                // One weight per term plus one extra weight.
                if (weights != null && terms.size() + 1 != weights.size()) {
                    throw new IOException(String.format(Locale.ROOT,
                            "invalid expression for collection '%s' - the number of weights must be %d, found %d",
                            collection, terms.size() + 1, weights.size()));
                }
            }

            List<List<Double>> shardWeights = new ArrayList<>();
            evaluation = new ClassificationEvaluation();
            error = 0.0d;
            for (Future<Tuple> future : callShards(getShardUrls())) {
                Tuple shardResult = future.get();
                shardWeights.add((List<Double>) shardResult.get("weights"));
                error += shardResult.getDouble(ERROR_KEY);
                evaluation.addEvaluation((Map) shardResult.get("evaluation"));
            }
            weights = averageWeights(shardWeights);

            Map<String, Object> model = new HashMap<>();
            model.put("id", name + "_" + iteration);
            model.put("name_s", name);
            model.put("field_s", field);
            model.put("terms_ss", terms);
            model.put("iteration_i", iteration);
            if (weights != null) {
                model.put("weights_ds", weights);
            }
            model.put("error_d", error);
            evaluation.putToMap(model);
            model.put("alpha_d", learningRate);
            model.put("idfs_ds", idfs);

            // Adaptive learning rate: back off when the error did not improve, otherwise speed up.
            if (iteration != 1) {
                learningRate *= (lastError <= error) ? 0.5d : 1.05d;
            }
            lastError = error;
            return new Tuple(model);
        } catch (IOException e) {
            throw e;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IOException(e);
        } catch (Exception e) {
            throw new IOException(e);
        }
    }

    /** Returns the element-wise mean of the shards' weight vectors (sized by the first one). */
    private List<Double> averageWeights(List<List<Double>> shardWeights) {
        if (shardWeights.isEmpty()) {
            throw new IllegalStateException("No shard weights were returned for collection " + collection);
        }
        double[] sums = new double[shardWeights.get(0).size()];
        for (List<Double> shardVector : shardWeights) {
            for (int i = 0; i < sums.length; i++) {
                sums[i] += shardVector.get(i);
            }
        }
        List<Double> averages = new ArrayList<>(sums.length);
        for (double sum : sums) {
            averages.add(sum / shardWeights.size());
        }
        return averages;
    }

    /** Joins the {@code toString()} of every element with commas (no separator-skipping for empty text). */
    static String toString(List list) {
        StringBuilder sb = new StringBuilder();
        boolean first = true;
        for (Object item : list) {
            if (!first) {
                sb.append(',');
            }
            sb.append(item.toString());
            first = false;
        }
        return sb.toString();
    }
}
