package org.datavec.spark.transform.analysis.histogram;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.datavec.api.writable.Writable;

/* loaded from: input_file:org/datavec/spark/transform/analysis/histogram/CategoricalHistogramCounter.class */
/**
 * Histogram counter for a categorical column.
 *
 * <p>Each value passed to {@link #add(Writable)} is keyed by its {@code toString()} form and
 * tallied. Bins are reported in the order of the state-name list given to the constructor:
 * bin {@code i} holds the count for {@code stateNames.get(i)}. Values that are not in that list
 * are still tallied internally (and carried through {@link #merge(HistogramCounter)}), but they
 * are not reported by {@link #getCounts()}.
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class CategoricalHistogramCounter implements HistogramCounter {
    // NOTE(review): declarations (type and modifiers) intentionally kept as-is; if HistogramCounter
    // is Serializable, changing them would alter the default serialVersionUID / serialized form.
    private HashMap<String, Integer> counts = new HashMap<>();
    private List<String> stateNames;

    /**
     * Creates an empty counter.
     *
     * @param list the category names, in bin order; the reference is stored as-is (not copied)
     */
    public CategoricalHistogramCounter(List<String> list) {
        this.stateNames = list;
    }

    /**
     * Increments the tally for the category given by {@code writable.toString()}.
     *
     * @param writable the value to count
     * @return this counter, for chaining
     */
    @Override // org.datavec.spark.transform.analysis.histogram.HistogramCounter
    public HistogramCounter add(Writable writable) {
        this.counts.merge(writable.toString(), 1, Integer::sum);
        return this;
    }

    /**
     * Adds every tally from {@code histogramCounter} into this counter.
     *
     * @param histogramCounter another categorical counter; it is read but not modified
     * @return this counter, for chaining
     * @throws IllegalArgumentException if {@code histogramCounter} is not a
     *     {@code CategoricalHistogramCounter} (including {@code null})
     */
    @Override // org.datavec.spark.transform.analysis.histogram.HistogramCounter
    public HistogramCounter merge(HistogramCounter histogramCounter) {
        if (!(histogramCounter instanceof CategoricalHistogramCounter)) {
            throw new IllegalArgumentException("Input must be CategoricalHistogramCounter; got " + histogramCounter);
        }
        CategoricalHistogramCounter other = (CategoricalHistogramCounter) histogramCounter;
        for (Map.Entry<String, Integer> entry : other.counts.entrySet()) {
            // Updating an existing key is not a structural change, so merging a counter into
            // itself (doubling every tally) does not disturb this iteration.
            this.counts.merge(entry.getKey(), entry.getValue(), Integer::sum);
        }
        return this;
    }

    /**
     * Returns the bin edges {@code 0, 1, ..., n}, where {@code n} is the number of state names.
     * Bin {@code i} spans {@code [i, i + 1)}.
     *
     * @return an array of length {@code stateNames.size() + 1}
     */
    @Override // org.datavec.spark.transform.analysis.histogram.HistogramCounter
    public double[] getBins() {
        double[] edges = new double[this.stateNames.size() + 1];
        for (int i = 0; i < edges.length; i++) {
            edges[i] = i;
        }
        return edges;
    }

    /**
     * Returns the tally for each state name, in constructor-list order.
     *
     * @return an array of length {@code stateNames.size()}; a category never seen has a count of 0
     */
    @Override // org.datavec.spark.transform.analysis.histogram.HistogramCounter
    public long[] getCounts() {
        long[] result = new long[this.stateNames.size()];
        int bin = 0;
        for (String state : this.stateNames) {
            // Single lookup: map values are never null, so null means "not seen".
            Integer count = this.counts.get(state);
            result[bin++] = (count == null) ? 0L : count.longValue();
        }
        return result;
    }
}
