package org.deeplearning4j.models.classifiers.dbn;

import java.util.List;
import java.util.Map;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.models.featuredetectors.rbm.RBM;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.Layer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.transformation.MatrixTransform;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/classifiers/dbn/DBN.class */
/**
 * Deep belief network: a stack of {@link RBM} layers (see {@link #createLayer}) that is
 * pretrained greedily, one layer at a time. Each layer is trained on the activations produced by
 * the already-trained layers beneath it.
 */
public class DBN extends BaseMultiLayerNetwork {
    private static final long serialVersionUID = -9068772752220902983L;
    private static final Logger log = LoggerFactory.getLogger(DBN.class);

    // Copied from the builder in Builder#build(); not read anywhere else in this class.
    private boolean useRBMPropUpAsActivations = true;

    /**
     * Builder for {@link DBN}. Every configuration method delegates to (or mirrors) the base
     * builder and returns {@code this}, so calls can be chained. The methods suffixed with
     * {@code 2} are decompiler-renamed bridge methods; their names are kept as-is because they are
     * part of the overridden base-builder API.
     */
    public static class Builder extends BaseMultiLayerNetwork.Builder<DBN> {
        // Never modified by any method of this builder, so build() always copies false.
        private boolean useRBMPropUpAsActivation = false;

        /** Creates a builder whose target network class is {@link DBN}. */
        public Builder() {
            this.clazz = DBN.class;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: configure, reason: merged with bridge method [inline-methods] */
        public BaseMultiLayerNetwork.Builder<DBN> configure2(NeuralNetConfiguration neuralNetConfiguration) {
            super.configure2(neuralNetConfiguration);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: useGaussNewtonVectorProductBackProp */
        public BaseMultiLayerNetwork.Builder<DBN> useGaussNewtonVectorProductBackProp2(boolean z) {
            super.useGaussNewtonVectorProductBackProp2(z);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: useDropConnection */
        public BaseMultiLayerNetwork.Builder<DBN> useDropConnection2(boolean z) {
            super.useDropConnection2(z);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        public BaseMultiLayerNetwork.Builder<DBN> layerWiseCOnfiguration(List<NeuralNetConfiguration> list) {
            super.layerWiseCOnfiguration(list);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: lineSearchBackProp */
        public BaseMultiLayerNetwork.Builder<DBN> lineSearchBackProp2(boolean z) {
            super.lineSearchBackProp2(z);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        public BaseMultiLayerNetwork.Builder<DBN> withVisibleBiasTransforms(Map<Integer, MatrixTransform> map) {
            super.withVisibleBiasTransforms(map);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        public BaseMultiLayerNetwork.Builder<DBN> withHiddenBiasTransforms(Map<Integer, MatrixTransform> map) {
            super.withHiddenBiasTransforms(map);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: forceEpochs */
        public BaseMultiLayerNetwork.Builder<DBN> forceEpochs2() {
            this.shouldForceEpochs = true;
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: disableBackProp */
        public BaseMultiLayerNetwork.Builder<DBN> disableBackProp2() {
            this.backProp = false;
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: transformWeightsAt */
        public BaseMultiLayerNetwork.Builder<DBN> transformWeightsAt2(int i, MatrixTransform matrixTransform) {
            this.weightTransforms.put(i, matrixTransform);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        public BaseMultiLayerNetwork.Builder<DBN> transformWeightsAt(Map<Integer, MatrixTransform> map) {
            this.weightTransforms.putAll(map);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: hiddenLayerSizes */
        public BaseMultiLayerNetwork.Builder<DBN> hiddenLayerSizes2(Integer[] numArr) {
            super.hiddenLayerSizes2(numArr);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: hiddenLayerSizes */
        public BaseMultiLayerNetwork.Builder<DBN> hiddenLayerSizes2(int[] iArr) {
            super.hiddenLayerSizes2(iArr);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: withInput */
        public BaseMultiLayerNetwork.Builder<DBN> withInput2(INDArray iNDArray) {
            super.withInput2(iNDArray);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: withLabels */
        public BaseMultiLayerNetwork.Builder<DBN> withLabels2(INDArray iNDArray) {
            super.withLabels2(iNDArray);
            return this;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        public BaseMultiLayerNetwork.Builder<DBN> withClazz(Class<? extends BaseMultiLayerNetwork> cls) {
            this.clazz = cls;
            return this;
        }

        /**
         * Builds the network via the base builder and copies the DBN-specific flag. It then
         * initializes the layers using a single all-zero row of width {@code nIn} from the default
         * configuration.
         */
        /* JADX WARN: Can't rename method to resolve collision */
        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        public DBN build() {
            DBN dbn = (DBN) super.build();
            dbn.useRBMPropUpAsActivations = this.useRBMPropUpAsActivation;
            dbn.initializeLayers(Nd4j.zeros(1, dbn.defaultConfiguration.getnIn()));
            return dbn;
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: withClazz, reason: avoid collision after fix types in other method */
        public /* bridge */ /* synthetic */ BaseMultiLayerNetwork.Builder<DBN> withClazz2(Class cls) {
            return withClazz((Class<? extends BaseMultiLayerNetwork>) cls);
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: transformWeightsAt, reason: avoid collision after fix types in other method */
        public /* bridge */ /* synthetic */ BaseMultiLayerNetwork.Builder<DBN> transformWeightsAt2(Map map) {
            return transformWeightsAt((Map<Integer, MatrixTransform>) map);
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: withHiddenBiasTransforms, reason: avoid collision after fix types in other method */
        public /* bridge */ /* synthetic */ BaseMultiLayerNetwork.Builder<DBN> withHiddenBiasTransforms2(Map map) {
            return withHiddenBiasTransforms((Map<Integer, MatrixTransform>) map);
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: withVisibleBiasTransforms, reason: avoid collision after fix types in other method */
        public /* bridge */ /* synthetic */ BaseMultiLayerNetwork.Builder<DBN> withVisibleBiasTransforms2(Map map) {
            return withVisibleBiasTransforms((Map<Integer, MatrixTransform>) map);
        }

        @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork.Builder
        /* renamed from: layerWiseCOnfiguration, reason: avoid collision after fix types in other method */
        public /* bridge */ /* synthetic */ BaseMultiLayerNetwork.Builder<DBN> layerWiseCOnfiguration2(List list) {
            return layerWiseCOnfiguration((List<NeuralNetConfiguration>) list);
        }
    }

    /** Delegates to the base implementation and narrows the result to {@link Layer}. */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork
    public Layer createHiddenLayer(int i, INDArray iNDArray) {
        return (Layer) super.createHiddenLayer(i, iNDArray);
    }

    /**
     * Pretrains every layer from the batches of {@code dataSetIterator}. This is repeated for a
     * number of passes taken from {@code objArr[3]} (default 1). {@code k}, the learning rate and
     * the iteration count come from the default configuration.
     *
     * <p>Previously this overload ignored the iterator and re-trained on {@code this.input}. It now
     * actually consumes the iterator.
     *
     * @param dataSetIterator source of training batches; reset after every layer
     * @param objArr optional extra arguments; only index 3 (pass count, a {@link Number}) is read.
     *     May be {@code null}.
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork
    public void pretrain(DataSetIterator dataSetIterator, Object[] objArr) {
        int passes = objArr != null && objArr.length > 3 ? ((Number) objArr[3]).intValue() : 1;
        for (int pass = 0; pass < passes; pass++) {
            pretrain(dataSetIterator, this.defaultConfiguration.getK(), this.defaultConfiguration.getLr(),
                    this.defaultConfiguration.getNumIterations());
        }
    }

    /**
     * Pretrains on {@code iNDArray}, using {@code k}, learning rate and iteration count from the
     * default configuration. {@code objArr} is ignored.
     */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork
    public void pretrain(INDArray iNDArray, Object[] objArr) {
        pretrain(iNDArray, this.defaultConfiguration.getK(), this.defaultConfiguration.getLr(),
                this.defaultConfiguration.getNumIterations());
    }

    /**
     * Greedy layer-wise pretraining over all batches of an iterator.
     *
     * <p>For each layer in turn, every batch is propagated through the layers below it, and the
     * result is used to train that layer. The iterator is reset after each layer. While training
     * layer 0, each batch also becomes the network input, and the layers are initialized on first
     * use.
     *
     * @param dataSetIterator source of training batches
     * @param i the {@code k} value passed to each layer (callers pass {@code getK()}; presumably
     *     the contrastive-divergence step count — confirm against RBM)
     * @param f learning rate passed to {@code iterate} when epochs are forced; otherwise each
     *     layer's own configured learning rate is used
     * @param i2 number of iterations (forced mode) or iteration count passed to {@code fit}
     */
    public void pretrain(DataSetIterator dataSetIterator, int i, float f, int i2) {
        for (int layer = 0; layer < getnLayers(); layer++) {
            while (dataSetIterator.hasNext()) {
                DataSet batch = dataSetIterator.next();
                INDArray layerInput;
                if (layer == 0) {
                    // Assign first so the initialization check sees this batch as the input.
                    this.input = batch.getFeatureMatrix();
                    setInputAndInitializeIfNeeded(this.input);
                    layerInput = batch.getFeatureMatrix();
                } else {
                    layerInput = propagateBelowLayer(layer, batch.getFeatureMatrix());
                    log.info("Training on layer " + (layer + 1));
                }
                float layerLr = this.layerWiseConfigurations.get(layer).getLr();
                pretrainSingleLayer(layer, layerInput, i, f, layerLr, i2, layer == 0 ? "iteration" : "epoch");
            }
            dataSetIterator.reset();
        }
    }

    /**
     * Greedy layer-wise pretraining on a single input matrix. Layer {@code n} is trained on the
     * output of layer {@code n - 1}, which is computed incrementally.
     *
     * @param iNDArray training input; becomes the network input
     * @param i the {@code k} value passed to each layer
     * @param f learning rate used when epochs are forced; otherwise each layer's configured rate
     *     is used
     * @param i2 number of iterations (forced mode) or iteration count passed to {@code fit}
     */
    public void pretrain(INDArray iNDArray, int i, float f, int i2) {
        if (isUseGaussNewtonVectorProductBackProp()) {
            log.warn("WARNING; Gauss newton back vector back propagation is primarily used for hessian free which does not involve pretrain; just finetune. Use this at your own risk");
        }
        setInputAndInitializeIfNeeded(iNDArray);
        INDArray layerInput = null;
        for (int layer = 0; layer < getnLayers(); layer++) {
            layerInput = layer == 0 ? getInput() : activationFromPrevLayer(layer - 1, layerInput);
            log.info("Training on layer " + (layer + 1));
            float layerLr = this.layers[layer].conf().getLr();
            pretrainSingleLayer(layer, layerInput, i, f, layerLr, i2, "epoch");
        }
    }

    /** Pretrains on the current network input. */
    public void pretrain(int i, float f, int i2) {
        pretrain(getInput(), i, f, i2);
    }

    /** Builds the RBM for layer {@code i} from the given input, weights and biases. */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork
    public NeuralNetwork createLayer(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, int i) {
        return new RBM.Builder().withInput2(iNDArray).withWeights2(iNDArray2).withHBias2(iNDArray3).withVisibleBias2(iNDArray4).configure(this.layerWiseConfigurations.get(i)).build();
    }

    /** Allocates the array that holds {@code i} RBM layers. */
    @Override // org.deeplearning4j.nn.BaseMultiLayerNetwork
    public NeuralNetwork[] createNetworkLayers(int i) {
        return new RBM[i];
    }

    /** Fitting a DBN means pretraining it with the default configuration; {@code objArr} is ignored. */
    @Override // org.deeplearning4j.nn.api.Model
    public void fit(INDArray iNDArray, Object[] objArr) {
        pretrain(iNDArray, this.defaultConfiguration.getK(), this.defaultConfiguration.getLr(),
                this.defaultConfiguration.getNumIterations());
    }

    /**
     * Sets {@code newInput} as the network input. Layers are initialized only if there was no
     * previous input or the layer array or its first layer is missing. The state is checked before
     * the input is replaced.
     */
    private void setInputAndInitializeIfNeeded(INDArray newInput) {
        boolean uninitialized = getInput() == null || getNeuralNets() == null || getNeuralNets()[0] == null;
        setInput(newInput);
        if (uninitialized) {
            initializeLayers(newInput);
        }
    }

    /**
     * Propagates {@code features} through layers {@code 0 .. layer - 1}, in order. The result is
     * the input seen by {@code layer}. This matches the incremental chaining in
     * {@code pretrain(INDArray, int, float, int)}.
     */
    private INDArray propagateBelowLayer(int layer, INDArray features) {
        INDArray current = features;
        for (int below = 0; below < layer; below++) {
            current = activationFromPrevLayer(below, current);
        }
        return current;
    }

    /**
     * Trains a single layer on {@code layerInput}. When epochs are forced, it runs exactly
     * {@code iterations} explicit {@code iterate} steps with {@code forcedLr}, logging the score
     * before each step. Otherwise it delegates to the layer's {@code fit} with {@code layerLr}.
     *
     * @param stepLabel word used in the per-step log line ("iteration" or "epoch")
     */
    private void pretrainSingleLayer(int layer, INDArray layerInput, int k, float forcedLr, float layerLr,
            int iterations, String stepLabel) {
        if (isForceNumEpochs()) {
            for (int step = 0; step < iterations; step++) {
                log.info("Error on " + stepLabel + " " + step + " for layer " + (layer + 1) + " is "
                        + getNeuralNets()[layer].score());
                getNeuralNets()[layer].iterate(layerInput, new Object[] {k, forcedLr});
                getNeuralNets()[layer].iterationDone(step);
            }
        } else {
            getNeuralNets()[layer].fit(layerInput, new Object[] {k, layerLr, iterations});
        }
    }
}
