package com.datarobot.mlops.drift;

import com.datarobot.mlops.collections.FeatureList;
import com.datarobot.mlops.drift.metric.DriftMetric;
import com.datarobot.mlops.drift.metric.DriftMetricType;
import com.datarobot.mlops.drift.metric.FeatureDistribution;
import com.datarobot.mlops.drift.metric.MetricFactory;
import com.datarobot.mlops.stats.FeatureDescriptor;
import com.datarobot.mlops.stats.StatsAggregator;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

/* loaded from: input_file:com/datarobot/mlops/drift/DriftReport.class */
public class DriftReport {
    private static double DEFAULT_DRIFT_THRESHOLD = 0.3d;
    private static int DEFAULT_NUMBER_OF_ROWS_TO_DEFINE = 10;
    private final Map<String, Double> featureDriftMetrics = new HashMap();
    private final Map<String, FeatureDistribution> featureDistributions = new HashMap();
    private final DriftMetricType metricType;

    private DriftReport(DriftMetricType driftMetricType) {
        this.metricType = driftMetricType;
    }

    public DriftMetricType getMetricType() {
        return this.metricType;
    }

    public Map<String, Double> getMetrics() {
        return this.featureDriftMetrics;
    }

    public static DriftReport compute(StatsAggregator statsAggregator, StatsAggregator statsAggregator2, List<FeatureDescriptor> list, String... strArr) {
        return compute(statsAggregator, statsAggregator2, list, DriftMetricType.PSI, null, strArr);
    }

    public static DriftReport compute(StatsAggregator statsAggregator, StatsAggregator statsAggregator2, List<FeatureDescriptor> list, Function<FeatureDistribution, Double> function, String... strArr) {
        return compute(statsAggregator, statsAggregator2, list, DriftMetricType.CUSTOM, function, strArr);
    }

    public static DriftReport compute(StatsAggregator statsAggregator, StatsAggregator statsAggregator2, List<FeatureDescriptor> list, DriftMetricType driftMetricType, String... strArr) {
        return compute(statsAggregator, statsAggregator2, list, driftMetricType, null, strArr);
    }

    private static DriftReport compute(StatsAggregator statsAggregator, StatsAggregator statsAggregator2, List<FeatureDescriptor> list, DriftMetricType driftMetricType, Function<FeatureDistribution, Double> function, String... strArr) {
        if (function == null) {
            DriftMetric metric = MetricFactory.getMetric(driftMetricType);
            Objects.requireNonNull(metric);
            function = metric::score;
        }
        DriftReport driftReport = new DriftReport(driftMetricType);
        FeatureList featureList = new FeatureList(list);
        HashSet hashSet = new HashSet(Arrays.asList(strArr));
        for (String str : statsAggregator2.getFeaturesNumericAggregates().keySet()) {
            if (hashSet.isEmpty() || hashSet.contains(str)) {
                FeatureDistribution constructNumericDistribution = Utils.constructNumericDistribution(str, featureList.getFeatureType(str), statsAggregator.getFeaturesNumericAggregates().get(str), statsAggregator2.getFeaturesNumericAggregates().get(str));
                driftReport.addFeatureMetric(constructNumericDistribution, function.apply(constructNumericDistribution).doubleValue());
            }
        }
        for (String str2 : statsAggregator2.getFeaturesCategoricalAggregates().keySet()) {
            if (hashSet.isEmpty() || hashSet.contains(str2)) {
                FeatureDistribution constructCategoricalDistribution = Utils.constructCategoricalDistribution(str2, featureList.getFeatureType(str2), statsAggregator.getFeaturesCategoricalAggregates().get(str2), statsAggregator2.getFeaturesCategoricalAggregates().get(str2));
                driftReport.addFeatureMetric(constructCategoricalDistribution, function.apply(constructCategoricalDistribution).doubleValue());
            }
        }
        return driftReport;
    }

    private void addFeatureMetric(FeatureDistribution featureDistribution, double d) {
        this.featureDriftMetrics.put(featureDistribution.featureName, Double.valueOf(d));
        this.featureDistributions.put(featureDistribution.featureName, featureDistribution);
    }

    public Set<String> getDriftedFeatures() {
        return getDriftedFeatures(DEFAULT_DRIFT_THRESHOLD, DEFAULT_NUMBER_OF_ROWS_TO_DEFINE);
    }

    public Set<String> getDriftedFeatures(double d) {
        return getDriftedFeatures(d, DEFAULT_NUMBER_OF_ROWS_TO_DEFINE);
    }

    public Set<String> getDriftedFeatures(double d, int i) {
        HashSet hashSet = new HashSet();
        for (String str : this.featureDriftMetrics.keySet()) {
            double doubleValue = this.featureDriftMetrics.get(str).doubleValue();
            FeatureDistribution featureDistribution = this.featureDistributions.get(str);
            boolean z = doubleValue >= d;
            boolean z2 = featureDistribution.actualSampleSize >= i;
            if (z && z2) {
                hashSet.add(str);
            }
        }
        return hashSet;
    }

    public FeatureDistribution getDistribution(String str) {
        return this.featureDistributions.get(str);
    }

    public Double getDriftScore(String str) {
        return this.featureDriftMetrics.get(str);
    }

    public Set<String> getFeatures() {
        return this.featureDistributions.keySet();
    }

    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (String str : this.featureDriftMetrics.keySet()) {
            sb.append(this.featureDistributions.get(str));
            sb.append(this.metricType.getName()).append(": ").append(this.featureDriftMetrics.get(str)).append(System.lineSeparator());
            sb.append("------------------------------------ ").append(System.lineSeparator());
        }
        return sb.toString();
    }
}
