package org.deeplearning4j.models.classifiers.sda;

import java.util.List;
import java.util.Map;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.models.featuredetectors.da.DenoisingAutoEncoder;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.transformation.MatrixTransform;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/classifiers/sda/StackedDenoisingAutoEncoder.class */
/**
 * A multi-layer network whose layers are {@link DenoisingAutoEncoder}s (see
 * {@link #createLayer(INDArray, INDArray, INDArray, INDArray, int)} and
 * {@link #createNetworkLayers(int)}).
 *
 * <p>Pretraining is layer-wise: each layer is trained on the representation produced by the
 * layers below it.
 */
public class StackedDenoisingAutoEncoder extends BaseMultiLayerNetwork {
    private static final long serialVersionUID = 1448581794985193009L;
    private static final Logger log = LoggerFactory.getLogger(StackedDenoisingAutoEncoder.class);

    /**
     * Builder for {@link StackedDenoisingAutoEncoder}.
     *
     * <p>Several method names carry a {@code 2} suffix because this source was recovered by a
     * decompiler that renamed them to avoid bridge-method collisions. They override superclass
     * methods of the same names and therefore must not be renamed here.
     */
    public static class Builder extends BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> {
        public Builder() {
            this.clazz = StackedDenoisingAutoEncoder.class;
        }

        /** Decompiler-renamed {@code configure}: applies the default layer configuration. */
        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> configure2(NeuralNetConfiguration neuralNetConfiguration) {
            super.configure2(neuralNetConfiguration);
            return this;
        }

        /** Decompiler-renamed {@code useGaussNewtonVectorProductBackProp}. */
        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> useGaussNewtonVectorProductBackProp2(boolean useGaussNewton) {
            super.useGaussNewtonVectorProductBackProp2(useGaussNewton);
            return this;
        }

        /** Decompiler-renamed {@code pretrain}: sets whether the built network runs pretraining. */
        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> pretrain2(boolean pretrain) {
            this.pretrain = pretrain;
            return this;
        }

        /** Decompiler-renamed {@code useDropConnection}. */
        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> useDropConnection2(boolean useDropConnection) {
            super.useDropConnection2(useDropConnection);
            return this;
        }

        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> withVisibleBiasTransforms(Map<Integer, MatrixTransform> transforms) {
            super.withVisibleBiasTransforms(transforms);
            return this;
        }

        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> withHiddenBiasTransforms(Map<Integer, MatrixTransform> transforms) {
            super.withHiddenBiasTransforms(transforms);
            return this;
        }

        /** Decompiler-renamed {@code forceIterations}: sets {@code shouldForceEpochs}. */
        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> forceIterations2() {
            this.shouldForceEpochs = true;
            return this;
        }

        /** Decompiler-renamed {@code disableBackProp}: clears the {@code backProp} flag. */
        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> disableBackProp2() {
            this.backProp = false;
            return this;
        }

        /** Decompiler-renamed {@code transformWeightsAt}: registers a weight transform for one layer. */
        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> transformWeightsAt2(int layer, MatrixTransform transform) {
            this.weightTransforms.put(Integer.valueOf(layer), transform);
            return this;
        }

        /** Registers weight transforms for every layer index present in {@code transforms}. */
        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> transformWeightsAt(Map<Integer, MatrixTransform> transforms) {
            this.weightTransforms.putAll(transforms);
            return this;
        }

        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> layerWiseConfiguration(List<NeuralNetConfiguration> configurations) {
            super.layerWiseConfiguration(configurations);
            return this;
        }

        /** Decompiler-renamed {@code hiddenLayerSizes} (boxed-array overload). */
        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> hiddenLayerSizes2(Integer[] sizes) {
            super.hiddenLayerSizes2(sizes);
            return this;
        }

        /** Decompiler-renamed {@code hiddenLayerSizes} (primitive-array overload). */
        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> hiddenLayerSizes2(int[] sizes) {
            super.hiddenLayerSizes2(sizes);
            return this;
        }

        /** Decompiler-renamed {@code withInput}. */
        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> withInput2(INDArray input) {
            super.withInput2(input);
            return this;
        }

        /** Decompiler-renamed {@code withLabels}. */
        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> withLabels2(INDArray labels) {
            super.withLabels2(labels);
            return this;
        }

        @Override
        public BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> withClazz(Class<? extends BaseMultiLayerNetwork> cls) {
            this.clazz = cls;
            return this;
        }

        /**
         * Builds the network. Falls back to the first layer-wise configuration when no default
         * configuration was set, then initializes the layers from a {@code 1 x nIn} zero matrix.
         *
         * @throws IllegalStateException if no default configuration was set and there is no
         *     layer-wise configuration to fall back on
         */
        @Override
        public StackedDenoisingAutoEncoder build() {
            StackedDenoisingAutoEncoder network = (StackedDenoisingAutoEncoder) super.build();
            if (network.defaultConfiguration == null) {
                if (this.layerWiseConfiguration == null || this.layerWiseConfiguration.isEmpty()) {
                    throw new IllegalStateException(
                            "Cannot build StackedDenoisingAutoEncoder: no default configuration was set "
                                    + "and no layer-wise configurations were supplied");
                }
                network.defaultConfiguration = this.layerWiseConfiguration.get(0);
            }
            // Presumably done so the layers exist before any real input is seen — a single zero
            // row sized by the configured nIn is used only to drive initializeLayers.
            network.initializeLayers(Nd4j.zeros(1, network.defaultConfiguration.getnIn()));
            return network;
        }

        /** Raw-typed bridge to {@link #withClazz(Class)}. */
        @Override
        @SuppressWarnings("unchecked")
        public /* bridge */ BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> withClazz2(Class cls) {
            return withClazz((Class<? extends BaseMultiLayerNetwork>) cls);
        }

        /** Raw-typed bridge to {@link #transformWeightsAt(Map)}. */
        @Override
        @SuppressWarnings("unchecked")
        public /* bridge */ BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> transformWeightsAt2(Map map) {
            return transformWeightsAt((Map<Integer, MatrixTransform>) map);
        }

        /** Raw-typed bridge to {@link #withHiddenBiasTransforms(Map)}. */
        @Override
        @SuppressWarnings("unchecked")
        public /* bridge */ BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> withHiddenBiasTransforms2(Map map) {
            return withHiddenBiasTransforms((Map<Integer, MatrixTransform>) map);
        }

        /** Raw-typed bridge to {@link #withVisibleBiasTransforms(Map)}. */
        @Override
        @SuppressWarnings("unchecked")
        public /* bridge */ BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> withVisibleBiasTransforms2(Map map) {
            return withVisibleBiasTransforms((Map<Integer, MatrixTransform>) map);
        }

        /** Raw-typed bridge to {@link #layerWiseConfiguration(List)}. */
        @Override
        @SuppressWarnings("unchecked")
        public /* bridge */ BaseMultiLayerNetwork.Builder<StackedDenoisingAutoEncoder> layerWiseConfiguration2(List list) {
            return layerWiseConfiguration((List<NeuralNetConfiguration>) list);
        }
    }

    /**
     * Pretrains on the network's current input ({@link #getInput()}).
     *
     * @param learningRate learning rate (note: first here, unlike the iterator overloads)
     * @param corruptionLevel input corruption level
     * @param iterations number of training iterations per layer
     */
    public void pretrain(float learningRate, float corruptionLevel, int iterations) {
        pretrain(getInput(), learningRate, corruptionLevel, iterations);
    }

    /**
     * Pretrains on every batch of {@code dataSetIterator}, repeating the whole pass a number of
     * times.
     *
     * @param dataSetIterator source of training batches; reset after each layer's pass
     * @param objArr {@code {corruptionLevel, learningRate, iterations[, passes]}}. Values may be any
     *     boxed {@link Number} type. {@code passes} defaults to 1.
     * @throws IllegalArgumentException if {@code objArr} is null or has fewer than three entries
     */
    @Override
    public void pretrain(DataSetIterator dataSetIterator, Object[] objArr) {
        if (objArr == null || objArr.length < 3) {
            throw new IllegalArgumentException(
                    "Expected pretrain parameters {corruptionLevel, learningRate, iterations[, passes]}");
        }
        // Accept any boxed numeric type: a literal such as 0.3 autoboxes to Double, not Float,
        // and a strict (Float) cast would throw ClassCastException.
        float corruptionLevel = ((Number) objArr[0]).floatValue();
        float learningRate = ((Number) objArr[1]).floatValue();
        int iterations = ((Number) objArr[2]).intValue();
        int passes = objArr.length > 3 ? ((Number) objArr[3]).intValue() : 1;
        for (int pass = 0; pass < passes; pass++) {
            pretrain(dataSetIterator, corruptionLevel, learningRate, iterations);
        }
    }

    /**
     * Layer-wise pretraining over an iterator: for each layer, every batch is fed through the
     * lower layers and used to train that layer. The iterator is reset after each layer.
     *
     * @param dataSetIterator source of training batches
     * @param corruptionLevel input corruption level passed to each layer
     * @param learningRate learning rate used on the forced-iteration path
     * @param iterations number of iterations per batch
     */
    public void pretrain(DataSetIterator dataSetIterator, float corruptionLevel, float learningRate, int iterations) {
        for (int layer = 0; layer < getnLayers(); layer++) {
            while (dataSetIterator.hasNext()) {
                DataSet batch = dataSetIterator.next();
                INDArray layerInput;
                if (layer == 0) {
                    this.input = batch.getFeatureMatrix();
                    // Evaluated before setInput, matching the original ordering.
                    boolean needsInitialization =
                            getInput() == null || getNeuralNets() == null || getNeuralNets()[0] == null;
                    setInput(this.input);
                    if (needsInitialization) {
                        initializeLayers(this.input);
                    }
                    layerInput = batch.getFeatureMatrix();
                } else {
                    layerInput = batch.getFeatureMatrix();
                    // NOTE(review): applies activationFromPrevLayer for indices 1..layer — confirm
                    // against BaseMultiLayerNetwork that index j means "output of layer j-1".
                    for (int j = 1; j <= layer; j++) {
                        layerInput = activationFromPrevLayer(j, layerInput);
                    }
                    log.info("Training on layer {}", layer + 1);
                }
                pretrainLayerOnBatch(layer, layerInput, corruptionLevel, learningRate, iterations);
            }
            dataSetIterator.reset();
        }
    }

    /**
     * Trains a single layer on one batch. When forceNumIterations() is set, it runs
     * {@code iterations} explicit iterate() steps; otherwise it makes one fit() call.
     */
    private void pretrainLayerOnBatch(int layer, INDArray layerInput, float corruptionLevel, float learningRate, int iterations) {
        float layerLearningRate = this.layerWiseConfigurations.get(layer).getLr();
        if (forceNumIterations()) {
            // NOTE(review): this path uses the caller-supplied learning rate while the fit() path
            // uses the layer's configured rate — confirm which is intended.
            for (int iteration = 0; iteration < iterations; iteration++) {
                log.info("Error on iteration {} for layer {} is {}",
                        new Object[] {iteration, layer + 1, getNeuralNets()[layer].score()});
                getNeuralNets()[layer].iterate(layerInput, new Object[] {Float.valueOf(corruptionLevel), Float.valueOf(learningRate)});
                getNeuralNets()[layer].iterationDone(iteration);
            }
        } else {
            getNeuralNets()[layer].fit(layerInput,
                    new Object[] {Float.valueOf(corruptionLevel), Float.valueOf(layerLearningRate), Integer.valueOf(iterations)});
        }
    }

    /**
     * Pretrains on {@code examples} using the default configuration's learning rate, corruption
     * level and iteration count.
     *
     * <p>NOTE(review): {@code objArr} is ignored here, unlike the {@link DataSetIterator}
     * overload, which reads its hyperparameters from the array.
     */
    @Override
    public void pretrain(INDArray examples, Object[] objArr) {
        pretrain(examples, this.defaultConfiguration.getLr(), this.defaultConfiguration.getCorruptionLevel(), this.defaultConfiguration.getNumIterations());
    }

    /**
     * Layer-wise pretraining on a single matrix. Layer 0 trains on {@code examples}. Each higher
     * layer trains on the sampled hidden representation of the layer below it. If the
     * {@code pretrain} flag is off, only the layer initialization (when there is no input yet)
     * takes place.
     *
     * @param examples training input
     * @param learningRate learning rate
     * @param corruptionLevel input corruption level
     * @param iterations number of iterations per layer
     */
    public void pretrain(INDArray examples, float learningRate, float corruptionLevel, int iterations) {
        if (getInput() == null) {
            initializeLayers(examples.dup());
        }
        if (!this.pretrain) {
            return;
        }
        if (isUseGaussNewtonVectorProductBackProp()) {
            log.warn("Warning; using gauss newton vector back prop with pretrain is known to cause issues with obscenely large activations.");
        }
        this.input = examples;
        INDArray layerInput = null;
        for (int layer = 0; layer < getnLayers(); layer++) {
            layerInput = layer == 0
                    ? examples
                    : getNeuralNets()[layer - 1].sampleHiddenGivenVisible(layerInput).getSecond();
            if (forceNumIterations()) {
                for (int iteration = 0; iteration < iterations; iteration++) {
                    // The autoencoder's parameter order is {corruptionLevel, learningRate}.
                    getNeuralNets()[layer].iterate(layerInput, new Object[] {Float.valueOf(corruptionLevel), Float.valueOf(learningRate)});
                    log.info("Error on iteration {} for layer {} is {}",
                            new Object[] {iteration, layer + 1, getNeuralNets()[layer].score()});
                    getNeuralNets()[layer].iterationDone(iteration);
                }
            } else {
                getNeuralNets()[layer].fit(layerInput,
                        new Object[] {Float.valueOf(corruptionLevel), Float.valueOf(learningRate), Integer.valueOf(iterations)});
            }
        }
    }

    /**
     * Creates the denoising autoencoder for layer {@code index}, configured from that layer's
     * configuration and seeded with the given input, weights and biases.
     */
    @Override
    public NeuralNetwork createLayer(INDArray input, INDArray weights, INDArray hiddenBias, INDArray visibleBias, int index) {
        return new DenoisingAutoEncoder.Builder()
                .configure(this.layerWiseConfigurations.get(index))
                .withInput2(input)
                .withWeights2(weights)
                .withHBias2(hiddenBias)
                .withVisibleBias2(visibleBias)
                .build();
    }

    /** Allocates an array of {@code count} empty (null) {@link DenoisingAutoEncoder} slots. */
    @Override
    public NeuralNetwork[] createNetworkLayers(int count) {
        return new DenoisingAutoEncoder[count];
    }

    /**
     * Does nothing.
     *
     * <p>NOTE(review): callers expecting {@code fit} to train this network will silently get a
     * no-op; training happens only through the {@code pretrain} overloads. Confirm this is
     * intended.
     */
    @Override
    public void fit(INDArray iNDArray, Object[] objArr) {
    }
}
