package org.deeplearning4j.nn.multilayer;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Classifier;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.layers.factory.LayerFactories;
import org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.util.MultiLayerUtil;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.nd4j.linalg.heartbeat.utils.TaskUtils;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.FeatureUtil;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/multilayer/MultiLayerNetwork.class */
public class MultiLayerNetwork implements Serializable, Classifier, Layer {
    private static final Logger log = LoggerFactory.getLogger(MultiLayerNetwork.class);
    /** Layers in forward order; the layer objects are created by init(). */
    protected Layer[] layers;
    /** The same layers as {@link #layers}, keyed by the layer name from each layer's configuration. */
    protected LinkedHashMap<String, Layer> layerMap;
    /** Current network input; set via setInput(...) and initializeLayers(INDArray). */
    protected INDArray input;
    /** Current labels; handed to the output layer before backprop. */
    protected INDArray labels;
    /** Set to true once init() has created the layers and the flattened parameter array. */
    protected boolean initCalled;
    // Listeners handed to every layer created in init().
    private Collection<IterationListener> listeners;
    /**
     * Network-level configuration: the constructor sets it to a clone of layer 0's configuration.
     * It is returned by conf() and used to configure the Solver. init() rewrites its variable
     * list as "<layerIndex>_<variable>".
     */
    protected NeuralNetConfiguration defaultConfiguration;
    /** Per-layer configurations, input preprocessors and training flags (pretrain/backprop/TBPTT). */
    protected MultiLayerConfiguration layerWiseConfigurations;
    /** Gradient from the most recent backprop() call (null if backprop could not run). */
    protected Gradient gradient;
    /** Epsilon returned from the first layer (after its input preprocessor) in the most recent backprop(). */
    protected INDArray epsilon;
    // NOTE(review): not written anywhere in this view; presumably set by score computation elsewhere.
    protected double score;
    // NOTE(review): only initialised to false in the constructor here; other usage not visible.
    protected boolean initDone;
    /** Single 1 x numParams row vector backing all layer parameters; init() gives each layer a view of it. */
    protected INDArray flattenedParams;
    /** Row vector backing all layer gradients; initGradientsView() gives each layer a view. Not serialized. */
    protected transient INDArray flattenedGradients;
    // NOTE(review): usage not visible in this view (init() calls initMask()).
    protected INDArray mask;
    // NOTE(review): not referenced in the visible code.
    protected int layerIndex;
    /** Optimizer driver, created lazily on first fit; not serialized. */
    protected transient Solver solver;

    public MultiLayerNetwork(MultiLayerConfiguration multiLayerConfiguration) {
        this.layerMap = new LinkedHashMap<>();
        this.initCalled = false;
        this.listeners = new ArrayList();
        this.initDone = false;
        this.layerWiseConfigurations = multiLayerConfiguration;
        this.defaultConfiguration = multiLayerConfiguration.getConf(0).m37clone();
    }

    /**
     * Builds a network from a JSON configuration, initialises it and loads the given parameters.
     *
     * @param str      JSON form of a {@link MultiLayerConfiguration}
     * @param iNDArray parameters applied via setParameters(INDArray) after init()
     */
    public MultiLayerNetwork(String str, INDArray iNDArray) {
        this(MultiLayerConfiguration.fromJson(str), iNDArray);
    }

    /**
     * Builds a network from a configuration, initialises it and loads the given parameters.
     *
     * @param multiLayerConfiguration per-layer configuration
     * @param iNDArray                parameters applied via setParameters(INDArray) after init()
     */
    public MultiLayerNetwork(MultiLayerConfiguration multiLayerConfiguration, INDArray iNDArray) {
        this(multiLayerConfiguration);
        this.init();
        this.setParameters(iNDArray);
    }

    /**
     * Replaces any missing configuration state with defaults:
     * <ul>
     *   <li>an empty {@link MultiLayerConfiguration};</li>
     *   <li>an unpopulated layer array of size {@code getnLayers()};</li>
     *   <li>a default {@link NeuralNetConfiguration}.</li>
     * </ul>
     * Values that are already set are kept.
     */
    protected void intializeConfigurations() {
        this.layerWiseConfigurations = this.layerWiseConfigurations != null
                ? this.layerWiseConfigurations
                : new MultiLayerConfiguration.Builder().build();
        this.layers = this.layers != null ? this.layers : new Layer[getnLayers()];
        this.defaultConfiguration = this.defaultConfiguration != null
                ? this.defaultConfiguration
                : new NeuralNetConfiguration.Builder().build();
    }

    /**
     * Greedy layer-wise pretraining over a data set iterator.
     * <p>
     * Does nothing unless pretraining is enabled in the configuration. For each layer {@code i}
     * in order, the iterator is traversed once. Each batch's features go through:
     * <ol>
     *   <li>layers {@code 0..i-1};</li>
     *   <li>layer {@code i}'s input preprocessor, if any;</li>
     *   <li>{@code fit} on layer {@code i}.</li>
     * </ol>
     * The iterator is reset after each layer.
     *
     * @param dataSetIterator source of training batches; {@code reset()} is called on it
     */
    public void pretrain(DataSetIterator dataSetIterator) {
        if (!this.layerWiseConfigurations.isPretrain()) {
            return;
        }
        for (int i = 0; i < getnLayers(); i++) {
            while (dataSetIterator.hasNext()) {
                DataSet dataSet = (DataSet) dataSetIterator.next();
                INDArray features = dataSet.getFeatureMatrix();
                int miniBatchSize = features.size(0);
                if (i == 0) {
                    INDArray layerInput = getLayerWiseConfigurations().getInputPreProcess(0) != null
                            ? getLayerWiseConfigurations().getInputPreProcess(0).preProcess(features, miniBatchSize)
                            : features;
                    setInput(layerInput);
                    if (getInput() == null || getLayers() == null) {
                        initializeLayers(input());
                    }
                    this.layers[i].fit(input());
                    log.info("Training on layer " + (i + 1) + " with " + input().size(0) + " examples");
                } else {
                    INDArray layerInput = features;
                    // activationFromPrevLayer(j, ...) applies preprocessor j and then layer j.
                    for (int j = 0; j < i; j++) {
                        layerInput = activationFromPrevLayer(j, layerInput, true);
                    }
                    // Layer i's own preprocessor must also run before fitting it, as it does in
                    // feedForwardToLayer. Previously only preprocessors 0..i-1 were applied.
                    if (getLayerWiseConfigurations().getInputPreProcess(i) != null) {
                        layerInput = getLayerWiseConfigurations().getInputPreProcess(i).preProcess(layerInput, miniBatchSize);
                    }
                    log.info("Training on layer " + (i + 1) + " with " + layerInput.size(0) + " examples");
                    getLayer(i).fit(layerInput);
                }
            }
            dataSetIterator.reset();
        }
    }

    /**
     * Greedy layer-wise pretraining on a single batch of input.
     * <p>
     * Does nothing unless pretraining is enabled. Each layer except the last is fit in order.
     * The input to layer {@code i} is built as follows:
     * <ul>
     *   <li>layer 0: the raw input, passed through preprocessor 0;</li>
     *   <li>layer {@code i > 0}: the activations of layer {@code i-1}, passed through
     *       preprocessor {@code i}.</li>
     * </ul>
     * Activations are computed after layer {@code i-1} has been fit. Every preprocessor is
     * applied exactly once.
     *
     * @param iNDArray network input; its first dimension is passed to preprocessors as the mini-batch size
     */
    public void pretrain(INDArray iNDArray) {
        if (!this.layerWiseConfigurations.isPretrain()) {
            return;
        }
        int miniBatchSize = iNDArray.size(0);
        INDArray layerInput = null;
        for (int i = 0; i < getnLayers() - 1; i++) {
            // layerInput currently holds the already-preprocessed input that layer i-1 was fit on,
            // so activate it directly. activationFromPrevLayer would re-apply preprocessor i-1.
            INDArray rawInput = i == 0 ? iNDArray : getLayers()[i - 1].activate(layerInput, true);
            layerInput = getLayerWiseConfigurations().getInputPreProcess(i) != null
                    ? getLayerWiseConfigurations().getInputPreProcess(i).preProcess(rawInput, miniBatchSize)
                    : rawInput;
            log.info("Training on layer " + (i + 1) + " with " + layerInput.size(0) + " examples");
            getLayers()[i].fit(layerInput);
        }
    }

    /** Returns the size of the first dimension of the current input. */
    @Override // org.deeplearning4j.nn.api.Model
    public int batchSize() {
        final INDArray current = input;
        return current.size(0);
    }

    /** Returns the network-level configuration (the constructor sets it to a clone of layer 0's). */
    @Override // org.deeplearning4j.nn.api.Model
    public NeuralNetConfiguration conf() {
        return defaultConfiguration;
    }

    /**
     * Not supported: per-layer settings come from the {@link MultiLayerConfiguration}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setConf(NeuralNetConfiguration neuralNetConfiguration) {
        throw new UnsupportedOperationException(
                "setConf is not supported on MultiLayerNetwork; configure layers via MultiLayerConfiguration");
    }

    /** Returns the current network input (the stored reference, not a copy). */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray input() {
        return input;
    }

    /** No-op: this class performs no input validation here. */
    @Override // org.deeplearning4j.nn.api.Model
    public void validateInput() {
    }

    /**
     * Not supported on the network as a whole.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public ConvexOptimizer getOptimizer() {
        throw new UnsupportedOperationException("getOptimizer is not supported on MultiLayerNetwork");
    }

    /**
     * Looks up one parameter by its network-level key {@code "<layerIndex>_<paramName>"}.
     * This is the key format produced by {@link #paramTable()}.
     *
     * @param str network-level parameter key
     * @return the parameter array from the addressed layer
     * @throws IllegalStateException if the key contains no {@code '_'} separator
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray getParam(String str) {
        final int sep = str.indexOf('_');
        if (sep < 0) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + str + "\"");
        }
        final int layerIdx = Integer.parseInt(str.substring(0, sep));
        return this.layers[layerIdx].getParam(str.substring(sep + 1));
    }

    /**
     * Not supported: parameters are allocated and handed to layers by {@link #init()}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void initParams() {
        throw new UnsupportedOperationException(
                "initParams is not supported on MultiLayerNetwork; parameters are initialised by init()");
    }

    /**
     * Collects every layer's parameters into one insertion-ordered map.
     * Each key is {@code "<layerIndex>_<layerParamKey>"}.
     *
     * @return new map of the layers' parameter arrays (the arrays themselves are not copied)
     */
    @Override // org.deeplearning4j.nn.api.Model
    public Map<String, INDArray> paramTable() {
        Map<String, INDArray> table = new LinkedHashMap<>();
        int layerIdx = 0;
        for (Layer layer : this.layers) {
            for (Map.Entry<String, INDArray> param : layer.paramTable().entrySet()) {
                table.put(layerIdx + "_" + param.getKey(), param.getValue());
            }
            layerIdx++;
        }
        return table;
    }

    /**
     * Not supported; use {@link #setParam(String, INDArray)} or {@link #setParams(INDArray)}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParamTable(Map<String, INDArray> map) {
        throw new UnsupportedOperationException(
                "setParamTable is not supported on MultiLayerNetwork; use setParam(String, INDArray) or setParams(INDArray)");
    }

    /**
     * Sets one parameter addressed by the network-level key {@code "<layerIndex>_<paramName>"}.
     *
     * @param str      network-level parameter key
     * @param iNDArray new parameter value, forwarded to the layer's setParam
     * @throws IllegalStateException if the key contains no {@code '_'} separator
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParam(String str, INDArray iNDArray) {
        final int sep = str.indexOf('_');
        if (sep < 0) {
            throw new IllegalStateException("Invalid param key: not have layer separator: \"" + str + "\"");
        }
        final int layerIdx = Integer.parseInt(str.substring(0, sep));
        this.layers[layerIdx].setParam(str.substring(sep + 1), iNDArray);
    }

    /** Returns the per-layer configuration this network was built from. */
    public MultiLayerConfiguration getLayerWiseConfigurations() {
        return layerWiseConfigurations;
    }

    /** Replaces the per-layer configuration; existing layers are not rebuilt by this call. */
    public void setLayerWiseConfigurations(MultiLayerConfiguration multiLayerConfiguration) {
        layerWiseConfigurations = multiLayerConfiguration;
    }

    /**
     * Stores {@code iNDArray} as the input and records its first dimension as the mini-batch size.
     * Then runs {@link #init()} if it has not yet created the layers.
     *
     * @param iNDArray network input
     * @throws IllegalArgumentException if {@code iNDArray} is null
     */
    public void initializeLayers(INDArray iNDArray) {
        if (iNDArray == null) {
            throw new IllegalArgumentException("Unable to initialize neuralNets with empty input");
        }
        this.input = iNDArray;
        setInputMiniBatchSize(iNDArray.size(0));
        if (!this.initCalled) {
            init();
        }
    }

    /**
     * Creates the layers (once) and refreshes the network-level variable list.
     * <p>
     * On the first successful call this method:
     * <ol>
     *   <li>allocates a single {@link #flattenedParams} row vector sized to the sum of every
     *       layer's parameter count;</li>
     *   <li>creates each layer from its configuration with a view of its slice;</li>
     *   <li>registers each layer in {@link #layerMap} under its configured name.</li>
     * </ol>
     * Every call then rebuilds the default configuration's variables as
     * {@code "<layerIndex>_<variable>"}.
     *
     * @throws IllegalStateException if the configuration defines fewer than one layer
     */
    public void init() {
        if (this.layerWiseConfigurations == null || this.layers == null) {
            intializeConfigurations();
        }
        if (this.initCalled) {
            return;
        }
        final int nLayers = getnLayers();
        if (nLayers < 1) {
            throw new IllegalStateException("Unable to createComplex network neuralNets; number specified is less than 1");
        }
        boolean needLayers = this.layers == null || this.layers[0] == null;
        if (needLayers) {
            if (this.layers == null) {
                this.layers = new Layer[nLayers];
            }
            // Pass 1: parameter count per layer, and the total.
            int[] paramCounts = new int[nLayers];
            int totalParams = 0;
            for (int layerIdx = 0; layerIdx < nLayers; layerIdx++) {
                NeuralNetConfiguration layerConf = this.layerWiseConfigurations.getConf(layerIdx);
                paramCounts[layerIdx] = LayerFactories.getFactory(layerConf).initializer().numParams(layerConf, true);
                totalParams += paramCounts[layerIdx];
            }
            // Pass 2: one shared row vector; each layer gets a view of its own slice (or null if it has none).
            this.flattenedParams = Nd4j.create(1, totalParams);
            int offset = 0;
            for (int layerIdx = 0; layerIdx < nLayers; layerIdx++) {
                int count = paramCounts[layerIdx];
                INDArray paramsView = null;
                if (count > 0) {
                    paramsView = this.flattenedParams.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + count)});
                }
                offset += count;
                NeuralNetConfiguration layerConf = this.layerWiseConfigurations.getConf(layerIdx);
                this.layers[layerIdx] = LayerFactories.getFactory(layerConf).create(layerConf, this.listeners, layerIdx, paramsView);
                this.layerMap.put(layerConf.getLayer().getLayerName(), this.layers[layerIdx]);
            }
            this.initCalled = true;
            initMask();
        }
        this.defaultConfiguration.clearVariables();
        for (int layerIdx = 0; layerIdx < this.layers.length; layerIdx++) {
            Iterator<String> variables = this.layers[layerIdx].conf().variables().iterator();
            while (variables.hasNext()) {
                this.defaultConfiguration.addVariable(layerIdx + "_" + variables.next());
            }
        }
    }

    /**
     * Allocates {@link #flattenedGradients} as a single 'f'-ordered row vector sized to all
     * layers' parameter counts. Each layer with parameters is given a view of its slice.
     * Calls {@link #init()} first if there is no layer array yet.
     */
    protected void initGradientsView() {
        if (this.layers == null) {
            init();
        }
        final int nLayers = this.layers.length;
        final int[] paramCounts = new int[nLayers];
        int total = 0;
        for (int idx = 0; idx < nLayers; idx++) {
            NeuralNetConfiguration layerConf = this.layerWiseConfigurations.getConf(idx);
            paramCounts[idx] = LayerFactories.getFactory(layerConf).initializer().numParams(layerConf, true);
            total += paramCounts[idx];
        }
        this.flattenedGradients = Nd4j.createUninitialized(new int[]{1, total}, 'f');
        int offset = 0;
        for (int idx = 0; idx < this.layers.length; idx++) {
            final int count = paramCounts[idx];
            if (count == 0) {
                continue; // parameterless layer: no gradient view
            }
            INDArray view = this.flattenedGradients.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + count)});
            this.layers[idx].setBackpropGradientsViewArray(view);
            offset += count;
        }
    }

    /** Delegates to {@code activate()} on the last layer of the network. */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate() {
        Layer[] all = getLayers();
        return all[all.length - 1].activate();
    }

    /** Delegates to {@code activate()} on layer {@code i}. */
    public INDArray activate(int i) {
        return this.getLayer(i).activate();
    }

    /**
     * Not supported; use {@link #feedForward(INDArray)} or {@link #activate(int, INDArray)}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray) {
        throw new UnsupportedOperationException(
                "activate(INDArray) is not supported on MultiLayerNetwork; use feedForward(INDArray) or activate(int, INDArray)");
    }

    /** Delegates to {@code activate(iNDArray)} on layer {@code i}; no input preprocessor is applied. */
    public INDArray activate(int i, INDArray iNDArray) {
        return this.getLayer(i).activate(iNDArray);
    }

    /**
     * Not supported on the network as a whole.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activationMean() {
        throw new UnsupportedOperationException("activationMean is not supported on MultiLayerNetwork");
    }

    /**
     * Sets the input from {@code dataSet}, runs a forward pass and stores the labels.
     * The labels are also passed to the output layer when it is a {@link BaseOutputLayer}.
     *
     * @param dataSet features and labels to install
     */
    public void initialize(DataSet dataSet) {
        setInput(dataSet.getFeatureMatrix());
        feedForward(getInput());
        this.labels = dataSet.getLabels();
        Layer outputLayer = getOutputLayer();
        if (outputLayer instanceof BaseOutputLayer) {
            ((BaseOutputLayer) outputLayer).setLabels(this.labels);
        }
    }

    /**
     * Computes layer {@code i}'s pre-activation output (z) from the previous layer's output.
     * Layer {@code i}'s input preprocessor, if configured, is applied first.
     *
     * @param i        layer index
     * @param iNDArray raw input to layer {@code i} (network input for layer 0, otherwise the
     *                 previous layer's activations)
     * @param z        flag forwarded to the layer's {@code preOutput}
     * @return layer {@code i}'s pre-activation output
     */
    public INDArray zFromPrevLayer(int i, INDArray iNDArray, boolean z) {
        INDArray layerInput = iNDArray;
        if (getLayerWiseConfigurations().getInputPreProcess(i) != null) {
            // Use the network mini-batch size, as activationFromPrevLayer and backprop do.
            // layerInput.size(0) can differ from it (e.g. when examples have been flattened into rows).
            layerInput = getLayerWiseConfigurations().getInputPreProcess(i).preProcess(layerInput, getInputMiniBatchSize());
        }
        return this.layers[i].preOutput(layerInput, z);
    }

    /**
     * Computes layer {@code i}'s activations from its raw input.
     * Layer {@code i}'s input preprocessor, if configured, is applied first using the network
     * mini-batch size.
     *
     * @param i        layer index
     * @param iNDArray raw input to layer {@code i}
     * @param z        flag forwarded to the layer's {@code activate}
     * @return layer {@code i}'s activations
     */
    public INDArray activationFromPrevLayer(int i, INDArray iNDArray, boolean z) {
        INDArray layerInput = getLayerWiseConfigurations().getInputPreProcess(i) == null
                ? iNDArray
                : getLayerWiseConfigurations().getInputPreProcess(i).preProcess(iNDArray, getInputMiniBatchSize());
        return this.layers[i].activate(layerInput, z);
    }

    /**
     * Runs {@code iNDArray} through layers {@code i..i2} (inclusive) in non-training mode.
     *
     * @param i        first layer index; must be in {@code [0, numLayers)} and less than {@code i2}
     * @param i2       last layer index; must be in {@code [1, numLayers)}
     * @param iNDArray raw input to layer {@code i}
     * @return activations of layer {@code i2}
     * @throws IllegalStateException if the input is null or either index is out of range
     */
    public INDArray activateSelectedLayers(int i, int i2, INDArray iNDArray) {
        if (iNDArray == null) {
            throw new IllegalStateException("Unable to perform activation; no input found");
        }
        boolean fromValid = i >= 0 && i < this.layers.length && i < i2;
        if (!fromValid) {
            throw new IllegalStateException("Unable to perform activation; FROM is out of layer space");
        }
        boolean toValid = i2 >= 1 && i2 < this.layers.length;
        if (!toValid) {
            throw new IllegalStateException("Unable to perform activation; TO is out of layer space");
        }
        INDArray current = iNDArray;
        int layerIdx = i;
        while (layerIdx <= i2) {
            current = activationFromPrevLayer(layerIdx, current, false);
            layerIdx++;
        }
        return current;
    }

    /**
     * Computes the pre-activation output (z) of every layer for the current input.
     * <p>
     * Each layer's z is computed from the previous layer's <em>activations</em>. The previous
     * code fed each layer the previous layer's z, which skipped every activation function. The
     * forward step is therefore run twice per layer: once for z ({@link #zFromPrevLayer}) and
     * once for the activations that feed the next layer ({@link #activationFromPrevLayer}).
     *
     * @param z flag forwarded to the layers' preOutput/activate calls
     * @return list whose element 0 is the input and element {@code i + 1} is layer {@code i}'s z
     */
    public List<INDArray> computeZ(boolean z) {
        INDArray layerInput = this.input;
        List<INDArray> result = new ArrayList<>(this.layers.length + 1);
        result.add(layerInput);
        for (int i = 0; i < this.layers.length; i++) {
            result.add(zFromPrevLayer(i, layerInput, z));
            if (i < this.layers.length - 1) {
                // Activations of layer i are only needed as the next layer's input.
                layerInput = activationFromPrevLayer(i, layerInput, z);
            }
        }
        return result;
    }

    /**
     * Sets {@code iNDArray} as the network input, then computes every layer's pre-activation output.
     * <p>
     * The raw input is stored. Layer 0's preprocessor is applied inside
     * {@link #zFromPrevLayer}; applying it here as well would preprocess the input twice.
     *
     * @param iNDArray network input
     * @param z        flag forwarded to the layers
     * @return see {@link #computeZ(boolean)}
     * @throws IllegalStateException if {@code iNDArray} is null
     */
    public List<INDArray> computeZ(INDArray iNDArray, boolean z) {
        if (iNDArray == null) {
            throw new IllegalStateException("Unable to perform feed forward; no input found");
        }
        setInput(iNDArray);
        return computeZ(z);
    }

    /**
     * Sets {@code iNDArray} as the network input and runs a forward pass through every layer.
     *
     * @param iNDArray raw network input
     * @param z        flag forwarded to each layer's {@code activate}
     * @return input followed by each layer's activations
     */
    public List<INDArray> feedForward(INDArray iNDArray, boolean z) {
        this.setInput(iNDArray);
        List<INDArray> activations = this.feedForward(z);
        return activations;
    }

    /**
     * Forward pass of the current input through every layer.
     *
     * @param z flag forwarded to each layer's {@code activate}
     * @return input followed by each layer's activations
     */
    public List<INDArray> feedForward(boolean z) {
        final int lastLayer = layers.length - 1;
        return feedForwardToLayer(lastLayer, z);
    }

    /** Same as {@code feedForwardToLayer(i, iNDArray, false)}. */
    public List<INDArray> feedForwardToLayer(int i, INDArray iNDArray) {
        return this.feedForwardToLayer(i, iNDArray, false);
    }

    /**
     * Sets {@code iNDArray} as the network input and runs it through layers {@code 0..i}.
     *
     * @param i        last layer index (inclusive)
     * @param iNDArray raw network input
     * @param z        flag forwarded to each layer's {@code activate}
     * @return input followed by the activations of layers {@code 0..i}
     */
    public List<INDArray> feedForwardToLayer(int i, INDArray iNDArray, boolean z) {
        this.setInput(iNDArray);
        List<INDArray> activations = this.feedForwardToLayer(i, z);
        return activations;
    }

    /**
     * Runs the current input through layers {@code 0..i} (inclusive).
     * Each layer's input preprocessor is applied via {@link #activationFromPrevLayer}.
     *
     * @param i last layer index (inclusive)
     * @param z flag forwarded to each layer's {@code activate}
     * @return list with the input at index 0 and layer {@code j}'s activations at index {@code j + 1}
     */
    public List<INDArray> feedForwardToLayer(int i, boolean z) {
        List<INDArray> activations = new ArrayList<>();
        INDArray current = this.input;
        activations.add(current);
        int layerIdx = 0;
        while (layerIdx <= i) {
            current = activationFromPrevLayer(layerIdx, current, z);
            activations.add(current);
            layerIdx++;
        }
        return activations;
    }

    /** Same as {@code feedForward(false)}. */
    public List<INDArray> feedForward() {
        return this.feedForward(false);
    }

    /**
     * Sets {@code iNDArray} as the network input and runs a forward pass through every layer.
     * <p>
     * The raw input is stored. Layer 0's preprocessor is applied by
     * {@link #activationFromPrevLayer} during the pass. Applying it here as well, as this method
     * used to, preprocessed the input twice. This now matches {@link #feedForward(INDArray, boolean)}.
     *
     * @param iNDArray raw network input
     * @return input followed by each layer's activations
     * @throws IllegalStateException if {@code iNDArray} is null
     */
    public List<INDArray> feedForward(INDArray iNDArray) {
        if (iNDArray == null) {
            throw new IllegalStateException("Unable to perform feed forward; no input found");
        }
        setInput(iNDArray);
        return feedForward();
    }

    /**
     * Forward pass with mask arrays installed for the duration of the call.
     * The masks are always cleared afterwards, even if the forward pass throws.
     *
     * @param iNDArray  network input
     * @param iNDArray2 features mask (may be null)
     * @param iNDArray3 labels mask (may be null)
     * @return see {@link #feedForward(INDArray)}
     */
    public List<INDArray> feedForward(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        setLayerMaskArrays(iNDArray2, iNDArray3);
        try {
            return feedForward(iNDArray);
        } finally {
            // Previously an exception here left the masks set on the layers.
            clearLayerMaskArrays();
        }
    }

    /** Returns the gradient from the most recent backprop() (may be null). */
    @Override // org.deeplearning4j.nn.api.Model
    public Gradient gradient() {
        return gradient;
    }

    /** Returns the epsilon stored by the most recent backprop() (may be null). */
    public INDArray epsilon() {
        return epsilon;
    }

    /** Pairs {@link #gradient()} with {@code score()}; gradient() is evaluated first. */
    @Override // org.deeplearning4j.nn.api.Model
    public Pair<Gradient, Double> gradientAndScore() {
        Gradient currentGradient = gradient();
        Double currentScore = Double.valueOf(score());
        return new Pair<>(currentGradient, currentScore);
    }

    /**
     * Computes per-layer delta matrices from the R forward pass ({@code feedForwardR}) for the
     * vector {@code iNDArray}. Presumably used for curvature (Hessian-vector) products; confirm
     * against callers.
     * <p>
     * Starting from the last R-activation divided by the batch size, for each layer {@code i}
     * from last to first:
     * <ul>
     *   <li>{@code delta[i] = activation_i^T . d};</li>
     *   <li>for {@code i > 0}, {@code d} is then propagated back through layer {@code i}'s
     *       weights and multiplied by the derivative of layer {@code i-1}'s activation function.</li>
     * </ul>
     *
     * @param iNDArray vector passed to feedForwardR (layout not visible here — TODO confirm)
     * @return one delta per layer, in layer order
     */
    protected List<INDArray> computeDeltasR(INDArray iNDArray) {
        ArrayList arrayList = new ArrayList();
        // One slot per layer plus one; the final slot is never filled or returned.
        INDArray[] iNDArrayArr = new INDArray[getnLayers() + 1];
        List<INDArray> feedForward = feedForward();
        List<INDArray> feedForwardR = feedForwardR(feedForward, iNDArray);
        // Per-layer weights "W", biases "b" and activation function names.
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        ArrayList arrayList4 = new ArrayList();
        for (int i = 0; i < getLayers().length; i++) {
            arrayList2.add(getLayers()[i].getParam("W"));
            arrayList3.add(getLayers()[i].getParam("b"));
            arrayList4.add(getLayers()[i].conf().getLayer().getActivationFunction());
        }
        // NOTE(review): divi is in-place, so this also modifies the last element of feedForwardR.
        INDArray divi = feedForwardR.get(feedForwardR.size() - 1).divi(Double.valueOf(this.input.size(0)));
        LinAlgExceptions.assertValidNum(divi);
        for (int i2 = getnLayers() - 1; i2 >= 0; i2--) {
            iNDArrayArr[i2] = feedForward.get(i2).transpose().mmul(divi);
            if (i2 > 0) {
                // NOTE(review): the bias row vector is added to W before transposing; this looks
                // mathematically suspicious for backpropagating deltas — confirm intent.
                // The derivative op is applied to feedForward.get(i2) via execAndReturn and may
                // overwrite that activation in place — confirm.
                divi = divi.mmul(((INDArray) arrayList2.get(i2)).addRowVector((INDArray) arrayList3.get(i2)).transpose()).muli(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform((String) arrayList4.get(i2 - 1), feedForward.get(i2)).derivative()));
            }
        }
        for (int i3 = 0; i3 < iNDArrayArr.length - 1; i3++) {
            arrayList.add(iNDArrayArr[i3]);
        }
        return arrayList;
    }

    /**
     * Computes, for each layer, a gradient-like term and a squared term by backpropagating
     * {@code (output - labels) / batchSize} through the network.
     * <p>
     * For each layer {@code i}, from last to first:
     * <ul>
     *   <li>first element: {@code a_i^T . d};</li>
     *   <li>second element: {@code (a_i^T)^2 . d^2 * batchSize} (element-wise powers);</li>
     *   <li>for {@code i > 0}, {@code d} is then propagated through {@code W_i^T} and multiplied
     *       by the derivative of layer {@code i-1}'s activation function.</li>
     * </ul>
     * Here {@code a_i} is the input to layer {@code i} from feedForward().
     *
     * @return one (first, second) pair per layer, in layer order
     */
    protected List<Pair<INDArray, INDArray>> computeDeltas2() {
        ArrayList arrayList = new ArrayList();
        List<INDArray> feedForward = feedForward();
        INDArray[] iNDArrayArr = new INDArray[feedForward.size() - 1];
        INDArray[] iNDArrayArr2 = new INDArray[feedForward.size() - 1];
        // Output error scaled by the number of examples (non-in-place sub/div).
        INDArray div = feedForward.get(feedForward.size() - 1).sub(this.labels).div(Integer.valueOf(this.labels.size(0)));
        // Per-layer weights "W", biases "b" (collected but unused below) and activation function names.
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        ArrayList arrayList4 = new ArrayList();
        for (int i = 0; i < getLayers().length; i++) {
            arrayList2.add(getLayers()[i].getParam("W"));
            arrayList3.add(getLayers()[i].getParam("b"));
            arrayList4.add(getLayers()[i].conf().getLayer().getActivationFunction());
        }
        for (int size = arrayList2.size() - 1; size >= 0; size--) {
            iNDArrayArr[size] = feedForward.get(size).transpose().mmul(div);
            iNDArrayArr2[size] = Transforms.pow(feedForward.get(size).transpose(), 2).mmul(Transforms.pow(div, 2)).muli(Integer.valueOf(this.labels.size(0)));
            if (size > 0) {
                // NOTE(review): the derivative op runs via execAndReturn on feedForward.get(size)
                // and may overwrite that activation in place — confirm.
                div = div.mmul(((INDArray) arrayList2.get(size)).transpose()).muli(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform((String) arrayList4.get(size - 1), feedForward.get(size)).derivative()));
            }
        }
        for (int i2 = 0; i2 < iNDArrayArr.length; i2++) {
            arrayList.add(new Pair(iNDArrayArr[i2], iNDArrayArr2[i2]));
        }
        return arrayList;
    }

    /**
     * Deep-copies the network.
     * <p>
     * A new instance of the runtime class is built through its
     * {@code (MultiLayerConfiguration)} constructor from a cloned configuration. It is then
     * updated from this network and given a copy of its parameters.
     *
     * @throws IllegalStateException wrapping any failure (e.g. the subclass lacks that constructor)
     */
    @Override // org.deeplearning4j.nn.api.Layer
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public MultiLayerNetwork m71clone() {
        MultiLayerNetwork copy;
        try {
            copy = getClass().getDeclaredConstructor(MultiLayerConfiguration.class).newInstance(getLayerWiseConfigurations().m35clone());
            copy.update(this);
            INDArray paramsCopy = params().dup();
            copy.setParameters(paramsCopy);
        } catch (Exception e) {
            throw new IllegalStateException("Unable to clone network", e);
        }
        return copy;
    }

    /**
     * Returns network parameters as a flattened row vector.
     *
     * @param z if true, return {@link #params()} (the flattened parameter view created in init());
     *          if false, concatenate every layer's {@code params()} in 'f' order
     * @return flattened parameters
     */
    public INDArray params(boolean z) {
        if (z) {
            return params();
        }
        // z is known to be false here, so the old "instanceof BasePretrainNetwork && z" branch
        // could never be taken; every layer contributes layer.params().
        List<INDArray> layerParams = new ArrayList<>();
        for (Layer layer : getLayers()) {
            INDArray p = layer.params();
            if (p != null) { // layers without parameters are skipped
                layerParams.add(p);
            }
        }
        return Nd4j.toFlattened('f', layerParams);
    }

    /** Returns the flattened parameter row vector created in init() (the field itself, not a copy). */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray params() {
        return flattenedParams;
    }

    /**
     * Loads parameters into the network.
     * <ul>
     *   <li>If {@code iNDArray} is already the backing array, this does nothing.</li>
     *   <li>If it has the same length as the backing array, its values are copied in with
     *       {@code assign}, so the layers' views of the backing array see the new values.</li>
     *   <li>Otherwise consecutive slices of {@code iNDArray} are handed to each layer via
     *       {@code setParams}. Pretrain layers use their backprop parameter count.</li>
     * </ul>
     *
     * @param iNDArray new parameters as a row vector
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParams(INDArray iNDArray) {
        if (iNDArray == this.flattenedParams) {
            return;
        }
        boolean sameLength = this.flattenedParams != null && iNDArray.length() == this.flattenedParams.length();
        if (sameLength) {
            this.flattenedParams.assign(iNDArray);
            return;
        }
        int offset = 0;
        for (int idx = 0; idx < getLayers().length; idx++) {
            Layer layer = getLayer(idx);
            int count = (layer instanceof BasePretrainNetwork)
                    ? ((BasePretrainNetwork) layer).numParamsBackprop()
                    : layer.numParams();
            if (count <= 0) {
                continue;
            }
            layer.setParams(iNDArray.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(offset, offset + count)}));
            offset += count;
        }
    }

    /**
     * Not yet implemented for the network as a whole.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParamsViewArray(INDArray iNDArray) {
        throw new UnsupportedOperationException("setParamsViewArray is not yet implemented for MultiLayerNetwork");
    }

    /**
     * Not yet implemented for the network as a whole.
     * Gradient views are allocated internally by initGradientsView().
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setBackpropGradientsViewArray(INDArray iNDArray) {
        throw new UnsupportedOperationException("setBackpropGradientsViewArray is not yet implemented for MultiLayerNetwork");
    }

    /** Same as {@code numParams(false)}. */
    @Override // org.deeplearning4j.nn.api.Model
    public int numParams() {
        return this.numParams(false);
    }

    /**
     * Sums {@code numParams(z)} over all layers.
     *
     * @param z flag forwarded to each layer's {@code numParams}
     * @return total parameter count
     */
    @Override // org.deeplearning4j.nn.api.Model
    public int numParams(boolean z) {
        int total = 0;
        for (Layer layer : this.layers) {
            total += layer.numParams(z);
        }
        return total;
    }

    /** Same as {@link #params()}. */
    public INDArray pack() {
        return this.params();
    }

    /**
     * Flattens a list of (first, second) array pairs into one array.
     * The order is first0, second0, first1, second1, ...
     *
     * @param list pairs to flatten
     * @return result of {@code Nd4j.toFlattened} over all arrays
     */
    public INDArray pack(List<Pair<INDArray, INDArray>> list) {
        List<INDArray> flat = new ArrayList<>();
        for (Pair<INDArray, INDArray> p : list) {
            flat.add(p.getFirst());
            flat.add(p.getSecond());
        }
        return Nd4j.toFlattened(flat);
    }

    /** F1 score on a data set; delegates to {@code f1Score(features, labels)}. */
    @Override // org.deeplearning4j.nn.api.Classifier
    public double f1Score(org.nd4j.linalg.dataset.api.DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        INDArray expected = dataSet.getLabels();
        return f1Score(features, expected);
    }

    /**
     * Splits a flat parameter vector into per-layer (weights, bias) pairs.
     * <p>
     * Each layer's chunk has length {@code |W| + |b|}, taken from the layer's current "W" and
     * "b" params. Each part is reshaped to the corresponding param's rows x columns.
     *
     * @param iNDArray flat parameters; reshaped to a single row if it is not already one
     * @return one (weights, bias) pair per layer, in layer order
     * @throws IllegalStateException if a chunk's bias or weight portion has the wrong length
     */
    public List<Pair<INDArray, INDArray>> unPack(INDArray iNDArray) {
        INDArray row = iNDArray.size(0) == 1 ? iNDArray : iNDArray.reshape(1, iNDArray.length());
        List<Pair<INDArray, INDArray>> result = new ArrayList<>();
        int offset = 0;
        for (int layerIdx = 0; layerIdx < this.layers.length; layerIdx++) {
            INDArray weights = this.layers[layerIdx].getParam("W");
            INDArray bias = this.layers[layerIdx].getParam("b");
            int weightLen = weights.length();
            int chunkLen = weightLen + bias.length();
            INDArray chunk = row.get(new INDArrayIndex[]{NDArrayIndex.interval(offset, offset + chunkLen)});
            INDArray weightPart = chunk.get(new INDArrayIndex[]{NDArrayIndex.interval(0, weightLen)});
            INDArray biasPart = chunk.get(new INDArrayIndex[]{NDArrayIndex.interval(weightLen, chunk.length())});
            if (weightPart.length() + biasPart.length() != chunkLen) {
                if (biasPart.length() != bias.length()) {
                    throw new IllegalStateException("Hidden bias on layer " + layerIdx + " was off");
                }
                if (weightPart.length() != weightLen) {
                    throw new IllegalStateException("Weight portion on layer " + layerIdx + " was off");
                }
            }
            result.add(new Pair<>(weightPart.reshape(weights.size(0), weights.columns()), biasPart.reshape(bias.size(0), bias.columns())));
            offset += chunkLen;
        }
        return result;
    }

    /**
     * Trains the network on every batch from the iterator.
     * <p>
     * If pretraining is enabled: runs {@link #pretrain(DataSetIterator)}, resets the iterator,
     * then calls {@code finetune()} once per batch.
     * <p>
     * If backprop is enabled: passes a task built from the iterator to {@code update(...)}
     * (presumably heartbeat reporting — confirm), resets the iterator, then handles each batch:
     * <ul>
     *   <li>TruncatedBPTT backprop type: the batch goes to {@code doTruncatedBPTT};</li>
     *   <li>otherwise: masks, input and labels are set and the lazily created {@link Solver} is
     *       run.</li>
     * </ul>
     * Each phase stops at the first batch whose features or labels are null. In the backprop
     * phase this returns from fit() immediately.
     *
     * @param dataSetIterator source of batches; reset between phases
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(DataSetIterator dataSetIterator) {
        if (this.layerWiseConfigurations.isPretrain()) {
            pretrain(dataSetIterator);
            dataSetIterator.reset();
            while (dataSetIterator.hasNext()) {
                DataSet dataSet = (DataSet) dataSetIterator.next();
                if (dataSet.getFeatureMatrix() == null || dataSet.getLabels() == null) {
                    break;
                }
                setInput(dataSet.getFeatureMatrix());
                setLabels(dataSet.getLabels());
                finetune();
            }
        }
        if (this.layerWiseConfigurations.isBackprop()) {
            if (this.layerWiseConfigurations.isPretrain()) {
                dataSetIterator.reset();
            }
            update(TaskUtils.buildTask(dataSetIterator));
            // NOTE(review): presumably buildTask may consume batches, hence the unconditional reset.
            dataSetIterator.reset();
            while (dataSetIterator.hasNext()) {
                DataSet dataSet2 = (DataSet) dataSetIterator.next();
                if (dataSet2.getFeatureMatrix() == null || dataSet2.getLabels() == null) {
                    return;
                }
                boolean hasMaskArrays = dataSet2.hasMaskArrays();
                if (this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
                    // Masks are passed to doTruncatedBPTT rather than installed here.
                    doTruncatedBPTT(dataSet2.getFeatureMatrix(), dataSet2.getLabels(), dataSet2.getFeaturesMaskArray(), dataSet2.getLabelsMaskArray());
                } else {
                    if (hasMaskArrays) {
                        setLayerMaskArrays(dataSet2.getFeaturesMaskArray(), dataSet2.getLabelsMaskArray());
                    }
                    setInput(dataSet2.getFeatureMatrix());
                    setLabels(dataSet2.getLabels());
                    if (this.solver == null) {
                        this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                    }
                    this.solver.optimize();
                }
                if (hasMaskArrays) {
                    clearLayerMaskArrays();
                }
            }
        }
    }

    /**
     * Runs backprop from the output layer and stores the results in {@link #gradient} and
     * {@link #epsilon}. Both are set to null when backprop cannot run (no output layer).
     */
    protected void backprop() {
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        Pair<Gradient, INDArray> result = calcBackpropGradients(null, true);
        if (result == null) {
            this.gradient = null;
            this.epsilon = null;
        } else {
            this.gradient = result.getFirst();
            this.epsilon = result.getSecond();
        }
    }

    /**
     * Backpropagates through the network and collects all layer gradients into one
     * {@link DefaultGradient} backed by {@link #flattenedGradients}.
     * <p>
     * After each layer, that layer's input preprocessor (if any) is backpropagated using the
     * network mini-batch size.
     *
     * @param iNDArray epsilon fed into the last layer when {@code z} is false (unused otherwise)
     * @param z        if true, start at the output layer using {@link #labels}; if false, start
     *                 by backpropagating {@code iNDArray} through the last layer
     * @return pair of:
     *         <ul>
     *           <li>the gradient, keyed {@code "<layerIndex>_<variable>"};</li>
     *           <li>the epsilon after the first layer and its preprocessor.</li>
     *         </ul>
     *         Returns {@code null} if {@code z} is true and the last layer is not a
     *         {@link BaseOutputLayer}.
     * @throws IllegalStateException if {@code z} is true and no labels are set
     */
    protected Pair<Gradient, INDArray> calcBackpropGradients(INDArray iNDArray, boolean z) {
        Pair<Gradient, INDArray> pair;
        int i;
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        DefaultGradient defaultGradient = new DefaultGradient(this.flattenedGradients);
        int i2 = getnLayers();
        // (key, gradient array, flattening order) triples; ends up ordered by layer index,
        // with each layer's variables in their original order.
        LinkedList linkedList = new LinkedList();
        if (!z) {
            // Caller-supplied epsilon enters at the last layer (loop starts at i2 - 1).
            pair = new Pair<>(null, iNDArray);
            i = i2 - 1;
        } else {
            if (!(getOutputLayer() instanceof BaseOutputLayer)) {
                log.warn("Warning: final layer isn't output layer. You cannot use backprop without an output layer.");
                return null;
            }
            BaseOutputLayer baseOutputLayer = (BaseOutputLayer) getOutputLayer();
            if (this.labels == null) {
                throw new IllegalStateException("No labels found");
            }
            baseOutputLayer.setLabels(this.labels);
            pair = baseOutputLayer.backpropGradient(null);
            for (Map.Entry<String, INDArray> entry : pair.getFirst().gradientForVariable().entrySet()) {
                String key = entry.getKey();
                linkedList.addLast(new Triple(String.valueOf(i2 - 1) + "_" + key, entry.getValue(), pair.getFirst().flatteningOrderForVariable(key)));
            }
            if (getLayerWiseConfigurations().getInputPreProcess(i2 - 1) != null) {
                pair = new Pair<>(pair.getFirst(), this.layerWiseConfigurations.getInputPreProcess(i2 - 1).backprop(pair.getSecond(), getInputMiniBatchSize()));
            }
            // Output layer already handled; continue from the layer below it.
            i = i2 - 2;
        }
        for (int i3 = i; i3 >= 0; i3--) {
            pair = getLayer(i3).backpropGradient(pair.getSecond());
            // Reverse this layer's entries, then prepend them one by one. The two reversals
            // cancel, so the layer's variables land in front of later layers in their original order.
            LinkedList linkedList2 = new LinkedList();
            for (Map.Entry<String, INDArray> entry2 : pair.getFirst().gradientForVariable().entrySet()) {
                String key2 = entry2.getKey();
                linkedList2.addFirst(new Triple(String.valueOf(i3) + "_" + key2, entry2.getValue(), pair.getFirst().flatteningOrderForVariable(key2)));
            }
            Iterator it = linkedList2.iterator();
            while (it.hasNext()) {
                linkedList.addFirst((Triple) it.next());
            }
            if (getLayerWiseConfigurations().getInputPreProcess(i3) != null) {
                pair = new Pair<>(pair.getFirst(), getLayerWiseConfigurations().getInputPreProcess(i3).backprop(pair.getSecond(), getInputMiniBatchSize()));
            }
        }
        Iterator it2 = linkedList.iterator();
        while (it2.hasNext()) {
            Triple triple = (Triple) it2.next();
            defaultGradient.setGradientFor((String) triple.getFirst(), (INDArray) triple.getSecond(), (Character) triple.getThird());
        }
        return new Pair<>(defaultGradient, pair.getSecond());
    }

    /**
     * Fits the network on one time-series minibatch using truncated backpropagation through time.
     *
     * <p>The series is split along dimension 2 into consecutive segments of
     * {@code getTbpttFwdLength()} steps. The solver fits each segment, and the recurrent state is
     * carried into the next segment via {@link #updateRnnStateWithTBPTTState()}. A trailing segment
     * shorter than the forward length is also fitted, so no time steps are silently ignored. RNN
     * state (and any masks that were set) are cleared when the method finishes, even if fitting
     * throws.
     *
     * @param iNDArray  input features; must be rank 3 with time along dimension 2
     * @param iNDArray2 labels; must be rank 3 with the same time-series length as the input
     * @param iNDArray3 optional features mask, sliced along dimension 1; may be {@code null}
     * @param iNDArray4 optional labels mask, sliced along dimension 1; may be {@code null}
     * @throws IllegalStateException if the configured TBPTT forward length is not positive
     */
    protected void doTruncatedBPTT(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4) {
        if (iNDArray.rank() != 3 || iNDArray2.rank() != 3) {
            log.warn("Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got " + Arrays.toString(iNDArray.shape()) + "\t" + Arrays.toString(iNDArray2.shape()));
            return;
        }
        if (iNDArray.size(2) != iNDArray2.size(2)) {
            log.warn("Input and label time series have different lengths: {} input length, {} label length", Integer.valueOf(iNDArray.size(2)), Integer.valueOf(iNDArray2.size(2)));
            return;
        }
        int fwdLength = this.layerWiseConfigurations.getTbpttFwdLength();
        if (fwdLength <= 0) {
            throw new IllegalStateException("Invalid truncated BPTT forward length: " + fwdLength + " (must be > 0)");
        }
        update(TaskUtils.buildTask(iNDArray, iNDArray2));
        int timeSeriesLength = iNDArray.size(2);
        if (fwdLength > timeSeriesLength) {
            log.warn("Cannot do TBPTT: Truncated BPTT forward length (" + fwdLength + ") > input time series length (" + timeSeriesLength + ")");
            return;
        }
        // Ceiling division: include the final partial segment when the length is not a multiple
        // of the forward length.
        int nSubsets = (timeSeriesLength + fwdLength - 1) / fwdLength;
        rnnClearPreviousState();
        try {
            for (int subset = 0; subset < nSubsets; subset++) {
                int startIdx = subset * fwdLength;
                int endIdx = Math.min(startIdx + fwdLength, timeSeriesLength);
                INDArray inputSubset = iNDArray.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startIdx, endIdx)});
                INDArray labelSubset = iNDArray2.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startIdx, endIdx)});
                setInput(inputSubset);
                setLabels(labelSubset);
                INDArray featuresMaskSubset = iNDArray3 != null ? iNDArray3.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(startIdx, endIdx)}) : null;
                INDArray labelsMaskSubset = iNDArray4 != null ? iNDArray4.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(startIdx, endIdx)}) : null;
                if (featuresMaskSubset != null || labelsMaskSubset != null) {
                    setLayerMaskArrays(featuresMaskSubset, labelsMaskSubset);
                }
                if (this.solver == null) {
                    this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
                }
                this.solver.optimize();
                // Carry the recurrent state forward into the next segment
                updateRnnStateWithTBPTTState();
            }
        } finally {
            rnnClearPreviousState();
            if (iNDArray3 != null || iNDArray4 != null) {
                clearLayerMaskArrays();
            }
        }
    }

    /**
     * For every recurrent layer, copies its stored TBPTT state into its "previous state".
     * Nested {@link MultiLayerNetwork} layers are handled recursively. Other layers are
     * skipped.
     */
    public void updateRnnStateWithTBPTTState() {
        for (Layer layer : this.layers) {
            if (layer instanceof BaseRecurrentLayer) {
                BaseRecurrentLayer recurrent = (BaseRecurrentLayer) layer;
                recurrent.rnnSetPreviousState(recurrent.rnnGetTBPTTState());
            } else if (layer instanceof MultiLayerNetwork) {
                MultiLayerNetwork nested = (MultiLayerNetwork) layer;
                nested.updateRnnStateWithTBPTTState();
            }
        }
    }

    /**
     * Computes the gradient for the current TBPTT segment and stores it in {@code this.gradient}.
     *
     * <p>Backpropagates from the output layer down to layer 0. Recurrent layers use
     * {@code tbpttBackpropGradient}, limited to {@code getTbpttBackLength()} steps, and input
     * preprocessors are applied to epsilons on the way down. Gradients are keyed
     * {@code "<layerIndex>_<paramName>"} and inserted in layer order 0..n-1. Each variable's
     * flattening order is carried along, as the standard (non-truncated) backprop path does.
     *
     * <p>If the final layer is not a {@code BaseOutputLayer}, a warning is logged and
     * {@code this.gradient} is left as an empty gradient.
     *
     * @throws IllegalStateException if no labels are set, or the output layer uses zero weight init
     */
    protected void truncatedBPTTGradient() {
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        DefaultGradient tbpttGradient = new DefaultGradient();
        this.gradient = tbpttGradient;
        if (!(getOutputLayer() instanceof BaseOutputLayer)) {
            log.warn("Warning: final layer isn't output layer. You cannot use backprop (truncated BPTT) without an output layer.");
            return;
        }
        BaseOutputLayer baseOutputLayer = (BaseOutputLayer) getOutputLayer();
        if (this.labels == null) {
            throw new IllegalStateException("No labels found");
        }
        if (baseOutputLayer.conf().getLayer().getWeightInit() == WeightInit.ZERO) {
            throw new IllegalStateException("Output layer weights cannot be initialized to zero when using backprop.");
        }
        baseOutputLayer.setLabels(this.labels);
        int numLayers = getnLayers();
        // Collected front-to-back so the final gradient map iterates layer 0 first
        LinkedList<Triple<String, INDArray, Character>> gradientList = new LinkedList<>();

        Pair<Gradient, INDArray> currPair = baseOutputLayer.backpropGradient(null);
        Gradient outputGradient = currPair.getFirst();
        for (Map.Entry<String, INDArray> entry : outputGradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            gradientList.addLast(new Triple<String, INDArray, Character>((numLayers - 1) + "_" + key, entry.getValue(), outputGradient.flatteningOrderForVariable(key)));
        }
        if (getLayerWiseConfigurations().getInputPreProcess(numLayers - 1) != null) {
            currPair = new Pair<>(currPair.getFirst(), this.layerWiseConfigurations.getInputPreProcess(numLayers - 1).backprop(currPair.getSecond(), getInputMiniBatchSize()));
        }

        for (int j = numLayers - 2; j >= 0; j--) {
            Layer layer = getLayer(j);
            if (layer instanceof BaseRecurrentLayer) {
                currPair = ((BaseRecurrentLayer) layer).tbpttBackpropGradient(currPair.getSecond(), this.layerWiseConfigurations.getTbpttBackLength());
            } else {
                currPair = layer.backpropGradient(currPair.getSecond());
            }
            Gradient layerGradient = currPair.getFirst();
            // Prepend this layer's entries, preserving their own iteration order
            List<Triple<String, INDArray, Character>> layerEntries = new ArrayList<>();
            for (Map.Entry<String, INDArray> entry : layerGradient.gradientForVariable().entrySet()) {
                String key = entry.getKey();
                layerEntries.add(new Triple<String, INDArray, Character>(j + "_" + key, entry.getValue(), layerGradient.flatteningOrderForVariable(key)));
            }
            gradientList.addAll(0, layerEntries);
            // Pass epsilon through the input preprocessor before handing it to the layer below
            if (getLayerWiseConfigurations().getInputPreProcess(j) != null) {
                currPair = new Pair<>(currPair.getFirst(), getLayerWiseConfigurations().getInputPreProcess(j).backprop(currPair.getSecond(), getInputMiniBatchSize()));
            }
        }

        for (Triple<String, INDArray, Character> triple : gradientList) {
            tbpttGradient.setGradientFor(triple.getFirst(), triple.getSecond(), triple.getThird());
        }
    }

    /**
     * Returns the iteration listeners currently held by this network.
     *
     * @return the stored collection instance (no copy is made)
     */
    @Override
    public Collection<IterationListener> getListeners() {
        Collection<IterationListener> current = this.listeners;
        return current;
    }

    /**
     * Stores the listeners and installs them on every layer, and on the solver if one exists.
     * If the layers have not been created yet, {@code init()} is called first.
     *
     * @param collection listeners to install; stored by reference
     */
    @Override
    public void setListeners(Collection<IterationListener> collection) {
        this.listeners = collection;
        if (this.layers == null) {
            init();
        }
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx].setListeners(collection);
        }
        if (this.solver == null) {
            return;
        }
        this.solver.setListeners(collection);
    }

    /**
     * Varargs convenience overload: copies the listeners into a new mutable list and delegates
     * to {@link #setListeners(Collection)}.
     *
     * @param iterationListenerArr listeners to install
     */
    @Override
    public void setListeners(IterationListener... iterationListenerArr) {
        Collection<IterationListener> copied = new ArrayList<>(Arrays.asList(iterationListenerArr));
        setListeners(copied);
    }

    /**
     * Finetunes the output layer on the stored labels. A full feed-forward pass is run first so
     * the output layer's input is current; the output layer is then fitted on that input.
     *
     * <p>If the final layer is not a {@code BaseOutputLayer}, a warning is logged and nothing is
     * done.
     *
     * @throws IllegalStateException         if no labels are set
     * @throws UnsupportedOperationException if the output layer is configured for HESSIAN_FREE
     */
    public void finetune() {
        if (this.flattenedGradients == null) {
            initGradientsView();
        }
        Layer last = getOutputLayer();
        if (!(last instanceof BaseOutputLayer)) {
            log.warn("Output layer not instance of output layer returning.");
            return;
        }
        if (this.labels == null) {
            throw new IllegalStateException("No labels found");
        }
        log.info("Finetune phase");
        BaseOutputLayer outputLayer = (BaseOutputLayer) last;
        boolean hessianFree = outputLayer.conf().getOptimizationAlgo() == OptimizationAlgorithm.HESSIAN_FREE;
        if (hessianFree) {
            throw new UnsupportedOperationException();
        }
        feedForward();
        outputLayer.fit(outputLayer.input(), this.labels);
    }

    /**
     * Predicts a class index per example: the position of the largest output value (BLAS iamax)
     * of each output row, computed in TEST mode.
     *
     * @param iNDArray input features, one example per row
     * @return one predicted index per row of the input
     */
    @Override
    public int[] predict(INDArray iNDArray) {
        INDArray networkOutput = output(iNDArray, Layer.TrainingMode.TEST);
        int[] predictions = new int[iNDArray.size(0)];
        if (!iNDArray.isRowVector()) {
            for (int row = 0; row < predictions.length; row++) {
                predictions[row] = Nd4j.getBlasWrapper().iamax(networkOutput.getRow(row));
            }
        } else {
            predictions[0] = Nd4j.getBlasWrapper().iamax(networkOutput);
        }
        return predictions;
    }

    /**
     * Predicts label names for a DataSet by mapping each predicted class index through
     * {@code dataSet.getLabelName}.
     *
     * @param dataSet data whose feature matrix is classified
     * @return label names in example order
     */
    @Override
    public List<String> predict(org.nd4j.linalg.dataset.api.DataSet dataSet) {
        int[] classIndices = predict(dataSet.getFeatureMatrix());
        List<String> labelNames = new ArrayList<>();
        for (int classIndex : classIndices) {
            labelNames.add(dataSet.getLabelName(classIndex));
        }
        return labelNames;
    }

    /**
     * Feeds the input forward and asks the output layer for label probabilities of the last
     * activation.
     *
     * @param iNDArray input features
     * @return the output layer's label probabilities
     */
    @Override
    public INDArray labelProbabilities(INDArray iNDArray) {
        List<INDArray> activations = feedForward(iNDArray);
        INDArray lastActivation = activations.get(activations.size() - 1);
        BaseOutputLayer outputLayer = (BaseOutputLayer) getOutputLayer();
        return outputLayer.labelProbabilities(lastActivation);
    }

    /**
     * Trains on the given features and labels.
     *
     * <p>Runs pretraining and finetuning first if the configuration enables pretraining. If
     * backprop is enabled, it then runs either truncated BPTT or a full solver optimization,
     * depending on the configured backprop type.
     *
     * @param iNDArray  input features
     * @param iNDArray2 labels
     */
    @Override
    public void fit(INDArray iNDArray, INDArray iNDArray2) {
        setInput(iNDArray);
        setLabels(iNDArray2);
        update(TaskUtils.buildTask(iNDArray, iNDArray2));
        if (this.layerWiseConfigurations.isPretrain()) {
            pretrain(iNDArray);
            finetune();
        }
        if (!this.layerWiseConfigurations.isBackprop()) {
            return;
        }
        boolean truncated = this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT;
        if (truncated) {
            doTruncatedBPTT(iNDArray, iNDArray2, null, null);
        } else {
            if (this.solver == null) {
                this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
            }
            this.solver.optimize();
        }
    }

    /**
     * Unsupervised fit: stores the input, reports the heartbeat task, and runs pretraining.
     *
     * @param iNDArray input features
     */
    @Override
    public void fit(INDArray iNDArray) {
        INDArray features = iNDArray;
        setInput(features);
        update(TaskUtils.buildTask(features));
        pretrain(features);
    }

    /**
     * Delegates to {@code pretrain} with the given input.
     *
     * @param iNDArray input features
     */
    @Override
    public void iterate(INDArray iNDArray) {
        INDArray features = iNDArray;
        pretrain(features);
    }

    /**
     * Fits the network on a DataSet.
     *
     * <p>With TruncatedBPTT backprop, features, labels and both masks go straight to
     * {@code doTruncatedBPTT}. Otherwise any mask arrays are applied to the layers for the
     * duration of {@link #fit(INDArray, INDArray)} and cleared afterwards, even if fitting throws.
     *
     * @param dataSet features, labels and optional mask arrays
     */
    @Override
    public void fit(org.nd4j.linalg.dataset.api.DataSet dataSet) {
        if (this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) {
            doTruncatedBPTT(dataSet.getFeatureMatrix(), dataSet.getLabels(), dataSet.getFeaturesMaskArray(), dataSet.getLabelsMaskArray());
            return;
        }
        boolean hasMaskArrays = dataSet.hasMaskArrays();
        if (hasMaskArrays) {
            setLayerMaskArrays(dataSet.getFeaturesMaskArray(), dataSet.getLabelsMaskArray());
        }
        try {
            fit(dataSet.getFeatureMatrix(), dataSet.getLabels());
        } finally {
            // Never leave stale masks on the layers, even when fitting fails
            if (hasMaskArrays) {
                clearLayerMaskArrays();
            }
        }
    }

    /**
     * Fits on integer class labels by converting them to an outcome matrix whose width is the
     * output layer's nOut.
     *
     * @param iNDArray input features
     * @param iArr     class index per example
     */
    @Override
    public void fit(INDArray iNDArray, int[] iArr) {
        OutputLayer outputConf = (OutputLayer) getOutputLayer().conf().getLayer();
        INDArray outcomes = FeatureUtil.toOutcomeMatrix(iArr, outputConf.getNOut());
        fit(iNDArray, outcomes);
    }

    /**
     * Computes the network output; {@code TrainingMode.TRAIN} maps to training = true, anything
     * else to false.
     *
     * @param iNDArray     input features
     * @param trainingMode train or test mode
     * @return the final layer's activations
     */
    public INDArray output(INDArray iNDArray, Layer.TrainingMode trainingMode) {
        boolean training = Layer.TrainingMode.TRAIN == trainingMode;
        return output(iNDArray, training);
    }

    /**
     * Feeds the input forward and returns the last activation.
     *
     * @param iNDArray input features
     * @param z        training flag passed to {@code feedForward}
     * @return the final layer's activations
     */
    public INDArray output(INDArray iNDArray, boolean z) {
        List<INDArray> activations = feedForward(iNDArray, z);
        int lastIdx = activations.size() - 1;
        return activations.get(lastIdx);
    }

    /**
     * Computes the network output with the given feature/label masks applied. The masks are
     * always cleared afterwards, even if the forward pass throws.
     *
     * @param iNDArray  input features
     * @param z         training flag
     * @param iNDArray2 features mask (may be {@code null})
     * @param iNDArray3 labels mask (may be {@code null})
     * @return the final layer's activations
     */
    public INDArray output(INDArray iNDArray, boolean z, INDArray iNDArray2, INDArray iNDArray3) {
        setLayerMaskArrays(iNDArray2, iNDArray3);
        try {
            return output(iNDArray, z);
        } finally {
            clearLayerMaskArrays();
        }
    }

    /**
     * Computes the network output for inference, in {@code TrainingMode.TEST}.
     *
     * <p>Test mode matches {@link #predict(INDArray)} and {@code evaluate}, so training-only
     * behavior (e.g. dropout) is not applied when simply asking for outputs. Use
     * {@link #output(INDArray, Layer.TrainingMode)} to request training mode explicitly.
     *
     * @param iNDArray input features
     * @return the final layer's activations
     */
    public INDArray output(INDArray iNDArray) {
        return output(iNDArray, Layer.TrainingMode.TEST);
    }

    /**
     * Returns element {@code i - 1} of the feed-forward activation list for the input.
     *
     * @param iNDArray input features
     * @param i        1-based position in the activation list
     * @return the selected activation
     */
    public INDArray reconstruct(INDArray iNDArray, int i) {
        List<INDArray> activations = feedForward(iNDArray);
        return activations.get(i - 1);
    }

    /** Logs, at info level, one " Layer &lt;i&gt; conf &lt;conf&gt;" entry per layer configuration. */
    public void printConfiguration() {
        StringBuilder description = new StringBuilder();
        int layerIdx = 0;
        for (NeuralNetConfiguration layerConf : getLayerWiseConfigurations().getConfs()) {
            description.append(" Layer ").append(layerIdx).append(" conf ").append(layerConf);
            layerIdx++;
        }
        log.info(description.toString());
    }

    /**
     * Copies state from another network into this one.
     *
     * <p>Clones the default configuration and each layer. Duplicates the input (through
     * {@code setInput}). Shares the labels reference. Installs a clone of the other network's
     * updater if it has a solver; otherwise this network's solver is cleared.
     *
     * @param multiLayerNetwork the network to copy from
     */
    public void update(MultiLayerNetwork multiLayerNetwork) {
        NeuralNetConfiguration otherDefault = multiLayerNetwork.defaultConfiguration;
        this.defaultConfiguration = otherDefault == null ? null : otherDefault.m37clone();
        if (multiLayerNetwork.input != null) {
            INDArray inputCopy = multiLayerNetwork.input.dup();
            setInput(inputCopy);
        }
        this.labels = multiLayerNetwork.labels;
        if (multiLayerNetwork.layers == null) {
            this.layers = null;
        } else {
            this.layers = new Layer[multiLayerNetwork.layers.length];
            for (int idx = 0; idx < this.layers.length; idx++) {
                this.layers[idx] = multiLayerNetwork.layers[idx].m71clone();
            }
        }
        if (multiLayerNetwork.solver == null) {
            this.solver = null;
        } else {
            Updater updaterCopy = multiLayerNetwork.getUpdater().m72clone();
            setUpdater(updaterCopy);
        }
    }

    /**
     * Computes the F1 score of the network's label probabilities against the given labels.
     * The labels are also stored on the network.
     *
     * @param iNDArray  input features
     * @param iNDArray2 true labels
     * @return F1 score from {@link Evaluation#f1()}
     */
    @Override
    public double f1Score(INDArray iNDArray, INDArray iNDArray2) {
        setLabels(iNDArray2);
        // labelProbabilities runs the forward pass itself; a separate feedForward here would
        // only repeat the same work.
        INDArray probabilities = labelProbabilities(iNDArray);
        Evaluation evaluation = new Evaluation();
        evaluation.eval(iNDArray2, probabilities);
        return evaluation.f1();
    }

    /**
     * Returns the column count of the stored labels; throws a NullPointerException if no labels
     * are set.
     */
    @Override
    public int numLabels() {
        INDArray storedLabels = this.labels;
        return storedLabels.columns();
    }

    /**
     * Scores the DataSet with the training flag off; equivalent to {@code score(dataSet, false)}.
     *
     * @param dataSet features and labels to score
     * @return the computed score
     */
    public double score(DataSet dataSet) {
        final boolean training = false;
        return score(dataSet, training);
    }

    /**
     * Scores a DataSet against its labels, including L1/L2 regularization terms.
     *
     * <p>Features are fed forward up to the layer before the output layer. The output
     * preprocessor (if any) is applied, and the output layer computes the score, which is also
     * stored in {@code this.score}. Mask arrays are applied for the duration of the call and
     * always cleared afterwards, including on the early-return and exception paths.
     *
     * @param dataSet features, labels and optional masks
     * @param z       training flag passed to the forward pass and to {@code computeScore}
     * @return the score, or 0.0 (with a warning) if the final layer is not an output layer
     */
    public double score(DataSet dataSet, boolean z) {
        boolean hasMaskArrays = dataSet.hasMaskArrays();
        if (hasMaskArrays) {
            setLayerMaskArrays(dataSet.getFeaturesMaskArray(), dataSet.getLabelsMaskArray());
        }
        try {
            List<INDArray> feedForwardToLayer = feedForwardToLayer(this.layers.length - 2, dataSet.getFeatureMatrix(), z);
            int size = feedForwardToLayer.size();
            setLabels(dataSet.getLabels());
            if (!(getOutputLayer() instanceof BaseOutputLayer)) {
                log.warn("Cannot calculate score wrt labels without an OutputLayer");
                return 0.0d;
            }
            BaseOutputLayer baseOutputLayer = (BaseOutputLayer) getOutputLayer();
            INDArray outputLayerInput = feedForwardToLayer.get(size - 1);
            if (getLayerWiseConfigurations().getInputPreProcess(size - 1) != null) {
                outputLayerInput = getLayerWiseConfigurations().getInputPreProcess(size - 1).preProcess(outputLayerInput, this.input.size(0));
            }
            baseOutputLayer.setInput(outputLayerInput);
            baseOutputLayer.setLabels(dataSet.getLabels());
            baseOutputLayer.computeScore(calcL1(), calcL2(), z);
            this.score = baseOutputLayer.score();
            return score();
        } finally {
            if (hasMaskArrays) {
                clearLayerMaskArrays();
            }
        }
    }

    /**
     * Scores every example from the iterator. The per-batch score arrays are flattened, in
     * 'f' order, into one array.
     *
     * @param dataSetIterator source of batches; consumed from its current position
     * @param z               passed to {@link #scoreExamples(DataSet, boolean)}
     * @return flattened per-example scores
     */
    public INDArray scoreExamples(DataSetIterator dataSetIterator, boolean z) {
        List<INDArray> perBatchScores = new ArrayList<>();
        while (dataSetIterator.hasNext()) {
            DataSet batch = (DataSet) dataSetIterator.next();
            perBatchScores.add(scoreExamples(batch, z));
        }
        return Nd4j.toFlattened('f', perBatchScores);
    }

    /**
     * Computes a score for each example in the DataSet via the output layer.
     *
     * <p>Mask arrays are applied for the duration of the call and always cleared afterwards,
     * including when the method throws.
     *
     * @param dataSet features, labels and optional masks
     * @param z       if true, L1/L2 regularization terms are included; otherwise 0.0 is passed
     * @return per-example scores from {@code computeScoreForExamples}
     * @throws UnsupportedOperationException if the final layer is not an output layer
     */
    public INDArray scoreExamples(DataSet dataSet, boolean z) {
        boolean hasMaskArrays = dataSet.hasMaskArrays();
        if (hasMaskArrays) {
            setLayerMaskArrays(dataSet.getFeaturesMaskArray(), dataSet.getLabelsMaskArray());
        }
        try {
            feedForward(dataSet.getFeatureMatrix(), false);
            setLabels(dataSet.getLabels());
            if (!(getOutputLayer() instanceof BaseOutputLayer)) {
                throw new UnsupportedOperationException("Cannot calculate score wrt labels without an OutputLayer");
            }
            BaseOutputLayer baseOutputLayer = (BaseOutputLayer) getOutputLayer();
            baseOutputLayer.setLabels(dataSet.getLabels());
            return baseOutputLayer.computeScoreForExamples(z ? calcL1() : 0.0d, z ? calcL2() : 0.0d);
        } finally {
            if (hasMaskArrays) {
                clearLayerMaskArrays();
            }
        }
    }

    /** Fits using the input and labels already stored on this network. */
    @Override
    public void fit() {
        INDArray storedInput = this.input;
        INDArray storedLabels = this.labels;
        fit(storedInput, storedLabels);
    }

    /**
     * Intentionally a no-op for this class.
     *
     * @param iNDArray ignored
     * @param str      ignored
     */
    @Override
    public void update(INDArray iNDArray, String str) {
        // no-op
    }

    /** Returns the most recently stored score value. */
    @Override
    public double score() {
        double current = this.score;
        return current;
    }

    /**
     * Overwrites the stored score.
     *
     * @param d new score value
     */
    public void setScore(double d) {
        score = d;
    }

    /**
     * Computes gradients for the current input/labels and stores the resulting score.
     *
     * <p>With TruncatedBPTT, activations use the stored RNN state and then
     * {@code truncatedBPTTGradient()} runs. Otherwise the network is fed forward to the layer
     * before the output layer, the output preprocessor (if any) is applied, the result becomes
     * the output layer's input, and {@code backprop()} runs. Either way, the score comes from the
     * output layer with L1/L2 terms.
     */
    @Override
    public void computeGradientAndScore() {
        boolean truncated = this.layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT;
        if (truncated) {
            rnnActivateUsingStoredState(getInput(), true, true);
            truncatedBPTTGradient();
        } else {
            int outputIdx = this.layers.length - 1;
            List<INDArray> activations = feedForwardToLayer(outputIdx - 1, true);
            INDArray outputLayerInput = activations.get(activations.size() - 1);
            if (this.layerWiseConfigurations.getInputPreProcess(outputIdx) != null) {
                outputLayerInput = this.layerWiseConfigurations.getInputPreProcess(outputIdx).preProcess(outputLayerInput, getInputMiniBatchSize());
            }
            getOutputLayer().setInput(outputLayerInput);
            backprop();
        }
        BaseOutputLayer outputLayer = (BaseOutputLayer) getOutputLayer();
        this.score = outputLayer.computeScore(calcL1(), calcL2(), true);
    }

    /**
     * Intentionally a no-op for this class.
     *
     * @param d ignored
     */
    @Override
    public void accumulateScore(double d) {
        // no-op
    }

    /**
     * Clears every layer (when layers have been created) and drops the references to the stored
     * input, labels and solver. Safe to call before the layers are initialized.
     */
    @Override
    public void clear() {
        // Layers are created lazily; an uninitialized network has nothing to clear per layer
        if (this.layers != null) {
            for (Layer layer : this.layers) {
                layer.clear();
            }
        }
        this.input = null;
        this.labels = null;
        this.solver = null;
    }

    /**
     * Not supported for a generic Layer; use {@link #merge(MultiLayerNetwork, int)}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void merge(Layer layer, int i) {
        throw new UnsupportedOperationException();
    }

    /**
     * Merges another network of the same depth into this one, layer by layer.
     *
     * <p>Each layer, including the output layer, is merged exactly once. The output layer is the
     * last element of the layer array and is covered by the loop, so it must not be merged again.
     *
     * @param multiLayerNetwork network to merge from; must have the same number of layers
     * @param i                 batch size forwarded to each layer's {@code merge}
     * @throws IllegalArgumentException if the layer counts differ
     */
    public void merge(MultiLayerNetwork multiLayerNetwork, int i) {
        if (multiLayerNetwork.layers.length != this.layers.length) {
            throw new IllegalArgumentException("Unable to merge networks that are not of equal length");
        }
        // Iterate over the array whose length was just validated
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx].merge(multiLayerNetwork.layers[idx], i);
        }
    }

    /**
     * Sets the network input.
     *
     * <p>Initializes the layers (from this input) if they do not exist yet, and propagates the
     * mini-batch size (input {@code size(0)}) to all layers. A zero-length input is rejected
     * before any state changes, so the network is not left holding, or initialized from, an
     * invalid array.
     *
     * @param iNDArray input features; may be {@code null}
     * @throws IllegalArgumentException if the input has length 0
     */
    @Override
    public void setInput(INDArray iNDArray) {
        if (iNDArray != null && iNDArray.length() == 0) {
            throw new IllegalArgumentException("Invalid input: length 0 (shape: " + Arrays.toString(iNDArray.shape()) + ")");
        }
        this.input = iNDArray;
        if (this.layers == null) {
            initializeLayers(getInput());
        }
        if (iNDArray != null) {
            setInputMiniBatchSize(iNDArray.size(0));
        }
    }

    /** Sets the mask to a 1 x N array of ones, where N is the length of {@code pack()}. */
    private void initMask() {
        int packedLength = pack().length();
        setMask(Nd4j.ones(1, packedLength));
    }

    /** Returns the last layer of the network. */
    public Layer getOutputLayer() {
        Layer[] allLayers = getLayers();
        int lastIdx = allLayers.length - 1;
        return allLayers[lastIdx];
    }

    /**
     * Alias for {@code setParams}.
     *
     * @param iNDArray parameters to set
     */
    public void setParameters(INDArray iNDArray) {
        INDArray params = iNDArray;
        setParams(params);
    }

    /**
     * Applies score-based decay to per-parameter learning rates. Each rate in every layer's
     * {@code getLearningRateByParam()} map is multiplied by
     * {@code (getLrPolicyDecayRate() + Nd4j.EPS_THRESHOLD)}. Layers with an empty map are
     * skipped.
     */
    @Override
    public void applyLearningRateScoreDecay() {
        for (Layer layer : this.layers) {
            Map<String, Double> lrByParam = layer.conf().getLearningRateByParam();
            if (lrByParam.isEmpty()) {
                continue;
            }
            for (Map.Entry<String, Double> entry : lrByParam.entrySet()) {
                double decayFactor = layer.conf().getLrPolicyDecayRate() + Nd4j.EPS_THRESHOLD;
                layer.conf().setLearningRateByParam(entry.getKey(), entry.getValue().doubleValue() * decayFactor);
            }
        }
    }

    /**
     * Propagates a per-layer perturbation forward through the network, one layer at a time.
     *
     * <p>NOTE(review): this looks like the R-operator (directional-derivative) forward pass used
     * by Hessian-free style optimization. Confirm before relying on that interpretation.
     *
     * @param list     presumably per-layer activations from a previous forward pass (confirm).
     *                 Index {@code i} is multiplied in for layer {@code i}, and index
     *                 {@code i + 1} feeds layer {@code i}'s activation derivative, so at least
     *                 {@code layers.length + 1} entries are required
     * @param iNDArray packed vector that {@code unPack} splits into one pair per layer
     *                 (presumably weight and bias perturbations; confirm)
     * @return a list with {@code layers.length + 1} entries: entry 0 is a zero matrix with the
     *         same rows/columns as {@code this.input}, followed by one entry per layer
     */
    public List<INDArray> feedForwardR(List<INDArray> list, INDArray iNDArray) {
        ArrayList arrayList = new ArrayList();
        // Entry 0: zeros with the same rows/columns as the network input
        arrayList.add(Nd4j.zeros(this.input.size(0), this.input.columns()));
        List<Pair<INDArray, INDArray>> unPack = unPack(iNDArray);
        List<INDArray> weightMatrices = MultiLayerUtil.weightMatrices(this);
        for (int i = 0; i < this.layers.length; i++) {
            // next = (prev.mmul(W_i) + list[i].mmul(first_i + second_i as row vector)) * f_i'(list[i + 1]), element-wise
            // NOTE: addiRowVector modifies unPack.get(i).getFirst() in place.
            // NOTE(review): the derivative transform is created on list.get(i + 1) and may overwrite it in place; confirm.
            arrayList.add(((INDArray) arrayList.get(i)).mmul(weightMatrices.get(i)).addi(list.get(i).mmul(unPack.get(i).getFirst().addiRowVector(unPack.get(i).getSecond()))).muli(Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(getLayers()[i].conf().getLayer().getActivationFunction(), list.get(i + 1)).derivative())));
        }
        return arrayList;
    }

    /** Returns the network's default (base) layer configuration. */
    public NeuralNetConfiguration getDefaultConfiguration() {
        return defaultConfiguration;
    }

    /** Returns the labels currently stored on the network. */
    public INDArray getLabels() {
        return labels;
    }

    /** Returns the input currently stored on the network. */
    public INDArray getInput() {
        return input;
    }

    /**
     * Stores the labels used by training and scoring.
     *
     * @param iNDArray labels (stored by reference)
     */
    public void setLabels(INDArray iNDArray) {
        labels = iNDArray;
    }

    /** Returns the number of per-layer configurations in the multi-layer configuration. */
    public int getnLayers() {
        MultiLayerConfiguration configuration = this.layerWiseConfigurations;
        return configuration.getConfs().size();
    }

    /** Returns the internal layer array (not a copy). */
    public Layer[] getLayers() {
        return layers;
    }

    /**
     * Returns the layer at the given index.
     *
     * @param i zero-based layer index
     */
    public Layer getLayer(int i) {
        Layer[] allLayers = this.layers;
        return allLayers[i];
    }

    /**
     * Looks up a layer by name in the layer map.
     *
     * @param str layer name
     * @return the mapped layer, or whatever {@code Map.get} returns for an unknown name
     */
    public Layer getLayer(String str) {
        Layer found = layerMap.get(str);
        return found;
    }

    /** Returns a new, mutable list containing the keys of the layer map. */
    public List<String> getLayerNames() {
        List<String> names = new ArrayList<>(layerMap.keySet());
        return names;
    }

    /**
     * Replaces the layer array.
     *
     * @param layerArr new layers (stored by reference)
     */
    public void setLayers(Layer[] layerArr) {
        layers = layerArr;
    }

    /** Returns the stored mask array. */
    public INDArray getMask() {
        return mask;
    }

    /**
     * Stores the mask array.
     *
     * @param iNDArray mask (stored by reference)
     */
    public void setMask(INDArray iNDArray) {
        mask = iNDArray;
    }

    /**
     * Not supported by MultiLayerNetwork.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Gradient error(INDArray iNDArray) {
        throw new UnsupportedOperationException();
    }

    /** Identifies this model as {@code Layer.Type.MULTILAYER}. */
    @Override
    public Layer.Type type() {
        final Layer.Type multiLayer = Layer.Type.MULTILAYER;
        return multiLayer;
    }

    /**
     * Not supported by MultiLayerNetwork.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray derivativeActivation(INDArray iNDArray) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by MultiLayerNetwork.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Gradient calcGradient(Gradient gradient, INDArray iNDArray) {
        throw new UnsupportedOperationException();
    }

    /**
     * Activates every layer except the last, then returns the last layer's pre-activation output.
     *
     * <p>Input preprocessors receive the mini-batch size of the array passed in. This method does
     * not set {@code this.input}, so the stored input may be unset or belong to a different
     * batch and cannot be used here. The same convention is used by
     * {@link #rnnActivateUsingStoredState}.
     *
     * @param iNDArray input features
     * @return the last layer's {@code preOutput}
     */
    @Override
    public INDArray preOutput(INDArray iNDArray) {
        int lastIdx = this.layers.length - 1;
        INDArray current = iNDArray;
        for (int i = 0; i < lastIdx; i++) {
            if (getLayerWiseConfigurations().getInputPreProcess(i) != null) {
                current = getLayerWiseConfigurations().getInputPreProcess(i).preProcess(current, iNDArray.size(0));
            }
            current = this.layers[i].activate(current);
        }
        if (getLayerWiseConfigurations().getInputPreProcess(lastIdx) != null) {
            current = getLayerWiseConfigurations().getInputPreProcess(lastIdx).preProcess(current, iNDArray.size(0));
        }
        return this.layers[lastIdx].preOutput(current);
    }

    /**
     * Maps the training mode to a boolean and delegates to {@code preOutput(INDArray, boolean)},
     * which this class does not support, so this always throws.
     *
     * @throws UnsupportedOperationException always, via the delegate
     */
    @Override
    public INDArray preOutput(INDArray iNDArray, Layer.TrainingMode trainingMode) {
        boolean training = Layer.TrainingMode.TRAIN == trainingMode;
        return preOutput(iNDArray, training);
    }

    /**
     * Maps the training mode to a boolean and delegates to {@code activate(boolean)}, which this
     * class does not support, so this always throws.
     *
     * @throws UnsupportedOperationException always, via the delegate
     */
    @Override
    public INDArray activate(Layer.TrainingMode trainingMode) {
        boolean training = Layer.TrainingMode.TRAIN == trainingMode;
        return activate(training);
    }

    /**
     * Maps the training mode to a boolean and delegates to {@code activate(INDArray, boolean)},
     * which this class does not support, so this always throws.
     *
     * @throws UnsupportedOperationException always, via the delegate
     */
    @Override
    public INDArray activate(INDArray iNDArray, Layer.TrainingMode trainingMode) {
        boolean training = Layer.TrainingMode.TRAIN == trainingMode;
        return activate(iNDArray, training);
    }

    /**
     * Not supported by MultiLayerNetwork.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException();
    }

    /**
     * Backpropagates an externally supplied epsilon through the network. Only allowed when the
     * last layer is not a {@code BaseOutputLayer}.
     *
     * @param iNDArray epsilon for the last layer
     * @return the gradient and the epsilon for the network input
     * @throws UnsupportedOperationException if the last layer is an output layer
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray) {
        Layer lastLayer = this.layers[this.layers.length - 1];
        if (lastLayer instanceof BaseOutputLayer) {
            throw new UnsupportedOperationException("Cannot calculate gradients based on epsilon with OutputLayer");
        }
        return calcBackpropGradients(iNDArray, false);
    }

    /**
     * Stores this network's index (used when it is nested as a layer).
     *
     * @param i index to store
     */
    @Override
    public void setIndex(int i) {
        layerIndex = i;
    }

    /** Returns the index stored by {@link #setIndex(int)}. */
    @Override
    public int getIndex() {
        return layerIndex;
    }

    /** Returns the sum of {@code calcL2()} over all layers, accumulated in layer order. */
    @Override
    public double calcL2() {
        double total = 0.0d;
        for (Layer layer : this.layers) {
            total += layer.calcL2();
        }
        return total;
    }

    /** Returns the sum of {@code calcL1()} over all layers, accumulated in layer order. */
    @Override
    public double calcL1() {
        double total = 0.0d;
        for (Layer layer : this.layers) {
            total += layer.calcL1();
        }
        return total;
    }

    /**
     * Not supported by MultiLayerNetwork.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void update(Gradient gradient) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by MultiLayerNetwork.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray preOutput(INDArray iNDArray, boolean z) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by MultiLayerNetwork.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray activate(boolean z) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported by MultiLayerNetwork.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public INDArray activate(INDArray iNDArray, boolean z) {
        throw new UnsupportedOperationException();
    }

    /**
     * Forwards the mini-batch size to every layer; does nothing if the layers are not created yet.
     *
     * @param i mini-batch size
     */
    @Override
    public void setInputMiniBatchSize(int i) {
        if (this.layers == null) {
            return;
        }
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx].setInputMiniBatchSize(i);
        }
    }

    /** Returns {@code size(0)} of the stored input; the stored input must be non-null. */
    @Override
    public int getInputMiniBatchSize() {
        INDArray storedInput = this.input;
        return storedInput.size(0);
    }

    /**
     * Not supported directly; use {@code setLayerMaskArrays} instead.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setMaskArray(INDArray iNDArray) {
        throw new UnsupportedOperationException();
    }

    /**
     * Runs one or more time steps through the network, using and updating each recurrent layer's
     * stored state.
     *
     * <p>Recurrent layers and nested networks use their own {@code rnnTimeStep}; other layers are
     * activated with training = false. Input preprocessors receive the mini-batch size of the
     * array passed to this method. The stored {@code this.input} is not touched here and may be
     * null or belong to a different batch. If the input was rank 2 and the last layer is
     * recurrent with a rank-3 output, the output is reduced to 2d via
     * {@code tensorAlongDimension(0, {1, 0})}.
     *
     * @param iNDArray input for the time step(s); {@code size(0)} is the mini-batch size
     * @return the network output
     */
    public INDArray rnnTimeStep(INDArray iNDArray) {
        int miniBatchSize = iNDArray.size(0);
        setInputMiniBatchSize(miniBatchSize);
        boolean inputIs2d = iNDArray.rank() == 2;
        INDArray current = iNDArray;
        for (int i = 0; i < this.layers.length; i++) {
            Layer layer = this.layers[i];
            if (getLayerWiseConfigurations().getInputPreProcess(i) != null) {
                current = getLayerWiseConfigurations().getInputPreProcess(i).preProcess(current, miniBatchSize);
            }
            if (layer instanceof BaseRecurrentLayer) {
                current = ((BaseRecurrentLayer) layer).rnnTimeStep(current);
            } else if (layer instanceof MultiLayerNetwork) {
                current = ((MultiLayerNetwork) layer).rnnTimeStep(current);
            } else {
                current = layer.activate(current, false);
            }
        }
        Layer lastLayer = this.layers[this.layers.length - 1];
        if (inputIs2d && current.rank() == 3 && lastLayer.type() == Layer.Type.RECURRENT) {
            return current.tensorAlongDimension(0, new int[]{1, 0});
        }
        return current;
    }

    /**
     * Returns the stored previous state of a recurrent layer.
     *
     * @param i layer index
     * @return the layer's {@code rnnGetPreviousState()}
     * @throws IllegalArgumentException if the index is out of range or the layer is not recurrent
     */
    public Map<String, INDArray> rnnGetPreviousState(int i) {
        boolean outOfRange = i < 0 || i >= this.layers.length;
        if (outOfRange) {
            throw new IllegalArgumentException("Invalid layer number");
        }
        Layer layer = this.layers[i];
        if (!(layer instanceof BaseRecurrentLayer)) {
            throw new IllegalArgumentException("Layer is not an RNN layer");
        }
        return ((BaseRecurrentLayer) layer).rnnGetPreviousState();
    }

    /**
     * Replaces the stored previous state of a recurrent layer.
     *
     * @param i   layer index
     * @param map state to install
     * @throws IllegalArgumentException if the index is out of range or the layer is not recurrent
     */
    public void rnnSetPreviousState(int i, Map<String, INDArray> map) {
        boolean outOfRange = i < 0 || i >= this.layers.length;
        if (outOfRange) {
            throw new IllegalArgumentException("Invalid layer number");
        }
        Layer layer = this.layers[i];
        if (!(layer instanceof BaseRecurrentLayer)) {
            throw new IllegalArgumentException("Layer is not an RNN layer");
        }
        BaseRecurrentLayer recurrent = (BaseRecurrentLayer) layer;
        recurrent.rnnSetPreviousState(map);
    }

    /**
     * Clears the stored state of every recurrent layer, recursing into nested networks.
     * Does nothing if the layers are not created yet.
     */
    public void rnnClearPreviousState() {
        Layer[] allLayers = this.layers;
        if (allLayers == null) {
            return;
        }
        for (Layer layer : allLayers) {
            if (layer instanceof BaseRecurrentLayer) {
                ((BaseRecurrentLayer) layer).rnnClearPreviousState();
            } else if (layer instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) layer).rnnClearPreviousState();
            }
        }
    }

    /**
     * Feeds the input forward, letting recurrent layers (and nested networks) start from their
     * stored state. Input preprocessors receive {@code iNDArray.size(0)} as the mini-batch size.
     *
     * @param iNDArray input features
     * @param z        training flag, passed to every layer's activation
     * @param z2       second flag passed to recurrent layers' {@code rnnActivateUsingStoredState}
     *                 (presumably whether to store the final state for TBPTT; confirm)
     * @return the input followed by each layer's activation, in layer order
     */
    public List<INDArray> rnnActivateUsingStoredState(INDArray iNDArray, boolean z, boolean z2) {
        List<INDArray> activations = new ArrayList<>();
        INDArray current = iNDArray;
        activations.add(current);
        for (int idx = 0; idx < this.layers.length; idx++) {
            if (getLayerWiseConfigurations().getInputPreProcess(idx) != null) {
                current = getLayerWiseConfigurations().getInputPreProcess(idx).preProcess(current, iNDArray.size(0));
            }
            Layer layer = this.layers[idx];
            if (layer instanceof BaseRecurrentLayer) {
                current = ((BaseRecurrentLayer) layer).rnnActivateUsingStoredState(current, z, z2);
            } else if (layer instanceof MultiLayerNetwork) {
                List<INDArray> nested = ((MultiLayerNetwork) layer).rnnActivateUsingStoredState(current, z, z2);
                current = nested.get(nested.size() - 1);
            } else {
                current = layer.activate(current, z);
            }
            activations.add(current);
        }
        return activations;
    }

    /**
     * Returns the optimizer's updater. If no solver exists yet, one is built from {@code conf()}
     * and the current listeners, and {@code UpdaterCreator.getUpdater(this)} is installed on it
     * first.
     */
    public synchronized Updater getUpdater() {
        if (this.solver != null) {
            return this.solver.getOptimizer().getUpdater();
        }
        this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        this.solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this));
        return this.solver.getOptimizer().getUpdater();
    }

    /**
     * Installs the given updater on the optimizer, creating a solver first if none exists.
     *
     * @param updater updater to install
     */
    public void setUpdater(Updater updater) {
        boolean needsSolver = this.solver == null;
        if (needsSolver) {
            this.solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build();
        }
        this.solver.getOptimizer().setUpdater(updater);
    }

    /**
     * Applies feature and label masks to the layers.
     *
     * <p>The features mask is reshaped by {@code TimeSeriesUtils.reshapeTimeSeriesMaskToVector}.
     * It is set on each CONVOLUTIONAL or FEED_FORWARD layer (excluding the last layer) up to the
     * first RECURRENT layer; other layer types are skipped. The labels mask is set on the last
     * layer only if that layer is a {@code BaseOutputLayer}.
     *
     * @param iNDArray  features mask, may be {@code null}
     * @param iNDArray2 labels mask, may be {@code null}
     */
    public void setLayerMaskArrays(INDArray iNDArray, INDArray iNDArray2) {
        if (iNDArray != null) {
            INDArray vectorMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(iNDArray);
            for (int idx = 0; idx < this.layers.length - 1; idx++) {
                Layer.Type layerType = this.layers[idx].type();
                if (layerType == Layer.Type.CONVOLUTIONAL || layerType == Layer.Type.FEED_FORWARD) {
                    this.layers[idx].setMaskArray(vectorMask);
                } else if (layerType == Layer.Type.RECURRENT) {
                    break;
                }
            }
        }
        if (iNDArray2 != null) {
            Layer lastLayer = this.layers[this.layers.length - 1];
            if (lastLayer instanceof BaseOutputLayer) {
                lastLayer.setMaskArray(iNDArray2);
            }
        }
    }

    /** Removes the mask array from every layer by setting it to {@code null}. */
    public void clearLayerMaskArrays() {
        for (int idx = 0; idx < this.layers.length; idx++) {
            this.layers[idx].setMaskArray(null);
        }
    }

    /**
     * Evaluates the network over the iterator without explicit label names.
     *
     * @param dataSetIterator data to evaluate
     * @return the accumulated evaluation
     */
    public Evaluation evaluate(DataSetIterator dataSetIterator) {
        List<String> noLabelNames = null;
        return evaluate(dataSetIterator, noLabelNames);
    }

    /**
     * Evaluates the network over the iterator, accumulating the results into one
     * {@link Evaluation}.
     *
     * <p>Outputs are computed with training = false, and masks are applied when the DataSet has
     * them. A labels mask selects masked time-series evaluation. Without a labels mask, the
     * evaluation type follows the label rank: rank 3 uses {@code evalTimeSeries}, anything else
     * uses {@code eval}. This holds whether or not a features mask is present. Iteration stops at
     * the first DataSet with null features or labels.
     *
     * @param dataSetIterator data to evaluate; consumed from its current position
     * @param list            optional label names, may be {@code null}
     * @return the accumulated evaluation
     * @throws IllegalStateException if the network has no layers or no output layer
     */
    public Evaluation evaluate(DataSetIterator dataSetIterator, List<String> list) {
        if (this.layers == null || !(this.layers[this.layers.length - 1] instanceof BaseOutputLayer)) {
            throw new IllegalStateException("Cannot evaluate network with no output layer");
        }
        Evaluation evaluation = list == null ? new Evaluation() : new Evaluation(list);
        while (dataSetIterator.hasNext()) {
            DataSet dataSet = (DataSet) dataSetIterator.next();
            if (dataSet.getFeatureMatrix() == null || dataSet.getLabels() == null) {
                break;
            }
            INDArray features = dataSet.getFeatures();
            INDArray labels = dataSet.getLabels();
            INDArray labelsMaskArray = null;
            INDArray output;
            if (dataSet.hasMaskArrays()) {
                INDArray featuresMaskArray = dataSet.getFeaturesMaskArray();
                labelsMaskArray = dataSet.getLabelsMaskArray();
                output = output(features, false, featuresMaskArray, labelsMaskArray);
            } else {
                output = output(features, false);
            }
            if (labelsMaskArray != null) {
                evaluation.evalTimeSeries(labels, output, labelsMaskArray);
            } else if (labels.rank() == 3) {
                evaluation.evalTimeSeries(labels, output);
            } else {
                // e.g. a features mask only, with per-example (rank 2) labels
                evaluation.eval(labels, output);
            }
        }
        return evaluation;
    }

    /**
     * Reports a STANDALONE heartbeat event, at most once per network (guarded by
     * {@code initDone}). The event carries the environment and a task built by
     * {@code ModelSerializer.taskByModel(this)}. The {@code task} argument itself is ignored.
     *
     * @param task unused
     */
    private void update(Task task) {
        if (!this.initDone) {
            this.initDone = true;
            Heartbeat beat = Heartbeat.getInstance();
            Task modelTask = ModelSerializer.taskByModel(this);
            beat.reportEvent(Event.STANDALONE, EnvironmentUtils.buildEnvironment(), modelTask);
        }
    }

    /**
     * Sets the flag that suppresses the one-time heartbeat report.
     *
     * @param z true to mark the report as already sent
     */
    public void setInitDone(boolean z) {
        initDone = z;
    }
}
