package org.deeplearning4j.spark.impl.multilayer;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import lombok.NonNull;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.listener.RoutingIterationListener;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.impl.common.reduce.IntDoubleReduceFunction;
import org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouterProvider;
import org.deeplearning4j.spark.impl.multilayer.evaluation.EvaluateFlatMapFunction;
import org.deeplearning4j.spark.impl.multilayer.evaluation.EvaluationReduceFunction;
import org.deeplearning4j.spark.impl.multilayer.scoring.ScoreExamplesFunction;
import org.deeplearning4j.spark.impl.multilayer.scoring.ScoreExamplesWithKeyFunction;
import org.deeplearning4j.spark.impl.multilayer.scoring.ScoreFlatMapFunction;
import org.deeplearning4j.spark.util.MLLibUtil;
import org.deeplearning4j.spark.util.SparkUtils;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/* loaded from: input_file:org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.class */
public class SparkDl4jMultiLayer implements Serializable {
    private static final Logger log = LoggerFactory.getLogger(SparkDl4jMultiLayer.class);

    /** Minibatch size used by the scoring/evaluation overloads that do not take an explicit one. */
    public static final int DEFAULT_EVAL_SCORE_BATCH_SIZE = 64;

    // transient: excluded when this object is Java-serialized.
    private transient JavaSparkContext sc;
    private TrainingMaster trainingMaster;
    // Clone of the network's configuration taken at construction; its JSON form is shipped to
    // executors for scoring and evaluation.
    private MultiLayerConfiguration conf;
    private MultiLayerNetwork network;
    private double lastScore;
    private final List<IterationListener> listeners = new ArrayList<>();
    private StatsStorage statsStorage;

    /**
     * Wraps an existing network for training on Spark.
     *
     * @param sparkContext      Spark context (wrapped in a {@link JavaSparkContext})
     * @param multiLayerNetwork network to train; initialized if {@code init()} was not yet called
     * @param trainingMaster    strategy that performs the distributed training
     */
    public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerNetwork multiLayerNetwork, TrainingMaster<?, ?> trainingMaster) {
        this(new JavaSparkContext(sparkContext), multiLayerNetwork, trainingMaster);
    }

    /**
     * Builds and initializes a new network from {@code multiLayerConfiguration} for training on Spark.
     *
     * @param sparkContext            Spark context (wrapped in a {@link JavaSparkContext})
     * @param multiLayerConfiguration configuration of the network to create
     * @param trainingMaster          strategy that performs the distributed training
     */
    public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerConfiguration multiLayerConfiguration, TrainingMaster<?, ?> trainingMaster) {
        this(new JavaSparkContext(sparkContext), initNetwork(multiLayerConfiguration), trainingMaster);
    }

    /**
     * Builds and initializes a new network from {@code multiLayerConfiguration} for training on Spark.
     *
     * @param javaSparkContext        Spark context; used as-is (not re-wrapped)
     * @param multiLayerConfiguration configuration of the network to create
     * @param trainingMaster          strategy that performs the distributed training
     */
    public SparkDl4jMultiLayer(JavaSparkContext javaSparkContext, MultiLayerConfiguration multiLayerConfiguration, TrainingMaster<?, ?> trainingMaster) {
        // Delegate directly instead of unwrapping to SparkContext and re-wrapping in a new
        // JavaSparkContext, so getSparkContext() returns the caller's instance.
        this(javaSparkContext, initNetwork(multiLayerConfiguration), trainingMaster);
    }

    /**
     * Wraps an existing network for training on Spark.
     *
     * @param javaSparkContext  Spark context
     * @param multiLayerNetwork network to train; initialized if {@code init()} was not yet called
     * @param trainingMaster    strategy that performs the distributed training
     */
    public SparkDl4jMultiLayer(JavaSparkContext javaSparkContext, MultiLayerNetwork multiLayerNetwork, TrainingMaster<?, ?> trainingMaster) {
        this.sc = javaSparkContext;
        this.conf = multiLayerNetwork.getLayerWiseConfigurations().clone();
        this.network = multiLayerNetwork;
        if (!multiLayerNetwork.isInitCalled()) {
            multiLayerNetwork.init();
        }
        this.trainingMaster = trainingMaster;
        SparkUtils.checkKryoConfiguration(javaSparkContext, log);
    }

    /** Creates a network from {@code multiLayerConfiguration} and calls {@code init()} on it. */
    private static MultiLayerNetwork initNetwork(MultiLayerConfiguration multiLayerConfiguration) {
        MultiLayerNetwork multiLayerNetwork = new MultiLayerNetwork(multiLayerConfiguration);
        multiLayerNetwork.init();
        return multiLayerNetwork;
    }

    /** @return the Spark context used by this instance */
    public JavaSparkContext getSparkContext() {
        return this.sc;
    }

    /** @return the network currently held by this instance */
    public MultiLayerNetwork getNetwork() {
        return this.network;
    }

    /** @return the training master used for distributed training */
    public TrainingMaster getTrainingMaster() {
        return this.trainingMaster;
    }

    /**
     * Replaces the held network. Note: the cached configuration used for scoring/evaluation is
     * not updated here.
     */
    public void setNetwork(MultiLayerNetwork multiLayerNetwork) {
        this.network = multiLayerNetwork;
    }

    /** Enables or disables collection of training statistics by the training master. */
    public void setCollectTrainingStats(boolean collectTrainingStats) {
        this.trainingMaster.setCollectTrainingStats(collectTrainingStats);
    }

    /** @return training statistics reported by the training master */
    public SparkTrainingStats getSparkTrainingStats() {
        return this.trainingMaster.getTrainingStats();
    }

    /** Runs the network locally (on the driver) on an MLlib matrix and returns the output. */
    public Matrix predict(Matrix matrix) {
        return MLLibUtil.toMatrix(this.network.output(MLLibUtil.toMatrix(matrix)));
    }

    /** Runs the network locally (on the driver) on an MLlib vector and returns the output. */
    public Vector predict(Vector vector) {
        return MLLibUtil.toVector(this.network.output(MLLibUtil.toVector(vector)));
    }

    /** Fits the network on a Scala RDD of DataSets; see {@link #fit(JavaRDD)}. */
    public MultiLayerNetwork fit(RDD<DataSet> rdd) {
        return fit(rdd.toJavaRDD());
    }

    /**
     * Fits the network on the given DataSets using the training master.
     *
     * @param javaRDD training data
     * @return the network held by this instance after training
     */
    public MultiLayerNetwork fit(JavaRDD<DataSet> javaRDD) {
        // Flush any ops queued by a GridExecutioner before training starts.
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        this.trainingMaster.executeTraining(this, javaRDD);
        return this.network;
    }

    /**
     * Fits the network on DataSet files found under {@code path}.
     *
     * @param path directory listed via {@link SparkUtils#listPaths}
     * @return the network held by this instance after training
     * @throws RuntimeException wrapping the {@link IOException} if listing the directory fails
     */
    public MultiLayerNetwork fit(String path) {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        try {
            return fitPaths(SparkUtils.listPaths(this.sc, path));
        } catch (IOException e) {
            throw new RuntimeException("Error listing paths in directory", e);
        }
    }

    /**
     * @deprecated the second argument is ignored; use {@link #fit(String)}.
     */
    @Deprecated
    public MultiLayerNetwork fit(String path, int minPartitions) {
        return fit(path);
    }

    /**
     * Fits the network on the DataSet files whose paths are contained in {@code javaRDD}.
     *
     * @return the network held by this instance after training
     */
    public MultiLayerNetwork fitPaths(JavaRDD<String> javaRDD) {
        this.trainingMaster.executeTrainingPaths(this, javaRDD);
        return this.network;
    }

    /**
     * Fits on MLlib labeled points, converting labels using the last layer's output size as the
     * number of classes.
     */
    public MultiLayerNetwork fitLabeledPoint(JavaRDD<LabeledPoint> javaRDD) {
        MultiLayerConfiguration layerWiseConf = this.network.getLayerWiseConfigurations();
        int lastLayerIndex = layerWiseConf.getConfs().size() - 1;
        // NOTE(review): depending on the DL4J version, getLayer() may need a cast to
        // FeedForwardLayer before getNOut() is available — confirm against the layer API.
        return fit(MLLibUtil.fromLabeledPoint(this.sc, javaRDD, layerWiseConf.getConf(lastLayerIndex).getLayer().getNOut()));
    }

    /** Fits on MLlib labeled points with continuous (regression) labels. */
    public MultiLayerNetwork fitContinuousLabeledPoint(JavaRDD<LabeledPoint> javaRDD) {
        return fit(MLLibUtil.fromContinuousLabeledPoint(this.sc, javaRDD));
    }

    /**
     * Sets the iteration listeners without a stats storage router.
     *
     * @throws NullPointerException if {@code collection} is null
     */
    public void setListeners(@NonNull Collection<IterationListener> collection) {
        if (collection == null) {
            throw new NullPointerException("listeners");
        }
        setListeners(null, collection);
    }

    /**
     * Replaces the iteration listeners and forwards them to the training master.
     *
     * <p>Each {@link RoutingIterationListener} is validated first: a missing storage router only
     * triggers a warning, while a non-serializable one is rejected.
     *
     * @param statsStorageRouter router passed to the training master (may be null)
     * @param collection         new listeners; null clears the local list only
     * @throws IllegalStateException if a routing listener has a non-serializable storage router
     */
    public void setListeners(StatsStorageRouter statsStorageRouter, Collection<? extends IterationListener> collection) {
        if (collection != null) {
            for (IterationListener listener : collection) {
                if (!(listener instanceof RoutingIterationListener)) {
                    continue;
                }
                StatsStorageRouter router = ((RoutingIterationListener) listener).getStorageRouter();
                if (router == null) {
                    log.warn("RoutingIterationListener provided without providing any StatsStorage instance. Iterator may not function without one. Listener: {}", listener);
                } else if (!(router instanceof Serializable)) {
                    throw new IllegalStateException("RoutingIterationListener provided with non-serializable storage router");
                }
            }
        }
        this.listeners.clear();
        if (collection != null) {
            this.listeners.addAll(collection);
            if (this.trainingMaster != null) {
                this.trainingMaster.setListeners(statsStorageRouter, this.listeners);
            }
        }
    }

    /**
     * Calls {@code iterationDone} on every registered listener. A failing listener is logged
     * (with its stack trace) and does not stop the remaining listeners from being invoked.
     */
    protected void invokeListeners(MultiLayerNetwork multiLayerNetwork, int iteration) {
        for (IterationListener listener : this.listeners) {
            try {
                listener.iterationDone(multiLayerNetwork, iteration);
            } catch (Exception e) {
                log.error("Exception caught at IterationListener invocation", e);
            }
        }
    }

    /** @return the last score stored via {@link #setScore(double)} */
    public double getScore() {
        return this.lastScore;
    }

    /** Stores the most recent score. */
    public void setScore(double score) {
        this.lastScore = score;
    }

    /** Scala-RDD overload of {@link #calculateScore(JavaRDD, boolean)}. */
    public double calculateScore(RDD<DataSet> rdd, boolean average) {
        return calculateScore(rdd.toJavaRDD(), average);
    }

    /** Computes the score using {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE} as the minibatch size. */
    public double calculateScore(JavaRDD<DataSet> javaRDD, boolean average) {
        return calculateScore(javaRDD, average, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * Computes the network's score over {@code javaRDD} on the executors.
     *
     * @param javaRDD       data to score
     * @param average       if true, divide the summed score by the example count
     * @param minibatchSize minibatch size used on the executors
     * @return summed score, or the per-example average when {@code average} is true
     */
    public double calculateScore(JavaRDD<DataSet> javaRDD, boolean average, int minibatchSize) {
        // Each partition emits (exampleCount, scoreSum); the reduce combines them.
        Tuple2<Integer, Double> countAndSum = javaRDD
                .mapPartitions(new ScoreFlatMapFunction(this.conf.toJson(), this.sc.broadcast(this.network.params(false)), minibatchSize))
                .reduce(new IntDoubleReduceFunction());
        double scoreSum = countAndSum._2();
        return average ? scoreSum / countAndSum._1() : scoreSum;
    }

    /** Scala-RDD overload of {@link #scoreExamples(JavaRDD, boolean)}. */
    public JavaDoubleRDD scoreExamples(RDD<DataSet> rdd, boolean includeRegularizationTerms) {
        return scoreExamples(rdd.toJavaRDD(), includeRegularizationTerms);
    }

    /** Scores each example using {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE} as the minibatch size. */
    public JavaDoubleRDD scoreExamples(JavaRDD<DataSet> javaRDD, boolean includeRegularizationTerms) {
        return scoreExamples(javaRDD, includeRegularizationTerms, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** Scala-RDD overload of {@link #scoreExamples(JavaRDD, boolean, int)}. */
    public JavaDoubleRDD scoreExamples(RDD<DataSet> rdd, boolean includeRegularizationTerms, int batchSize) {
        return scoreExamples(rdd.toJavaRDD(), includeRegularizationTerms, batchSize);
    }

    /**
     * Scores each example individually on the executors.
     *
     * @param includeRegularizationTerms flag forwarded to {@link ScoreExamplesFunction}
     * @param batchSize                  minibatch size used on the executors
     */
    public JavaDoubleRDD scoreExamples(JavaRDD<DataSet> javaRDD, boolean includeRegularizationTerms, int batchSize) {
        return javaRDD.mapPartitionsToDouble(new ScoreExamplesFunction(this.sc.broadcast(this.network.params()), this.sc.broadcast(this.conf.toJson()), includeRegularizationTerms, batchSize));
    }

    /** Keyed per-example scoring using {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE} as the minibatch size. */
    public <K> JavaPairRDD<K, Double> scoreExamples(JavaPairRDD<K, DataSet> javaPairRDD, boolean includeRegularizationTerms) {
        return scoreExamples(javaPairRDD, includeRegularizationTerms, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * Scores each keyed example individually on the executors, preserving the keys.
     *
     * @param includeRegularizationTerms flag forwarded to {@link ScoreExamplesWithKeyFunction}
     * @param batchSize                  minibatch size used on the executors
     */
    public <K> JavaPairRDD<K, Double> scoreExamples(JavaPairRDD<K, DataSet> javaPairRDD, boolean includeRegularizationTerms, int batchSize) {
        return javaPairRDD.mapPartitionsToPair(new ScoreExamplesWithKeyFunction(this.sc.broadcast(this.network.params()), this.sc.broadcast(this.conf.toJson()), includeRegularizationTerms, batchSize));
    }

    /** Scala-RDD overload of {@link #evaluate(JavaRDD)}. */
    public Evaluation evaluate(RDD<DataSet> rdd) {
        return evaluate(rdd.toJavaRDD());
    }

    /** Evaluates the network without label names. */
    public Evaluation evaluate(JavaRDD<DataSet> javaRDD) {
        return evaluate(javaRDD, (List<String>) null);
    }

    /** Scala-RDD overload of {@link #evaluate(JavaRDD, List)}. */
    public Evaluation evaluate(RDD<DataSet> rdd, List<String> labelsList) {
        return evaluate(rdd.toJavaRDD(), labelsList);
    }

    /** Evaluates using {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE} as the minibatch size. */
    public Evaluation evaluate(JavaRDD<DataSet> javaRDD, List<String> labelsList) {
        return evaluate(javaRDD, labelsList, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * Reports a SPARK heartbeat event with the given core count and memory.
     * Not invoked anywhere within this class.
     */
    private void update(int numCores, long availableMemory) {
        Environment buildEnvironment = EnvironmentUtils.buildEnvironment();
        buildEnvironment.setNumCores(numCores);
        buildEnvironment.setAvailableMemory(availableMemory);
        Heartbeat.getInstance().reportEvent(Event.SPARK, buildEnvironment, ModelSerializer.taskByModel(this.network));
    }

    /**
     * Evaluates the network on the executors and merges the per-partition results.
     *
     * @param labelsList    optional label names (broadcast only when non-null)
     * @param evalBatchSize minibatch size used on the executors
     */
    public Evaluation evaluate(JavaRDD<DataSet> javaRDD, List<String> labelsList, int evalBatchSize) {
        return (Evaluation) javaRDD
                .mapPartitions(new EvaluateFlatMapFunction(this.sc.broadcast(this.conf.toJson()), this.sc.broadcast(this.network.params()), evalBatchSize, labelsList == null ? null : this.sc.broadcast(labelsList)))
                .reduce(new EvaluationReduceFunction());
    }
}
