package org.dsmartml;

import java.util.Date;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.classification.LDA;
import org.apache.spark.ml.classification.LDAModel;
import org.apache.spark.ml.classification.QDA;
import org.apache.spark.ml.classification.QDAModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.TrainValidationSplit;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Serializable;
import scala.Tuple3;
import scala.collection.immutable.Map;
import scala.collection.mutable.StringBuilder;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;

/* compiled from: GridSearchManager.scala */
/* loaded from: input_file:org/dsmartml/GridSearchManager$$anonfun$Search$1.class */
/**
 * Per-classifier step of {@code GridSearchManager.Search}. This is the scalac-generated
 * anonymous {@code String => Unit} closure, in decompiled form.
 *
 * <p>For each classifier name passed to {@link #apply(String)} it does one of two things:
 * <ul>
 *   <li>runs a {@link TrainValidationSplit} over the estimator and parameter grid registered in
 *       {@link ClassifiersManager} (indices 0-6, subject to the eligibility rules in
 *       {@link #isGridSearchCandidate(int)}), or
 *   <li>fits a single LDA (index 7) or QDA (index 8) model directly.
 * </ul>
 * Each successful result is added to {@code selectedModelMap$1} as
 * {@code name -> (model, paramMap-or-null, accuracy)}. A failure for one classifier is logged
 * and does not abort processing of the remaining classifiers.
 */
public final class GridSearchManager$$anonfun$Search$1 extends AbstractFunction1<String, BoxedUnit> implements Serializable {
    public static final long serialVersionUID = 0;
    // Variables captured from the enclosing Search method.
    private final SparkSession spark$1;
    private final String featureCol$1;
    private final String TargetCol$1;
    private final int Parallelism$1;
    private final long seed$1;
    // Mutable cell holding an immutable Scala Map; reassigned on every recorded result.
    private final ObjectRef selectedModelMap$1;
    private final IntRef nr_classes$1;
    private final boolean hasNegativeFeatures$1;
    private final Dataset trainingData$1;
    private final Dataset testData$1;
    private final ClassifiersManager ClassifierMgr$1;

    /**
     * Runs the search or fit step for one classifier and records its test-set score.
     *
     * @param str classifier name, as listed in {@code ClassifiersManager.classifiersLsit()};
     *            names not in that list are ignored
     */
    public final void apply(String str) {
        int indexOf = ClassifiersManager$.MODULE$.classifiersLsit().indexOf(str);
        if (isGridSearchCandidate(indexOf)) {
            runGridSearch(str);
        } else if (indexOf == 7) {
            fitLda(str);
        } else if (indexOf == 8) {
            fitQda(str);
        }
    }

    /** Decides whether the classifier at {@code index} is tuned via TrainValidationSplit. */
    private boolean isGridSearchCandidate(int index) {
        int nrClasses = this.nr_classes$1.elem;
        // Indices 4 and 6 are only searched for binary problems.
        boolean binaryOnly = nrClasses == 2 && (index == 4 || index == 6);
        // Indices 0-3 are searched for any problem with at least two classes.
        boolean anyClassCount = nrClasses >= 2 && index >= 0 && index <= 3;
        // Index 5 is only searched when the data has no negative feature values.
        boolean nonNegativeOnly = !this.hasNegativeFeatures$1 && index == 5;
        return binaryOnly || anyClassCount || nonNegativeOnly;
    }

    /**
     * Tunes {@code name} with a TrainValidationSplit, scores its best model on the test set and
     * records the result. Any exception is logged and swallowed so other classifiers still run.
     */
    private void runGridSearch(String name) {
        try {
            Predef$.MODULE$.println(new StringBuilder().append("-- GridSearch for algoritm: ").append(name).append(" Start").toString());
            long startMillis = System.currentTimeMillis();
            TrainValidationSplit split = new TrainValidationSplit()
                    .setEstimator((Estimator) this.ClassifierMgr$1.ClassifiersMap().apply(name))
                    .setEvaluator(this.ClassifierMgr$1.evaluator())
                    .setEstimatorParamMaps((ParamMap[]) this.ClassifierMgr$1.ClassifierParamsMap().apply(name))
                    .setCollectSubModels(false)
                    .setSeed(this.seed$1)
                    .setParallelism(this.Parallelism$1);
            TrainValidationSplitModel fit = split.fit(this.trainingData$1);
            double accuracy = this.ClassifierMgr$1.evaluator().evaluate(fit.bestModel().transform(this.testData$1));
            recordResult(name, fit, fit.bestModel().extractParamMap(), accuracy);
            Predef$.MODULE$.println(new StringBuilder().append("   -- GridSearch for algoritm: ").append(name).append(" End (Time:").append(elapsedSeconds(startMillis)).append(")  Accuracy: ").append(BoxesRunTime.boxToDouble(accuracy)).toString());
        } catch (Exception e) {
            logFailure(name, e);
        }
    }

    /**
     * Fits a single LDA model (no parameter grid), scores it by accuracy on the test set and
     * records it with a {@code null} param map. Failures are logged like the grid-search path.
     */
    private void fitLda(String name) {
        try {
            long startMillis = System.currentTimeMillis();
            MulticlassClassificationEvaluator evaluator = accuracyEvaluator();
            LDA lda = new LDA();
            lda.sc_$eq(this.spark$1.sparkContext());
            lda.setLabelCol(this.TargetCol$1);
            lda.setFeaturesCol(this.featureCol$1);
            lda.setScaledData(false);
            lda.setPredictionCol("prediction");
            lda.setProbabilityCol("Probability");
            lda.setRawPredictionCol("RawPrediction");
            LDAModel model = lda.fit(this.trainingData$1);
            double accuracy = evaluator.evaluate(model.transform(this.testData$1));
            recordResult(name, model, null, accuracy);
            // NOTE(review): label says "Hyperband" although this is the grid-search manager;
            // kept verbatim in case log output is parsed downstream.
            Predef$.MODULE$.println(new StringBuilder().append("   -- Hyperband for algoritm:").append(name).append(" (Time:").append(elapsedSeconds(startMillis)).append(") Accuracy: ").append(BoxesRunTime.boxToDouble(accuracy)).toString());
        } catch (Exception e) {
            logFailure(name, e);
        }
    }

    /**
     * Fits a single QDA model (no parameter grid), scores it by accuracy on the test set and
     * records it with a {@code null} param map. Failures are logged like the grid-search path.
     */
    private void fitQda(String name) {
        try {
            long startMillis = System.currentTimeMillis();
            MulticlassClassificationEvaluator evaluator = accuracyEvaluator();
            QDA qda = new QDA(this.spark$1.sparkContext());
            qda.setLabelCol(this.TargetCol$1);
            qda.setFeaturesCol(this.featureCol$1);
            qda.setScaledData(false);
            qda.setPredictionCol("prediction");
            qda.setProbabilityCol("Probability");
            qda.setRawPredictionCol("RawPrediction");
            QDAModel model = qda.fit(this.trainingData$1);
            double accuracy = evaluator.evaluate(model.transform(this.testData$1));
            recordResult(name, model, null, accuracy);
            // NOTE(review): label says "Hyperband" although this is the grid-search manager;
            // kept verbatim in case log output is parsed downstream.
            Predef$.MODULE$.println(new StringBuilder().append("   -- Hyperband for algoritm:").append(name).append(" (Time:").append(elapsedSeconds(startMillis)).append(")  Accuracy: ").append(BoxesRunTime.boxToDouble(accuracy)).toString());
        } catch (Exception e) {
            logFailure(name, e);
        }
    }

    /**
     * Builds the accuracy evaluator used for the LDA/QDA paths. It reads the captured target
     * column and the {@code "prediction"} column.
     */
    private MulticlassClassificationEvaluator accuracyEvaluator() {
        return new MulticlassClassificationEvaluator().setLabelCol(this.TargetCol$1).setPredictionCol("prediction").setMetricName("accuracy");
    }

    /**
     * Adds {@code name -> (model, params, accuracy)} to the captured selected-model map. The map
     * is immutable, so the cell is reassigned with the extended map.
     */
    private void recordResult(String name, Object model, Object params, double accuracy) {
        this.selectedModelMap$1.elem = ((Map) this.selectedModelMap$1.elem).$plus(Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(name), new Tuple3(model, params, BoxesRunTime.boxToDouble(accuracy))));
    }

    /** Prints the failure message for {@code name} and the stack trace of {@code e}. */
    private static void logFailure(String name, Exception e) {
        Predef$.MODULE$.println(new StringBuilder().append("   -- Exception (GridSearch - Search - ").append(name).append(" ):").append(e.getMessage()).toString());
        e.printStackTrace();
    }

    /** Seconds elapsed since {@code startMillis}, rendered via {@code Double.toString}. */
    private static String elapsedSeconds(long startMillis) {
        return BoxesRunTime.boxToDouble((System.currentTimeMillis() - startMillis) / 1000.0d).toString();
    }

    public final /* bridge */ /* synthetic */ Object apply(Object obj) {
        apply((String) obj);
        return BoxedUnit.UNIT;
    }

    /**
     * Captures the enclosing Search method's state.
     *
     * @param gridSearchManager enclosing instance (outer reference); not used by this closure
     */
    public GridSearchManager$$anonfun$Search$1(GridSearchManager gridSearchManager, SparkSession sparkSession, String str, String str2, int i, long j, ObjectRef objectRef, IntRef intRef, boolean z, Dataset dataset, Dataset dataset2, ClassifiersManager classifiersManager) {
        this.spark$1 = sparkSession;
        this.featureCol$1 = str;
        this.TargetCol$1 = str2;
        this.Parallelism$1 = i;
        this.seed$1 = j;
        this.selectedModelMap$1 = objectRef;
        this.nr_classes$1 = intRef;
        this.hasNegativeFeatures$1 = z;
        this.trainingData$1 = dataset;
        this.testData$1 = dataset2;
        this.ClassifierMgr$1 = classifiersManager;
    }
}
