package org.deeplearning4j.nn;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.lang.reflect.Constructor;
import java.util.Arrays;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.Persistable;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.optimize.optimizers.NeuralNetworkOptimizer;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.deeplearning4j.util.Dl4jReflection;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/BaseNeuralNetwork.class */
/**
 * Base implementation of a single-layer {@link NeuralNetwork}.
 *
 * <p>It holds three parameters, with the shapes checked by the setters when assertions are
 * enabled:
 * <ul>
 *   <li>a weight matrix {@code W} of shape {@code [nIn, nOut]};</li>
 *   <li>a hidden bias {@code hBias} of shape {@code [nOut]};</li>
 *   <li>a visible bias {@code vBias} of shape {@code [nIn]}.</li>
 * </ul>
 * Each parameter has its own {@link AdaGrad} instance.
 *
 * <p>Subclasses implement {@link #transform(INDArray)}. This class provides:
 * <ul>
 *   <li>parameter initialisation;</li>
 *   <li>flattening via {@link #params()} and {@link #setParams(INDArray)};</li>
 *   <li>gradient post-processing: AdaGrad or fixed learning rate, sparsity, momentum, L2 and
 *       unit-norm;</li>
 *   <li>Java-serialization based persistence.</li>
 * </ul>
 */
public abstract class BaseNeuralNetwork implements NeuralNetwork, Persistable {
    private static final long serialVersionUID = -7074102204433996574L;
    private static final Logger log = LoggerFactory.getLogger(BaseNeuralNetwork.class);

    /** Weight matrix; {@link #setW} asserts shape {@code [nIn, nOut]}. */
    protected INDArray W;
    /** Hidden bias; {@link #sethBias} asserts shape {@code [nOut]}. */
    protected INDArray hBias;
    /** Visible bias; {@link #setvBias} asserts shape {@code [nIn]}. */
    protected INDArray vBias;
    /**
     * Current input batch.
     *
     * <p>Multiplied by {@code W} in {@link #hBiasMean()}, so presumably {@code [numExamples, nIn]}.
     * TODO confirm.
     */
    protected INDArray input;
    protected transient NeuralNetworkOptimizer optimizer;
    /** Last mask produced by {@link #applyDropOutIfNecessary(INDArray)}. */
    protected INDArray doMask;
    /** Weight gradient from the previous update; used as the momentum term. */
    protected INDArray wGradient;
    /** Visible-bias gradient from the previous update; used as the momentum term. */
    protected INDArray vBiasGradient;
    /** Hidden-bias gradient from the previous update; used as the momentum term. */
    protected INDArray hBiasGradient;
    /** Divisor applied to every gradient in {@link #updateGradientAccordingToParams}. */
    protected int lastMiniBatchSize = 1;
    protected AdaGrad wAdaGrad;
    protected AdaGrad hBiasAdaGrad;
    protected AdaGrad vBiasAdaGrad;
    protected NeuralNetConfiguration conf;
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Reflective builder for {@link BaseNeuralNetwork} subclasses.
     *
     * @param <E> concrete network type produced by {@link #build()}
     */
    public static class Builder<E extends BaseNeuralNetwork> {
        private E ret = null;
        private INDArray W;
        protected Class<? extends NeuralNetwork> clazz;
        private INDArray vBias;
        private INDArray hBias;
        private INDArray input;
        private NeuralNetConfiguration conf;

        /** Sets the configuration passed to the network constructor. */
        public Builder<E> configure(NeuralNetConfiguration configuration) {
            this.conf = configuration;
            return this;
        }

        /**
         * Instantiates the configured class through its no-arg constructor.
         *
         * <p>All other builder state is ignored.
         *
         * @throws RuntimeException wrapping any reflective instantiation failure
         */
        @SuppressWarnings("unchecked")
        public E buildEmpty() {
            try {
                return (E) this.clazz.newInstance();
            } catch (IllegalAccessException | InstantiationException e) {
                throw new RuntimeException(e);
            }
        }

        /** Sets the concrete class to instantiate. */
        public Builder<E> withClazz(Class<? extends BaseNeuralNetwork> cls) {
            this.clazz = cls;
            return this;
        }

        /* renamed from: withInput */
        /** Sets the input passed to the network constructor. */
        public Builder<E> withInput2(INDArray in) {
            this.input = in;
            return this;
        }

        /* renamed from: asType */
        /** Sets the concrete class to instantiate (typed variant of {@link #withClazz}). */
        public Builder<E> asType2(Class<E> cls) {
            this.clazz = cls;
            return this;
        }

        /* renamed from: withWeights */
        /** Sets the weight matrix passed to the network constructor. */
        public Builder<E> withWeights2(INDArray weights) {
            this.W = weights;
            return this;
        }

        /* renamed from: withVisibleBias */
        /** Sets the visible bias passed to the network constructor. */
        public Builder<E> withVisibleBias2(INDArray visibleBias) {
            this.vBias = visibleBias;
            return this;
        }

        /* renamed from: withHBias */
        /** Sets the hidden bias passed to the network constructor. */
        public Builder<E> withHBias2(INDArray hiddenBias) {
            this.hBias = hiddenBias;
            return this;
        }

        /**
         * Builds the network via {@code buildWithInput()}.
         *
         * @return the new network, or {@code null} if no suitable constructor was found
         */
        public E build() {
            return buildWithInput();
        }

        /**
         * Invokes the first declared constructor whose first parameter type is assignable from
         * {@link INDArray}.
         *
         * <p>It is called with {@code (input, W, hBias, vBias, conf)}.
         * {@link Class#getDeclaredConstructors()} returns constructors in no particular order, so
         * a class with several matching constructors gets an arbitrary one.
         *
         * @return the new network, or the previous {@code ret} (initially {@code null}) if no
         *     constructor matched
         * @throws RuntimeException wrapping any reflective invocation failure
         */
        @SuppressWarnings("unchecked")
        private E buildWithInput() {
            for (Constructor<?> constructor : this.clazz.getDeclaredConstructors()) {
                constructor.setAccessible(true);
                Class<?>[] parameterTypes = constructor.getParameterTypes();
                if (parameterTypes != null && parameterTypes.length > 0 && parameterTypes[0].isAssignableFrom(INDArray.class)) {
                    try {
                        this.ret = (E) constructor.newInstance(this.input, this.W, this.hBias, this.vBias, this.conf);
                        return this.ret;
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
            }
            return this.ret;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /** Creates an empty network with no parameters or configuration set. */
    public BaseNeuralNetwork() {
    }

    /**
     * Creates a network from the given parameters, then calls {@link #initWeights()}.
     *
     * <p>{@link #initWeights()} creates any parameter that is {@code null}. Note that it is an
     * overridable method called from a constructor.
     *
     * @param input initial input, may be {@code null}
     * @param weights weight matrix, or {@code null} to sample one from {@code conf.getDist()}
     * @param hiddenBias hidden bias, or {@code null} for zeros
     * @param visibleBias visible bias, or {@code null} for zeros
     * @param configuration configuration; must be non-null because {@link #initWeights()} reads it
     */
    public BaseNeuralNetwork(INDArray input, INDArray weights, INDArray hiddenBias, INDArray visibleBias, NeuralNetConfiguration configuration) {
        this.input = input;
        this.W = weights;
        this.conf = configuration;
        if (this.W != null) {
            this.wAdaGrad = new AdaGrad(this.W.rows(), this.W.columns());
        }
        this.vBias = visibleBias;
        if (this.vBias != null) {
            this.vBiasAdaGrad = new AdaGrad(this.vBias.rows(), this.vBias.columns());
        }
        this.hBias = hiddenBias;
        if (this.hBias != null) {
            this.hBiasAdaGrad = new AdaGrad(this.hBias.rows(), this.hBias.columns());
        }
        initWeights();
    }

    /**
     * Flattens all parameters into one array.
     *
     * <p>Order is {@code W}, then {@code vBias}, then {@code hBias}. This is the layout
     * {@link #setParams(INDArray)} expects.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray params() {
        return Nd4j.toFlattened(new INDArray[]{this.W, this.vBias, this.hBias});
    }

    /**
     * Returns {@code sum(W^2) / 2 * l2 + 1e-6f}, the L2 penalty on the weights plus a small
     * constant.
     */
    public double l2RegularizedCoefficient() {
        // Cast to Number rather than Double: the element's boxed type depends on the array's data type.
        double sumOfSquares = ((Number) Transforms.pow(getW(), 2).sum(Integer.MAX_VALUE).element()).doubleValue();
        // 9.999999974752427E-7d is the float literal 1e-6f widened to double.
        return ((sumOfSquares / 2.0d) * this.conf.getL2()) + 9.999999974752427E-7d;
    }

    /**
     * Creates any missing parameter and (re)creates the AdaGrad instance for every parameter.
     *
     * <p>Missing parameters are filled as follows:
     * <ul>
     *   <li>{@code W}: each row is sampled from {@code conf.getDist()};</li>
     *   <li>{@code hBias} and {@code vBias}: zeros.</li>
     * </ul>
     *
     * @throws IllegalStateException if {@code nIn} or {@code nOut} is less than 1
     */
    protected void initWeights() {
        if (this.conf.getnIn() < 1) {
            throw new IllegalStateException("Number of visible can not be less than 1");
        }
        if (this.conf.getnOut() < 1) {
            throw new IllegalStateException("Number of hidden can not be less than 1");
        }
        int nIn = this.conf.getnIn();
        int nOut = this.conf.getnOut();
        if (this.W == null) {
            this.W = Nd4j.zeros(nIn, nOut);
            for (int row = 0; row < this.W.rows(); row++) {
                this.W.putRow(row, Nd4j.create(this.conf.getDist().sample(this.W.columns())));
            }
        }
        this.wAdaGrad = new AdaGrad(this.W.rows(), this.W.columns());
        if (this.hBias == null) {
            this.hBias = Nd4j.zeros(nOut);
        }
        this.hBiasAdaGrad = new AdaGrad(this.hBias.rows(), this.hBias.columns());
        if (this.vBias == null) {
            this.vBias = Nd4j.zeros(nIn);
        }
        this.vBiasAdaGrad = new AdaGrad(this.vBias.rows(), this.vBias.columns());
    }

    /** Returns {@code nIn * nOut + nIn + nOut}, the combined length of W, vBias and hBias. */
    @Override // org.deeplearning4j.nn.api.Model
    public int numParams() {
        return (this.conf.getnIn() * this.conf.getnOut()) + this.conf.getnIn() + this.conf.getnOut();
    }

    /**
     * Splits a flat parameter array laid out as by {@link #params()} into W, vBias and hBias.
     *
     * <p>The bias slices are copied with {@code dup()}. The weight slice is only reshaped, not
     * duplicated.
     *
     * @param params flat parameters of length {@link #numParams()} (checked when assertions are on)
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setParams(INDArray params) {
        if (!$assertionsDisabled && params.length() != numParams()) {
            throw new AssertionError("Illegal number of parameters passed in, must be of length " + numParams());
        }
        int nIn = this.conf.getnIn();
        int nOut = this.conf.getnOut();
        int wLength = nIn * nOut;
        INDArray wSlice = params.get(new NDArrayIndex[]{NDArrayIndex.interval(0, wLength)});
        INDArray vBiasSlice = params.get(new NDArrayIndex[]{NDArrayIndex.interval(wLength, wLength + nIn)});
        INDArray hBiasSlice = params.get(new NDArrayIndex[]{NDArrayIndex.interval(wLength + nIn, wLength + nIn + nOut)});
        setW(wSlice.reshape(nIn, nOut));
        setvBias(vBiasSlice.dup());
        sethBias(hBiasSlice.dup());
    }

    /**
     * Repeatedly re-scores the squared reconstruction loss of {@code input}.
     *
     * <p>Each pass behaves as follows:
     * <ul>
     *   <li>If the score gets worse, or becomes NaN or infinite, the network reverts to the last
     *       snapshot and the method returns.</li>
     *   <li>If the score is unchanged, the method returns.</li>
     *   <li>Otherwise a new snapshot is taken and the next pass runs.</li>
     * </ul>
     *
     * @param d not used by this implementation
     * @param iterations bound on the pass counter; passes run while the counter, starting at 0,
     *     is {@code <= iterations}
     * @param extraParams forwarded to {@code getGradient} when plotting
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void backProp(double d, int iterations, Object[] extraParams) {
        double lastScore = reconstructionSquaredLoss();
        NeuralNetwork revert = mo11clone();
        // A dedicated counter: the previous condition compared the argument with itself
        // (always true), so the iteration bound was never enforced.
        int iteration = 0;
        while (iteration <= iterations) {
            double newScore = reconstructionSquaredLoss();
            if (newScore > lastScore || (lastScore < 0.0d && newScore < lastScore)) {
                update((BaseNeuralNetwork) revert);
                log.info("Converged for new recon; breaking...");
                return;
            }
            if (Double.isNaN(newScore) || Double.isInfinite(newScore)) {
                update((BaseNeuralNetwork) revert);
                log.info("Converged for new recon; breaking...");
                return;
            }
            if (newScore == lastScore) {
                return;
            }
            lastScore = newScore;
            revert = mo11clone();
            log.info("Recon went down " + lastScore);
            iteration++;
            int renderEvery = this.conf.getRenderWeightsEveryNumEpochs();
            if (renderEvery > 0 && iteration % renderEvery == 0) {
                new NeuralNetPlotter().plotNetworkGradient(this, getGradient(extraParams), getInput().rows());
            }
        }
    }

    /** Squared-loss score of {@code transform(input)} against {@code input}. */
    private double reconstructionSquaredLoss() {
        return LossFunctions.score(this.input, LossFunctions.LossFunction.SQUARED_LOSS, transform(this.input), this.conf.getL2(), this.conf.isUseRegularization());
    }

    /**
     * Fits the network to {@code input}.
     *
     * <p>Delegates to {@code fit(input, null)}, which is declared elsewhere.
     */
    public void fit(INDArray input) {
        fit(input, null);
    }

    /**
     * Adds a sparsity term to the hidden-bias gradient, in place.
     *
     * <p>The term is {@code -rates * s * (grad * s)}, where {@code s} is {@code conf.getSparsity()}
     * and {@code rates} are the hBias AdaGrad rates of {@code hBias}. Without AdaGrad, the rates
     * are replaced by {@code conf.getLr()}.
     */
    protected void applySparsity(INDArray hBiasGrad) {
        if (this.conf.isUseAdaGrad()) {
            hBiasGrad.addi(this.hBiasAdaGrad.getLearningRates(this.hBias).neg().muli(Float.valueOf(this.conf.getSparsity())).mul(hBiasGrad.mul(Float.valueOf(this.conf.getSparsity()))));
        } else {
            hBiasGrad.addi(hBiasGrad.mul(Float.valueOf(this.conf.getSparsity())).mul(Float.valueOf((-this.conf.getLr()) * this.conf.getSparsity())));
        }
    }

    /**
     * Post-processes the three gradients in {@code gradient} in place.
     *
     * <p>The steps, in order:
     * <ol>
     *   <li>Optionally reset AdaGrad history.</li>
     *   <li>Scale by AdaGrad rates or by {@code learningRate}.</li>
     *   <li>Apply sparsity to the hidden-bias gradient.</li>
     *   <li>Add momentum from the previous gradients.</li>
     *   <li>Divide by {@link #lastMiniBatchSize}.</li>
     *   <li>Subtract L2 weight decay.</li>
     *   <li>Optionally normalise each gradient to unit L2 norm.</li>
     * </ol>
     * The processed arrays are stored as the previous gradients for the next call's momentum.
     *
     * @param gradient gradients to modify in place
     * @param iteration current iteration; drives the AdaGrad reset and the momentum schedule
     * @param learningRate scale factor used when AdaGrad is disabled
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void updateGradientAccordingToParams(NeuralNetworkGradient gradient, int iteration, double learningRate) {
        INDArray wGrad = gradient.getwGradient();
        INDArray hBiasGrad = gradient.gethBiasGradient();
        INDArray vBiasGrad = gradient.getvBiasGradient();

        int resetEvery = this.conf.getResetAdaGradIterations();
        if (iteration != 0 && resetEvery > 0 && iteration % resetEvery == 0) {
            // Null-safe: previously the fields were dereferenced before the "create if null" checks.
            this.wAdaGrad = resetAdaGrad(this.wAdaGrad, this.W);
            this.hBiasAdaGrad = resetAdaGrad(this.hBiasAdaGrad, this.hBias);
            this.vBiasAdaGrad = resetAdaGrad(this.vBiasAdaGrad, this.vBias);
            log.info("Resetting adagrad");
        }

        // Computed unconditionally, as before. getLearningRates may update AdaGrad's internal
        // state, so skipping it when AdaGrad is off could change later results. TODO confirm.
        INDArray wLearningRates = this.wAdaGrad.getLearningRates(wGrad);

        double momentum = this.conf.getMomentum();
        if (this.conf.getMomentumAfter() != null && !this.conf.getMomentumAfter().isEmpty()) {
            // Only the first key returned by the map's iterator is consulted.
            int switchIteration = this.conf.getMomentumAfter().keySet().iterator().next().intValue();
            if (iteration >= switchIteration) {
                momentum = this.conf.getMomentumAfter().get(Integer.valueOf(switchIteration)).floatValue();
            }
        }

        boolean useAdaGrad = this.conf.isUseAdaGrad();
        if (useAdaGrad) {
            wGrad.muli(wLearningRates);
            hBiasGrad.muli(this.hBiasAdaGrad.getLearningRates(hBiasGrad));
            vBiasGrad.muli(this.vBiasAdaGrad.getLearningRates(vBiasGrad));
        } else {
            wGrad.muli(Double.valueOf(learningRate));
            hBiasGrad.muli(Double.valueOf(learningRate));
            vBiasGrad.muli(Double.valueOf(learningRate));
        }

        if (this.hBiasGradient != null && this.conf.getSparsity() != 0.0f) {
            applySparsity(hBiasGrad);
        }

        if (momentum != 0.0d && this.wGradient != null) {
            wGrad.addi(this.wGradient.mul(Double.valueOf(momentum)).addi(wGrad.mul(Double.valueOf(1.0d - momentum))));
        }
        if (momentum != 0.0d && this.vBiasGradient != null) {
            vBiasGrad.addi(this.vBiasGradient.mul(Double.valueOf(momentum)).addi(vBiasGrad.mul(Double.valueOf(1.0d - momentum))));
        }
        if (momentum != 0.0d && this.hBiasGradient != null) {
            hBiasGrad.addi(this.hBiasGradient.mul(Double.valueOf(momentum)).addi(hBiasGrad.mul(Double.valueOf(1.0d - momentum))));
        }

        wGrad.divi(Integer.valueOf(this.lastMiniBatchSize));
        vBiasGrad.divi(Integer.valueOf(this.lastMiniBatchSize));
        hBiasGrad.divi(Integer.valueOf(this.lastMiniBatchSize));

        if (this.conf.isUseRegularization() && this.conf.getL2() > 0.0f) {
            if (useAdaGrad) {
                wGrad.subi(this.W.mul(Float.valueOf(this.conf.getL2())).muli(wLearningRates));
            } else {
                wGrad.subi(this.W.mul(Double.valueOf(this.conf.getL2() * learningRate)));
            }
        }

        if (this.conf.isConstrainGradientToUnitNorm()) {
            wGrad.divi(wGrad.norm2(Integer.MAX_VALUE));
            vBiasGrad.divi(vBiasGrad.norm2(Integer.MAX_VALUE));
            hBiasGrad.divi(hBiasGrad.norm2(Integer.MAX_VALUE));
        }

        this.wGradient = wGrad;
        this.vBiasGradient = vBiasGrad;
        this.hBiasGradient = hBiasGrad;
    }

    /**
     * Clears the history of an existing AdaGrad.
     *
     * <p>If {@code adaGrad} is {@code null}, a fresh AdaGrad is created sized like {@code param}
     * instead; {@code null} is returned if {@code param} is also {@code null}.
     */
    private static AdaGrad resetAdaGrad(AdaGrad adaGrad, INDArray param) {
        if (adaGrad != null) {
            adaGrad.historicalGradient = null;
            return adaGrad;
        }
        return param == null ? null : new AdaGrad(param.rows(), param.columns());
    }

    /**
     * Scores the network on {@code input} with the configured loss function.
     *
     * <p>For {@code RECONSTRUCTION_CROSSENTROPY} the result is the negated
     * {@code LossFunctions.reconEntropy}.
     */
    @Override // org.deeplearning4j.nn.api.Model
    public double score() {
        if (this.conf.getLossFunction() != LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY) {
            return LossFunctions.score(this.input, this.conf.getLossFunction(), transform(this.input), this.conf.getL2(), this.conf.isUseRegularization());
        }
        return -LossFunctions.reconEntropy(this.input, this.hBias, this.vBias, this.W, this.conf.getActivationFunction());
    }

    /** Drops the reference to the current input. */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void clearInput() {
        this.input = null;
    }

    /** Returns the AdaGrad instance for the weight matrix. */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public AdaGrad getAdaGrad() {
        return this.wAdaGrad;
    }

    /** Replaces the AdaGrad instance for the weight matrix. */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void setAdaGrad(AdaGrad adaGrad) {
        this.wAdaGrad = adaGrad;
    }

    /**
     * Creates a new instance of this class with the roles of the layers swapped.
     *
     * <p>The new instance gets:
     * <ul>
     *   <li>{@code W} transposed;</li>
     *   <li>a hidden bias copied from this visible bias;</li>
     *   <li>a zero visible bias shaped like this hidden bias;</li>
     *   <li>this hidden-bias AdaGrad as its visible-bias AdaGrad;</li>
     *   <li>the same {@code conf} object.</li>
     * </ul>
     *
     * <p>NOTE(review): {@code conf} is shared and its nIn/nOut are not swapped. With assertions
     * enabled, the shape checks in {@link #sethBias}/{@link #setW} therefore fail unless
     * {@code nIn == nOut}.
     *
     * @throws RuntimeException wrapping any reflective instantiation failure
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public NeuralNetwork transpose() {
        try {
            Constructor<?> emptyConstructor = Dl4jReflection.getEmptyConstructor(getClass());
            emptyConstructor.setAccessible(true);
            NeuralNetwork transposed = (NeuralNetwork) emptyConstructor.newInstance(new Object[0]);
            // conf must be set first: the shape assertions in the bias/weight setters read it.
            transposed.setConf(this.conf);
            transposed.setVBiasAdaGrad(this.hBiasAdaGrad);
            transposed.sethBias(this.vBias.dup());
            transposed.setvBias(Nd4j.zeros(this.hBias.rows(), this.hBias.columns()));
            transposed.setW(this.W.transpose());
            return transposed;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the configuration. */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public NeuralNetConfiguration conf() {
        return this.conf;
    }

    /** Replaces the configuration. */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void setConf(NeuralNetConfiguration configuration) {
        this.conf = configuration;
    }

    /* renamed from: clone */
    /**
     * Creates a new instance of this class with the parameters copied.
     *
     * <p>{@code W}, {@code hBias} and {@code vBias} are copied with {@code dup()}. The
     * configuration and the three AdaGrad instances are shared, not copied.
     *
     * @throws RuntimeException wrapping any reflective instantiation failure
     */
    @Override
    public NeuralNetwork mo11clone() {
        try {
            Constructor<?> emptyConstructor = Dl4jReflection.getEmptyConstructor(getClass());
            emptyConstructor.setAccessible(true);
            NeuralNetwork copy = (NeuralNetwork) emptyConstructor.newInstance(new Object[0]);
            copy.setConf(this.conf);
            copy.setHbiasAdaGrad(this.hBiasAdaGrad);
            copy.setVBiasAdaGrad(this.vBiasAdaGrad);
            copy.sethBias(this.hBias.dup());
            copy.setvBias(this.vBias.dup());
            copy.setW(this.W.dup());
            copy.setAdaGrad(this.wAdaGrad);
            return copy;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Moves each parameter of this network towards {@code other}'s, in place.
     *
     * <p>Each parameter becomes {@code p += (other.p - p) / divisor}. The other network is not
     * modified.
     *
     * @param other network whose parameters are merged in
     * @param divisor divisor applied to each parameter difference
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void merge(NeuralNetwork other, int divisor) {
        this.W.addi(other.getW().sub(this.W).divi(Integer.valueOf(divisor)));
        this.hBias.addi(other.gethBias().sub(this.hBias).divi(Integer.valueOf(divisor)));
        // sub (not subi): subi would overwrite the other network's visible bias in place.
        this.vBias.addi(other.getvBias().sub(this.vBias).divi(Integer.valueOf(divisor)));
    }

    /**
     * Adopts references to {@code other}'s parameters, configuration and AdaGrad instances.
     *
     * <p>Nothing is copied.
     */
    public void update(BaseNeuralNetwork other) {
        this.W = other.W;
        this.conf = other.conf;
        this.hBias = other.hBias;
        this.vBias = other.vBias;
        this.wAdaGrad = other.wAdaGrad;
        this.hBiasAdaGrad = other.hBiasAdaGrad;
        this.vBiasAdaGrad = other.vBiasAdaGrad;
    }

    /**
     * Reads a Java-serialized {@code BaseNeuralNetwork} from {@code inputStream} and adopts its
     * state via {@link #update}.
     *
     * <p>The stream is not closed.
     *
     * <p>SECURITY(review): this is Java-native deserialization. Never call it on untrusted input;
     * consider an {@code ObjectInputFilter}.
     *
     * @throws RuntimeException wrapping any I/O, class-resolution or cast failure
     */
    @Override // org.deeplearning4j.nn.api.Persistable
    public void load(InputStream inputStream) {
        try {
            update((BaseNeuralNetwork) new ObjectInputStream(inputStream).readObject());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the weight matrix (not a copy). */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public INDArray getW() {
        return this.W;
    }

    /**
     * Sets the weight matrix.
     *
     * <p>With assertions enabled, its shape must be {@code [nIn, nOut]}.
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void setW(INDArray weights) {
        if (!$assertionsDisabled && !Arrays.equals(weights.shape(), new int[]{this.conf.getnIn(), this.conf.getnOut()})) {
            throw new AssertionError("Invalid shape for w, must be " + Arrays.toString(new int[]{this.conf.getnIn(), this.conf.getnOut()}));
        }
        this.W = weights;
    }

    /** Returns the hidden bias (not a copy). */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public INDArray gethBias() {
        return this.hBias;
    }

    /**
     * Sets the hidden bias.
     *
     * <p>With assertions enabled, its shape must be {@code [nOut]}.
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void sethBias(INDArray hiddenBias) {
        if (!$assertionsDisabled && !Arrays.equals(hiddenBias.shape(), new int[]{this.conf.getnOut()})) {
            throw new AssertionError("Illegal shape for hidden bias, must be of shape " + Arrays.toString(new int[]{this.conf.getnOut()}));
        }
        this.hBias = hiddenBias;
    }

    /** Returns the visible bias (not a copy). */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public INDArray getvBias() {
        return this.vBias;
    }

    /**
     * Sets the visible bias.
     *
     * <p>With assertions enabled, its shape must be {@code [nIn]}.
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void setvBias(INDArray visibleBias) {
        if (!$assertionsDisabled && !Arrays.equals(visibleBias.shape(), new int[]{this.conf.getnIn()})) {
            throw new AssertionError("Illegal shape for visible bias, must be of shape " + Arrays.toString(new int[]{this.conf.getnIn()}));
        }
        this.vBias = visibleBias;
    }

    /** Returns the current input (not a copy). */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public INDArray getInput() {
        return this.input;
    }

    /** Replaces the current input. */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void setInput(INDArray in) {
        this.input = in;
    }

    /** Returns the AdaGrad instance for the hidden bias. */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public AdaGrad gethBiasAdaGrad() {
        return this.hBiasAdaGrad;
    }

    /** Replaces the AdaGrad instance for the hidden bias. */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void setHbiasAdaGrad(AdaGrad adaGrad) {
        this.hBiasAdaGrad = adaGrad;
    }

    /** Returns the AdaGrad instance for the visible bias. */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public AdaGrad getVBiasAdaGrad() {
        return this.vBiasAdaGrad;
    }

    /** Replaces the AdaGrad instance for the visible bias. */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public void setVBiasAdaGrad(AdaGrad adaGrad) {
        this.vBiasAdaGrad = adaGrad;
    }

    /**
     * Java-serializes this network to {@code outputStream} and flushes it.
     *
     * <p>The stream is not closed; the caller owns it.
     *
     * @throws RuntimeException wrapping any {@link IOException}
     */
    @Override // org.deeplearning4j.nn.api.Persistable
    public void write(OutputStream outputStream) {
        try {
            ObjectOutputStream objectOut = new ObjectOutputStream(outputStream);
            objectOut.writeObject(this);
            // Push any buffered bytes through to the caller's stream.
            objectOut.flush();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Subclass-specific forward transform; used to compute reconstruction scores. */
    public abstract INDArray transform(INDArray input);

    /**
     * Builds a dropout mask and multiplies {@code in} by it, in place.
     *
     * <p>The mask is {@code Nd4j.rand(...) > dropOut}, or all ones when {@code dropOut <= 0}. It
     * is stored in {@link #doMask}.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void applyDropOutIfNecessary(INDArray in) {
        if (this.conf.getDropOut() > 0.0f) {
            this.doMask = Nd4j.rand(in.rows(), in.columns()).gt(Float.valueOf(this.conf.getDropOut()));
        } else {
            this.doMask = Nd4j.ones(in.rows(), in.columns());
        }
        in.muli(this.doMask);
    }

    /** Returns {@code input * W} with the hidden bias added to every row. */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork
    public INDArray hBiasMean() {
        return getInput().mmul(getW()).addRowVector(gethBias());
    }

    /**
     * Appends a column of ones to {@code in} when {@code conf.isConcatBiases()}; otherwise
     * returns {@code in} unchanged.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public INDArray preProcessInput(INDArray in) {
        return this.conf.isConcatBiases() ? Nd4j.hstack(new INDArray[]{in, Nd4j.ones(in.rows(), 1)}) : in;
    }

    /**
     * Plots the network gradient every {@code renderWeightsEveryNumEpochs} iterations.
     *
     * <p>Nothing is plotted when that setting is {@code <= 0}.
     */
    @Override // org.deeplearning4j.nn.api.NeuralNetwork, org.deeplearning4j.optimize.api.IterationListener
    public void iterationDone(int iteration) {
        int renderEvery = this.conf.getRenderWeightsEveryNumEpochs();
        if (renderEvery <= 0) {
            return;
        }
        if (iteration % renderEvery == 0 || iteration == 0) {
            new NeuralNetPlotter().plotNetworkGradient(this, getGradient(new Object[]{1, Double.valueOf(0.001d), 1000}), getInput().rows());
        }
    }

    static {
        $assertionsDisabled = !BaseNeuralNetwork.class.desiredAssertionStatus();
    }
}
