package org.bigml.mimir.nlp.topicmodel;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.random.MersenneTwister;
import org.bigml.mimir.nlp.tokenization.TokenStreamFactory;

/**
 * Infers a topic distribution for a single document against a fixed, pre-trained topic model.
 *
 * <p>The model is given by the per-topic term weights {@code _phi[topic][term]} and the Dirichlet
 * prior {@code _alpha}. Random number generators (and, when tokenization parameters are present,
 * term mappers) are kept in {@link ThreadLocal}s and recreated after deserialization.
 */
public class Inferencer implements Serializable {
    /** Lower bound on the number of sampling sweeps per stage chosen by {@link #inference(int[])}. */
    private static final int _MIN_UPDATES = 16;
    /** Upper bound on the number of sampling sweeps per stage chosen by {@link #inference(int[])}. */
    private static final int _MAX_UPDATES = 512;
    /** Target number of samples per topic used to size the sweep count in {@link #inference(int[])}. */
    private static final int _SAMPLES_PER_TOPIC = 128;
    private int _numberOfTopics;
    private int[] _seed;
    // Topic assignments from the final sweep of the most recent inference() call. This is a plain
    // instance field (not thread-local), so concurrent callers may overwrite each other's result.
    private int[] _lastAssignments;
    private double _alpha;
    // _numberOfTopics * _alpha, precomputed in the constructor.
    private double _Kalpha;
    // Indexed as _phi[topic][term].
    private double[][] _phi;
    private transient ThreadLocal<MersenneTwister> _randomGenerator;
    private transient ThreadLocal<Mappifier> _mappifier;
    private String _language;
    private List<String> _terms;
    private TokenStreamFactory _streamBuilder;
    private static final long serialVersionUID = 1L;

    /**
     * Converts per-topic counts into a smoothed probability distribution.
     *
     * <p>Each entry is {@code (counts[k] + alpha) / (sum(counts) + K * alpha)}, where {@code K} is
     * {@code counts.length}.
     *
     * @param jArr per-topic counts
     * @param d the smoothing (Dirichlet prior) value added to every count
     * @return a new array of the same length as {@code jArr}; empty when {@code jArr} is empty
     */
    public static double[] thetaForDoc(long[] jArr, double d) {
        int topics = jArr.length;
        long total = 0;
        for (long count : jArr) {
            total += count;
        }
        double denominator = total + (topics * d);
        double[] theta = new double[topics];
        for (int k = 0; k < topics; k++) {
            theta[k] = (jArr[k] + d) / denominator;
        }
        return theta;
    }

    /**
     * Builds an inferencer from trained model parameters.
     *
     * @param topicModelParameters source of the seed, topic count, alpha, phi and term list
     * @param str language passed to the per-thread {@code Mappifier}
     * @param tokenStreamFactory tokenizer used by {@link #docToInts(String)}; may be {@code null},
     *     in which case the string-based methods throw {@link UnsupportedOperationException}
     */
    public Inferencer(TopicModelParameters topicModelParameters, String str, TokenStreamFactory tokenStreamFactory) {
        // Copy before normalizing so the caller's parameter object is not mutated.
        this._seed = topicModelParameters.getSeed().clone();
        this._seed[0] = Math.abs(this._seed[0]);
        this._numberOfTopics = topicModelParameters.getNumberOfTopics();
        this._alpha = topicModelParameters.getAlpha();
        this._Kalpha = this._numberOfTopics * this._alpha;
        this._phi = topicModelParameters.getPhi();
        this._streamBuilder = tokenStreamFactory;
        this._language = str;
        this._terms = topicModelParameters.getTerms();
        createThreadLocals();
    }

    /**
     * Tokenizes {@code str} and returns its smoothed topic distribution.
     *
     * @throws IllegalArgumentException wrapping an {@link UnsupportedEncodingException} from
     *     tokenization
     * @throws UnsupportedOperationException if no tokenization parameters were supplied
     */
    public double[] topicDistribution(String str) {
        try {
            return topicDistribution(docToInts(str));
        } catch (UnsupportedEncodingException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /** Returns the smoothed topic distribution for a document given as term indices. */
    public double[] topicDistribution(int[] iArr) {
        return thetaForDoc(inference(iArr), this._alpha);
    }

    /**
     * Tokenizes {@code str} and maps its tokens to term indices.
     *
     * @throws UnsupportedOperationException if no tokenization parameters were supplied
     */
    public int[] docToInts(String str) throws UnsupportedEncodingException {
        if (this._mappifier == null || this._streamBuilder == null) {
            throw new UnsupportedOperationException("No tokenization parameters specified!");
        }
        return this._mappifier.get().documentToInts(this._streamBuilder.getTokenList(str));
    }

    /**
     * Maps a token-to-count map to term indices.
     *
     * @throws UnsupportedOperationException if no tokenization parameters were supplied
     */
    public int[] docToInts(Map<String, Long> map) {
        // The mapper only exists when a TokenStreamFactory was given; fail with the same
        // descriptive error as the String overload instead of a bare NullPointerException.
        if (this._mappifier == null) {
            throw new UnsupportedOperationException("No tokenization parameters specified!");
        }
        return this._mappifier.get().documentToInts(map);
    }

    /**
     * Runs {@link #inference(int[], int)} with an automatically chosen sweep count.
     *
     * <p>The count is {@code _SAMPLES_PER_TOPIC * K / doc.length}, clamped to
     * [{@code _MIN_UPDATES}, {@code _MAX_UPDATES}], or 0 for an empty document.
     */
    public long[] inference(int[] iArr) {
        int updates = 0;
        if (iArr.length > 0) {
            int proposed = (_SAMPLES_PER_TOPIC * this._numberOfTopics) / iArr.length;
            updates = Math.min(_MAX_UPDATES, Math.max(_MIN_UPDATES, proposed));
        }
        return inference(iArr, updates);
    }

    /**
     * Samples topic counts for a document in three stages of {@code i} sweeps each.
     *
     * <ol>
     *   <li>Each token's topic is drawn in proportion to {@code _phi[k][token]}.
     *   <li>Each token's topic is drawn in proportion to
     *       {@code _phi[k][token] * (stage1Count[k] + alpha)}.
     *   <li>As stage 2, but conditioned on the stage-2 counts.
     * </ol>
     *
     * <p>The thread's generator is reseeded with the model seed at the start of every call, so
     * results are reproducible. The topics drawn in the final sweep are stored for
     * {@link #getLastAssignments()}.
     *
     * @param iArr term indices of the document (indices into the second dimension of phi)
     * @param i number of sweeps per stage
     * @return per-topic counts accumulated during stage 3
     */
    public long[] inference(int[] iArr, int i) {
        MersenneTwister random = this._randomGenerator.get();
        random.setSeed(this._seed);
        double[] scratch = new double[this._numberOfTopics];
        int length = iArr.length;

        long[] uniformCounts = new long[this._numberOfTopics];
        for (int sweep = 0; sweep < i; sweep++) {
            for (int term : iArr) {
                uniformCounts[sampleUniform(term, scratch, random)]++;
            }
        }

        long[] refinedCounts = new long[this._numberOfTopics];
        // Widen before multiplying: length * i in int arithmetic can overflow on huge documents.
        double uniformNorm = ((double) length * i) + this._Kalpha;
        double refinedNorm = this._Kalpha;
        for (int sweep = 0; sweep < i; sweep++) {
            for (int term : iArr) {
                refinedCounts[sampleTopic(term, uniformCounts, uniformNorm, scratch, random)]++;
            }
            refinedNorm += length;
        }

        long[] finalCounts = new long[this._numberOfTopics];
        int[] assignments = new int[length];
        for (int sweep = 0; sweep < i; sweep++) {
            for (int pos = 0; pos < length; pos++) {
                int topic = sampleTopic(iArr[pos], refinedCounts, refinedNorm, scratch, random);
                finalCounts[topic]++;
                assignments[pos] = topic;
            }
        }
        this._lastAssignments = assignments;
        return finalCounts;
    }

    /**
     * Returns the topic drawn for each token in the final sweep of the most recent
     * {@link #inference(int[], int)} call (the internal array, not a copy), or {@code null} if
     * inference has not run on this instance.
     */
    public int[] getLastAssignments() {
        return this._lastAssignments;
    }

    /** Returns the Dirichlet prior used for smoothing. */
    public double getAlpha() {
        return this._alpha;
    }

    /** Restores the transient thread-local helpers after default deserialization. */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        createThreadLocals();
    }

    /**
     * Creates the per-thread random generator and, when a tokenizer is configured, the per-thread
     * term mapper.
     */
    private void createThreadLocals() {
        this._randomGenerator = new ThreadLocal<MersenneTwister>() {
            @Override
            protected MersenneTwister initialValue() {
                return new MersenneTwister(Inferencer.this._seed);
            }
        };
        if (this._streamBuilder != null) {
            this._mappifier = new ThreadLocal<Mappifier>() {
                @Override
                protected Mappifier initialValue() {
                    return new Mappifier(Inferencer.this._terms, Inferencer.this._language);
                }
            };
        }
    }

    /**
     * Draws a topic for {@code term} weighted by {@code phi[k][term] * (counts[k] + alpha) / norm}.
     * {@code scratch} is overwritten.
     */
    private int sampleTopic(int term, long[] counts, double norm, double[] scratch, MersenneTwister random) {
        for (int k = 0; k < this._numberOfTopics; k++) {
            scratch[k] = (this._phi[k][term] * (counts[k] + this._alpha)) / norm;
        }
        return sampleTopic(scratch, random);
    }

    /** Draws a topic for {@code term} weighted by {@code phi[k][term]}. {@code scratch} is overwritten. */
    private int sampleUniform(int term, double[] scratch, MersenneTwister random) {
        for (int k = 0; k < this._numberOfTopics; k++) {
            scratch[k] = this._phi[k][term];
        }
        return sampleTopic(scratch, random);
    }

    /**
     * Draws an index in proportion to the given non-negative weights.
     *
     * <p>{@code weights} is converted in place to its cumulative sums. The result is always a
     * valid topic index: if the draw is not below the last cumulative weight (e.g. due to rounding
     * or degenerate weights), the last topic is returned instead of running off the array.
     */
    private int sampleTopic(double[] weights, MersenneTwister random) {
        int last = this._numberOfTopics - 1;
        for (int k = 1; k <= last; k++) {
            weights[k] += weights[k - 1];
        }
        double target = random.nextDouble() * weights[last];
        int topic = 0;
        // Bounds check first: the original evaluated weights[topic] before the index test.
        while (topic < last && weights[topic] < target) {
            topic++;
        }
        return topic;
    }
}
