package org.deeplearning4j.nn;

import java.io.InputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.models.featuredetectors.autoencoder.AutoEncoder;
import org.deeplearning4j.nn.api.Classifier;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.Persistable;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.Layer;
import org.deeplearning4j.nn.layers.OutputLayer;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.deeplearning4j.optimize.optimizers.BackPropOptimizer;
import org.deeplearning4j.optimize.optimizers.BackPropROptimizer;
import org.deeplearning4j.optimize.optimizers.MultiLayerNetworkOptimizer;
import org.deeplearning4j.util.Dl4jReflection;
import org.deeplearning4j.util.SerializationUtils;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.sampling.Sampling;
import org.nd4j.linalg.transformation.MatrixTransform;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/BaseMultiLayerNetwork.class */
public abstract class BaseMultiLayerNetwork implements Serializable, Persistable, Classifier {
    // Class logger. NOTE(review): not assigned in the part of the class shown here; presumably set in a static initializer further down.
    private static Logger log;
    private static final long serialVersionUID = -5029161847383716484L;
    // Units per hidden layer; the 4-argument constructor requires its length to equal nLayers.
    protected int[] hiddenLayerSizes;
    // Hidden layers followed by the output layer in the last slot (arrays are allocated with nLayers + 1 entries).
    protected Layer[] layers;
    // Current input; feedForward() starts from it.
    protected INDArray input;
    // Training targets; read by computeDeltas/computeDeltas2 and handed to the output layer in initialize(DataSet).
    protected INDArray labels;
    protected MultiLayerNetworkOptimizer optimizer;
    // Per-layer-index transforms copied in by the Builder; init() calls applyTransforms() after building the layers.
    protected Map<Integer, MatrixTransform> weightTransforms;
    protected Map<Integer, MatrixTransform> hiddenBiasTransforms;
    protected Map<Integer, MatrixTransform> visibleBiasTransforms;
    protected boolean shouldBackProp;
    protected boolean forceNumEpochs;
    // Set at the end of init(); initializeLayers(INDArray) skips init() once it is true.
    protected boolean initCalled;
    protected boolean sampleFromHiddenActivations;
    // Template configuration; cloned once per layer when the layer-wise configurations are created.
    protected NeuralNetConfiguration defaultConfiguration;
    // One configuration per hidden layer plus a final entry used for the output layer.
    protected List<NeuralNetConfiguration> layerWiseConfigurations;
    protected double learningRateUpdate;
    // One pretrainable network per hidden layer; setParams/backPropStep keep their W/bias in sync with layers[i].
    protected NeuralNetwork[] neuralNets;
    protected double errorTolerance;
    // Passed through to the back-prop optimizers in backProp(...).
    protected boolean lineSearchBackProp;
    protected INDArray mask;
    // When true, applyDropConnectIfNecessary multiplies arrays in place by a random binomial mask.
    protected boolean useDropConnect;
    // Adjusted by dampingUpdate(...) and temporarily zeroed inside reductionRatio(...).
    protected double dampingFactor;
    // Selects BackPropROptimizer instead of BackPropOptimizer in backProp(...).
    protected boolean useGaussNewtonVectorProductBackProp;
    // Compiler-synthesized assertion switch surfaced by the decompiler; dimensionCheck() consults it.
    // NOTE(review): no initializer here; presumably assigned in a static block not shown.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:org/deeplearning4j/nn/BaseMultiLayerNetwork$Builder.class */
    /**
     * Fluent builder for {@link BaseMultiLayerNetwork} subclasses.
     *
     * <p>The concrete class is chosen with {@link #withClazz(Class)} and instantiated reflectively
     * through its no-argument constructor (made accessible first), so subclasses may keep that
     * constructor non-public.
     *
     * @param <E> the concrete network type produced by {@link #build()}
     */
    public static class Builder<E extends BaseMultiLayerNetwork> {
        protected Class<? extends BaseMultiLayerNetwork> clazz;
        private int[] hiddenLayerSizes;
        private int nLayers;
        private INDArray input;
        private INDArray labels;
        protected Map<Integer, MatrixTransform> weightTransforms = new HashMap<>();
        protected boolean backProp = true;
        protected boolean shouldForceEpochs = false;
        private Map<Integer, MatrixTransform> hiddenBiasTransforms = new HashMap<>();
        private Map<Integer, MatrixTransform> visibleBiasTransforms = new HashMap<>();
        private boolean lineSearchBackProp = false;
        private boolean useDropConnect = false;
        private boolean useGaussNewtonVectorProductBackProp = false;
        protected NeuralNetConfiguration conf;
        protected List<NeuralNetConfiguration> layerWiseConfiguration;

        /**
         * Sets the per-layer configurations handed to the network. (The misspelled name is part of the
         * public API and is kept for compatibility.)
         */
        public Builder<E> layerWiseCOnfiguration(List<NeuralNetConfiguration> list) {
            this.layerWiseConfiguration = list;
            return this;
        }

        /** Sets the default configuration of the built network. */
        public Builder<E> configure(NeuralNetConfiguration neuralNetConfiguration) {
            this.conf = neuralNetConfiguration;
            return this;
        }

        /** Chooses BackPropROptimizer (true) or BackPropOptimizer (false) for fine-tuning. */
        public Builder<E> useGaussNewtonVectorProductBackProp(boolean z) {
            this.useGaussNewtonVectorProductBackProp = z;
            return this;
        }

        /** Enables or disables drop-connect on the built network. */
        public Builder<E> useDropConnection(boolean z) {
            this.useDropConnect = z;
            return this;
        }

        /** Enables or disables line search during back-propagation. */
        public Builder<E> lineSearchBackProp(boolean z) {
            this.lineSearchBackProp = z;
            return this;
        }

        /** Replaces the visible-bias transforms (layer index to transform); the map is copied at build time. */
        public Builder<E> withVisibleBiasTransforms(Map<Integer, MatrixTransform> map) {
            this.visibleBiasTransforms = map;
            return this;
        }

        /** Replaces the hidden-bias transforms (layer index to transform); the map is copied at build time. */
        public Builder<E> withHiddenBiasTransforms(Map<Integer, MatrixTransform> map) {
            this.hiddenBiasTransforms = map;
            return this;
        }

        /** Sets the network's forceNumEpochs flag. */
        public Builder<E> forceEpochs() {
            this.shouldForceEpochs = true;
            return this;
        }

        /** Sets the network's shouldBackProp flag to false. */
        public Builder<E> disableBackProp() {
            this.backProp = false;
            return this;
        }

        /** Registers a weight transform for the layer at index {@code i}. */
        public Builder<E> transformWeightsAt(int i, MatrixTransform matrixTransform) {
            this.weightTransforms.put(Integer.valueOf(i), matrixTransform);
            return this;
        }

        /** Registers several weight transforms (layer index to transform). */
        public Builder<E> transformWeightsAt(Map<Integer, MatrixTransform> map) {
            this.weightTransforms.putAll(map);
            return this;
        }

        /** Sets the hidden layer sizes; the number of layers is the array length. */
        public Builder<E> hiddenLayerSizes(Integer[] numArr) {
            this.hiddenLayerSizes = new int[numArr.length];
            this.nLayers = numArr.length;
            for (int i = 0; i < numArr.length; i++) {
                this.hiddenLayerSizes[i] = numArr[i].intValue();
            }
            return this;
        }

        /** Sets the hidden layer sizes (the array is stored, not copied); the number of layers is its length. */
        public Builder<E> hiddenLayerSizes(int[] iArr) {
            this.hiddenLayerSizes = iArr;
            this.nLayers = iArr.length;
            return this;
        }

        /** Sets the input handed to the network. */
        public Builder<E> withInput(INDArray iNDArray) {
            this.input = iNDArray;
            return this;
        }

        /** Sets the labels handed to the network. */
        public Builder<E> withLabels(INDArray iNDArray) {
            this.labels = iNDArray;
            return this;
        }

        /** Sets the concrete network class to instantiate. */
        public Builder<E> withClazz(Class<? extends BaseMultiLayerNetwork> cls) {
            this.clazz = cls;
            return this;
        }

        /**
         * Instantiates the configured class through its no-argument constructor without applying any
         * builder settings.
         *
         * @throws RuntimeException wrapping any reflective failure
         */
        @SuppressWarnings("unchecked") // clazz is expected to be E; not checkable at runtime
        public E buildEmpty() {
            try {
                Constructor<?> emptyConstructor = Dl4jReflection.getEmptyConstructor(this.clazz);
                emptyConstructor.setAccessible(true);
                return (E) emptyConstructor.newInstance(new Object[0]);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Instantiates the configured class and applies every builder setting to it.
         *
         * @throws IllegalStateException if no hidden layer sizes were defined (checked before anything is
         *     instantiated, and no longer wrapped in a generic RuntimeException)
         * @throws RuntimeException wrapping any reflective or setter failure
         */
        @SuppressWarnings("unchecked") // clazz is expected to be E; not checkable at runtime
        public E build() {
            // Validate first: previously this ran after construction and all setters, and the catch block
            // below re-wrapped the IllegalStateException in a generic RuntimeException.
            if (this.hiddenLayerSizes == null) {
                throw new IllegalStateException("Unable to build network, no hidden layer sizes defined");
            }
            try {
                Constructor<?> emptyConstructor = Dl4jReflection.getEmptyConstructor(this.clazz);
                emptyConstructor.setAccessible(true);
                E e = (E) emptyConstructor.newInstance(new Object[0]);
                e.setDefaultConfiguration(this.conf);
                e.useGaussNewtonVectorProductBackProp = this.useGaussNewtonVectorProductBackProp;
                e.setUseDropConnect(this.useDropConnect);
                e.setInput(this.input);
                e.setLabels(this.labels);
                e.setHiddenLayerSizes(this.hiddenLayerSizes);
                e.setnLayers(this.nLayers);
                e.setShouldBackProp(this.backProp);
                e.setLayerWiseConfigurations(this.layerWiseConfiguration);
                e.neuralNets = new NeuralNetwork[this.nLayers];
                // NOTE(review): input and labels are set a second time after neuralNets is allocated; kept
                // in case setInput/setLabels depend on that state — confirm whether the repeat is needed.
                e.setInput(this.input);
                e.setLineSearchBackProp(this.lineSearchBackProp);
                e.setLabels(this.labels);
                e.setForceNumEpochs(this.shouldForceEpochs);
                e.getWeightTransforms().putAll(this.weightTransforms);
                e.getVisibleBiasTransforms().putAll(this.visibleBiasTransforms);
                e.getHiddenBiasTransforms().putAll(this.hiddenBiasTransforms);
                return e;
            } catch (Exception e2) {
                throw new RuntimeException(e2);
            }
        }
    }

    /* loaded from: input_file:org/deeplearning4j/nn/BaseMultiLayerNetwork$ParamRange.class */
    /**
     * Mutable holder for four integer bounds: a weight start/end pair and a bias start/end pair.
     * NOTE(review): presumably offsets into a flattened parameter vector — confirm against the callers.
     * Only instances created inside this class are possible because the constructor is private.
     */
    public static class ParamRange implements Serializable {
        private int wStart;
        private int wEnd;
        private int biasStart;
        private int biasEnd;

        private ParamRange(int weightFrom, int weightTo, int biasFrom, int biasTo) {
            wStart = weightFrom;
            wEnd = weightTo;
            biasStart = biasFrom;
            biasEnd = biasTo;
        }

        /** @return the weight start bound */
        public int getwStart() {
            return wStart;
        }

        /** @param value the new weight start bound */
        public void setwStart(int value) {
            wStart = value;
        }

        /** @return the weight end bound */
        public int getwEnd() {
            return wEnd;
        }

        /** @param value the new weight end bound */
        public void setwEnd(int value) {
            wEnd = value;
        }

        /** @return the bias start bound */
        public int getBiasStart() {
            return biasStart;
        }

        /** @param value the new bias start bound */
        public void setBiasStart(int value) {
            biasStart = value;
        }

        /** @return the bias end bound */
        public int getBiasEnd() {
            return biasEnd;
        }

        /** @param value the new bias end bound */
        public void setBiasEnd(int value) {
            biasEnd = value;
        }
    }

    /**
     * Creates a network with default settings and no layers, input or labels.
     * Decompiler note: this constructor was originally declared {@code protected}.
     */
    public BaseMultiLayerNetwork() {
        // Transform registries start out empty.
        weightTransforms = new HashMap<>();
        hiddenBiasTransforms = new HashMap<>();
        visibleBiasTransforms = new HashMap<>();
        // Behavioural flags.
        shouldBackProp = true;
        forceNumEpochs = false;
        initCalled = false;
        sampleFromHiddenActivations = true;
        lineSearchBackProp = false;
        useDropConnect = false;
        useGaussNewtonVectorProductBackProp = false;
        // Numeric defaults (presumably 0.95f and 1e-4f widened to double).
        learningRateUpdate = 0.949999988079071d;
        errorTolerance = 9.999999747378752E-5d;
        dampingFactor = 10.0d;
    }

    /**
     * Creates a network with the given hidden layer sizes and no input or labels.
     *
     * @param layerSizes units per hidden layer; its length must equal {@code numLayers}
     * @param numLayers number of hidden layers
     */
    protected BaseMultiLayerNetwork(int[] layerSizes, int numLayers) {
        this(layerSizes, numLayers, (INDArray) null, (INDArray) null);
    }

    /**
     * Creates a network with the given hidden layer sizes and, optionally, training data.
     *
     * @param iArr units per hidden layer; its length must equal {@code i}
     * @param i number of hidden layers
     * @param iNDArray input examples, copied if non-null; when null, layer initialization is deferred
     * @param iNDArray2 labels, copied if non-null
     * @throws IllegalArgumentException if {@code iArr.length != i}
     */
    protected BaseMultiLayerNetwork(int[] iArr, int i, INDArray iNDArray, INDArray iNDArray2) {
        this.weightTransforms = new HashMap<>();
        this.hiddenBiasTransforms = new HashMap<>();
        this.visibleBiasTransforms = new HashMap<>();
        this.shouldBackProp = true;
        this.forceNumEpochs = false;
        this.initCalled = false;
        this.sampleFromHiddenActivations = true;
        this.learningRateUpdate = 0.949999988079071d;
        this.errorTolerance = 9.999999747378752E-5d;
        this.lineSearchBackProp = false;
        this.useDropConnect = false;
        this.dampingFactor = 10.0d;
        this.useGaussNewtonVectorProductBackProp = false;
        this.hiddenLayerSizes = iArr;
        // Input and labels are optional (the two-argument constructor passes null for both). Calling
        // dup() unconditionally made every such construction fail with a NullPointerException.
        this.input = iNDArray != null ? iNDArray.dup() : null;
        this.labels = iNDArray2 != null ? iNDArray2.dup() : null;
        if (iArr.length != i) {
            throw new IllegalArgumentException("The number of hidden layer sizes must be equivalent to the nLayers argument which is a value of " + i);
        }
        setnLayers(i);
        // One slot per hidden layer plus the output layer.
        this.layers = new org.deeplearning4j.nn.layers.Layer[i + 1];
        intializeConfigurations();
        if (iNDArray != null) {
            initializeLayers(iNDArray);
        }
    }

    /**
     * Lazily allocates the layer-wise configuration list, the layer array, the pretraining network
     * array and the default configuration. It then ensures there is one configuration per hidden
     * layer plus one for the output layer (the last entry), padding with clones of the default.
     */
    protected void intializeConfigurations() {
        if (this.layerWiseConfigurations == null) {
            this.layerWiseConfigurations = new ArrayList<>();
        }
        if (this.layers == null) {
            this.layers = new Layer[getnLayers() + 1];
        }
        if (this.neuralNets == null) {
            this.neuralNets = new NeuralNetwork[getnLayers()];
        }
        if (this.defaultConfiguration == null) {
            this.defaultConfiguration = new NeuralNetConfiguration.Builder().build();
        }
        // Only pad up to the required count. Appending hiddenLayerSizes.length + 1 clones on every call
        // duplicated entries when a list was already supplied (e.g. via the Builder) or when init()
        // re-invoked this method. init() would then configure the output layer from a fresh default clone
        // at size()-1, with nIn taken from another unconfigured clone at size()-2.
        int required = this.hiddenLayerSizes.length + 1;
        while (this.layerWiseConfigurations.size() < required) {
            this.layerWiseConfigurations.add(this.defaultConfiguration.m21clone());
        }
    }

    /**
     * Validates the shapes of the hidden layers and their pretraining networks.
     *
     * <p>Shape checks guarded by {@code $assertionsDisabled} only run when assertions are enabled and
     * throw {@link AssertionError}. The wiring check between consecutive layers always runs.
     *
     * @throws IllegalStateException if a layer's input count does not match the previous layer's output count
     */
    public void dimensionCheck() {
        if (!$assertionsDisabled && this.layers.length != this.neuralNets.length + 1) {
            throw new AssertionError("Missing output layer");
        }
        for (int i = 0; i < getnLayers(); i++) {
            Layer layer = this.layers[i];
            NeuralNetwork neuralNetwork = this.neuralNets[i];
            LinAlgExceptions.assertSameShape(neuralNetwork.getW(), layer.getW());
            LinAlgExceptions.assertSameShape(layer.getB(), neuralNetwork.gethBias());
            if (!$assertionsDisabled && layer.conf().getnIn() != layer.getW().rows()) {
                throw new AssertionError("Number of inputs not consistent with number of rows in weight matrix");
            }
            // The nOut/columns checks previously reused the "inputs/rows" message verbatim.
            if (!$assertionsDisabled && layer.conf().getnOut() != layer.getW().columns()) {
                throw new AssertionError("Number of outputs not consistent with number of columns in weight matrix");
            }
            if (i < getnLayers() - 1) {
                Layer layer2 = this.layers[i + 1];
                NeuralNetwork neuralNetwork2 = this.neuralNets[i + 1];
                if (!$assertionsDisabled && layer2.conf().getnIn() != layer2.getW().rows()) {
                    throw new AssertionError("Number of inputs not consistent with number of rows in weight matrix");
                }
                if (!$assertionsDisabled && layer2.conf().getnOut() != layer2.getW().columns()) {
                    throw new AssertionError("Number of outputs not consistent with number of columns in weight matrix");
                }
                if (!$assertionsDisabled && neuralNetwork2.conf().getnIn() != neuralNetwork2.getW().rows()) {
                    throw new AssertionError("Number of inputs not consistent with number of rows in weight matrix");
                }
                if (!$assertionsDisabled && neuralNetwork2.conf().getnOut() != neuralNetwork2.getW().columns()) {
                    throw new AssertionError("Number of outputs not consistent with number of columns in weight matrix");
                }
                // Consecutive layers must chain: outputs of i feed the inputs of i + 1.
                if (layer2.conf().getnIn() != layer.conf().getnOut()) {
                    throw new IllegalStateException("Invalid structure: number of inputs of layer " + (i + 1) + " not equal to number of outputs of layer " + i);
                }
                if (neuralNetwork.conf().getnOut() != neuralNetwork2.conf().getnIn()) {
                    throw new IllegalStateException("Invalid structure: number of inputs of network " + (i + 1) + " not equal to number of outputs of network " + i);
                }
            }
        }
    }

    /**
     * Alias for {@code output(INDArray)}.
     *
     * @param examples the examples to run through the network
     * @return whatever {@code output} returns for {@code examples}
     */
    public INDArray transform(INDArray examples) {
        INDArray result = output(examples);
        return result;
    }

    /** @return the template configuration cloned for each layer (may be null before initialization) */
    public NeuralNetConfiguration getDefaultConfiguration() {
        return defaultConfiguration;
    }

    /** @param template the new template configuration (stored as-is, not copied) */
    public void setDefaultConfiguration(NeuralNetConfiguration template) {
        defaultConfiguration = template;
    }

    /** @return the live per-layer configuration list (not a copy) */
    public List<NeuralNetConfiguration> getLayerWiseConfigurations() {
        return layerWiseConfigurations;
    }

    /** @param configurations the per-layer configuration list to use (stored as-is, not copied) */
    public void setLayerWiseConfigurations(List<NeuralNetConfiguration> configurations) {
        layerWiseConfigurations = configurations;
    }

    /**
     * Validates {@code examples}, stores a copy as the network input and runs {@link #init()} once.
     *
     * @param examples input matrix whose column count must equal the default configuration's nIn
     * @throws IllegalArgumentException if {@code examples} is null, has the wrong column count, or any
     *     hidden layer size is below 1
     */
    public void initializeLayers(INDArray examples) {
        if (examples == null) {
            throw new IllegalArgumentException("Unable to initialize neuralNets with empty input");
        }
        int columnCount = examples.columns();
        int expectedInputs = this.defaultConfiguration.getnIn();
        if (columnCount != expectedInputs) {
            String message = String.format("Unable to iterate on number of inputs; columns should be equal to number of inputs. Number of inputs was %d while number of columns was %d", Integer.valueOf(this.defaultConfiguration.getnIn()), Integer.valueOf(examples.columns()));
            throw new IllegalArgumentException(message);
        }
        if (this.neuralNets == null) {
            this.neuralNets = new NeuralNetwork[getnLayers()];
        }
        for (int size : this.hiddenLayerSizes) {
            if (size < 1) {
                throw new IllegalArgumentException("All hidden layer sizes must be >= 1");
            }
        }
        this.input = examples.dup();
        // init() is only performed once per network instance.
        if (!this.initCalled) {
            init();
            log.info("Initializing neuralNets with input of dims " + examples.rows() + " x " + examples.columns());
        }
    }

    /**
     * Builds the network structure.
     *
     * <p>Ensures the configurations exist. When the pretraining networks are missing, it creates each
     * hidden layer and its pretraining network, feeding the activation of the previous layer forward
     * when an input is present. It then builds the output layer from the last layer-wise configuration
     * and runs the dimension check, the configured transforms and mask initialization.
     *
     * @throws IllegalStateException if fewer than one layer is configured
     */
    public void init() {
        if (this.layerWiseConfigurations == null || this.layers == null) {
            intializeConfigurations();
        }
        INDArray layerInput = this.input;
        if (getnLayers() < 1) {
            throw new IllegalStateException("Unable to createComplex network neuralNets; number specified is less than 1");
        }
        // The previous condition listed "neuralNets == null" and "neuralNets[0] == null" twice each; the
        // duplicates are dropped. NOTE(review): one pair may have been meant to test layers / layers[0] — confirm.
        if (this.neuralNets == null || this.neuralNets[0] == null) {
            this.neuralNets = new NeuralNetwork[getnLayers()];
            for (int i = 0; i < getnLayers(); i++) {
                int nIn = i == 0 ? this.defaultConfiguration.getnIn() : this.hiddenLayerSizes[i - 1];
                // Layers after the first are built on the previous layer's activation (if there is input).
                if (i > 0 && this.input != null) {
                    layerInput = activationFromPrevLayer(i - 1, layerInput);
                }
                NeuralNetConfiguration layerConf = this.layerWiseConfigurations.get(i);
                layerConf.setnIn(nIn);
                layerConf.setnOut(this.hiddenLayerSizes[i]);
                this.layers[i] = createHiddenLayer(i, layerInput);
                this.neuralNets[i] = createLayer(layerInput, this.layers[i].getW(), this.layers[i].getB(), null, i);
            }
        }
        // The last configuration describes the output layer; its inputs are the previous layer's outputs.
        int configCount = this.layerWiseConfigurations.size();
        NeuralNetConfiguration outputConf = this.layerWiseConfigurations.get(configCount - 1);
        outputConf.setnIn(this.layerWiseConfigurations.get(configCount - 2).getnOut());
        this.layers[this.layers.length - 1] = new OutputLayer.Builder().configure(outputConf).build();
        dimensionCheck();
        applyTransforms();
        this.initCalled = true;
        initMask();
    }

    /**
     * Activates the layer at index {@code getNeuralNets().length - 1}, i.e. the last hidden layer rather
     * than the output layer.
     */
    public INDArray activate() {
        int lastHiddenIndex = getNeuralNets().length - 1;
        return getLayers()[lastHiddenIndex].activate();
    }

    /** Activates the layer at {@code layerIndex} using that layer's own state. */
    public INDArray activate(int layerIndex) {
        return getLayers()[layerIndex].activate();
    }

    /** Activates the layer at {@code layerIndex} on the supplied {@code layerInput}. */
    public INDArray activate(int layerIndex, INDArray layerInput) {
        return getLayers()[layerIndex].activate(layerInput);
    }

    /**
     * Fine-tunes against the network's stored labels; delegates to the three-argument overload.
     * NOTE(review): the arguments are presumably learning rate and epoch count — confirm in that overload.
     */
    public void finetune(double learningRate, int epochs) {
        INDArray targets = this.labels;
        finetune(targets, learningRate, epochs);
    }

    /**
     * Uses {@code dataSet} as the current training data: sets its features as input, runs a forward
     * pass over them, and stores its labels both here and on the output layer.
     */
    public void initialize(DataSet dataSet) {
        setInput(dataSet.getFeatureMatrix());
        feedForward(dataSet.getFeatureMatrix());
        INDArray datasetLabels = dataSet.getLabels();
        this.labels = datasetLabels;
        getOutputLayer().setLabels(datasetLabels);
    }

    /**
     * Computes the activation of layer {@code layerIndex} for {@code layerInput}.
     *
     * <p>When {@code layerIndex} equals the number of hidden layers, the output layer's label
     * probabilities are returned. Otherwise the layer's configured activation type decides whether the
     * hidden layer, the pretraining network's hidden activation, or a hidden sample is used.
     *
     * @throws IllegalStateException for an unrecognised activation type
     */
    public INDArray activationFromPrevLayer(int layerIndex, INDArray layerInput) {
        boolean isOutputLayer = layerIndex == this.neuralNets.length;
        if (isOutputLayer) {
            return getOutputLayer().labelProbabilities(layerInput);
        }
        switch (this.layers[layerIndex].conf().getActivationType()) {
            case SAMPLE:
                return this.neuralNets[layerIndex].sampleHiddenGivenVisible(layerInput).getSecond();
            case NET_ACTIVATION:
                return this.neuralNets[layerIndex].hiddenActivation(layerInput);
            case HIDDEN_LAYER_ACTIVATION:
                return this.layers[layerIndex].activate(layerInput);
            default:
                throw new IllegalStateException("Invalid activation type");
        }
    }

    /**
     * Runs the stored input through every layer.
     *
     * @return the input followed by each layer's activation (drop-connect is applied in place to every
     *     activation when enabled)
     * @throws IllegalStateException if the input's column count differs from the default configuration's nIn
     */
    public List<INDArray> feedForward() {
        if (this.input.columns() != this.defaultConfiguration.getnIn()) {
            throw new IllegalStateException("Illegal input length");
        }
        List<INDArray> activations = new ArrayList<>();
        INDArray current = this.input;
        activations.add(current);
        int layerIndex = 0;
        while (layerIndex < this.layers.length) {
            current = activationFromPrevLayer(layerIndex, current);
            applyDropConnectIfNecessary(current);
            activations.add(current);
            layerIndex++;
        }
        return activations;
    }

    /**
     * Replaces the stored input with {@code examples} (no copy) and runs {@link #feedForward()}.
     *
     * @throws IllegalStateException if {@code examples} is null
     */
    public List<INDArray> feedForward(INDArray examples) {
        if (examples != null) {
            this.input = examples;
            return feedForward();
        }
        throw new IllegalStateException("Unable to perform feed forward; no input found");
    }

    /**
     * When drop-connect is enabled, multiplies {@code target} in place by a binomial sample drawn from
     * a matrix filled with 0.5 (one trial per entry). If the default configuration's L2 value is
     * positive, {@code target} is then also scaled in place by that value. Does nothing otherwise.
     * Decompiler note: this method was originally declared {@code protected}.
     */
    public void applyDropConnectIfNecessary(INDArray target) {
        if (!this.useDropConnect) {
            return;
        }
        INDArray probabilities = Nd4j.valueArrayOf(target.rows(), target.columns(), 0.5d);
        target.muli(Sampling.binomial(probabilities, 1, this.defaultConfiguration.getRng()));
        float l2 = this.defaultConfiguration.getL2();
        // NOTE(review): scaling by the L2 coefficient here is unusual — confirm intent.
        if (l2 > 0.0f) {
            target.muli(Float.valueOf(l2));
        }
    }

    /**
     * Back-propagates the R-forward-pass output (scaled by 1 / number of input rows) through the network.
     *
     * @param iNDArray packed direction vector, forwarded to {@code feedForwardR}
     * @return one gradient per layer (hidden layers then output layer), each divided by its norm2 when
     *     the default configuration constrains gradients to unit norm
     */
    protected List<INDArray> computeDeltasR(INDArray iNDArray) {
        List<INDArray> deltas = new ArrayList<>();
        INDArray[] gradients = new INDArray[getnLayers() + 1];
        List<INDArray> activations = feedForward();
        List<INDArray> rActivations = feedForwardR(activations, iNDArray);
        List<INDArray> weights = new ArrayList<>();
        List<INDArray> biases = new ArrayList<>();
        List<ActivationFunction> activationFunctions = new ArrayList<>();
        for (int i = 0; i < getNeuralNets().length; i++) {
            weights.add(getNeuralNets()[i].getW());
            biases.add(getNeuralNets()[i].gethBias());
            activationFunctions.add(getNeuralNets()[i].conf().getActivationFunction());
        }
        // The backward loop starts at index getnLayers(), i.e. the output layer. Without these entries
        // weights.get(getnLayers()) always threw IndexOutOfBoundsException (computeDeltas appends the same).
        weights.add(getOutputLayer().getW());
        biases.add(getOutputLayer().getB());
        activationFunctions.add(getOutputLayer().conf().getActivationFunction());
        INDArray delta = rActivations.get(rActivations.size() - 1).divi(Integer.valueOf(this.input.rows()));
        LinAlgExceptions.assertValidNum(delta);
        for (int i = getnLayers(); i >= 0; i--) {
            gradients[i] = activations.get(i).transpose().mmul(delta);
            applyDropConnectIfNecessary(gradients[i]);
            if (i > 0) {
                // Propagate through layer i and the derivative of the activation that produced its input.
                delta = delta.mmul(weights.get(i).addRowVector(biases.get(i)).transpose())
                        .muli(activationFunctions.get(i - 1).applyDerivative(activations.get(i)));
            }
        }
        for (INDArray gradient : gradients) {
            INDArray result = this.defaultConfiguration.isConstrainGradientToUnitNorm()
                    ? gradient.div(gradient.norm2(Integer.MAX_VALUE))
                    : gradient;
            deltas.add(result);
            LinAlgExceptions.assertValidNum(result);
        }
        return deltas;
    }

    /**
     * Adjusts {@code dampingFactor}. It is multiplied by {@code boost} when {@code rho} is NaN or below
     * 0.25, multiplied by {@code decay} when {@code rho} is above 0.75, and otherwise left unchanged.
     * NOTE(review): {@code rho} is presumably the reduction ratio from reductionRatio(...) — confirm.
     */
    public void dampingUpdate(double rho, double boost, double decay) {
        boolean poorAgreement = Double.isNaN(rho) || rho < 0.25d;
        if (poorAgreement) {
            dampingFactor *= boost;
            return;
        }
        if (rho > 0.75d) {
            dampingFactor *= decay;
        }
    }

    /**
     * Computes {@code (d - d2)} divided by a model term derived from the undamped R-gradient for
     * {@code iNDArray}. Returns negative infinity when {@code d2 > d}.
     *
     * <p>{@code dampingFactor} is zeroed only while the R-gradient is computed and is always restored,
     * even if that computation throws. Previously an exception left the network with zero damping.
     */
    public double reductionRatio(INDArray iNDArray, double d, double d2, INDArray iNDArray2) {
        double savedDamping = this.dampingFactor;
        this.dampingFactor = 0.0d;
        double ratio;
        try {
            INDArray backPropRGradient = getBackPropRGradient(iNDArray);
            // NOTE(review): muli(0.5) halves the gradient in place before iNDArray.mul(...) reads it, and the
            // trailing sum(0) result is discarded. This does not match the usual quadratic-model reduction
            // 0.5 * p'Gp + g'p — confirm the intended formula before relying on the ratio.
            backPropRGradient.muli(Double.valueOf(0.5d)).muli(iNDArray.mul(backPropRGradient)).sum(0);
            backPropRGradient.subi(iNDArray2.mul(iNDArray).sum(0));
            ratio = (d - d2) / ((Double) backPropRGradient.getScalar(0).element()).doubleValue();
        } finally {
            this.dampingFactor = savedDamping;
        }
        if (d2 - d > 0.0d) {
            return Double.NEGATIVE_INFINITY;
        }
        return ratio;
    }

    /**
     * Back-propagates the output error {@code (output - labels) / rows} through every layer, including
     * the output layer.
     *
     * @return one pair per layer (hidden layers then output layer). Each pair holds the gradient (divided
     *     by its norm2 when the configuration constrains gradients to unit norm) and an accumulated
     *     squared-gradient term scaled by the number of label rows.
     */
    protected List<Pair<INDArray, INDArray>> computeDeltas2() {
        List<Pair<INDArray, INDArray>> result = new ArrayList<>();
        List<INDArray> activations = feedForward();
        // activations = input + one entry per layer, so this gives one slot per layer.
        INDArray[] gradients = new INDArray[activations.size() - 1];
        INDArray[] squaredGradients = new INDArray[activations.size() - 1];
        INDArray delta = activations.get(activations.size() - 1).sub(this.labels).div(Integer.valueOf(this.labels.rows()));
        log.info("Ix mean " + delta.sum(Integer.MAX_VALUE));
        List<INDArray> weights = new ArrayList<>();
        List<ActivationFunction> activationFunctions = new ArrayList<>();
        for (int i = 0; i < getNeuralNets().length; i++) {
            weights.add(getNeuralNets()[i].getW());
            activationFunctions.add(getLayers()[i].conf().getActivationFunction());
        }
        // Include the output layer, as computeDeltas does. Without it the backward pass started at the
        // last hidden layer using the output error, and the final slot stayed null (an NPE under
        // unit-norm, otherwise a Pair(null, null)).
        weights.add(getOutputLayer().getW());
        activationFunctions.add(getOutputLayer().conf().getActivationFunction());
        for (int layer = weights.size() - 1; layer >= 0; layer--) {
            gradients[layer] = activations.get(layer).transpose().mmul(delta);
            log.info("Delta sum at " + layer + " is " + gradients[layer].sum(Integer.MAX_VALUE));
            squaredGradients[layer] = Transforms.pow(activations.get(layer).transpose(), 2).mmul(Transforms.pow(delta, 2)).mul(Integer.valueOf(this.labels.rows()));
            applyDropConnectIfNecessary(gradients[layer]);
            if (layer > 0) {
                delta = delta.mmul(weights.get(layer).transpose()).muli(activationFunctions.get(layer - 1).applyDerivative(activations.get(layer)));
            }
        }
        for (int i = 0; i < gradients.length; i++) {
            if (this.defaultConfiguration.isConstrainGradientToUnitNorm()) {
                result.add(new Pair<>(gradients[i].divi(gradients[i].norm2(Integer.MAX_VALUE)), squaredGradients[i]));
            } else {
                result.add(new Pair<>(gradients[i], squaredGradients[i]));
            }
        }
        return result;
    }

    /**
     * Distributes a flat parameter vector over the network.
     *
     * <p>Each hidden network consumes {@code numParams()} consecutive entries, and its resulting weights
     * and bias are copied onto the matching layer. The output layer receives whatever remains.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParams(INDArray iNDArray) {
        int offset = 0;
        for (int layer = 0; layer < this.neuralNets.length; layer++) {
            int count = getNeuralNets()[layer].numParams();
            INDArray slice = iNDArray.get(new NDArrayIndex[] {NDArrayIndex.interval(offset, offset + count)});
            getNeuralNets()[layer].setParams(slice);
            getLayers()[layer].setW(getNeuralNets()[layer].getW());
            getLayers()[layer].setB(getNeuralNets()[layer].gethBias());
            offset += count;
        }
        INDArray remainder = iNDArray.get(new NDArrayIndex[] {NDArrayIndex.interval(offset, iNDArray.length())});
        getOutputLayer().setParams(remainder);
    }

    /** Pretrains on {@code examples}, passing the single extra argument {@code 1} to {@code pretrain}. */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit(INDArray examples) {
        Object[] pretrainArgs = {Integer.valueOf(1)};
        pretrain(examples, pretrainArgs);
    }

    /**
     * Back-propagates the output error through every layer and appends the results to {@code list}.
     *
     * <p>Appends getnLayers() + 2 arrays, in index order. Entries 0..getnLayers() are the per-layer
     * gradients (hidden layers then output layer). The last entry is the raw output error itself.
     * NOTE(review): that last entry is presumably consumed as the output bias gradient — confirm with the caller.
     * Every entry is divided in place by its norm2 when the configuration constrains gradients to unit norm.
     *
     * @param list receives the computed arrays (mutated)
     */
    protected void computeDeltas(List<INDArray> list) {
        INDArray[] iNDArrayArr = new INDArray[getnLayers() + 2];
        List<INDArray> feedForward = feedForward();
        // Output error: (output - labels) minus the output activation's derivative.
        // NOTE(review): subtracting the derivative (rather than multiplying by it) is unusual for a delta — confirm intent.
        INDArray subi = this.labels.sub(feedForward.get(feedForward.size() - 1)).negi().subi(getOutputLayer().conf().getActivationFunction().applyDerivative(feedForward.get(feedForward.size() - 1)));
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        // Collect weights, biases and activation functions of the hidden layers...
        for (int i = 0; i < getNeuralNets().length; i++) {
            arrayList.add(getNeuralNets()[i].getW());
            arrayList2.add(getNeuralNets()[i].gethBias());
            arrayList3.add(getLayers()[i].conf().getActivationFunction());
        }
        // ...followed by the output layer, so index getnLayers() refers to it.
        arrayList.add(getOutputLayer().getW());
        arrayList2.add(getOutputLayer().getB());
        arrayList3.add(getOutputLayer().conf().getActivationFunction());
        for (int i2 = getnLayers() + 1; i2 >= 0; i2--) {
            if (i2 >= getnLayers() + 1) {
                // Extra trailing slot: store the output error unchanged.
                iNDArrayArr[i2] = subi;
            } else {
                // Gradient for layer i2 = (activation feeding layer i2)^T x current delta.
                iNDArrayArr[i2] = feedForward.get(i2).transpose().mmul(subi);
                applyDropConnectIfNecessary(iNDArrayArr[i2]);
                INDArray transpose = ((INDArray) arrayList.get(i2)).addRowVector((INDArray) arrayList2.get(i2)).transpose();
                INDArray iNDArray = feedForward.get(i2);
                if (i2 > 0) {
                    // Propagate the delta back through layer i2 and the derivative of the activation that produced its input.
                    subi = subi.mmul(transpose).muli(((ActivationFunction) arrayList3.get(i2 - 1)).applyDerivative(iNDArray));
                }
            }
        }
        for (int i3 = 0; i3 < iNDArrayArr.length; i3++) {
            if (this.defaultConfiguration.isConstrainGradientToUnitNorm()) {
                list.add(iNDArrayArr[i3].divi(iNDArrayArr[i3].norm2(Integer.MAX_VALUE)));
            } else {
                list.add(iNDArrayArr[i3]);
            }
        }
    }

    /**
     * Applies one plain gradient step. Each layer's weights and bias are decreased in place by the
     * matching pair from {@code backPropGradient()}, and the updated arrays are copied onto the
     * corresponding pretraining network (for every layer except the output layer).
     */
    public void backPropStep() {
        List<Pair<INDArray, INDArray>> gradients = backPropGradient();
        for (int layerIndex = 0; layerIndex < this.layers.length; layerIndex++) {
            Pair<INDArray, INDArray> step = gradients.get(layerIndex);
            this.layers[layerIndex].getW().subi(step.getFirst());
            this.layers[layerIndex].getB().subi(step.getSecond());
            boolean hasPretrainNet = layerIndex < this.neuralNets.length;
            if (!hasPretrainNet) {
                continue;
            }
            this.neuralNets[layerIndex].setW(this.layers[layerIndex].getW());
            this.neuralNets[layerIndex].sethBias(this.layers[layerIndex].getB());
        }
    }

    /**
     * Applies one step using the gradients from {@code backPropGradientR(iNDArray)}.
     *
     * <p>Entry {@code i} updates hidden network {@code i} (whose weights and bias are then copied onto
     * layer {@code i}), and the last entry updates the output layer.
     */
    public void backPropStepR(INDArray iNDArray) {
        List<Pair<INDArray, INDArray>> gradients = backPropGradientR(iNDArray);
        int outputIndex = gradients.size() - 1;
        // The previous guard "gradients.size() < neuralNets.length" did not depend on i. Because the list
        // also carries the output-layer entry, it was false and no hidden layer was ever updated; when it
        // was true, the loop read past the end of the list.
        for (int i = 0; i < this.neuralNets.length && i < outputIndex; i++) {
            Pair<INDArray, INDArray> step = gradients.get(i);
            this.neuralNets[i].getW().subi(step.getFirst());
            this.neuralNets[i].gethBias().subi(step.getSecond());
            this.layers[i].setW(this.neuralNets[i].getW());
            this.layers[i].setB(this.neuralNets[i].gethBias());
        }
        Pair<INDArray, INDArray> outputStep = gradients.get(outputIndex);
        getOutputLayer().getW().subi(outputStep.getFirst());
        getOutputLayer().getB().subi(outputStep.getSecond());
    }

    /** @return the backing layer array (hidden layers, then the output layer in the last slot); not a copy */
    public Layer[] getLayers() {
        return layers;
    }

    /** @param newLayers replacement layer array (stored as-is, not copied) */
    public void setLayers(Layer[] newLayers) {
        layers = newLayers;
    }

    /** @param networks replacement pretraining networks, one per hidden layer (stored as-is, not copied) */
    public void setNeuralNets(NeuralNetwork[] networks) {
        neuralNets = networks;
    }

    /** Flattens the per-layer result of {@code backPropGradientR(direction)} with {@code pack}. */
    public INDArray getBackPropRGradient(INDArray direction) {
        return this.pack(this.backPropGradientR(direction));
    }

    /**
     * Splits {@code backPropGradient2()} into its first and second components and packs each side.
     *
     * @return (packed first components, packed second components)
     */
    public Pair<INDArray, INDArray> getBackPropGradient2() {
        List<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> perLayer = backPropGradient2();
        List<Pair<INDArray, INDArray>> firstParts = new ArrayList<>();
        List<Pair<INDArray, INDArray>> secondParts = new ArrayList<>();
        for (Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> entry : perLayer) {
            firstParts.add(entry.getFirst());
            secondParts.add(entry.getSecond());
        }
        INDArray packedFirst = pack(firstParts);
        INDArray packedSecond = pack(secondParts);
        return new Pair<>(packedFirst, packedSecond);
    }

    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    /**
     * Creates a new instance of the runtime class through its no-argument constructor and copies this
     * network's state into it via {@code update(this)}.
     *
     * <p>The constructor is made accessible first, as {@link Builder} does, so subclasses with
     * non-public no-argument constructors can be cloned. Previously reflective failures were printed
     * and swallowed, and the method then failed with a NullPointerException.
     *
     * @throws IllegalStateException if the runtime class cannot be instantiated
     */
    public BaseMultiLayerNetwork m15clone() {
        BaseMultiLayerNetwork copy;
        try {
            Constructor<?> emptyConstructor = Dl4jReflection.getEmptyConstructor(getClass());
            emptyConstructor.setAccessible(true);
            copy = (BaseMultiLayerNetwork) emptyConstructor.newInstance(new Object[0]);
        } catch (Exception e) {
            throw new IllegalStateException("Unable to clone network of type " + getClass().getName(), e);
        }
        copy.update(this);
        return copy;
    }

    /** @return the weight matrices of every hidden network followed by the output layer's (not copies) */
    public List<INDArray> weightMatrices() {
        List<INDArray> weights = new ArrayList<>();
        int index = 0;
        while (index < this.neuralNets.length) {
            weights.add(this.neuralNets[index].getW());
            index++;
        }
        weights.add(getOutputLayer().getW());
        return weights;
    }

    /**
     * R-forward pass: propagates the direction {@code iNDArray} (unpacked into per-layer weight/bias
     * pairs) alongside the ordinary activations {@code list}.
     *
     * <p>The first entry is a zero matrix shaped like the input. Each following entry is
     * {@code (previous R * W + activation * (dW + db)) * f'(next activation)}, computed for every hidden
     * layer and then for the output layer.
     *
     * @param list activations from {@code feedForward()}
     * @param iNDArray packed direction vector
     * @return the R-activations, one more than the number of layers
     */
    public List<INDArray> feedForwardR(List<INDArray> list, INDArray iNDArray) {
        List<INDArray> rActivations = new ArrayList<>();
        rActivations.add(Nd4j.zeros(this.input.rows(), this.input.columns()));
        List<Pair<INDArray, INDArray>> directions = unPack(iNDArray);
        List<INDArray> weights = weightMatrices();
        for (int layer = 0; layer < this.neuralNets.length; layer++) {
            ActivationFunction fn = getNeuralNets()[layer].conf().getActivationFunction();
            if (getNeuralNets()[layer] instanceof AutoEncoder) {
                fn = ((AutoEncoder) getNeuralNets()[layer]).conf.getActivationFunction();
            }
            INDArray next = rActivations.get(layer).mmul(weights.get(layer));
            Pair<INDArray, INDArray> direction = directions.get(layer);
            INDArray directTerm = list.get(layer).mmul(direction.getFirst().addRowVector(direction.getSecond()));
            next = next.addi(directTerm);
            next = next.muli(fn.applyDerivative(list.get(layer + 1)));
            rActivations.add(next);
        }
        // Output layer: last weight matrix and last direction pair, using the output activation's derivative.
        INDArray output = rActivations.get(rActivations.size() - 1).mmul(weights.get(weights.size() - 1));
        Pair<INDArray, INDArray> outputDirection = directions.get(directions.size() - 1);
        INDArray outputDirect = list.get(list.size() - 2).mmul(outputDirection.getFirst().addRowVector(outputDirection.getSecond()));
        output = output.addi(outputDirect);
        output = output.muli(getOutputLayer().conf().getActivationFunction().applyDerivative(list.get(list.size() - 1)));
        rActivations.add(output);
        return rActivations;
    }

    /**
     * Runs the no-argument {@code feedForward()} and passes its activations, together with
     * {@code v}, to {@link #feedForwardR(List, INDArray)}.
     *
     * @param v flat direction vector
     * @return the R-activations computed by the two-argument overload
     */
    public List<INDArray> feedForwardR(INDArray v) {
        List<INDArray> activations = feedForward();
        return feedForwardR(activations, v);
    }

    /**
     * Fine-tunes the network with back-propagation. Uses {@code BackPropROptimizer} when
     * {@code useGaussNewtonVectorProductBackProp} is set, otherwise {@code BackPropOptimizer}.
     *
     * @param rate forwarded to the optimizer's constructor (presumably the learning rate - confirm)
     * @param iterations forwarded to both the optimizer's constructor and {@code optimize}
     * @param trainingEvaluator optional evaluator forwarded to {@code optimize}; may be {@code null}
     */
    public void backProp(double rate, int iterations, TrainingEvaluator trainingEvaluator) {
        if (!this.useGaussNewtonVectorProductBackProp) {
            new BackPropOptimizer(this, rate, iterations).optimize(trainingEvaluator, iterations, this.lineSearchBackProp);
            return;
        }
        new BackPropROptimizer(this, rate, iterations).optimize(trainingEvaluator, iterations, this.lineSearchBackProp);
    }

    /**
     * Same as {@link #backProp(double, int, TrainingEvaluator)} without an evaluator.
     */
    public void backProp(double rate, int iterations) {
        TrainingEvaluator noEvaluator = null;
        backProp(rate, iterations, noEvaluator);
    }

    /** Returns the {@code useDropConnect} flag. */
    public boolean isUseDropConnect() {
        return useDropConnect;
    }

    /** Sets the {@code useDropConnect} flag. */
    public void setUseDropConnect(boolean enabled) {
        useDropConnect = enabled;
    }

    /**
     * Flattens the network parameters into one vector. For each pretraining net the order is its
     * weights then its hidden bias, followed by the output layer's {@code params()}.
     *
     * @return the flattened parameter vector
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray params() {
        int hiddenCount = getnLayers();
        List<INDArray> parts = new ArrayList<>(2 * hiddenCount + 1);
        for (int layer = 0; layer < hiddenCount; layer++) {
            parts.add(getNeuralNets()[layer].getW());
            parts.add(getNeuralNets()[layer].gethBias());
        }
        parts.add(getOutputLayer().params());
        return Nd4j.toFlattened(parts);
    }

    /**
     * Counts the parameters of the network as laid out by {@link #params()} and {@link #pack()}.
     * Each pretraining net contributes its parameter count minus its visible bias; the output layer
     * contributes its full count.
     *
     * @return total parameter count
     */
    @Override // org.deeplearning4j.nn.api.Model
    public int numParams() {
        int hiddenTotal = 0;
        for (int layer = 0; layer < neuralNets.length; layer++) {
            // the visible bias is subtracted because params() and pack() only include W and hBias per net
            int withoutVisibleBias = neuralNets[layer].numParams() - neuralNets[layer].getvBias().length();
            hiddenTotal += withoutVisibleBias;
        }
        return hiddenTotal + getOutputLayer().numParams();
    }

    /**
     * Packs the current (weights, hidden bias) of every pretraining net plus the output layer's
     * (weights, bias) into one flat vector via {@link #pack(List)}.
     *
     * @return the flattened parameters
     */
    public INDArray pack() {
        List<Pair<INDArray, INDArray>> pairs = new ArrayList<>(this.neuralNets.length + 1);
        for (int layer = 0; layer < this.neuralNets.length; layer++) {
            pairs.add(new Pair<>(this.neuralNets[layer].getW(), this.neuralNets[layer].gethBias()));
        }
        pairs.add(new Pair<>(getOutputLayer().getW(), getOutputLayer().getB()));
        return pack(pairs);
    }

    /**
     * Flattens one (weights, bias) pair per layer into a single vector, in list order:
     * first element then second element of each pair.
     *
     * @param list exactly {@code neuralNets.length + 1} pairs (pretraining nets, then the output layer)
     * @return the flattened vector; its length equals {@link #numParams()}
     * @throws IllegalArgumentException if the list does not have {@code neuralNets.length + 1} entries
     * @throws IllegalStateException if the flattened length differs from {@link #numParams()}
     */
    public INDArray pack(List<Pair<INDArray, INDArray>> list) {
        int expectedPairs = this.neuralNets.length + 1;
        if (list.size() != expectedPairs) {
            throw new IllegalArgumentException("Illegal number of neuralNets passed in. Was " + list.size() + " when should have been " + expectedPairs);
        }
        List<INDArray> parts = new ArrayList<>(2 * list.size());
        for (Pair<INDArray, INDArray> pair : list) {
            parts.add(pair.getFirst());
            parts.add(pair.getSecond());
        }
        INDArray flattened = Nd4j.toFlattened(parts);
        // numParams() walks every layer, so compute it once
        int expectedParams = numParams();
        if (flattened.length() != expectedParams) {
            throw new IllegalStateException("Illegal number of parameters found in the layers with a difference of " + Math.abs(flattened.length() - expectedParams));
        }
        return flattened;
    }

    /**
     * Scores {@code dataSet} by delegating its feature matrix and labels to the two-argument score overload.
     *
     * @param dataSet features and labels to score
     * @return the score computed by the two-argument overload
     */
    public double score(final org.nd4j.linalg.dataset.api.DataSet dataSet) {
        return this.score(dataSet.getFeatureMatrix(), dataSet.getLabels());
    }

    /**
     * Computes the back-propagation gradient of every layer (each pretraining net plus the output
     * layer), adds the masked L2 term and returns the result split into per-layer pairs.
     *
     * <p>The weight gradient of layer {@code i} is the i-th matrix that {@code computeDeltas} adds
     * to the supplied list. The bias gradient is that same matrix summed along dimension 0.
     * NOTE(review): deriving the bias gradient from the weight delta looks unusual - confirm
     * against {@code computeDeltas}.
     *
     * @return one (weight gradient, bias gradient) pair per layer, as produced by {@link #unPack(INDArray)}
     * @throws IllegalStateException if a delta's length differs from the corresponding layer's weights
     */
    protected List<Pair<INDArray, INDArray>> backPropGradient() {
        List<INDArray> deltas = new ArrayList<>();
        computeDeltas(deltas);
        int layerCount = getnLayers() + 1;
        List<Pair<INDArray, INDArray>> gradients = new ArrayList<>(layerCount);
        for (int i2 = 0; i2 < layerCount; i2++) {
            INDArray weightGradient = deltas.get(i2);
            if (weightGradient.length() != getLayers()[i2].getW().length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            gradients.add(new Pair<>(weightGradient, weightGradient.sum(0)));
        }
        if (this.mask == null) {
            initMask();
        }
        // Shape checks, active only when assertions are enabled: hidden entries are compared with
        // the pretraining nets, the last entry with the output layer.
        for (int i3 = 0; i3 < gradients.size(); i3++) {
            INDArray first = gradients.get(i3).getFirst();
            INDArray second = gradients.get(i3).getSecond();
            if (i3 < getnLayers()) {
                if (!$assertionsDisabled && !Arrays.equals(first.shape(), getNeuralNets()[i3].getW().shape())) {
                    throw new AssertionError("Illegal shape for layer " + i3 + " weight gradient, should have been " + Arrays.toString(getNeuralNets()[i3].getW().shape()) + " but was " + Arrays.toString(first.shape()));
                }
                if (!$assertionsDisabled && !Arrays.equals(second.shape(), getNeuralNets()[i3].gethBias().shape())) {
                    throw new AssertionError("Illegal shape for layer " + i3 + " bias   gradient, should have been " + Arrays.toString(getNeuralNets()[i3].gethBias().shape()) + " but was " + Arrays.toString(second.shape()));
                }
            } else {
                if (!$assertionsDisabled && !Arrays.equals(first.shape(), getOutputLayer().getW().shape())) {
                    throw new AssertionError("Illegal shape for output  weight gradient, should have been " + Arrays.toString(getOutputLayer().getW().shape()) + " but was " + Arrays.toString(first.shape()));
                }
                if (!$assertionsDisabled && !Arrays.equals(second.shape(), getOutputLayer().getB().shape())) {
                    throw new AssertionError("Illegal shape for output layer  bias   gradient, should have been " + Arrays.toString(getOutputLayer().getB().shape()) + " but was " + Arrays.toString(second.shape()));
                }
            }
        }
        // L2 penalty on weights only: initMask zeroes every bias entry of the mask
        INDArray packed = pack(gradients);
        packed.addi(this.mask.mul(params().mul(Float.valueOf(this.defaultConfiguration.getL2()))));
        return unPack(packed);
    }

    /**
     * Splits a flat vector back into one (weight, bias) pair per entry of {@code this.layers}. Each
     * part is reshaped to the shape of that layer's current {@code getW()} / {@code getB()}.
     * Intended as the inverse of {@link #pack(List)}. NOTE(review): pack() iterates
     * {@code neuralNets} (+ output layer) while this iterates {@code layers}; confirm they stay aligned.
     *
     * @param iNDArray vector whose length must equal {@link #numParams()}; reshaped to one row if needed
     * @return one (weights, bias) pair per layer, in layer order
     * @throws IllegalArgumentException if the length differs from {@link #numParams()}
     * @throws IllegalStateException if a slice's weight or bias portion has the wrong length
     */
    public List<Pair<INDArray, INDArray>> unPack(INDArray iNDArray) {
        int numParams = numParams();
        if (iNDArray.length() != numParams) {
            throw new IllegalArgumentException("Parameter vector not equal of length to " + numParams);
        }
        if (iNDArray.rows() != 1) {
            iNDArray = iNDArray.reshape(1, iNDArray.length());
        }
        ArrayList arrayList = new ArrayList();
        // running offset of the current layer's slice within the flat vector
        int i = 0;
        for (int i2 = 0; i2 < this.layers.length; i2++) {
            int length = this.layers[i2].getW().length() + this.layers[i2].getB().length();
            // this layer's slice: weights first, then bias
            INDArray iNDArray2 = iNDArray.get(new NDArrayIndex[]{NDArrayIndex.interval(i, i + length)});
            INDArray iNDArray3 = iNDArray2.get(new NDArrayIndex[]{NDArrayIndex.interval(0, this.layers[i2].getW().length())});
            INDArray iNDArray4 = iNDArray2.get(new NDArrayIndex[]{NDArrayIndex.interval(this.layers[i2].getW().length(), iNDArray2.length())});
            // Only when the two parts do not add up to the expected total are they checked
            // individually, to report which part is off.
            if (iNDArray3.length() + iNDArray4.length() != length) {
                if (iNDArray4.length() != this.layers[i2].getB().length()) {
                    throw new IllegalStateException("Hidden bias on layer " + i2 + " was off");
                }
                if (iNDArray3.length() != this.layers[i2].getW().length()) {
                    throw new IllegalStateException("Weight portion on layer " + i2 + " was off");
                }
            }
            arrayList.add(new Pair(iNDArray3.reshape(this.layers[i2].getW().rows(), this.layers[i2].getW().columns()), iNDArray4.reshape(this.layers[i2].getB().rows(), this.layers[i2].getB().columns())));
            i += length;
        }
        return arrayList;
    }

    /**
     * Two-component variant of {@link #backPropGradient()}. For each layer, {@code computeDeltas2()}
     * yields a pair of matrices. Each matrix is paired with its mean along dimension 1, which is used
     * as the bias part. The packed first components get the masked L2 term
     * {@code L2 * params .* mask}. The packed second components get
     * {@code (mask * L2 + dampingFactor) ^ 0.75} (element-wise). Both are then unpacked into
     * per-layer pairs.
     * NOTE(review): the second component looks like a diagonal curvature/preconditioner estimate - confirm.
     *
     * @return per layer: (first-component (weight, bias) pair, second-component (weight, bias) pair)
     * @throws IllegalStateException if a first-component weight or bias part has the wrong length
     */
    protected List<Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>> backPropGradient2() {
        List<Pair<INDArray, INDArray>> computeDeltas2 = computeDeltas2();
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        for (int i = 0; i < computeDeltas2.size(); i++) {
            INDArray first = computeDeltas2.get(i).getFirst();
            INDArray second = computeDeltas2.get(i).getSecond();
            // hidden entries are checked against the pretraining nets, the last one against the output layer
            if (i < getNeuralNets().length && first.length() != getNeuralNets()[i].getW().length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            if (i == getNeuralNets().length && first.length() != getOutputLayer().getW().length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            INDArray mean = computeDeltas2.get(i).getFirst().mean(1);
            INDArray mean2 = computeDeltas2.get(i).getSecond().mean(1);
            arrayList2.add(new Pair<>(first, mean));
            arrayList3.add(new Pair<>(second, mean2));
            // only the first component's bias length is validated
            if (i < getNeuralNets().length && mean.length() != this.neuralNets[i].gethBias().length()) {
                throw new IllegalStateException("Bias change not equal to weight change");
            }
            if (i == getNeuralNets().length && mean.length() != getOutputLayer().getB().length()) {
                throw new IllegalStateException("Bias change not equal to weight change");
            }
        }
        INDArray pack = pack(arrayList2);
        INDArray pack2 = pack(arrayList3);
        INDArray params = params();
        if (this.mask == null) {
            initMask();
        }
        // first components: add masked L2 penalty (mask zeroes bias entries)
        pack.addi(params.mul(Float.valueOf(this.defaultConfiguration.getL2())).muli(this.mask));
        // second components: add (mask * L2 + damping)^0.75; the damping matrix takes pack's shape
        pack2.addi(Transforms.pow(this.mask.mul(Float.valueOf(this.defaultConfiguration.getL2())).add(Nd4j.valueArrayOf(pack.rows(), pack.columns(), this.dampingFactor)), Double.valueOf(0.75d)));
        List<Pair<INDArray, INDArray>> unPack = unPack(pack);
        List<Pair<INDArray, INDArray>> unPack2 = unPack(pack2);
        for (int i2 = 0; i2 < unPack.size(); i2++) {
            arrayList.add(new Pair(unPack.get(i2), unPack2.get(i2)));
        }
        return arrayList;
    }

    /**
     * R-path counterpart of {@link #backPropGradient()}. Takes the per-layer matrices from
     * {@code computeDeltasR(iNDArray)} and pairs each with its mean along dimension 1 as the bias
     * part. It then packs them, adds {@code L2 * mask .* v} and {@code dampingFactor * v}, and
     * unpacks the result into per-layer pairs. Presumably this is the damped curvature-vector
     * product consumed by BackPropROptimizer - confirm.
     *
     * @param iNDArray flat direction vector {@code v}; must be compatible with the mask's shape
     * @return one (weight part, bias part) pair per layer
     * @throws IllegalStateException if a hidden layer's weight or bias part has the wrong length
     */
    protected List<Pair<INDArray, INDArray>> backPropGradientR(INDArray iNDArray) {
        if (this.mask == null) {
            initMask();
        }
        List<INDArray> computeDeltasR = computeDeltasR(iNDArray);
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < getnLayers(); i++) {
            INDArray iNDArray2 = computeDeltasR.get(i);
            if (iNDArray2.length() != getNeuralNets()[i].getW().length()) {
                throw new IllegalStateException("Gradient change not equal to weight change");
            }
            INDArray mean = computeDeltasR.get(i).mean(1);
            if (mean.length() != this.neuralNets[i].gethBias().length()) {
                throw new IllegalStateException("Bias change not equal to weight change");
            }
            arrayList.add(new Pair<>(iNDArray2, mean));
        }
        // output layer entry: appended without the length checks applied to hidden layers
        arrayList.add(new Pair<>(computeDeltasR.get(getnLayers()), computeDeltasR.get(getnLayers()).mean(1)));
        // packed + L2 * mask .* v + damping * v
        return unPack(pack(arrayList).addi(this.mask.mul(Float.valueOf(this.defaultConfiguration.getL2())).mul(iNDArray)).addi(iNDArray.mul(Double.valueOf(this.dampingFactor))));
    }

    /**
     * Fine-tunes on every batch of {@code dataSetIterator}, starting from a reset iterator. Each batch
     * becomes the network's input and labels, is fed forward, and is optimized by a fresh
     * {@code MultiLayerNetworkOptimizer}. A batch with missing features or labels stops fine-tuning
     * for the rest of the iterator.
     *
     * @param dataSetIterator source of training batches
     * @param lr learning rate passed to the optimizer
     * @param numIterations iteration count passed to {@code optimize}
     */
    public void finetune(DataSetIterator dataSetIterator, double lr, int numIterations) {
        for (dataSetIterator.reset(); dataSetIterator.hasNext(); ) {
            DataSet batch = dataSetIterator.next();
            boolean incomplete = batch.getFeatureMatrix() == null || batch.getLabels() == null;
            if (incomplete) {
                return;
            }
            setInput(batch.getFeatureMatrix());
            setLabels(batch.getLabels());
            feedForward();
            this.optimizer = new MultiLayerNetworkOptimizer(this, lr);
            this.optimizer.optimize(batch.getLabels(), lr, numIterations);
        }
    }

    /**
     * Like {@link #finetune(DataSetIterator, double, int)}, but passes {@code trainingEvaluator} to the
     * optimizer and does not call {@code feedForward()} before optimizing. A batch with missing
     * features or labels stops fine-tuning for the rest of the iterator.
     *
     * @param dataSetIterator source of training batches
     * @param lr learning rate passed to the optimizer
     * @param numIterations iteration count passed to {@code optimize}
     * @param trainingEvaluator evaluator passed to {@code optimize}
     */
    public void finetune(DataSetIterator dataSetIterator, double lr, int numIterations, TrainingEvaluator trainingEvaluator) {
        for (dataSetIterator.reset(); dataSetIterator.hasNext(); ) {
            DataSet batch = dataSetIterator.next();
            boolean incomplete = batch.getFeatureMatrix() == null || batch.getLabels() == null;
            if (incomplete) {
                return;
            }
            setInput(batch.getFeatureMatrix());
            setLabels(batch.getLabels());
            this.optimizer = new MultiLayerNetworkOptimizer(this, lr);
            this.optimizer.optimize(this.labels, lr, numIterations, trainingEvaluator);
        }
    }

    /**
     * Fine-tunes the network on the current input against {@code newLabels}.
     *
     * <p>When {@code newLabels} is {@code null} the previously set labels are kept. This matches the
     * {@code TrainingEvaluator} overload; the old code overwrote the labels unconditionally, which
     * made its own {@code null} check dead. The output layer always receives the labels actually
     * used for training.
     *
     * @param newLabels labels to train against, or {@code null} to keep the current labels
     * @param lr learning rate passed to the optimizer
     * @param numIterations iteration count passed to {@code optimize}
     */
    public void finetune(INDArray newLabels, double lr, int numIterations) {
        if (newLabels != null) {
            this.labels = newLabels;
        }
        getOutputLayer().setLabels(this.labels);
        feedForward();
        this.optimizer = new MultiLayerNetworkOptimizer(this, lr);
        this.optimizer.optimize(this.labels, lr, numIterations);
    }

    /**
     * Feeds the current input forward, then fine-tunes against {@code newLabels} (or the existing
     * labels when {@code newLabels} is {@code null}) using a fresh {@code MultiLayerNetworkOptimizer}.
     *
     * @param newLabels labels to train against, or {@code null} to keep the current labels
     * @param lr learning rate passed to the optimizer
     * @param numIterations iteration count passed to {@code optimize}
     * @param trainingEvaluator evaluator passed to {@code optimize}
     */
    public void finetune(INDArray newLabels, double lr, int numIterations, TrainingEvaluator trainingEvaluator) {
        feedForward();
        this.labels = newLabels == null ? this.labels : newLabels;
        this.optimizer = new MultiLayerNetworkOptimizer(this, lr);
        this.optimizer.optimize(this.labels, lr, numIterations, trainingEvaluator);
    }

    /**
     * Predicts one class index per input row. The index is the position that BLAS {@code iamax}
     * reports for the corresponding row of {@link #output(INDArray)}.
     *
     * @param examples input matrix, one example per row
     * @return predicted class index for each row
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public int[] predict(INDArray examples) {
        INDArray outputs = output(examples);
        int rowCount = examples.rows();
        int[] predictions = new int[rowCount];
        for (int row = 0; row < rowCount; row++) {
            predictions[row] = Nd4j.getBlasWrapper().iamax(outputs.getRow(row));
        }
        return predictions;
    }

    /**
     * Feeds {@code examples} forward and asks the output layer for label probabilities of the last activation.
     *
     * @param examples input matrix
     * @return the output layer's label probabilities
     */
    public INDArray labelProbabilities(INDArray examples) {
        List<INDArray> activations = feedForward(examples);
        INDArray lastActivation = activations.get(activations.size() - 1);
        return getOutputLayer().labelProbabilities(lastActivation);
    }

    /**
     * Pretrains on {@code examples} with {k, lr} from the default configuration, then fine-tunes
     * against {@code exampleLabels} with the configured learning rate and iteration count.
     *
     * @param examples training inputs
     * @param exampleLabels training labels
     */
    public void fit(INDArray examples, INDArray exampleLabels) {
        Object[] pretrainArgs = {Integer.valueOf(this.defaultConfiguration.getK()), Float.valueOf(this.defaultConfiguration.getLr())};
        pretrain(examples, pretrainArgs);
        finetune(exampleLabels, this.defaultConfiguration.getLr(), this.defaultConfiguration.getNumIterations());
    }

    /**
     * Fits on the feature matrix and labels of {@code dataSet}.
     */
    public void fit(org.nd4j.linalg.dataset.api.DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        INDArray targets = dataSet.getLabels();
        fit(features, targets);
    }

    /**
     * Delegates integer-label fitting directly to the output layer; the hidden layers are not trained.
     */
    public void fit(INDArray examples, int[] classIndices) {
        OutputLayer out = getOutputLayer();
        out.fit(examples, classIndices);
    }

    /**
     * Returns the last activation produced by {@code feedForward(examples)}.
     *
     * @param examples input matrix
     * @return the network output
     */
    public INDArray output(INDArray examples) {
        List<INDArray> activations = feedForward(examples);
        int lastIndex = activations.size() - 1;
        return activations.get(lastIndex);
    }

    /**
     * Returns entry {@code layerNum - 1} of the activations produced by {@code feedForward(examples)}.
     *
     * @param examples input matrix
     * @param layerNum 1-based position in the activation list
     * @return the selected activation
     */
    public INDArray reconstruct(INDArray examples, int layerNum) {
        List<INDArray> activations = feedForward(examples);
        return activations.get(layerNum - 1);
    }

    /**
     * Logs every layer-wise configuration on a single INFO line, each prefixed with its 0-based
     * layer index.
     */
    public void printConfiguration() {
        // StringBuilder: the buffer is method-local, so StringBuffer's synchronization is unnecessary
        StringBuilder summary = new StringBuilder();
        int layerIndex = 0;
        Iterator<NeuralNetConfiguration> it = getLayerWiseConfigurations().iterator();
        while (it.hasNext()) {
            summary.append(" Layer ").append(layerIndex++).append(" conf ").append(it.next());
        }
        log.info(summary.toString());
    }

    /**
     * Serializes this network to {@code out} with {@code SerializationUtils.writeObject}.
     */
    @Override // org.deeplearning4j.nn.api.Persistable
    public void write(OutputStream out) {
        SerializationUtils.writeObject(this, out);
    }

    /**
     * Reads a network with {@code SerializationUtils.readObject} and copies its state into this one
     * via {@link #update(BaseMultiLayerNetwork)}.
     * NOTE(review): if this is Java native deserialization, never call it on untrusted streams.
     */
    @Override // org.deeplearning4j.nn.api.Persistable
    public void load(InputStream in) {
        Object restored = SerializationUtils.readObject(in);
        update((BaseMultiLayerNetwork) restored);
    }

    /**
     * Copies state from {@code baseMultiLayerNetwork} into this network. Each pretraining net is
     * copied with {@code mo11clone()}. The remaining listed fields are assigned by reference.
     *
     * <p>The nets were previously cloned twice (a second, identical pass ran after the field copies).
     * They are now cloned once. The source array is captured before {@code setnLayers}, so updating a
     * network from itself no longer loses its nets.
     * NOTE(review): {@code layers}, {@code mask}, {@code dampingFactor}, {@code lineSearchBackProp} and
     * {@code sampleFromHiddenActivations} are not copied - confirm that is intended.
     *
     * @param baseMultiLayerNetwork the network to copy from
     * @throws IllegalStateException if any of the source network's pretraining nets is {@code null}
     */
    public void update(BaseMultiLayerNetwork baseMultiLayerNetwork) {
        if (baseMultiLayerNetwork.neuralNets != null && baseMultiLayerNetwork.getnLayers() > 0) {
            NeuralNetwork[] sourceNets = baseMultiLayerNetwork.getNeuralNets();
            // kept for any side effects of createNetworkLayers; the array it creates is replaced below
            setnLayers(sourceNets.length);
            this.neuralNets = new NeuralNetwork[sourceNets.length];
            for (int i = 0; i < sourceNets.length; i++) {
                if (sourceNets[i] == null) {
                    throw new IllegalStateException("Will not clone uninitialized network, layer " + i + " of network was null");
                }
                getNeuralNets()[i] = sourceNets[i].mo11clone();
            }
        }
        this.hiddenLayerSizes = baseMultiLayerNetwork.hiddenLayerSizes;
        this.defaultConfiguration = baseMultiLayerNetwork.defaultConfiguration;
        this.errorTolerance = baseMultiLayerNetwork.errorTolerance;
        this.forceNumEpochs = baseMultiLayerNetwork.forceNumEpochs;
        this.input = baseMultiLayerNetwork.input;
        this.labels = baseMultiLayerNetwork.labels;
        this.learningRateUpdate = baseMultiLayerNetwork.learningRateUpdate;
        this.shouldBackProp = baseMultiLayerNetwork.shouldBackProp;
        this.weightTransforms = baseMultiLayerNetwork.weightTransforms;
        this.visibleBiasTransforms = baseMultiLayerNetwork.visibleBiasTransforms;
        this.hiddenBiasTransforms = baseMultiLayerNetwork.hiddenBiasTransforms;
        this.useDropConnect = baseMultiLayerNetwork.useDropConnect;
        this.useGaussNewtonVectorProductBackProp = baseMultiLayerNetwork.useGaussNewtonVectorProductBackProp;
    }

    /**
     * Scores the network on {@code examples} as the F1 of {@link #labelProbabilities(INDArray)}
     * against {@code exampleLabels}. Also stores {@code exampleLabels} as the network's labels.
     *
     * <p>The separate {@code feedForward(examples)} call was removed: {@code labelProbabilities}
     * runs the same forward pass itself, so the network was previously evaluated twice.
     * Note that, unlike {@link #score()}, this returns an F1 value rather than the output layer's score.
     *
     * @param examples input matrix
     * @param exampleLabels expected labels
     * @return the F1 score
     */
    @Override // org.deeplearning4j.nn.api.Classifier
    public double score(INDArray examples, INDArray exampleLabels) {
        setLabels(exampleLabels);
        Evaluation evaluation = new Evaluation();
        evaluation.eval(exampleLabels, labelProbabilities(examples));
        return evaluation.f1();
    }

    /** Returns the number of columns of the current label matrix. */
    public int numLabels() {
        INDArray current = this.labels;
        return current.columns();
    }

    /**
     * Feeds the data set's features forward, stores its labels and returns {@link #score()}.
     */
    public double score(DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        feedForward(features);
        INDArray targets = dataSet.getLabels();
        setLabels(targets);
        return score();
    }

    /**
     * Runs {@code feedForward()} and returns the output layer's score.
     */
    @Override // org.deeplearning4j.nn.api.Classifier, org.deeplearning4j.nn.api.Model
    public double score() {
        feedForward();
        double outputScore = getOutputLayer().score();
        return outputScore;
    }

    /**
     * Scores the network with the given parameters temporarily installed. The result is
     * {@link #score()} plus {@code 0.5 * L2 * sum((mask .* candidate)^2)}.
     *
     * <p>Fixes: the mask is initialized on demand (previously an NPE if no gradient had been computed
     * yet), and the original parameters are restored even if scoring throws. The summed element is
     * read through {@link Number}, so a non-{@code Double} element no longer causes a
     * ClassCastException.
     *
     * @param candidate flat parameter vector laid out as {@link #setParameters(INDArray)} expects
     * @return the regularized score
     */
    public double score(INDArray candidate) {
        if (this.mask == null) {
            initMask();
        }
        INDArray saved = params();
        setParameters(candidate);
        try {
            double loss = score();
            double sumOfSquares = ((Number) Transforms.pow(this.mask.mul(candidate), 2).sum(Integer.MAX_VALUE).element()).doubleValue();
            double l2 = 0.5f * this.defaultConfiguration.getL2() * sumOfSquares;
            return loss + l2;
        } finally {
            setParameters(saved);
        }
    }

    /**
     * Applies the registered per-layer transforms (keyed by layer index) to each pretraining net's
     * weights, hidden bias and visible bias, replacing each with the transformed matrix.
     *
     * @throws IllegalStateException if the pretraining nets are missing or empty
     */
    protected void applyTransforms() {
        boolean uninitialized = this.neuralNets == null || this.neuralNets.length < 1;
        if (uninitialized) {
            throw new IllegalStateException("Layers not initialized");
        }
        for (int layer = 0; layer < this.neuralNets.length; layer++) {
            Integer key = Integer.valueOf(layer);
            if (this.weightTransforms.containsKey(key)) {
                INDArray w = this.neuralNets[layer].getW();
                this.neuralNets[layer].setW((INDArray) this.weightTransforms.get(key).apply(w));
            }
            if (this.hiddenBiasTransforms.containsKey(key)) {
                INDArray hBias = this.neuralNets[layer].gethBias();
                this.neuralNets[layer].sethBias((INDArray) getHiddenBiasTransforms().get(key).apply(hBias));
            }
            if (this.visibleBiasTransforms.containsKey(key)) {
                INDArray vBias = this.neuralNets[layer].getvBias();
                this.neuralNets[layer].setvBias((INDArray) getVisibleBiasTransforms().get(key).apply(vBias));
            }
        }
    }

    /**
     * Creates a single pretraining net; implemented by subclasses.
     * NOTE(review): the roles of the four INDArray arguments and the int are not visible here
     * (presumably input, weights, hidden bias, visible bias and layer index) - confirm in a subclass.
     */
    public abstract NeuralNetwork createLayer(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, int i);

    /**
     * Pretrains the network from an iterator of data sets. {@code objArr} carries
     * implementation-specific extra arguments.
     */
    public abstract void pretrain(DataSetIterator dataSetIterator, Object[] objArr);

    /**
     * Pretrains the network on {@code iNDArray}. {@code objArr} carries implementation-specific
     * extra arguments; {@code fit(INDArray, INDArray)} passes {@code {k, lr}} from the default configuration.
     */
    public abstract void pretrain(INDArray iNDArray, Object[] objArr);

    /**
     * Creates the array of {@code i} pretraining nets that {@code setnLayers(int)} stores as {@code neuralNets}.
     */
    public abstract NeuralNetwork[] createNetworkLayers(int i);

    /**
     * Builds a hidden layer from the {@code index}-th layer-wise configuration with the given input.
     *
     * @param index position in {@code layerWiseConfigurations}
     * @param layerInput input matrix for the new layer
     * @return the built layer
     */
    public Layer createHiddenLayer(int index, INDArray layerInput) {
        return new Layer.Builder()
                .withInput(layerInput)
                .conf(this.layerWiseConfigurations.get(index))
                .build();
    }

    /**
     * Merges {@code other} into this network layer by layer. Each pretraining net is merged first,
     * then its hidden bias and weights are copied into the matching hidden layer, and finally the
     * output layers are merged.
     *
     * @param other network with the same number of pretraining nets
     * @param batchSize passed unchanged to every merge call (presumably a batch size - confirm)
     * @throws IllegalArgumentException if the networks have different numbers of pretraining nets
     */
    public void merge(BaseMultiLayerNetwork other, int batchSize) {
        if (other.getnLayers() != getnLayers()) {
            throw new IllegalArgumentException("Unable to merge networks that are not of equal length");
        }
        for (int layer = 0; layer < getnLayers(); layer++) {
            NeuralNetwork mine = this.neuralNets[layer];
            mine.merge(other.neuralNets[layer], batchSize);
            Layer hidden = getLayers()[layer];
            hidden.setB(mine.gethBias());
            hidden.setW(mine.getW());
        }
        getOutputLayer().merge(other.getOutputLayer(), batchSize);
    }

    /** Returns the current label matrix (the field itself). */
    public INDArray getLabels() {
        return labels;
    }

    /**
     * Stores {@code newInput} as the network input. The first time a non-null input is set while
     * the pretraining nets are absent, {@code initializeLayers} is called with it.
     *
     * @param newInput the new input, may be {@code null}
     */
    public void setInput(INDArray newInput) {
        boolean needsLayers = newInput != null && neuralNets == null;
        if (needsLayers) {
            initializeLayers(newInput);
        }
        input = newInput;
    }

    /**
     * Builds the regularization mask: a single row as long as {@link #pack()}, set to 1 for every
     * weight entry and 0 for every bias entry, so that masked L2 terms skip biases.
     */
    private void initMask() {
        setMask(Nd4j.ones(1, pack().length()));
        List<Pair<INDArray, INDArray>> sections = unPack(getMask());
        for (Pair<INDArray, INDArray> section : sections) {
            INDArray bias = section.getSecond();
            section.setSecond(Nd4j.zeros(bias.rows(), bias.columns()));
        }
        setMask(pack(sections));
    }

    /** Returns the {@code shouldBackProp} flag. */
    public boolean isShouldBackProp() {
        return shouldBackProp;
    }

    /** Returns the current input matrix (the field itself). */
    public INDArray getInput() {
        return input;
    }

    /** Returns the pretraining nets array itself (not a copy); synchronized on this network. */
    public synchronized NeuralNetwork[] getNeuralNets() {
        return neuralNets;
    }

    /** Returns the {@code forceNumEpochs} flag. */
    public boolean isForceNumEpochs() {
        return forceNumEpochs;
    }

    /** Returns the hidden layer sizes array itself (not a copy). */
    public int[] getHiddenLayerSizes() {
        return hiddenLayerSizes;
    }

    /** Stores {@code sizes} as the hidden layer sizes without copying. */
    public void setHiddenLayerSizes(int[] sizes) {
        hiddenLayerSizes = sizes;
    }

    /** Returns the weight transforms, keyed by layer index. */
    public Map<Integer, MatrixTransform> getWeightTransforms() {
        return weightTransforms;
    }

    /** Replaces the label matrix. */
    public void setLabels(INDArray newLabels) {
        labels = newLabels;
    }

    /** Sets the {@code forceNumEpochs} flag. */
    public void setForceNumEpochs(boolean force) {
        forceNumEpochs = force;
    }

    /** Returns the {@code sampleFromHiddenActivations} flag. */
    public boolean isSampleFromHiddenActivations() {
        return sampleFromHiddenActivations;
    }

    /** Sets the {@code sampleFromHiddenActivations} flag. */
    public void setSampleFromHiddenActivations(boolean sample) {
        sampleFromHiddenActivations = sample;
    }

    /** Returns the hidden-bias transforms, keyed by layer index. */
    public Map<Integer, MatrixTransform> getHiddenBiasTransforms() {
        return hiddenBiasTransforms;
    }

    /** Returns the visible-bias transforms, keyed by layer index. */
    public Map<Integer, MatrixTransform> getVisibleBiasTransforms() {
        return visibleBiasTransforms;
    }

    /** Returns {@code neuralNets.length}; throws NullPointerException if the nets are not initialized. */
    public int getnLayers() {
        return neuralNets.length;
    }

    /** Replaces the pretraining nets with a new array from {@code createNetworkLayers(count)}. */
    public void setnLayers(int count) {
        neuralNets = createNetworkLayers(count);
    }

    /** Sets the {@code shouldBackProp} flag. */
    public void setShouldBackProp(boolean backProp) {
        shouldBackProp = backProp;
    }

    /** Replaces the pretraining nets with {@code nets} (stored without copying). */
    public void setLayers(NeuralNetwork[] nets) {
        neuralNets = nets;
    }

    /** Returns the {@code useGaussNewtonVectorProductBackProp} flag (selects BackPropROptimizer in backProp). */
    public boolean isUseGaussNewtonVectorProductBackProp() {
        return useGaussNewtonVectorProductBackProp;
    }

    /** Sets the {@code useGaussNewtonVectorProductBackProp} flag. */
    public void setUseGaussNewtonVectorProductBackProp(boolean enabled) {
        useGaussNewtonVectorProductBackProp = enabled;
    }

    /** Returns the regularization mask; may be {@code null} until it is first initialized. */
    public INDArray getMask() {
        return mask;
    }

    /** Replaces the regularization mask. */
    public void setMask(INDArray newMask) {
        mask = newMask;
    }

    /**
     * Drops the network input and the inputs held by every pretraining net and every layer.
     *
     * <p>Fix: the previous loop ran only {@code neuralNets.length} times over {@code layers}. Since
     * {@code layers} also holds the output layer as its last entry, the output layer kept its input.
     * Missing arrays are now skipped instead of causing a NullPointerException.
     */
    public void clearInput() {
        this.input = null;
        if (this.neuralNets != null) {
            for (int i = 0; i < this.neuralNets.length; i++) {
                this.neuralNets[i].clearInput();
            }
        }
        if (this.layers != null) {
            for (int i = 0; i < this.layers.length; i++) {
                this.layers[i].setInput(null);
            }
        }
    }

    /** Returns the last entry of {@link #getLayers()}, cast to {@code OutputLayer}. */
    public OutputLayer getOutputLayer() {
        int last = getLayers().length - 1;
        return (OutputLayer) getLayers()[last];
    }

    /** Returns the {@code lineSearchBackProp} flag passed to the back-prop optimizers. */
    public boolean isLineSearchBackProp() {
        return lineSearchBackProp;
    }

    /** Sets the {@code lineSearchBackProp} flag passed to the back-prop optimizers. */
    public void setLineSearchBackProp(boolean enabled) {
        lineSearchBackProp = enabled;
    }

    /**
     * Overwrites every layer's weights and bias from a flat parameter matrix laid out as
     * {@link #startIndexForLayer(int)} describes. For each pretraining net the layout is weights
     * then hidden bias, followed by the output layer's weights then bias. Each column slice (all
     * rows kept) is reshaped to the shape of the matrix it replaces.
     *
     * @param flatParams the flat parameters
     */
    public void setParameters(INDArray flatParams) {
        for (int layer = 0; layer < getNeuralNets().length; layer++) {
            ParamRange range = startIndexForLayer(layer);
            INDArray weightSlice = flatParams.get(new NDArrayIndex[]{NDArrayIndex.interval(0, flatParams.rows()), NDArrayIndex.interval(range.getwStart(), range.getwEnd())});
            INDArray biasSlice = flatParams.get(new NDArrayIndex[]{NDArrayIndex.interval(0, flatParams.rows()), NDArrayIndex.interval(range.getBiasStart(), range.getBiasEnd())});
            INDArray oldWeights = getNeuralNets()[layer].getW();
            getNeuralNets()[layer].setW(weightSlice.reshape(oldWeights.rows(), oldWeights.columns()));
            INDArray oldBias = getNeuralNets()[layer].gethBias();
            getNeuralNets()[layer].sethBias(biasSlice.reshape(oldBias.rows(), oldBias.columns()));
        }
        ParamRange outputRange = startIndexForLayer(getNeuralNets().length);
        INDArray outWeightSlice = flatParams.get(new NDArrayIndex[]{NDArrayIndex.interval(0, flatParams.rows()), NDArrayIndex.interval(outputRange.getwStart(), outputRange.getwEnd())});
        INDArray outBiasSlice = flatParams.get(new NDArrayIndex[]{NDArrayIndex.interval(0, flatParams.rows()), NDArrayIndex.interval(outputRange.getBiasStart(), outputRange.getBiasEnd())});
        INDArray oldOutWeights = getOutputLayer().getW();
        getOutputLayer().setW(outWeightSlice.reshape(oldOutWeights.rows(), oldOutWeights.columns()));
        INDArray oldOutBias = getOutputLayer().getB();
        getOutputLayer().setB(outBiasSlice.reshape(oldOutBias.rows(), oldOutBias.columns()));
    }

    /**
     * Computes where layer {@code i}'s weights and bias sit in the flat parameter vector. The offset
     * is the summed weight and hidden-bias lengths of all earlier pretraining nets. Indices at or
     * beyond the number of pretraining nets describe the output layer.
     *
     * @param i layer index
     * @return (weight start, weight end, bias start, bias end), ends exclusive
     */
    public ParamRange startIndexForLayer(int i) {
        int offset = 0;
        for (int prev = 0; prev < i; prev++) {
            offset += getNeuralNets()[prev].getW().length() + getNeuralNets()[prev].gethBias().length();
        }
        boolean isOutput = i >= getNeuralNets().length;
        int weightLength = isOutput ? getOutputLayer().getW().length() : getNeuralNets()[i].getW().length();
        int weightEnd = offset + weightLength;
        int biasLength = isOutput ? getOutputLayer().getB().length() : getNeuralNets()[i].gethBias().length();
        return new ParamRange(offset, weightEnd, weightEnd, weightEnd + biasLength);
    }

    /**
     * No-op: all arguments are ignored.
     * NOTE(review): callers going through the Model interface get no training; consider
     * implementing this or throwing UnsupportedOperationException.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void iterate(INDArray iNDArray, Object[] objArr) {
    }

    /** No-op: all arguments are ignored (see NOTE on iterate(INDArray, Object[])). */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(INDArray iNDArray, INDArray iNDArray2, Object[] objArr) {
    }

    /** No-op: all arguments are ignored. */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(org.nd4j.linalg.dataset.api.DataSet dataSet, Object[] objArr) {
    }

    /** No-op: all arguments are ignored. */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void fit(INDArray iNDArray, int[] iArr, Object[] objArr) {
    }

    /** No-op: all arguments are ignored. */
    @Override // org.deeplearning4j.nn.api.Classifier
    public void iterate(INDArray iNDArray, int[] iArr, Object[] objArr) {
    }

    // Class initialization as emitted by the decompiler. $assertionsDisabled records whether
    // assertions are disabled for this class; the shape checks in backPropGradient() test it.
    // log is the class-wide SLF4J logger.
    static {
        $assertionsDisabled = !BaseMultiLayerNetwork.class.desiredAssertionStatus();
        log = LoggerFactory.getLogger(BaseMultiLayerNetwork.class);
    }
}
