package org.deeplearning4j.berkeley;

import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;
import org.deeplearning4j.berkeley.MapFactory;

/* loaded from: input_file:org/deeplearning4j/berkeley/Counter.class */
/**
 * A map from keys to {@code double} counts, with helpers for incrementing, normalizing,
 * sampling, ranking and comparing counts.
 *
 * <p>Keys without an explicit entry report {@link #getDeflt()} from {@link #getCount}. The sum of
 * all stored counts is cached and lazily recomputed by {@link #totalCount()} whenever a mutator
 * has set the {@code dirty} flag.
 *
 * <p>No method synchronizes; concurrent mutation must be guarded externally.
 *
 * @param <E> key type
 */
public class Counter<E> implements Serializable {
    private static final long serialVersionUID = 1;

    /** Backing storage from key to count, built by {@link #mf}. */
    Map<E, Double> entries;
    /** True when {@link #cacheTotal} may be stale and must be recomputed. */
    boolean dirty;
    /** Cached result of {@link #totalCount()}; only meaningful while {@code dirty} is false. */
    double cacheTotal;
    /** Factory used to build (and, in {@link #clear()}, rebuild) {@link #entries}. */
    MapFactory<E, Double> mf;
    /** Count reported by {@link #getCount} for keys that have no stored entry. */
    double deflt;

    /** Returns the count reported for keys that have no stored entry. */
    public double getDeflt() {
        return this.deflt;
    }

    /**
     * Sets the count reported for keys that have no stored entry. Stored entries and the cached
     * total are not affected.
     */
    public void setDeflt(double deflt) {
        this.deflt = deflt;
    }

    /** Returns the backing map's live key set. */
    public Set<E> keySet() {
        return this.entries.keySet();
    }

    /**
     * Returns the backing map's live entry set. Changing counts through it bypasses the cached
     * total; call {@link #setDirty(boolean) setDirty(true)} afterwards.
     */
    public Set<Map.Entry<E, Double>> entrySet() {
        return this.entries.entrySet();
    }

    /** Returns the number of keys with a stored count. */
    public int size() {
        return this.entries.size();
    }

    /** Returns true when no key has a stored count. */
    public boolean isEmpty() {
        return size() == 0;
    }

    /** Returns true when {@code key} has a stored count (even if that count is 0). */
    public boolean containsKey(E key) {
        return this.entries.containsKey(key);
    }

    /** Returns the stored count for {@code key}, or {@link #getDeflt()} when there is none. */
    public double getCount(E key) {
        Double count = this.entries.get(key);
        return count == null ? this.deflt : count.doubleValue();
    }

    /**
     * Returns {@code getCount(key) / totalCount()}, or 0.0 when the total is exactly 0.
     *
     * @throws RuntimeException if the total count is negative
     */
    public double getProbability(E key) {
        double count = getCount(key);
        double total = totalCount();
        if (total < 0.0d) {
            throw new RuntimeException("Can't call getProbability() with totalCount < 0.0");
        }
        if (total > 0.0d) {
            return count / total;
        }
        return 0.0d;
    }

    /**
     * Divides every stored count by the current total. There is no guard for a zero total; in
     * that case the counts become NaN or infinite.
     */
    public void normalize() {
        double total = totalCount();
        for (E key : keySet()) {
            setCount(key, getCount(key) / total);
        }
        this.dirty = true;
    }

    /** Stores {@code count} for {@code key}, replacing any previous value. */
    public void setCount(E key, double count) {
        this.entries.put(key, Double.valueOf(count));
        this.dirty = true;
    }

    /**
     * Stores {@code count} for {@code key}. When {@code keepHigher} is true and the key already
     * has a stored count, the value is only replaced if {@code count} is strictly larger.
     */
    public void put(E key, double count, boolean keepHigher) {
        if (!keepHigher || !this.entries.containsKey(key)
                || count > this.entries.get(key).doubleValue()) {
            this.entries.put(key, Double.valueOf(count));
        }
        this.dirty = true;
    }

    /**
     * Draws a key with probability proportional to its count.
     *
     * @param random source of randomness
     * @return a sampled key
     * @throws RuntimeException if the total count is not positive
     * @throws IllegalStateException if no key could be chosen (e.g. counts contain NaN)
     */
    public E sample(Random random) {
        double total = totalCount();
        if (total <= 0.0d) {
            throw new RuntimeException(String.format("Attempting to sample() with totalCount() %.3f\n", Double.valueOf(total)));
        }
        double target = random.nextDouble();
        double cumulative = 0.0d;
        E fallback = null;
        boolean haveFallback = false;
        for (Map.Entry<E, Double> entry : this.entries.entrySet()) {
            double weight = entry.getValue().doubleValue();
            cumulative += weight / total;
            if (target < cumulative) {
                return entry.getKey();
            }
            if (weight > 0.0d) {
                fallback = entry.getKey();
                haveFallback = true;
            }
        }
        // Rounding in the running sum can leave it fractionally below 1.0 while the draw lands
        // in that gap; attribute the draw to the last key that carried positive weight.
        if (haveFallback) {
            return fallback;
        }
        throw new IllegalStateException("Shoudl've have returned a sample by now....");
    }

    /** Draws a key as {@link #sample(Random)} does, using a fresh {@link Random}. */
    public E sample() {
        return sample(new Random());
    }

    /** Removes {@code key} and its count (via {@link #removeKeyFromEntries}). */
    public void removeKey(E key) {
        setCount(key, 0.0d);
        this.dirty = true;
        removeKeyFromEntries(key);
    }

    /** Removes {@code key} from the backing map; an extension point for subclasses. */
    protected void removeKeyFromEntries(E key) {
        this.entries.remove(key);
    }

    /** Sets the count to {@code count} if the key is absent or {@code count} is larger. */
    public void setMaxCount(E key, double count) {
        Double current = this.entries.get(key);
        if (current == null || count > current.doubleValue()) {
            setCount(key, count);
            this.dirty = true;
        }
    }

    /** Sets the count to {@code count} if the key is absent or {@code count} is smaller. */
    public void setMinCount(E key, double count) {
        Double current = this.entries.get(key);
        if (current == null || count < current.doubleValue()) {
            setCount(key, count);
            this.dirty = true;
        }
    }

    /**
     * Adds {@code increment} to the key's current count (which starts from {@link #getDeflt()}
     * for absent keys).
     *
     * @return the new count
     */
    public double incrementCount(E key, double increment) {
        double updated = getCount(key) + increment;
        setCount(key, updated);
        this.dirty = true;
        return updated;
    }

    /** Increments each element of {@code keys} by {@code increment} (duplicates add up). */
    public void incrementAll(Collection<? extends E> keys, double increment) {
        for (E key : keys) {
            incrementCount(key, increment);
        }
        this.dirty = true;
    }

    /** Adds every stored count of {@code counter} to this counter. */
    public <T extends E> void incrementAll(Counter<T> counter) {
        for (T key : counter.keySet()) {
            incrementCount(key, counter.getCount(key));
        }
        this.dirty = true;
    }

    /** Returns the sum of all stored counts, recomputing it only when the cache is dirty. */
    public double totalCount() {
        if (!this.dirty) {
            return this.cacheTotal;
        }
        double total = 0.0d;
        for (Map.Entry<E, Double> entry : this.entries.entrySet()) {
            total += entry.getValue().doubleValue();
        }
        this.cacheTotal = total;
        this.dirty = false;
        return total;
    }

    /** Returns the keys in the order produced by draining {@link #asPriorityQueue()}. */
    public List<E> getSortedKeys() {
        PriorityQueue<E> queue = asPriorityQueue();
        List<E> sortedKeys = new ArrayList<>(size());
        while (queue.hasNext()) {
            sortedKeys.add(queue.next());
        }
        return sortedKeys;
    }

    /** Returns a key with the largest stored count, or null when the counter is empty. */
    public E argMax() {
        double best = Double.NEGATIVE_INFINITY;
        E bestKey = null;
        for (Map.Entry<E, Double> entry : this.entries.entrySet()) {
            if (entry.getValue().doubleValue() > best || bestKey == null) {
                bestKey = entry.getKey();
                best = entry.getValue().doubleValue();
            }
        }
        return bestKey;
    }

    /** Returns the smallest stored count, or positive infinity when empty. */
    public double min() {
        return maxMinHelp(false);
    }

    /** Returns the largest stored count, or negative infinity when empty. */
    public double max() {
        return maxMinHelp(true);
    }

    /** Scans the stored counts for the maximum (when {@code wantMax}) or minimum. */
    private double maxMinHelp(boolean wantMax) {
        double extreme = wantMax ? Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
        for (Map.Entry<E, Double> entry : this.entries.entrySet()) {
            double value = entry.getValue().doubleValue();
            if ((wantMax && value > extreme) || (!wantMax && value < extreme)) {
                extreme = value;
            }
        }
        return extreme;
    }

    /** Formats every entry via {@link #toString(int)}. */
    @Override
    public String toString() {
        return toString(keySet().size());
    }

    /**
     * Formats the entries as {@code [key : count, ...]} in the keys' natural order, with at most
     * five fraction digits. Keys must be mutually {@link Comparable}.
     */
    public String toStringSortedByKeys() {
        StringBuilder sb = new StringBuilder("[");
        NumberFormat numberFormat = NumberFormat.getInstance();
        numberFormat.setMaximumFractionDigits(5);
        int keyCount = size();
        int written = 0;
        for (E key : new TreeSet<E>(keySet())) {
            sb.append(key.toString());
            sb.append(" : ");
            sb.append(numberFormat.format(getCount(key)));
            written++;
            if (written < keyCount) {
                sb.append(", ");
            }
        }
        sb.append("]");
        return sb.toString();
    }

    /** Formats up to {@code maxKeys} entries using the priority queue's formatter. */
    public String toString(int maxKeys) {
        return asPriorityQueue().toString(maxKeys, false);
    }

    /**
     * Formats up to {@code maxKeys} entries using the priority queue's formatter; {@code flag} is
     * forwarded to {@code PriorityQueue.toString(int, boolean)}, whose meaning is defined there.
     */
    public String toString(int maxKeys, boolean flag) {
        return asPriorityQueue().toString(maxKeys, flag);
    }

    /** Builds a new priority queue holding every key with its count as priority. */
    public PriorityQueue<E> asPriorityQueue() {
        PriorityQueue<E> queue = new PriorityQueue<>(this.entries.size());
        for (Map.Entry<E, Double> entry : this.entries.entrySet()) {
            queue.add(entry.getKey(), entry.getValue().doubleValue());
        }
        return queue;
    }

    /**
     * Builds a new priority queue holding every key with its negated count as priority, so keys
     * rank in the reverse of {@link #asPriorityQueue()}'s order.
     */
    public PriorityQueue<E> asMinPriorityQueue() {
        PriorityQueue<E> queue = new PriorityQueue<>(this.entries.size());
        for (Map.Entry<E, Double> entry : this.entries.entrySet()) {
            queue.add(entry.getKey(), -entry.getValue().doubleValue());
        }
        return queue;
    }

    /** Creates an empty counter backed by the hash-map factory. */
    public Counter() {
        this(false);
    }

    /**
     * Creates an empty counter backed by the identity-hash-map factory when
     * {@code identityHashMap} is true (presumably comparing keys by reference), otherwise by the
     * hash-map factory.
     */
    public Counter(boolean identityHashMap) {
        this(identityHashMap ? new MapFactory.IdentityHashMapFactory() : new MapFactory.HashMapFactory());
    }

    /** Creates an empty counter whose storage is built by {@code mapFactory}. */
    public Counter(MapFactory<E, Double> mapFactory) {
        this.dirty = true;
        this.cacheTotal = 0.0d;
        this.deflt = 0.0d;
        this.mf = mapFactory;
        this.entries = mapFactory.buildMap();
    }

    /** Creates a counter initialised from the entries of {@code map}. */
    public Counter(Map<? extends E, Double> map) {
        // this(false) already installs an empty map from the hash-map factory.
        this(false);
        for (Map.Entry<? extends E, Double> entry : map.entrySet()) {
            incrementCount(entry.getKey(), entry.getValue().doubleValue());
        }
    }

    /** Creates a copy of {@code counter}'s stored counts (the default count is not copied). */
    public Counter(Counter<? extends E> counter) {
        this();
        incrementAll(counter);
    }

    /** Creates a counter where each element of {@code collection} is counted once per occurrence. */
    public Counter(Collection<? extends E> collection) {
        this();
        incrementAll(collection, 1.0d);
    }

    /** Removes every key whose stored count is strictly below {@code threshold}. */
    public void pruneKeysBelowThreshold(double threshold) {
        Iterator<Map.Entry<E, Double>> it = this.entries.entrySet().iterator();
        while (it.hasNext()) {
            if (it.next().getValue().doubleValue() < threshold) {
                it.remove();
            }
        }
        this.dirty = true;
    }

    /** Same as {@link #entrySet()}. */
    public Set<Map.Entry<E, Double>> getEntrySet() {
        return this.entries.entrySet();
    }

    /**
     * Returns true when both counters report exactly the same count (by {@code ==}) for every key
     * stored in either of them.
     */
    public boolean isEqualTo(Counter<E> counter) {
        // Check both key sets: a key stored only in the other counter must also be compared.
        for (E key : keySet()) {
            if (counter.getCount(key) != getCount(key)) {
                return false;
            }
        }
        for (E key : counter.keySet()) {
            if (counter.getCount(key) != getCount(key)) {
                return false;
            }
        }
        return true;
    }

    /** Small demonstration that prints a counter after a few updates. */
    public static void main(String[] args) {
        Counter<String> counter = new Counter<>();
        System.out.println(counter);
        counter.incrementCount("planets", 7.0d);
        System.out.println(counter);
        counter.incrementCount("planets", 1.0d);
        System.out.println(counter);
        counter.setCount("suns", 1.0d);
        System.out.println(counter);
        counter.setCount("aliens", 0.0d);
        System.out.println(counter);
        System.out.println(counter.toString(2));
        System.out.println("Total: " + counter.totalCount());
    }

    /** Discards all stored counts by rebuilding the backing map from {@link #mf}. */
    public void clear() {
        this.entries = this.mf.buildMap();
        this.dirty = true;
    }

    /** Keeps only the {@code n} keys ranked first by {@link #asPriorityQueue()}. */
    public void keepTopNKeys(int n) {
        keepKeysHelper(n, true);
    }

    /** Keeps only the {@code n} keys ranked first by {@link #asMinPriorityQueue()}. */
    public void keepBottomNKeys(int n) {
        keepKeysHelper(n, false);
    }

    /**
     * Retains exactly the first {@code n} keys produced by the chosen queue (all keys when fewer
     * exist, none when {@code n <= 0}) with their counts unchanged.
     */
    private void keepKeysHelper(int n, boolean useMaxQueue) {
        PriorityQueue<E> queue = useMaxQueue ? asPriorityQueue() : asMinPriorityQueue();
        // Collect into a map from the same factory so key-equality semantics are preserved.
        Map<E, Double> kept = this.mf.buildMap();
        for (int taken = 0; taken < n && queue.hasNext(); taken++) {
            E key = queue.next();
            kept.put(key, Double.valueOf(getCount(key)));
        }
        clear();
        for (Map.Entry<E, Double> entry : kept.entrySet()) {
            // setCount rather than incrementCount: the latter would add the default count on top.
            setCount(entry.getKey(), entry.getValue().doubleValue());
        }
        this.dirty = true;
    }

    /** Sets every stored key's count to {@code count}. */
    public void setAllCounts(double count) {
        for (E key : keySet()) {
            setCount(key, count);
        }
    }

    /**
     * Returns the sum over this counter's stored keys of {@code this count * other count}.
     * Keys stored only in {@code counter} contribute nothing.
     */
    public double dotProduct(Counter<E> counter) {
        double sum = 0.0d;
        for (Map.Entry<E, Double> entry : getEntrySet()) {
            double otherCount = counter.getCount(entry.getKey());
            if (otherCount != 0.0d) {
                double ownCount = entry.getValue().doubleValue();
                if (ownCount != 0.0d) {
                    sum += ownCount * otherCount;
                }
            }
        }
        return sum;
    }

    /** Multiplies every stored count by {@code factor} in place. */
    public void scale(double factor) {
        for (Map.Entry<E, Double> entry : getEntrySet()) {
            entry.setValue(Double.valueOf(entry.getValue().doubleValue() * factor));
        }
        // Counts were rewritten through the entry view, so the cached total is stale.
        this.dirty = true;
    }

    /** Returns a new default counter holding every stored count multiplied by {@code factor}. */
    public Counter<E> scaledClone(double factor) {
        Counter<E> scaled = new Counter<>();
        for (Map.Entry<E, Double> entry : getEntrySet()) {
            scaled.setCount(entry.getKey(), entry.getValue().doubleValue() * factor);
        }
        return scaled;
    }

    /** Returns a copy of this counter with {@code counter}'s stored counts subtracted. */
    public Counter<E> difference(Counter<E> counter) {
        Counter<E> result = new Counter<>(this);
        for (E key : counter.keySet()) {
            result.incrementCount(key, (-1.0d) * counter.getCount(key));
        }
        return result;
    }

    /** Returns a copy of this counter with every stored count replaced by its natural log. */
    public Counter<E> toLogSpace() {
        Counter<E> logCounter = new Counter<>(this);
        for (E key : logCounter.keySet()) {
            logCounter.setCount(key, Math.log(getCount(key)));
        }
        return logCounter;
    }

    /**
     * Returns true when, for every key stored in either counter, the two counts differ by at most
     * {@code tolerance}.
     */
    public boolean approxEquals(Counter<E> counter, double tolerance) {
        for (E key : keySet()) {
            if (Math.abs(getCount(key) - counter.getCount(key)) > tolerance) {
                return false;
            }
        }
        for (E key : counter.keySet()) {
            if (Math.abs(getCount(key) - counter.getCount(key)) > tolerance) {
                return false;
            }
        }
        return true;
    }

    /** Overrides the dirty flag; pass true after modifying counts through an entry view. */
    public void setDirty(boolean dirty) {
        this.dirty = dirty;
    }

    /** Returns one {@code key<TAB>count} line per key, in {@link #getSortedKeys()} order. */
    public String toStringTabSeparated() {
        StringBuilder sb = new StringBuilder();
        for (E key : getSortedKeys()) {
            sb.append(key.toString() + "\t" + getCount(key) + "\n");
        }
        return sb.toString();
    }

    /**
     * Two counters of the same class are equal when their stored entries and default counts are
     * equal. The cached total and dirty flag are deliberately ignored: they are derived state and
     * would otherwise make equality depend on whether {@link #totalCount()} had been called.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        Counter<?> other = (Counter<?>) obj;
        if (Double.compare(other.deflt, this.deflt) != 0) {
            return false;
        }
        return this.entries == null ? other.entries == null : this.entries.equals(other.entries);
    }

    /**
     * Hash over exactly the fields compared by {@link #equals}. The map factory is excluded: it is
     * not compared by equals, and including it made equal counters hash differently.
     */
    @Override
    public int hashCode() {
        int result = this.entries != null ? this.entries.hashCode() : 0;
        long defaultBits = Double.doubleToLongBits(this.deflt);
        return 31 * result + (int) (defaultBits ^ (defaultBits >>> 32));
    }
}
