package org.datavec.spark.transform;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.spark.api.java.JavaRDD;
import org.datavec.api.transform.analysis.AnalysisCounter;
import org.datavec.api.transform.analysis.DataAnalysis;
import org.datavec.api.transform.analysis.DataVecAnalysisUtils;
import org.datavec.api.transform.analysis.SequenceDataAnalysis;
import org.datavec.api.transform.analysis.columns.ColumnAnalysis;
import org.datavec.api.transform.analysis.quality.QualityAnalysisAddFunction;
import org.datavec.api.transform.analysis.quality.QualityAnalysisCombineFunction;
import org.datavec.api.transform.analysis.quality.QualityAnalysisState;
import org.datavec.api.transform.analysis.sequence.SequenceLengthAnalysis;
import org.datavec.api.transform.quality.DataQualityAnalysis;
import org.datavec.api.transform.quality.columns.ColumnQuality;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.comparator.Comparators;
import org.datavec.spark.transform.analysis.SelectColumnFunction;
import org.datavec.spark.transform.analysis.SequenceFlatMapFunction;
import org.datavec.spark.transform.analysis.SequenceLengthFunction;
import org.datavec.spark.transform.analysis.aggregate.AnalysisAddFunction;
import org.datavec.spark.transform.analysis.aggregate.AnalysisCombineFunction;
import org.datavec.spark.transform.analysis.histogram.HistogramAddFunction;
import org.datavec.spark.transform.analysis.histogram.HistogramCombineFunction;
import org.datavec.spark.transform.analysis.seqlength.IntToDoubleFunction;
import org.datavec.spark.transform.analysis.seqlength.SequenceLengthAnalysisAddFunction;
import org.datavec.spark.transform.analysis.seqlength.SequenceLengthAnalysisCounter;
import org.datavec.spark.transform.analysis.seqlength.SequenceLengthAnalysisMergeFunction;
import org.datavec.spark.transform.analysis.unique.UniqueAddFunction;
import org.datavec.spark.transform.analysis.unique.UniqueMergeFunction;
import org.datavec.spark.transform.filter.FilterWritablesBySchemaFunction;
import org.datavec.spark.transform.misc.ColumnToKeyPairTransform;
import org.datavec.spark.transform.misc.SumLongsFunction2;
import org.datavec.spark.transform.misc.comparator.Tuple2Comparator;
import org.datavec.spark.transform.utils.adapter.BiFunctionAdapter;
import scala.Tuple2;

/* loaded from: input_file:org/datavec/spark/transform/AnalyzeSpark.class */
public class AnalyzeSpark {
    /** Maximum number of histogram buckets used by the overloads that do not take one explicitly. */
    public static final int DEFAULT_HISTOGRAM_BUCKETS = 30;

    /**
     * Analyzes sequence data using at most {@link #DEFAULT_HISTOGRAM_BUCKETS} histogram buckets.
     *
     * @see #analyzeSequence(Schema, JavaRDD, int)
     */
    public static SequenceDataAnalysis analyzeSequence(Schema schema, JavaRDD<List<List<Writable>>> data) {
        return analyzeSequence(schema, data, DEFAULT_HISTOGRAM_BUCKETS);
    }

    /**
     * Analyzes sequence data. The result holds per-column statistics computed over every time step
     * of every sequence, plus statistics and a histogram of the sequence lengths.
     *
     * <p>{@code data} is cached (and left cached) because it is traversed more than once.
     *
     * @param schema              schema describing a single time step
     * @param data                the sequences to analyze
     * @param maxHistogramBuckets upper bound on the bucket count of every histogram produced: the
     *                            per-column histograms and the sequence-length histogram
     * @return the combined column and sequence-length analysis
     */
    public static SequenceDataAnalysis analyzeSequence(Schema schema, JavaRDD<List<List<Writable>>> data,
            int maxHistogramBuckets) {
        data.cache();

        // analyze() caches its input. This flattened RDD is private to this method, so release it
        // once the column analysis is computed instead of leaving it cached indefinitely.
        JavaRDD<List<Writable>> timeSteps = data.flatMap(new SequenceFlatMapFunction());
        DataAnalysis columnAnalysis;
        try {
            // The caller's bucket limit applies to the per-column histograms as well.
            columnAnalysis = analyze(schema, timeSteps, maxHistogramBuckets);
        } finally {
            timeSteps.unpersist();
        }

        JavaRDD<Integer> seqLengths = data.map(new SequenceLengthFunction());
        seqLengths.cache();
        SequenceLengthAnalysisCounter counter;
        Tuple2<double[], long[]> lengthHistogram;
        try {
            counter = seqLengths.aggregate(new SequenceLengthAnalysisCounter(),
                    new SequenceLengthAnalysisAddFunction(), new SequenceLengthAnalysisMergeFunction());
            int minLength = counter.getMinLengthSeen();
            int maxLength = counter.getMaxLengthSeen();
            if (maxLength == minLength) {
                // All sequences share one length: build the single-bucket histogram directly
                // instead of asking Spark to histogram a zero-width range.
                lengthHistogram = new Tuple2<>(new double[] {minLength}, new long[] {counter.getCountTotal()});
            } else {
                // Cap the bucket count at the width of the observed length range.
                int nBuckets = Math.min(maxLength - minLength, maxHistogramBuckets);
                lengthHistogram = seqLengths.mapToDouble(new IntToDoubleFunction()).histogram(nBuckets);
            }
        } finally {
            // Always release the cached lengths, even if the aggregation or histogram fails.
            seqLengths.unpersist();
        }

        SequenceLengthAnalysis lengthAnalysis = SequenceLengthAnalysis.builder()
                .totalNumSequences(counter.getCountTotal())
                .minSeqLength(counter.getMinLengthSeen())
                .maxSeqLength(counter.getMaxLengthSeen())
                .countZeroLength(counter.getCountZeroLength())
                .countOneLength(counter.getCountOneLength())
                .meanLength(counter.getMean())
                .histogramBuckets(lengthHistogram._1())
                .histogramBucketCounts(lengthHistogram._2())
                .build();
        return new SequenceDataAnalysis(schema, columnAnalysis.getColumnAnalysis(), lengthAnalysis);
    }

    /**
     * Analyzes tabular data using at most {@link #DEFAULT_HISTOGRAM_BUCKETS} histogram buckets.
     *
     * @see #analyze(Schema, JavaRDD, int)
     */
    public static DataAnalysis analyze(Schema schema, JavaRDD<List<Writable>> data) {
        return analyze(schema, data, DEFAULT_HISTOGRAM_BUCKETS);
    }

    /**
     * Computes per-column statistics and histograms for tabular data using two passes over
     * {@code data}. The RDD is cached (and left cached) for that reason.
     *
     * @param schema              schema of each record
     * @param data                records to analyze
     * @param maxHistogramBuckets bucket count passed to the histogram computation
     * @return the per-column analysis
     * @throws IllegalStateException if the first aggregation yields no result (e.g. empty input)
     */
    public static DataAnalysis analyze(Schema schema, JavaRDD<List<Writable>> data, int maxHistogramBuckets) {
        data.cache();

        // Pass 1: per-column counters.
        List<AnalysisCounter> counters =
                data.aggregate(null, new AnalysisAddFunction(schema), new AnalysisCombineFunction());
        checkAggregated(counters, "analyze data");

        // convertCounters fills one [2]-entry row per column (presumably min/max); pass 2 uses it
        // to lay out the histogram buckets.
        double[][] minsMaxes = new double[counters.size()][2];
        List<ColumnAnalysis> columnAnalysis =
                DataVecAnalysisUtils.convertCounters(counters, minsMaxes, schema.getColumnTypes());

        // Pass 2: histograms, merged into the column analysis in place.
        DataVecAnalysisUtils.mergeCounters(columnAnalysis, data.aggregate(null,
                new HistogramAddFunction(maxHistogramBuckets, schema, minsMaxes), new HistogramCombineFunction()));
        return new DataAnalysis(schema, columnAnalysis);
    }

    /**
     * Randomly samples (without replacement) up to {@code count} values of a single column.
     *
     * @param count      maximum number of values to return
     * @param columnName column to sample from
     * @param schema     schema of the records
     * @param data       records to sample from
     * @return the sampled values
     */
    public static List<Writable> sampleFromColumn(int count, String columnName, Schema schema,
            JavaRDD<List<Writable>> data) {
        return data.map(new SelectColumnFunction(schema.getIndexOfColumn(columnName))).takeSample(false, count);
    }

    /** As {@link #sampleFromColumn}, sampling over all time steps of all sequences. */
    public static List<Writable> sampleFromColumnSequence(int count, String columnName, Schema schema,
            JavaRDD<List<List<Writable>>> sequences) {
        return sampleFromColumn(count, columnName, schema, sequences.flatMap(new SequenceFlatMapFunction()));
    }

    /**
     * Returns the distinct values of one column. The whole distinct set is collected to the
     * driver.
     */
    public static List<Writable> getUnique(String columnName, Schema schema, JavaRDD<List<Writable>> data) {
        return data.map(new SelectColumnFunction(schema.getIndexOfColumn(columnName))).distinct().collect();
    }

    /**
     * Returns the distinct values of several columns in a single pass over {@code data}.
     *
     * @param columnNames columns to collect unique values for
     * @param schema      schema of the records
     * @param data        records to scan
     * @return a new map whose values are mutable copies of the aggregated unique-value collections
     * @throws IllegalStateException if the aggregation yields no result (e.g. empty input)
     */
    public static Map<String, List<Writable>> getUnique(List<String> columnNames, Schema schema,
            JavaRDD<List<Writable>> data) {
        Map<String, ? extends Collection<Writable>> uniqueByColumn =
                data.aggregate(null, new UniqueAddFunction(columnNames, schema), new UniqueMergeFunction());
        checkAggregated(uniqueByColumn, "collect unique values");

        Map<String, List<Writable>> out = new HashMap<>();
        for (Map.Entry<String, ? extends Collection<Writable>> entry : uniqueByColumn.entrySet()) {
            out.put(entry.getKey(), new ArrayList<>(entry.getValue()));
        }
        return out;
    }

    /** As {@link #getUnique(String, Schema, JavaRDD)}, over all time steps of all sequences. */
    public static List<Writable> getUniqueSequence(String columnName, Schema schema,
            JavaRDD<List<List<Writable>>> sequences) {
        return getUnique(columnName, schema, sequences.flatMap(new SequenceFlatMapFunction()));
    }

    /** As {@link #getUnique(List, Schema, JavaRDD)}, over all time steps of all sequences. */
    public static Map<String, List<Writable>> getUniqueSequence(List<String> columnNames, Schema schema,
            JavaRDD<List<List<Writable>>> sequences) {
        return getUnique(columnNames, schema, sequences.flatMap(new SequenceFlatMapFunction()));
    }

    /** Randomly samples (without replacement) up to {@code count} records. */
    public static List<List<Writable>> sample(int count, JavaRDD<List<Writable>> data) {
        return data.takeSample(false, count);
    }

    /** Randomly samples (without replacement) up to {@code count} whole sequences. */
    public static List<List<List<Writable>>> sampleSequence(int count, JavaRDD<List<List<Writable>>> data) {
        return data.takeSample(false, count);
    }

    /** Runs {@link #analyzeQuality} over all time steps of all sequences. */
    public static DataQualityAnalysis analyzeQualitySequence(Schema schema, JavaRDD<List<List<Writable>>> data) {
        return analyzeQuality(schema, data.flatMap(new SequenceFlatMapFunction()));
    }

    /**
     * Computes a per-column data-quality report for {@code data}.
     *
     * @param schema schema the records are checked against
     * @param data   records to check
     * @return one column-quality entry per aggregated column state, in aggregation order
     * @throws IllegalStateException if the aggregation yields no result (e.g. empty input)
     */
    public static DataQualityAnalysis analyzeQuality(Schema schema, JavaRDD<List<Writable>> data) {
        List<QualityAnalysisState> states = data.aggregate(null,
                new BiFunctionAdapter<>(new QualityAnalysisAddFunction(schema)),
                new BiFunctionAdapter<>(new QualityAnalysisCombineFunction()));
        checkAggregated(states, "analyze data quality");

        List<ColumnQuality> columnQuality = new ArrayList<>(schema.numColumns());
        for (QualityAnalysisState state : states) {
            columnQuality.add(state.getColumnQuality());
        }
        return new DataQualityAnalysis(schema, columnQuality);
    }

    /** Samples invalid values of a column; equivalent to passing {@code ignoreMissing = false}. */
    public static List<Writable> sampleInvalidFromColumn(int count, String columnName, Schema schema,
            JavaRDD<List<Writable>> data) {
        return sampleInvalidFromColumn(count, columnName, schema, data, false);
    }

    /**
     * Randomly samples up to {@code count} values of a column that are kept by
     * {@link FilterWritablesBySchemaFunction} for the column's metadata, with its second argument
     * {@code false} (presumably "keep invalid values" — verify against that class).
     *
     * @param ignoreMissing forwarded as the filter's third argument
     */
    public static List<Writable> sampleInvalidFromColumn(int count, String columnName, Schema schema,
            JavaRDD<List<Writable>> data, boolean ignoreMissing) {
        return data.map(new SelectColumnFunction(schema.getIndexOfColumn(columnName)))
                .filter(new FilterWritablesBySchemaFunction(schema.getMetaData(columnName), false, ignoreMissing))
                .takeSample(false, count);
    }

    /** As {@link #sampleInvalidFromColumn(int, String, Schema, JavaRDD)}, over all time steps. */
    public static List<Writable> sampleInvalidFromColumnSequence(int count, String columnName, Schema schema,
            JavaRDD<List<List<Writable>>> sequences) {
        return sampleInvalidFromColumn(count, columnName, schema, sequences.flatMap(new SequenceFlatMapFunction()));
    }

    /**
     * Counts occurrences of each value in a column and returns the first {@code nMostFrequent}
     * (value, count) pairs under {@code Tuple2Comparator(false)}.
     *
     * @return an insertion-ordered map in that comparator's order
     */
    public static Map<Writable, Long> sampleMostFrequentFromColumn(int nMostFrequent, String columnName,
            Schema schema, JavaRDD<List<Writable>> data) {
        // false presumably selects descending count order (most frequent first) — confirm in Tuple2Comparator.
        Tuple2Comparator<Writable> byCount = new Tuple2Comparator<>(false);
        List<Tuple2<Writable, Long>> top = new ArrayList<>(data
                .mapToPair(new ColumnToKeyPairTransform(schema.getIndexOfColumn(columnName)))
                .reduceByKey(new SumLongsFunction2())
                .takeOrdered(nMostFrequent, byCount));
        Collections.sort(top, byCount);

        Map<Writable, Long> out = new LinkedHashMap<>();
        for (Tuple2<Writable, Long> entry : top) {
            out.put(entry._1(), entry._2());
        }
        return out;
    }

    /** Returns the minimum value of a column, compared with the comparator for its writable type. */
    public static Writable min(JavaRDD<List<Writable>> data, String columnName, Schema schema) {
        return data.map(new SelectColumnFunction(schema.getIndexOfColumn(columnName)))
                .min(Comparators.forType(schema.getType(columnName).getWritableType()));
    }

    /** Returns the maximum value of a column, compared with the comparator for its writable type. */
    public static Writable max(JavaRDD<List<Writable>> data, String columnName, Schema schema) {
        return data.map(new SelectColumnFunction(schema.getIndexOfColumn(columnName)))
                .max(Comparators.forType(schema.getType(columnName).getWritableType()));
    }

    /**
     * Fails with a descriptive message when an {@code aggregate(null, ...)} call returned null.
     * Every such aggregation here starts from a {@code null} zero value. A {@code null} result
     * therefore presumably means no record was folded in (e.g. the input RDD is empty). Without
     * this check the caller would get an opaque NullPointerException further down.
     */
    private static void checkAggregated(Object result, String operation) {
        if (result == null) {
            throw new IllegalStateException(
                    "Cannot " + operation + ": aggregation produced no result (is the input RDD empty?)");
        }
    }
}
