package cc.mallet.types;

import cc.mallet.util.Randoms;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;

/* loaded from: input_file:cc/mallet/types/Multinomial.class */
/**
 * A discrete probability distribution over a finite set of outcomes.
 *
 * <p>Outcomes are either the entries of an {@link Alphabet} (the inherited {@code dictionary}) or,
 * when no alphabet is supplied, the plain indices {@code 0 .. size()-1}. Probabilities are stored
 * densely in the inherited {@code values} array; the {@link Logged} subclass stores natural-log
 * probabilities in that array instead.
 */
public class Multinomial extends FeatureVector {
    private static final long serialVersionUID = 1;

    /**
     * Maximum allowed distance between the sum of the probabilities and 1.0 when a constructor
     * validates its input. Kept small so that clearly unnormalized input is rejected.
     */
    private static final double SUM_TOLERANCE = 1.0e-4;

    /**
     * Accumulates per-index counts and turns them into a {@link Multinomial} via {@link #estimate()}.
     *
     * <p>Counts live in {@code counts}, which is grown on demand by {@link #ensureCapacity(int)};
     * {@code size} tracks one past the highest index incremented (or the size it was created with).
     */
    public static abstract class Estimator implements Cloneable, Serializable {
        Alphabet dictionary;
        double[] counts;
        int size;
        static final int minCapacity = 16;
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 1;

        /**
         * Wraps an existing counts array.
         *
         * @param counts backing array of counts (used directly, not copied)
         * @param size number of outcomes initially tracked
         * @param dictionary alphabet naming the outcomes, or {@code null}
         */
        protected Estimator(double[] counts, int size, Alphabet dictionary) {
            this.counts = counts;
            this.size = size;
            this.dictionary = dictionary;
        }

        /** Wraps {@code counts} (not copied), tracking {@code dictionary.size()} outcomes. */
        public Estimator(double[] counts, Alphabet dictionary) {
            this(counts, dictionary.size(), dictionary);
        }

        /** Creates an empty estimator with no dictionary and {@code minCapacity} slots. */
        public Estimator() {
            this(new double[minCapacity], 0, null);
        }

        /** Creates an estimator with no dictionary tracking {@code size} outcomes. */
        public Estimator(int size) {
            this(new double[Math.max(size, minCapacity)], size, null);
        }

        /** Creates an estimator with a zero count for every entry of {@code dictionary}. */
        public Estimator(Alphabet dictionary) {
            this(new double[dictionary.size()], dictionary.size(), dictionary);
        }

        /** Replaces the dictionary and discards all counts (fresh zeroed array of the new size). */
        public void setAlphabet(Alphabet dictionary) {
            this.size = dictionary.size();
            this.counts = new double[this.size];
            this.dictionary = dictionary;
        }

        /** Returns the dictionary size if a dictionary is attached, otherwise the tracked size. */
        public int size() {
            return dictionary == null ? size : dictionary.size();
        }

        /**
         * Makes {@code index} a valid position in {@code counts}, growing the array by doubling
         * (starting from at least {@code minCapacity}) and preserving existing counts.
         *
         * @param index the index that must become addressable
         */
        protected void ensureCapacity(int index) {
            if (index > size) {
                size = index;
            }
            if (index < counts.length) {
                return;
            }
            int newLength = Math.max(counts.length, minCapacity);
            while (newLength <= index) {
                newLength *= 2;
            }
            counts = Arrays.copyOf(counts, newLength);
        }

        /** Sets every count back to zero without shrinking the backing array. */
        public void reset() {
            Arrays.fill(counts, 0.0d);
        }

        /** Replaces the backing counts array (not copied). */
        private void setCounts(double[] newCounts) {
            assert dictionary == null || newCounts.length <= size();
            counts = newCounts;
        }

        /**
         * Adds {@code amount} to the count at {@code index}, growing storage as needed.
         *
         * @param index outcome index
         * @param amount value added to that outcome's count
         */
        public void increment(int index, double amount) {
            ensureCapacity(index);
            counts[index] += amount;
            if (size < index + 1) {
                size = index + 1;
            }
        }

        /**
         * Adds {@code amount} to the count of the outcome {@code key}, whose index is obtained
         * from {@code dictionary.lookupIndex(key)}.
         */
        public void increment(String key, double amount) {
            increment(dictionary.lookupIndex(key), amount);
        }

        /**
         * Adds {@code weight} once for every position of {@code sequence}.
         *
         * @throws IllegalArgumentException if the sequence's alphabet is not this estimator's
         *     dictionary (compared by identity)
         */
        public void increment(FeatureSequence sequence, double weight) {
            if (sequence.getAlphabet() != dictionary) {
                throw new IllegalArgumentException("Vocabularies don't match.");
            }
            for (int pos = 0; pos < sequence.size(); pos++) {
                increment(sequence.getIndexAtPosition(pos), weight);
            }
        }

        /** Adds one count for every position of {@code sequence}. */
        public void increment(FeatureSequence sequence) {
            increment(sequence, 1.0d);
        }

        /**
         * Adds {@code weight * value} for every stored location of {@code vector}.
         *
         * @throws IllegalArgumentException if the vector's alphabet is not this estimator's
         *     dictionary (compared by identity)
         */
        public void increment(FeatureVector vector, double weight) {
            if (vector.getAlphabet() != dictionary) {
                throw new IllegalArgumentException("Vocabularies don't match.");
            }
            for (int loc = 0; loc < vector.numLocations(); loc++) {
                increment(vector.indexAtLocation(loc), weight * vector.valueAtLocation(loc));
            }
        }

        /** Adds the values of {@code vector} with weight one. */
        public void increment(FeatureVector vector) {
            increment(vector, 1.0d);
        }

        /** Returns the accumulated count at {@code index}. */
        public double getCount(int index) {
            return counts[index];
        }

        /**
         * Returns a copy of this estimator whose counts array is independent of the original, so
         * incrementing the copy does not change this estimator. The dictionary is shared.
         */
        @Override
        public Object clone() {
            try {
                Estimator copy = (Estimator) super.clone();
                copy.counts = counts.clone();
                return copy;
            } catch (CloneNotSupportedException e) {
                // Unreachable: this class implements Cloneable.
                throw new AssertionError(e);
            }
        }

        /** Prints the first {@code size} counts to standard output. */
        public void print() {
            System.out.println("Multinomial.Estimator");
            for (int i = 0; i < size; i++) {
                System.out.println("counts[" + i + "] = " + counts[i]);
            }
        }

        /** Builds a distribution from the accumulated counts. */
        public abstract Multinomial estimate();

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
            out.writeObject(dictionary);
            out.writeObject(counts);
            out.writeInt(size);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
            if (version != CURRENT_SERIAL_VERSION) {
                throw new ClassNotFoundException("Mismatched Multinomial.Estimator versions: wanted "
                        + CURRENT_SERIAL_VERSION + ", got " + version);
            }
            dictionary = (Alphabet) in.readObject();
            counts = (double[]) in.readObject();
            size = in.readInt();
        }
    }

    /** An {@link MEstimator} that adds one pseudo-count to every outcome (add-one smoothing). */
    public static class LaplaceEstimator extends MEstimator {
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 1;

        public LaplaceEstimator() {
            super(1.0d);
        }

        public LaplaceEstimator(int size) {
            super(size, 1.0d);
        }

        public LaplaceEstimator(Alphabet dictionary) {
            super(dictionary, 1.0d);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
            if (version != CURRENT_SERIAL_VERSION) {
                throw new ClassNotFoundException("Mismatched Multinomial.LaplaceEstimator versions: wanted "
                        + CURRENT_SERIAL_VERSION + ", got " + version);
            }
        }
    }

    /**
     * A multinomial whose {@code values} array holds natural-log probabilities.
     * {@link #probability(int)} exponentiates and {@link #logProbability(int)} returns the stored
     * value directly.
     */
    public static class Logged extends Multinomial {
        private static final long serialVersionUID = 1;

        /**
         * @param values the first {@code size} entries are copied
         * @param dictionary alphabet naming the outcomes, or {@code null}
         * @param size number of entries to copy
         * @param areLoggedAlready {@code true} if {@code values} already holds log probabilities;
         *     otherwise they are plain probabilities whose sum is validated and then logged
         */
        public Logged(double[] values, Alphabet dictionary, int size, boolean areLoggedAlready) {
            super(values, dictionary, size, true, !areLoggedAlready);
            assert dictionary == null || dictionary.size() == size;
            if (!areLoggedAlready) {
                for (int i = 0; i < size; i++) {
                    this.values[i] = Math.log(this.values[i]);
                }
            }
        }

        public Logged(double[] values, Alphabet dictionary, boolean areLoggedAlready) {
            this(values, dictionary, dictionary == null ? values.length : dictionary.size(), areLoggedAlready);
        }

        public Logged(double[] probabilities, Alphabet dictionary, int size) {
            this(probabilities, dictionary, size, false);
        }

        public Logged(double[] probabilities, Alphabet dictionary) {
            this(probabilities, dictionary, dictionary.size(), false);
        }

        /**
         * Creates a log-space copy of {@code multinomial}. If it is already a {@code Logged}, its
         * stored values are log probabilities and are copied as-is instead of being logged again.
         */
        public Logged(Multinomial multinomial) {
            this(multinomial.values, multinomial.dictionary, multinomial instanceof Logged);
        }

        public Logged(double[] probabilities) {
            this(probabilities, (Alphabet) null, false);
        }

        @Override // cc.mallet.types.Multinomial
        public double probability(int index) {
            return Math.exp(values[index]);
        }

        @Override // cc.mallet.types.Multinomial
        public double logProbability(int index) {
            return values[index];
        }

        /** Adds {@code exp(values[i])} into {@code accumulator[i]}; lengths must match. */
        public void addProbabilities(double[] accumulator) {
            assert accumulator.length == values.length;
            for (int i = 0; i < accumulator.length; i++) {
                accumulator[i] += Math.exp(values[i]);
            }
        }

        /**
         * Adds the stored log probabilities into {@code accumulator}. Entries of
         * {@code accumulator} beyond this distribution's length are set to negative infinity.
         */
        public void addLogProbabilities(double[] accumulator) {
            for (int i = 0; i < values.length; i++) {
                accumulator[i] += values[i];
            }
            for (int i = values.length; i < accumulator.length; i++) {
                accumulator[i] = Double.NEGATIVE_INFINITY;
            }
        }
    }

    /**
     * Estimator intended to combine counts with a Dirichlet prior. {@link #estimate()} is not
     * implemented and returns {@code null}.
     */
    public static class MAPEstimator extends Estimator {
        Dirichlet prior;
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 1;

        public MAPEstimator(Dirichlet prior) {
            super(prior.size());
            this.prior = prior;
        }

        /** Not implemented: always returns {@code null}. */
        @Override // cc.mallet.types.Multinomial.Estimator
        public Multinomial estimate() {
            return null;
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
            if (version != CURRENT_SERIAL_VERSION) {
                throw new ClassNotFoundException("Mismatched Multinomial.MAPEstimator versions: wanted "
                        + CURRENT_SERIAL_VERSION + ", got " + version);
            }
        }
    }

    /**
     * Estimator that adds a constant pseudo-count {@code m} to every outcome's count and then
     * normalizes so the result sums to one.
     */
    public static class MEstimator extends Estimator {
        double m;
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 1;

        public MEstimator(Alphabet dictionary, double m) {
            super(dictionary);
            this.m = m;
        }

        public MEstimator(int size, double m) {
            super(size);
            this.m = m;
        }

        public MEstimator(double m) {
            this.m = m;
        }

        /**
         * Returns {@code p[i] = (counts[i] + m) / sum_j (counts[j] + m)} over {@link #size()}
         * outcomes. If the total is zero (no counts and {@code m == 0}) the entries are NaN.
         */
        @Override // cc.mallet.types.Multinomial.Estimator
        public Multinomial estimate() {
            int outcomes = (dictionary == null) ? size : dictionary.size();
            if (outcomes > 0) {
                // The dictionary may have grown past the counts array; missing counts are zero.
                ensureCapacity(outcomes - 1);
            }
            double[] probabilities = new double[outcomes];
            double total = 0.0d;
            for (int i = 0; i < outcomes; i++) {
                probabilities[i] = counts[i] + m;
                total += probabilities[i];
            }
            for (int i = 0; i < outcomes; i++) {
                probabilities[i] /= total;
            }
            // copy=false hands the array over; pass its real length so the size assertion holds.
            return new Multinomial(probabilities, dictionary, outcomes, false, false);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
            out.writeDouble(m);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
            if (version != CURRENT_SERIAL_VERSION) {
                throw new ClassNotFoundException("Mismatched Multinomial.MEstimator versions: wanted "
                        + CURRENT_SERIAL_VERSION + ", got " + version);
            }
            m = in.readDouble();
        }
    }

    /** An {@link MEstimator} with no pseudo-count: plain normalized counts. */
    public static class MLEstimator extends MEstimator {
        private static final long serialVersionUID = 1;
        private static final int CURRENT_SERIAL_VERSION = 1;

        public MLEstimator() {
            super(0.0d);
        }

        public MLEstimator(int size) {
            super(size, 0.0d);
        }

        public MLEstimator(Alphabet dictionary) {
            super(dictionary, 0.0d);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
            if (version != CURRENT_SERIAL_VERSION) {
                throw new ClassNotFoundException("Mismatched Multinomial.MLEstimator versions: wanted "
                        + CURRENT_SERIAL_VERSION + ", got " + version);
            }
        }
    }

    /**
     * Prepares the dense values array handed to the FeatureVector constructor.
     *
     * @param probabilities source values
     * @param dictionary alphabet, or {@code null}
     * @param size number of leading entries to copy when {@code copy} is true
     * @param copy if true, copy into a new array sized to the dictionary (or {@code size});
     *     otherwise use {@code probabilities} directly
     * @param checkSum if true, require the values to sum to one within {@link #SUM_TOLERANCE}
     * @throws IllegalArgumentException if {@code checkSum} is set and the sum is off (or NaN)
     */
    private static double[] getValues(double[] probabilities, Alphabet dictionary, int size,
            boolean copy, boolean checkSum) {
        assert dictionary == null || dictionary.size() >= size;
        double[] result;
        if (copy) {
            result = new double[dictionary == null ? size : dictionary.size()];
            System.arraycopy(probabilities, 0, result, 0, size);
        } else {
            assert dictionary == null || dictionary.size() == probabilities.length;
            result = probabilities;
        }
        if (checkSum) {
            double sum = 0.0d;
            for (double p : result) {
                sum += p;
            }
            // Written as !(x <= tol) so that a NaN sum is rejected as well.
            if (!(Math.abs(sum - 1.0d) <= SUM_TOLERANCE)) {
                throw new IllegalArgumentException("Probabilities sum to " + sum + ", not to one.");
            }
        }
        return result;
    }

    /**
     * General constructor used by subclasses and estimators.
     *
     * @param probabilities source values
     * @param dictionary alphabet, or {@code null}
     * @param size number of entries to copy when {@code copy} is true
     * @param copy whether to copy {@code probabilities} rather than adopt it
     * @param checkSum whether to validate that the values sum to one
     */
    public Multinomial(double[] probabilities, Alphabet dictionary, int size, boolean copy, boolean checkSum) {
        super(dictionary, getValues(probabilities, dictionary, size, copy, checkSum));
    }

    /** Copies {@code probabilities} (one per dictionary entry) and validates their sum. */
    public Multinomial(double[] probabilities, Alphabet dictionary) {
        this(probabilities, dictionary, dictionary.size(), true, true);
    }

    /** Copies the first {@code size} probabilities and validates their sum. */
    public Multinomial(double[] probabilities, int size) {
        this(probabilities, null, size, true, true);
    }

    /** Copies all of {@code probabilities} and validates their sum. */
    public Multinomial(double[] probabilities) {
        this(probabilities, null, probabilities.length, true, true);
    }

    /** Returns the number of outcomes (length of the values array). */
    public int size() {
        return values.length;
    }

    /** Returns the probability of outcome {@code index}. */
    public double probability(int index) {
        return values[index];
    }

    /**
     * Returns the probability of {@code key}, using {@code dictionary.lookupIndex(key)} to find
     * its index.
     *
     * @throws IllegalStateException if there is no dictionary
     */
    public double probability(Object key) {
        if (dictionary == null) {
            throw new IllegalStateException("This Multinomial has no dictionary.");
        }
        return probability(dictionary.lookupIndex(key));
    }

    /** Returns the natural log of the probability of outcome {@code index}. */
    public double logProbability(int index) {
        return Math.log(values[index]);
    }

    /**
     * Returns the log probability of {@code key}, using {@code dictionary.lookupIndex(key)}.
     *
     * @throws IllegalStateException if there is no dictionary
     */
    public double logProbability(Object key) {
        if (dictionary == null) {
            throw new IllegalStateException("This Multinomial has no dictionary.");
        }
        return logProbability(dictionary.lookupIndex(key));
    }

    @Override // cc.mallet.types.FeatureVector, cc.mallet.types.AlphabetCarrying
    public Alphabet getAlphabet() {
        return dictionary;
    }

    /**
     * Adds each outcome's probability into {@code accumulator}. Uses {@link #probability(int)}
     * so log-space subclasses contribute real probabilities, not logs.
     */
    public void addProbabilitiesTo(double[] accumulator) {
        for (int i = 0; i < values.length; i++) {
            accumulator[i] += probability(i);
        }
    }

    /**
     * Samples an outcome index. Draws {@code u = randoms.nextUniform()} and returns the first
     * index whose cumulative probability exceeds {@code u}. If rounding leaves the total at or
     * below {@code u}, the last index with positive probability is returned.
     *
     * @throws IllegalStateException if no outcome has positive probability
     */
    public int randomIndex(Randoms randoms) {
        double target = randoms.nextUniform();
        double cumulative = 0.0d;
        int lastPositive = -1;
        for (int i = 0; i < values.length; i++) {
            double p = probability(i);
            if (p > 0.0d) {
                lastPositive = i;
            }
            cumulative += p;
            if (cumulative > target) {
                return i;
            }
        }
        if (lastPositive < 0) {
            throw new IllegalStateException("Cannot sample from a Multinomial with no positive probability mass.");
        }
        return lastPositive;
    }

    /**
     * Samples an outcome and maps it back through the dictionary.
     *
     * @throws IllegalStateException if there is no dictionary
     */
    public Object randomObject(Randoms randoms) {
        if (dictionary == null) {
            throw new IllegalStateException("This Multinomial has no dictionary.");
        }
        return dictionary.lookupObject(randomIndex(randoms));
    }

    /**
     * Builds a sequence of {@code length} independently sampled outcome indices.
     *
     * @throws UnsupportedOperationException if there is no dictionary
     */
    public FeatureSequence randomFeatureSequence(Randoms randoms, int length) {
        if (!(dictionary instanceof Alphabet)) {
            throw new UnsupportedOperationException("Multinomial's dictionary must be a Alphabet");
        }
        FeatureSequence sequence = new FeatureSequence(dictionary, length);
        for (int drawn = 0; drawn < length; drawn++) {
            sequence.add(randomIndex(randoms));
        }
        return sequence;
    }

    /** Builds a feature vector from a random sequence of {@code length} samples. */
    public FeatureVector randomFeatureVector(Randoms randoms, int length) {
        return new FeatureVector(randomFeatureSequence(randoms, length));
    }
}
