package org.datavec.spark.transform;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.analysis.DataAnalysis;
import org.datavec.api.transform.analysis.SequenceDataAnalysis;
import org.datavec.api.transform.analysis.columns.BytesAnalysis;
import org.datavec.api.transform.analysis.columns.CategoricalAnalysis;
import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
import org.datavec.api.transform.analysis.columns.DoubleAnalysis;
import org.datavec.api.transform.analysis.columns.IntegerAnalysis;
import org.datavec.api.transform.analysis.columns.LongAnalysis;
import org.datavec.api.transform.analysis.columns.StringAnalysis;
import org.datavec.api.transform.analysis.columns.TimeAnalysis;
import org.datavec.api.transform.analysis.sequence.SequenceLengthAnalysis;
import org.datavec.api.transform.metadata.CategoricalMetaData;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.DoubleMetaData;
import org.datavec.api.transform.metadata.IntegerMetaData;
import org.datavec.api.transform.metadata.LongMetaData;
import org.datavec.api.transform.metadata.StringMetaData;
import org.datavec.api.transform.metadata.TimeMetaData;
import org.datavec.api.transform.quality.DataQualityAnalysis;
import org.datavec.api.transform.quality.columns.BytesQuality;
import org.datavec.api.transform.quality.columns.CategoricalQuality;
import org.datavec.api.transform.quality.columns.ColumnQuality;
import org.datavec.api.transform.quality.columns.DoubleQuality;
import org.datavec.api.transform.quality.columns.IntegerQuality;
import org.datavec.api.transform.quality.columns.LongQuality;
import org.datavec.api.transform.quality.columns.StringQuality;
import org.datavec.api.transform.quality.columns.TimeQuality;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.analysis.SelectColumnFunction;
import org.datavec.spark.transform.analysis.SequenceFlatMapFunction;
import org.datavec.spark.transform.analysis.SequenceLengthFunction;
import org.datavec.spark.transform.analysis.aggregate.AnalysisAddFunction;
import org.datavec.spark.transform.analysis.aggregate.AnalysisCombineFunction;
import org.datavec.spark.transform.analysis.columns.BytesAnalysisCounter;
import org.datavec.spark.transform.analysis.columns.CategoricalAnalysisCounter;
import org.datavec.spark.transform.analysis.columns.DoubleAnalysisCounter;
import org.datavec.spark.transform.analysis.columns.IntegerAnalysisCounter;
import org.datavec.spark.transform.analysis.columns.LongAnalysisCounter;
import org.datavec.spark.transform.analysis.histogram.HistogramAddFunction;
import org.datavec.spark.transform.analysis.histogram.HistogramCombineFunction;
import org.datavec.spark.transform.analysis.histogram.HistogramCounter;
import org.datavec.spark.transform.analysis.seqlength.IntToDoubleFunction;
import org.datavec.spark.transform.analysis.seqlength.SequenceLengthAnalysisAddFunction;
import org.datavec.spark.transform.analysis.seqlength.SequenceLengthAnalysisCounter;
import org.datavec.spark.transform.analysis.seqlength.SequenceLengthAnalysisMergeFunction;
import org.datavec.spark.transform.analysis.string.StringAnalysisCounter;
import org.datavec.spark.transform.filter.FilterWritablesBySchemaFunction;
import org.datavec.spark.transform.quality.categorical.CategoricalQualityAddFunction;
import org.datavec.spark.transform.quality.categorical.CategoricalQualityMergeFunction;
import org.datavec.spark.transform.quality.integer.IntegerQualityAddFunction;
import org.datavec.spark.transform.quality.integer.IntegerQualityMergeFunction;
import org.datavec.spark.transform.quality.longq.LongQualityAddFunction;
import org.datavec.spark.transform.quality.longq.LongQualityMergeFunction;
import org.datavec.spark.transform.quality.real.RealQualityAddFunction;
import org.datavec.spark.transform.quality.real.RealQualityMergeFunction;
import org.datavec.spark.transform.quality.string.StringQualityAddFunction;
import org.datavec.spark.transform.quality.string.StringQualityMergeFunction;
import org.datavec.spark.transform.quality.time.TimeQualityAddFunction;
import org.datavec.spark.transform.quality.time.TimeQualityMergeFunction;
import scala.Tuple2;

/* loaded from: input_file:org/datavec/spark/transform/AnalyzeSpark.class */
public class AnalyzeSpark {
    public static final int DEFAULT_HISTOGRAM_BUCKETS = 30;

    /*
     * Lookup table used by the switch statements in this class: it is indexed by
     * ColumnType.ordinal() and holds the case number (1..7) assigned to each column type.
     * Entries that are never assigned keep the array default of 0, which matches no case
     * label and therefore falls through to the "default" branch of those switches.
     *
     * NOTE(review): the shape of this class (synthetic markers, one try/catch per constant)
     * suggests it is the compiler-generated switch map recovered by the decompiler rather
     * than hand-written code - confirm before editing it by hand.
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.datavec.spark.transform.AnalyzeSpark$1, reason: invalid class name */
    /* loaded from: input_file:org/datavec/spark/transform/AnalyzeSpark$1.class */
    public static /* synthetic */ class AnonymousClass1 {
        // One slot per ColumnType constant, addressed by ordinal.
        static final /* synthetic */ int[] $SwitchMap$org$datavec$api$transform$ColumnType = new int[ColumnType.values().length];

        static {
            // Each assignment is guarded separately so that a NoSuchFieldError raised for one
            // constant is ignored and the remaining constants are still registered.
            try {
                $SwitchMap$org$datavec$api$transform$ColumnType[ColumnType.String.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$datavec$api$transform$ColumnType[ColumnType.Integer.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$datavec$api$transform$ColumnType[ColumnType.Long.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$datavec$api$transform$ColumnType[ColumnType.Double.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$datavec$api$transform$ColumnType[ColumnType.Categorical.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$org$datavec$api$transform$ColumnType[ColumnType.Time.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$org$datavec$api$transform$ColumnType[ColumnType.Bytes.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
        }
    }

    /**
     * Analyses sequence data using the default histogram bucket limit.
     *
     * @param schema  schema handed through to the three-argument overload
     * @param javaRDD sequence data to analyse
     * @return the result of {@link #analyzeSequence(Schema, JavaRDD, int)} called with
     *         {@link #DEFAULT_HISTOGRAM_BUCKETS}
     */
    public static SequenceDataAnalysis analyzeSequence(Schema schema, JavaRDD<List<List<Writable>>> javaRDD) {
        final int bucketLimit = DEFAULT_HISTOGRAM_BUCKETS;
        return analyzeSequence(schema, javaRDD, bucketLimit);
    }

    /**
     * Analyses sequence data: column statistics over the flattened time steps, plus statistics
     * and a histogram of the sequence lengths.
     *
     * <p>The input RDD is cached by this method and is not unpersisted afterwards. The
     * intermediate RDD of sequence lengths is cached and unpersisted internally.
     *
     * @param schema  schema passed to the column analysis and stored in the result
     * @param javaRDD sequence data to analyse
     * @param i       maximum number of histogram buckets; applied both to the per-column
     *                histograms and to the sequence-length histogram
     * @return column analysis combined with the sequence-length analysis
     */
    public static SequenceDataAnalysis analyzeSequence(Schema schema, JavaRDD<List<List<Writable>>> javaRDD, int i) {
        javaRDD.cache();

        // Forward the caller's bucket limit: previously the two-argument analyze() was called
        // here, so a non-default limit was silently ignored for the column histograms.
        JavaRDD<List<Writable>> flattened = javaRDD.flatMap(new SequenceFlatMapFunction());
        DataAnalysis columnAnalysis = analyze(schema, flattened, i);

        JavaRDD<Integer> sequenceLengths = javaRDD.map(new SequenceLengthFunction());
        sequenceLengths.cache();
        SequenceLengthAnalysisCounter counter = sequenceLengths.aggregate(new SequenceLengthAnalysisCounter(),
                new SequenceLengthAnalysisAddFunction(), new SequenceLengthAnalysisMergeFunction());

        int maxLength = counter.getMaxLengthSeen();
        int minLength = counter.getMinLengthSeen();

        Tuple2<double[], long[]> histogram;
        if (maxLength == minLength) {
            // All sequences have the same length: build the single-bucket result by hand
            // instead of asking Spark for a histogram over a zero-width range.
            histogram = new Tuple2<double[], long[]>(new double[] {minLength}, new long[] {counter.getCountTotal()});
        } else {
            // Never request more buckets than the width of the observed length range.
            int bucketCount = Math.min(maxLength - minLength, i);
            histogram = sequenceLengths.mapToDouble(new IntToDoubleFunction()).histogram(bucketCount);
        }
        sequenceLengths.unpersist();

        SequenceLengthAnalysis lengthAnalysis = SequenceLengthAnalysis.builder()
                .totalNumSequences(counter.getCountTotal())
                .minSeqLength(minLength)
                .maxSeqLength(maxLength)
                .countZeroLength(counter.getCountZeroLength())
                .countOneLength(counter.getCountOneLength())
                .meanLength(counter.getMean())
                .histogramBuckets(histogram._1())
                .histogramBucketCounts(histogram._2())
                .build();
        return new SequenceDataAnalysis(schema, columnAnalysis.getColumnAnalysis(), lengthAnalysis);
    }

    /**
     * Analyses the data using the default number of histogram buckets.
     *
     * @param schema  schema handed through to the three-argument overload
     * @param javaRDD data to analyse
     * @return the result of {@link #analyze(Schema, JavaRDD, int)} called with
     *         {@link #DEFAULT_HISTOGRAM_BUCKETS}
     */
    public static DataAnalysis analyze(Schema schema, JavaRDD<List<Writable>> javaRDD) {
        final int bucketLimit = DEFAULT_HISTOGRAM_BUCKETS;
        return analyze(schema, javaRDD, bucketLimit);
    }

    /**
     * Analyses the data and returns one {@link ColumnAnalysis} per schema column.
     *
     * <p>Two aggregations are run over the data. The first collects one counter per column and
     * records the observed minimum/maximum for String (lengths), Integer, Long, Double and Time
     * columns. The second builds histograms from those ranges and attaches them to the
     * analyses of those same column types. Categorical and Bytes columns get no histogram.
     *
     * <p>The input RDD is cached by this method and is not unpersisted afterwards.
     *
     * @param schema  schema supplying the column types and the column count
     * @param javaRDD data to analyse
     * @param i       number of histogram buckets handed to {@link HistogramAddFunction}
     * @return the per-column analysis wrapped together with the schema
     * @throws IllegalStateException if the schema contains a column type not handled here
     */
    public static DataAnalysis analyze(Schema schema, JavaRDD<List<Writable>> javaRDD, int i) {
        javaRDD.cache();
        List<ColumnType> columnTypes = schema.getColumnTypes();

        // The zero value is an untyped null: casting it to Object (as the decompiler did) pins
        // the accumulator type to Object, which does not match the add/combine functions.
        // NOTE(review): if the aggregate yields null (e.g. for empty data) the size() call
        // below fails with a NullPointerException - confirm how AnalysisCombineFunction
        // treats null inputs.
        List<?> counters = (List<?>) javaRDD.aggregate(null, new AnalysisAddFunction(schema), new AnalysisCombineFunction());

        // minMax[col] = {min, max} seen in the first pass; input to the histogram pass.
        double[][] minMax = new double[counters.size()][2];
        int numColumns = schema.numColumns();
        List<ColumnAnalysis> columnAnalyses = new ArrayList<ColumnAnalysis>(numColumns);

        for (int col = 0; col < numColumns; col++) {
            ColumnType columnType = columnTypes.get(col);
            switch (columnType) {
                case String: {
                    StringAnalysisCounter sc = (StringAnalysisCounter) counters.get(col);
                    columnAnalyses.add(new StringAnalysis.Builder()
                            .countTotal(sc.getCountTotal())
                            .minLength(sc.getMinLengthSeen())
                            .maxLength(sc.getMaxLengthSeen())
                            .meanLength(sc.getSumLength() / sc.getCountTotal())
                            .build());
                    minMax[col][0] = sc.getMinLengthSeen();
                    minMax[col][1] = sc.getMaxLengthSeen();
                    break;
                }
                case Integer: {
                    IntegerAnalysisCounter ic = (IntegerAnalysisCounter) counters.get(col);
                    columnAnalyses.add(new IntegerAnalysis.Builder()
                            .min(ic.getMinValueSeen())
                            .max(ic.getMaxValueSeen())
                            .mean(ic.getSum() / ic.getCountTotal())
                            .countZero(ic.getCountZero())
                            .countNegative(ic.getCountNegative())
                            .countPositive(ic.getCountPositive())
                            .countMinValue(ic.getCountMinValue())
                            .countMaxValue(ic.getCountMaxValue())
                            .countTotal(ic.getCountTotal())
                            .build());
                    minMax[col][0] = ic.getMinValueSeen();
                    minMax[col][1] = ic.getMaxValueSeen();
                    break;
                }
                case Long: {
                    LongAnalysisCounter lc = (LongAnalysisCounter) counters.get(col);
                    columnAnalyses.add(new LongAnalysis.Builder()
                            .min(lc.getMinValueSeen())
                            .max(lc.getMaxValueSeen())
                            .mean(lc.getSum().doubleValue() / lc.getCountTotal())
                            .countZero(lc.getCountZero())
                            .countNegative(lc.getCountNegative())
                            .countPositive(lc.getCountPositive())
                            .countMinValue(lc.getCountMinValue())
                            .countMaxValue(lc.getCountMaxValue())
                            .countTotal(lc.getCountTotal())
                            .build());
                    minMax[col][0] = lc.getMinValueSeen();
                    minMax[col][1] = lc.getMaxValueSeen();
                    break;
                }
                case Double: {
                    DoubleAnalysisCounter dc = (DoubleAnalysisCounter) counters.get(col);
                    columnAnalyses.add(new DoubleAnalysis.Builder()
                            .min(dc.getMinValueSeen())
                            .max(dc.getMaxValueSeen())
                            .mean(dc.getSum() / dc.getCountTotal())
                            .countZero(dc.getCountZero())
                            .countNegative(dc.getCountNegative())
                            .countPositive(dc.getCountPositive())
                            .countMinValue(dc.getCountMinValue())
                            .countMaxValue(dc.getCountMaxValue())
                            .countNaN(dc.getCountNaN())
                            .countTotal(dc.getCountTotal())
                            .build());
                    minMax[col][0] = dc.getMinValueSeen();
                    minMax[col][1] = dc.getMaxValueSeen();
                    break;
                }
                case Categorical: {
                    CategoricalAnalysisCounter cc = (CategoricalAnalysisCounter) counters.get(col);
                    columnAnalyses.add(new CategoricalAnalysis(cc.getCounts()));
                    break;
                }
                case Time: {
                    // Time columns are counted with a LongAnalysisCounter.
                    LongAnalysisCounter tc = (LongAnalysisCounter) counters.get(col);
                    columnAnalyses.add(new TimeAnalysis.Builder()
                            .min(tc.getMinValueSeen())
                            .max(tc.getMaxValueSeen())
                            .mean(tc.getSum().doubleValue() / tc.getCountTotal())
                            .countZero(tc.getCountZero())
                            .countNegative(tc.getCountNegative())
                            .countPositive(tc.getCountPositive())
                            .countMinValue(tc.getCountMinValue())
                            .countMaxValue(tc.getCountMaxValue())
                            .countTotal(tc.getCountTotal())
                            .build());
                    minMax[col][0] = tc.getMinValueSeen();
                    minMax[col][1] = tc.getMaxValueSeen();
                    break;
                }
                case Bytes: {
                    BytesAnalysisCounter bc = (BytesAnalysisCounter) counters.get(col);
                    columnAnalyses.add(new BytesAnalysis.Builder().countTotal(bc.getCountTotal()).build());
                    break;
                }
                default:
                    throw new IllegalStateException("Unknown column type: " + columnType);
            }
        }

        List<?> histograms = (List<?>) javaRDD.aggregate(null, new HistogramAddFunction(i, schema, minMax), new HistogramCombineFunction());

        for (int col = 0; col < columnAnalyses.size(); col++) {
            // The counter is only dereferenced for the column types that receive a histogram.
            HistogramCounter histogram = (HistogramCounter) histograms.get(col);
            // Hold the element as ColumnAnalysis and narrow per branch; declaring it as
            // IntegerAnalysis (decompiler artifact) cannot hold the other analysis types.
            ColumnAnalysis analysis = columnAnalyses.get(col);
            if (analysis instanceof IntegerAnalysis) {
                IntegerAnalysis target = (IntegerAnalysis) analysis;
                target.setHistogramBuckets(histogram.getBins());
                target.setHistogramBucketCounts(histogram.getCounts());
            } else if (analysis instanceof DoubleAnalysis) {
                DoubleAnalysis target = (DoubleAnalysis) analysis;
                target.setHistogramBuckets(histogram.getBins());
                target.setHistogramBucketCounts(histogram.getCounts());
            } else if (analysis instanceof LongAnalysis) {
                LongAnalysis target = (LongAnalysis) analysis;
                target.setHistogramBuckets(histogram.getBins());
                target.setHistogramBucketCounts(histogram.getCounts());
            } else if (analysis instanceof TimeAnalysis) {
                TimeAnalysis target = (TimeAnalysis) analysis;
                target.setHistogramBuckets(histogram.getBins());
                target.setHistogramBucketCounts(histogram.getCounts());
            } else if (analysis instanceof StringAnalysis) {
                StringAnalysis target = (StringAnalysis) analysis;
                target.setHistogramBuckets(histogram.getBins());
                target.setHistogramBucketCounts(histogram.getCounts());
            }
        }
        return new DataAnalysis(schema, columnAnalyses);
    }

    /**
     * Draws a random sample, without replacement, of the values in one column.
     *
     * @param i       number of values requested from {@code takeSample}
     * @param str     name of the column to sample from
     * @param schema  schema used to resolve the column name to its index
     * @param javaRDD data to sample from
     * @return the sampled column values
     */
    public static List<Writable> sampleFromColumn(int i, String str, Schema schema, JavaRDD<List<Writable>> javaRDD) {
        int columnIndex = schema.getIndexOfColumn(str);
        JavaRDD<Writable> columnValues = javaRDD.map(new SelectColumnFunction(columnIndex));
        return columnValues.takeSample(false, i);
    }

    /**
     * Sequence variant of {@link #sampleFromColumn(int, String, Schema, JavaRDD)}: the
     * sequences are flattened with {@link SequenceFlatMapFunction} before sampling.
     *
     * @param i       number of values requested
     * @param str     name of the column to sample from
     * @param schema  schema used to resolve the column name
     * @param javaRDD sequence data to sample from
     * @return the sampled column values
     */
    public static List<Writable> sampleFromColumnSequence(int i, String str, Schema schema, JavaRDD<List<List<Writable>>> javaRDD) {
        JavaRDD<List<Writable>> flattened = javaRDD.flatMap(new SequenceFlatMapFunction());
        return sampleFromColumn(i, str, schema, flattened);
    }

    /**
     * Collects the distinct values of one column.
     *
     * @param str     name of the column
     * @param schema  schema used to resolve the column name to its index
     * @param javaRDD data to inspect
     * @return the distinct values of the column, collected from the RDD
     */
    public static List<Writable> getUnique(String str, Schema schema, JavaRDD<List<Writable>> javaRDD) {
        int columnIndex = schema.getIndexOfColumn(str);
        JavaRDD<Writable> columnValues = javaRDD.map(new SelectColumnFunction(columnIndex));
        JavaRDD<Writable> distinctValues = columnValues.distinct();
        return distinctValues.collect();
    }

    /**
     * Sequence variant of {@link #getUnique(String, Schema, JavaRDD)}: the sequences are
     * flattened with {@link SequenceFlatMapFunction} before the distinct values are collected.
     *
     * @param str     name of the column
     * @param schema  schema used to resolve the column name
     * @param javaRDD sequence data to inspect
     * @return the distinct values of the column
     */
    public static List<Writable> getUniqueSequence(String str, Schema schema, JavaRDD<List<List<Writable>>> javaRDD) {
        JavaRDD<List<Writable>> flattened = javaRDD.flatMap(new SequenceFlatMapFunction());
        return getUnique(str, schema, flattened);
    }

    /**
     * Draws a random sample of examples, without replacement.
     *
     * @param i       number of examples requested from {@code takeSample}
     * @param javaRDD data to sample from
     * @return the sampled examples
     */
    public static List<List<Writable>> sample(int i, JavaRDD<List<Writable>> javaRDD) {
        final boolean withReplacement = false;
        List<List<Writable>> drawn = javaRDD.takeSample(withReplacement, i);
        return drawn;
    }

    /**
     * Draws a random sample of whole sequences, without replacement.
     *
     * @param i       number of sequences requested from {@code takeSample}
     * @param javaRDD sequence data to sample from
     * @return the sampled sequences
     */
    public static List<List<List<Writable>>> sampleSequence(int i, JavaRDD<List<List<Writable>>> javaRDD) {
        final boolean withReplacement = false;
        List<List<List<Writable>>> drawn = javaRDD.takeSample(withReplacement, i);
        return drawn;
    }

    /**
     * Computes the quality summary of a single column by aggregating over its values with the
     * add/merge function pair that matches the column type.
     *
     * <p>String columns are cached and additionally get a distinct-value count, which is added
     * to the aggregated result. Bytes columns are not aggregated at all and yield a fresh
     * {@link BytesQuality}.
     *
     * @param columnMetaData metadata of the column; its type selects the aggregation
     * @param javaRDD        the values of that column
     * @return the quality summary for the column
     * @throws RuntimeException if the column type is not handled here
     */
    private static ColumnQuality analyze(ColumnMetaData columnMetaData, JavaRDD<Writable> javaRDD) {
        final ColumnType type = columnMetaData.getColumnType();
        switch (type) {
            case String: {
                javaRDD.cache();
                StringQuality aggregated = (StringQuality) javaRDD.aggregate(new StringQuality(),
                        new StringQualityAddFunction((StringMetaData) columnMetaData), new StringQualityMergeFunction());
                // The distinct count is computed separately and folded in through add().
                long distinctCount = javaRDD.distinct().count();
                StringQuality distinctOnly = new StringQuality(0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, distinctCount);
                return aggregated.add(distinctOnly);
            }
            case Integer: {
                IntegerQuality zero = new IntegerQuality(0L, 0L, 0L, 0L, 0L);
                return (ColumnQuality) javaRDD.aggregate(zero,
                        new IntegerQualityAddFunction((IntegerMetaData) columnMetaData), new IntegerQualityMergeFunction());
            }
            case Long: {
                LongQuality zero = new LongQuality();
                return (ColumnQuality) javaRDD.aggregate(zero,
                        new LongQualityAddFunction((LongMetaData) columnMetaData), new LongQualityMergeFunction());
            }
            case Double: {
                DoubleQuality zero = new DoubleQuality();
                return (ColumnQuality) javaRDD.aggregate(zero,
                        new RealQualityAddFunction((DoubleMetaData) columnMetaData), new RealQualityMergeFunction());
            }
            case Categorical: {
                CategoricalQuality zero = new CategoricalQuality();
                return (ColumnQuality) javaRDD.aggregate(zero,
                        new CategoricalQualityAddFunction((CategoricalMetaData) columnMetaData), new CategoricalQualityMergeFunction());
            }
            case Time: {
                TimeQuality zero = new TimeQuality();
                return (ColumnQuality) javaRDD.aggregate(zero,
                        new TimeQualityAddFunction((TimeMetaData) columnMetaData), new TimeQualityMergeFunction());
            }
            case Bytes:
                return new BytesQuality();
            default:
                throw new RuntimeException("Unknown or not implemented column type: " + type);
        }
    }

    /**
     * Sequence variant of {@link #analyzeQuality(Schema, JavaRDD)}: the sequences are flattened
     * with {@link SequenceFlatMapFunction} before the quality analysis runs.
     *
     * @param schema  schema of the data
     * @param javaRDD sequence data to analyse
     * @return the data quality analysis of the flattened data
     */
    public static DataQualityAnalysis analyzeQualitySequence(Schema schema, JavaRDD<List<List<Writable>>> javaRDD) {
        JavaRDD<List<Writable>> flattened = javaRDD.flatMap(new SequenceFlatMapFunction());
        return analyzeQuality(schema, flattened);
    }

    /**
     * Runs the per-column quality analysis for every column of the schema, in column order.
     *
     * <p>The input RDD is cached by this method; each column is then extracted with a
     * {@link SelectColumnFunction} and analysed on its own.
     *
     * @param schema  schema supplying the column count and per-column metadata
     * @param javaRDD data to analyse
     * @return the per-column quality results wrapped together with the schema
     */
    public static DataQualityAnalysis analyzeQuality(Schema schema, JavaRDD<List<Writable>> javaRDD) {
        javaRDD.cache();
        final int columnCount = schema.numColumns();
        List<ColumnQuality> perColumn = new ArrayList<ColumnQuality>(columnCount);
        int column = 0;
        while (column < columnCount) {
            ColumnMetaData meta = schema.getMetaData(column);
            JavaRDD<Writable> values = javaRDD.map(new SelectColumnFunction(column));
            perColumn.add(analyze(meta, values));
            column++;
        }
        return new DataQualityAnalysis(schema, perColumn);
    }

    /**
     * Convenience overload of
     * {@link #sampleInvalidFromColumn(int, String, Schema, JavaRDD, boolean)} that passes
     * {@code false} for the trailing flag.
     *
     * @param i       number of values requested
     * @param str     name of the column to sample from
     * @param schema  schema of the data
     * @param javaRDD data to sample from
     * @return the sampled values
     */
    public static List<Writable> sampleInvalidFromColumn(int i, String str, Schema schema, JavaRDD<List<Writable>> javaRDD) {
        final boolean trailingFlag = false;
        return sampleInvalidFromColumn(i, str, schema, javaRDD, trailingFlag);
    }

    /**
     * Samples, without replacement, values of one column that pass a
     * {@link FilterWritablesBySchemaFunction} built from the column's metadata.
     *
     * @param i       number of values requested from {@code takeSample}
     * @param str     name of the column to sample from
     * @param schema  schema used to resolve the column index and its metadata
     * @param javaRDD data to sample from
     * @param z       forwarded as the third constructor argument of the filter function
     *                (NOTE(review): presumably controls how missing values are treated -
     *                confirm against FilterWritablesBySchemaFunction)
     * @return the sampled values that passed the filter
     */
    public static List<Writable> sampleInvalidFromColumn(int i, String str, Schema schema, JavaRDD<List<Writable>> javaRDD, boolean z) {
        JavaRDD<Writable> columnValues = javaRDD.map(new SelectColumnFunction(schema.getIndexOfColumn(str)));
        // Second argument is fixed to false; given the method name this presumably selects
        // the values that are NOT valid for the column - TODO confirm.
        FilterWritablesBySchemaFunction filter = new FilterWritablesBySchemaFunction(schema.getMetaData(str), false, z);
        JavaRDD<Writable> filtered = columnValues.filter(filter);
        return filtered.takeSample(false, i);
    }

    /**
     * Sequence variant of {@link #sampleInvalidFromColumn(int, String, Schema, JavaRDD)}: the
     * sequences are flattened with {@link SequenceFlatMapFunction} before sampling.
     *
     * @param i       number of values requested
     * @param str     name of the column to sample from
     * @param schema  schema of the data
     * @param javaRDD sequence data to sample from
     * @return the sampled values
     */
    public static List<Writable> sampleInvalidFromColumnSequence(int i, String str, Schema schema, JavaRDD<List<List<Writable>>> javaRDD) {
        JavaRDD<List<Writable>> flattened = javaRDD.flatMap(new SequenceFlatMapFunction());
        return sampleInvalidFromColumn(i, str, schema, flattened);
    }
}
