package com.datarobot.mlops.stats;

import com.datarobot.mlops.stats.TypeConversion;
import com.datarobot.mlops.stats.exceptions.StatsAggregationException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import tech.tablesaw.api.DoubleColumn;
import tech.tablesaw.api.Row;
import tech.tablesaw.api.StringColumn;
import tech.tablesaw.api.Table;

/* loaded from: input_file:com/datarobot/mlops/stats/StatsAggregator.class */
/**
 * Accumulates summary statistics for model features, predictions and segment attributes.
 *
 * <p>Each call to {@code aggregate} computes statistics for one batch of data and merges them
 * into the aggregates held by this instance, keyed by feature name, prediction name or segment
 * attribute name. The accumulated state is exposed through the getters at the bottom of the class.
 */
public class StatsAggregator {
    private static final Logger logger = LoggerFactory.getLogger(StatsAggregator.class);

    /** Histogram bin count used when the caller passes {@code null}. */
    private static final int DEFAULT_HISTOGRAM_BIN_COUNT = 100;

    /**
     * Value forwarded to {@code SegmentAggregate.aggregate} when the caller passes {@code null}
     * for the last argument of {@code aggregate}.
     * NOTE(review): presumably a cap on the number of values tracked per segment attribute -
     * confirm against SegmentAggregate.
     */
    private static final int DEFAULT_SEGMENT_VALUE_LIMIT = 10000;

    private final Map<String, NumericAggregate> featuresNumericAggregates = new HashMap<>();
    private final Map<String, CategoricalAggregate> featuresCategoricalAggregates = new HashMap<>();
    // LinkedHashMap: keeps prediction names in first-seen order for getPredictionClassNamesList().
    private final Map<String, NumericAggregate> predictionsAggregatesMap = new LinkedHashMap<>();
    private final Map<String, SegmentAggregate> segmentAggregateMap = new HashMap<>();

    /**
     * Returns the names of all predictions aggregated so far.
     *
     * @return a new list of prediction names, in the order they were first aggregated
     */
    public List<String> getPredictionClassNamesList() {
        return new ArrayList<>(this.predictionsAggregatesMap.keySet());
    }

    /** Wraps the values in a tablesaw column and aggregates them. */
    private NumericAggregate aggregateNumericalFeature(List<Double> list, int i) {
        return aggregateNumericalFeature(DoubleColumn.create(Table.MELT_VALUE_COLUMN_NAME, list), i);
    }

    /**
     * Builds a numeric aggregate (counts, min, max, sum, sum of squares, histogram) for a column.
     *
     * @param doubleColumn values to summarize
     * @param i            size argument passed to the {@code CentroidHistogram} constructor
     */
    private NumericAggregate aggregateNumericalFeature(DoubleColumn doubleColumn, int i) {
        NumericAggregate aggregate = new NumericAggregate();
        int missing = doubleColumn.countMissing();
        aggregate.missingCount = missing;
        aggregate.count = doubleColumn.size() - missing;
        aggregate.max = doubleColumn.max();
        aggregate.sumOfSquares = doubleColumn.sumOfSquares();
        aggregate.min = doubleColumn.min();
        aggregate.sum = doubleColumn.sum();
        aggregate.histogram = new CentroidHistogram(i);
        // Only non-NaN values go into the histogram. The explicitly typed lambda selects the
        // DoublePredicate overload of filter, which returns a DoubleColumn.
        aggregate.histogram.push(doubleColumn.filter((double value) -> !Double.isNaN(value)));
        return aggregate;
    }

    /** Wraps the values in a tablesaw column and aggregates them. */
    private CategoricalAggregate aggregateCategoricalFeature(List<String> list) {
        return aggregateCategoricalFeature(StringColumn.create("category", list));
    }

    /**
     * Builds a categorical aggregate: missing/present counts plus a count per non-empty value.
     * {@code textWordCount} is explicitly left {@code null}.
     */
    private CategoricalAggregate aggregateCategoricalFeature(StringColumn stringColumn) {
        CategoricalAggregate aggregate = new CategoricalAggregate();
        int missing = stringColumn.countMissing();
        aggregate.missingCount = missing;
        aggregate.count = stringColumn.size() - missing;
        aggregate.textWordCount = null;
        StringColumn nonEmpty = (StringColumn) stringColumn.filter(value -> !value.isEmpty());
        copyCategoryCounts(nonEmpty, aggregate);
        return aggregate;
    }

    /**
     * Builds a categorical aggregate over the tokens of text values.
     *
     * <p>Each value is stringified ({@code null} becomes the empty string), stripped of digits and
     * lower-cased. It is then split into words (on {@code \W+}) for {@code TEXT_WORDS} or into
     * single characters otherwise. Tokens are de-duplicated per value, so each category count is
     * the number of values containing that token.
     *
     * @param list        raw feature values
     * @param featureType decides word versus character tokenization
     */
    private CategoricalAggregate aggregateTextStatsFeature(List<Object> list, FeatureType featureType) {
        // NOTE(review): toLowerCase() uses the JVM default locale.
        List<String> normalized = list.stream()
                .map(value -> value == null ? "" : String.valueOf(value))
                .map(text -> text.replaceAll("\\d", "").toLowerCase())
                .collect(Collectors.toList());

        StringColumn normalizedColumn = StringColumn.create("category", normalized);
        CategoricalAggregate aggregate = new CategoricalAggregate();
        int missing = normalizedColumn.countMissing();
        aggregate.missingCount = missing;
        aggregate.count = normalizedColumn.size() - missing;

        StringColumn tokens = StringColumn.create("category");
        for (String text : normalized) {
            String[] parts = featureType == FeatureType.TEXT_WORDS ? text.split("\\W+") : text.split("");
            // De-duplicate within a single value before adding to the token column.
            HashSet<String> distinctParts = new HashSet<>(Arrays.asList(parts));
            tokens.addAll(new ArrayList<String>(distinctParts));
        }

        StringColumn nonEmptyTokens = (StringColumn) tokens.filter(token -> !token.isEmpty());
        Table counts = nonEmptyTokens.countByCategory();
        aggregate.textWordCount = nonEmptyTokens.size() - nonEmptyTokens.countMissing();
        copyCategoryCounts(counts, aggregate);
        return aggregate;
    }

    /** Counts the column's values by category and stores the counts in the aggregate. */
    private static void copyCategoryCounts(StringColumn values, CategoricalAggregate target) {
        copyCategoryCounts(values.countByCategory(), target);
    }

    /** Copies the "Category"/"Count" rows of a count table into the aggregate's category map. */
    private static void copyCategoryCounts(Table counts, CategoricalAggregate target) {
        for (Row row : counts) {
            target.categoryCounts.put(row.getString("Category"), Integer.valueOf(row.getInt("Count")));
        }
    }

    /**
     * Aggregates each prediction column and merges it into the stored prediction aggregates.
     *
     * @param map prediction name to prediction values
     * @param i   histogram size passed on to the numeric aggregation
     */
    public void aggregatePredictions(Map<String, List<Double>> map, int i) {
        map.forEach((name, values) -> {
            NumericAggregate fresh = aggregateNumericalFeature(values, i);
            NumericAggregate existing = this.predictionsAggregatesMap.get(name);
            this.predictionsAggregatesMap.put(name, existing != null ? existing.merge(fresh) : fresh);
        });
    }

    /**
     * Aggregates every described feature and merges it into the stored feature aggregates.
     *
     * <p>Does nothing when either argument is {@code null}. A feature that is absent from
     * {@code map} or has no values is skipped and the remaining features are still processed.
     * Feature types other than NUMERIC, CATEGORY, TEXT_WORDS and TEXT_CHARS are logged and ignored.
     *
     * @param map  feature name to feature values
     * @param list descriptors (name and type) of the features to aggregate
     * @param i    histogram size passed on to the numeric aggregation
     */
    public void aggregateFeatures(Map<String, List<Object>> map, List<FeatureDescriptor> list, int i) {
        if (map == null || list == null) {
            return;
        }
        for (FeatureDescriptor descriptor : list) {
            String name = descriptor.getName();
            FeatureType featureType = descriptor.getFeatureType();
            logger.debug("Iterating over: {} of type: {}", name, featureType);
            List<Object> values = map.get(name);
            if (values == null || values.isEmpty()) {
                logger.debug("Skipping feature: {} not in data", name);
                // Skip only this feature; returning here would drop all remaining features.
                continue;
            }
            switch (featureType) {
                case NUMERIC:
                    updateNumericAggregate(name, aggregateNumericalFeature(convertToListOfDouble(values), i));
                    break;
                case CATEGORY:
                    updateCategoricalAggregate(name, aggregateCategoricalFeature(convertToListOfString(values)));
                    break;
                case TEXT_WORDS:
                case TEXT_CHARS:
                    updateCategoricalAggregate(name, aggregateTextStatsFeature(values, featureType));
                    break;
                default:
                    logger.warn("Aggregation not implemented for {}", featureType);
                    break;
            }
        }
    }

    /** Stores the aggregate under the feature name, merging with any existing one. */
    private void updateNumericAggregate(String str, NumericAggregate numericAggregate) {
        NumericAggregate existing = this.featuresNumericAggregates.get(str);
        this.featuresNumericAggregates.put(str, existing != null ? existing.merge(numericAggregate) : numericAggregate);
    }

    /** Stores the aggregate under the feature name, merging with any existing one. */
    private void updateCategoricalAggregate(String str, CategoricalAggregate categoricalAggregate) {
        CategoricalAggregate existing = this.featuresCategoricalAggregates.get(str);
        this.featuresCategoricalAggregates.put(str, existing != null ? existing.merge(categoricalAggregate) : categoricalAggregate);
    }

    /** Stringifies every value; {@code null} becomes the empty string. */
    private List<String> convertToListOfString(List<Object> list) {
        return list.stream()
                .map(value -> value == null ? "" : String.valueOf(value))
                .collect(Collectors.toList());
    }

    /** Converts every value to a Double; see {@link #toDouble(Object)}. */
    private List<Double> convertToListOfDouble(List<Object> list) {
        return list.stream().map(this::toDouble).collect(Collectors.toList());
    }

    /** Strings are parsed, numbers are widened, anything else (including null) becomes NaN. */
    private Double toDouble(Object value) {
        if (value instanceof String) {
            return parseStringToDouble((String) value);
        }
        if (value instanceof Number) {
            return Double.valueOf(((Number) value).doubleValue());
        }
        return Double.valueOf(Double.NaN);
    }

    /**
     * Parses a string into a Double.
     *
     * @return 1.0 for "true" and 0.0 for "false" (case-insensitive), NaN for {@code null},
     *         "nan" or any text that is not a valid double, otherwise the parsed value
     */
    private Double parseStringToDouble(String str) {
        if (str == null || str.equals("nan")) {
            return Double.valueOf(Double.NaN);
        }
        if (str.equalsIgnoreCase("true")) {
            return Double.valueOf(1.0d);
        }
        if (str.equalsIgnoreCase("false")) {
            return Double.valueOf(0.0d);
        }
        try {
            return Double.valueOf(Double.parseDouble(str));
        } catch (NumberFormatException ignored) {
            // Unparseable text is deliberately treated as a missing value.
            return Double.valueOf(Double.NaN);
        }
    }

    /** Returns {@code t}, or {@code t2} when {@code t} is {@code null}. */
    private static <T> T getValueOrDefault(T t, T t2) {
        return t == null ? t2 : t;
    }

    /**
     * Aggregates one batch of features and predictions without segment attributes.
     *
     * <p>Equivalent to the seven-argument overload with {@code null} for the last two arguments.
     *
     * @throws StatsAggregationException if the input fails validation
     */
    public void aggregate(Map<String, List<Object>> map, List<FeatureDescriptor> list, Map<String, List<Double>> map2, Integer num, Integer num2) throws StatsAggregationException {
        aggregate(map, list, map2, num, num2, null, null);
    }

    /**
     * Validates one batch of data and merges its statistics into this aggregator.
     *
     * @param map   feature name to feature values; may be {@code null}
     * @param list  descriptors of the features to aggregate; may be {@code null}
     * @param map2  prediction name to prediction values; may be {@code null}
     * @param num   histogram bin count; must be positive if given, defaults to 100
     * @param num2  distinct category count; must be positive if given. It is validated but not
     *              otherwise used by this method.
     * @param list2 names of the features to use as segment attributes; may be {@code null}
     * @param num3  value forwarded to {@code SegmentAggregate.aggregate}, defaults to 10000
     * @throws StatsAggregationException if the input fails validation
     */
    public void aggregate(Map<String, List<Object>> map, List<FeatureDescriptor> list, Map<String, List<Double>> map2, Integer num, Integer num2, List<String> list2, Integer num3) throws StatsAggregationException {
        validateInput(map, list, map2, num, num2, list2);
        int histogramBinCount = getValueOrDefault(num, DEFAULT_HISTOGRAM_BIN_COUNT);
        int segmentValueLimit = getValueOrDefault(num3, DEFAULT_SEGMENT_VALUE_LIMIT);

        TypeConversion.ConvertedFeatureTypes converted = TypeConversion.convertFeaturesForAggregation(map, list);
        Map<String, List<Object>> convertedFeatures = converted.features;
        List<FeatureDescriptor> convertedFeatureTypes = converted.featureTypes;

        if (map2 != null) {
            logger.debug("Aggregating predictions");
            aggregatePredictions(map2, histogramBinCount);
        }
        if (convertedFeatures != null) {
            logger.debug("Aggregating features");
            aggregateFeatures(convertedFeatures, convertedFeatureTypes, histogramBinCount);
        }
        if (list2 != null) {
            for (String attribute : list2) {
                // Segment values are read from the caller's feature map, not the converted one.
                List<String> segmentValues = convertToListOfString(map.get(attribute));
                SegmentAggregate segment = this.segmentAggregateMap.computeIfAbsent(attribute, key -> new SegmentAggregate());
                segment.aggregate(segmentValues, aggregateCategoricalFeature(segmentValues), convertedFeatures, convertedFeatureTypes, map2, histogramBinCount, segmentValueLimit);
            }
        }
    }

    /**
     * Checks the arguments of {@code aggregate} for consistency.
     *
     * <p>Row counts are compared using only the first entry of the feature and prediction maps.
     *
     * @throws StatsAggregationException if row counts differ, predictions contain {@code null},
     *         feature types or segment attributes refer to features that were not supplied, or a
     *         supplied count is not positive
     */
    private void validateInput(Map<String, List<Object>> map, List<FeatureDescriptor> list, Map<String, List<Double>> map2, Integer num, Integer num2, List<String> list2) throws StatsAggregationException {
        boolean hasFeatures = map != null && !map.isEmpty();
        boolean hasFeatureTypes = list != null && !list.isEmpty();
        boolean hasPredictions = map2 != null && !map2.isEmpty();
        boolean hasSegmentAttributes = list2 != null && !list2.isEmpty();

        int featureRows = hasFeatures ? map.entrySet().iterator().next().getValue().size() : 0;
        int predictionRows = hasPredictions ? map2.entrySet().iterator().next().getValue().size() : 0;
        if (hasFeatures && hasPredictions && featureRows != predictionRows) {
            throw new StatsAggregationException("Different numbers of rows for features and predictions specified");
        }

        if (hasPredictions) {
            for (List<Double> values : map2.values()) {
                if (values.contains(null)) {
                    throw new StatsAggregationException("Missing values are not permitted in predictions");
                }
            }
        }

        if (hasFeatureTypes) {
            if (!hasFeatures) {
                throw new StatsAggregationException("Features must be specified when FeatureTypes are specified");
            }
            List<String> missingNames = new ArrayList<>();
            for (FeatureDescriptor descriptor : list) {
                if (!map.containsKey(descriptor.getName())) {
                    missingNames.add(descriptor.getName());
                }
            }
            if (!missingNames.isEmpty()) {
                throw new StatsAggregationException("Feature types '" + String.join(", ", missingNames) + "' not present in provided dataset");
            }
        }

        if (hasSegmentAttributes) {
            if (!hasFeatures) {
                throw new StatsAggregationException("Features must be specified when segment attributes are specified");
            }
            if (!list2.stream().allMatch(map::containsKey)) {
                throw new StatsAggregationException("All segment attributes must be specified in features");
            }
        }

        if (num != null && num.intValue() <= 0) {
            throw new StatsAggregationException("If specified, histogram_bin_count must be a positive integer");
        }
        if (num2 != null && num2.intValue() <= 0) {
            throw new StatsAggregationException("If specified, distinct_category_count must be a positive integer");
        }
    }

    /**
     * Returns the accumulated prediction aggregates.
     *
     * @return the internal map itself (not a copy), keyed by prediction name
     */
    public Map<String, NumericAggregate> getPredictionsAggregates() {
        return this.predictionsAggregatesMap;
    }

    /**
     * Returns the accumulated aggregates of numeric features.
     *
     * @return the internal map itself (not a copy), keyed by feature name
     */
    public Map<String, NumericAggregate> getFeaturesNumericAggregates() {
        return this.featuresNumericAggregates;
    }

    /**
     * Returns the accumulated aggregates of categorical and text features.
     *
     * @return the internal map itself (not a copy), keyed by feature name
     */
    public Map<String, CategoricalAggregate> getFeaturesCategoricalAggregates() {
        return this.featuresCategoricalAggregates;
    }

    /**
     * Returns the accumulated segment aggregates.
     *
     * @return the internal map itself (not a copy), keyed by segment attribute name
     */
    public Map<String, SegmentAggregate> getSegmentAggregateMap() {
        return this.segmentAggregateMap;
    }
}
