package org.deeplearning4j.nn.layers;

import java.lang.reflect.Constructor;
import java.util.Arrays;
import org.deeplearning4j.nn.WeightInitUtil;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/nn/layers/BaseLayer.class */
/**
 * Base implementation of {@link org.deeplearning4j.nn.api.Layer} that holds a weight matrix
 * {@code W}, a bias {@code b}, the most recent input, the layer configuration and the most
 * recently generated dropout mask.
 *
 * <p>{@link #mo25clone()} and {@link #transpose()} create new layers reflectively through the
 * runtime class's public {@code (NeuralNetConfiguration, INDArray, INDArray, INDArray)}
 * constructor, so concrete subclasses are expected to declare one.
 */
public abstract class BaseLayer implements org.deeplearning4j.nn.api.Layer {
    protected INDArray W;
    protected INDArray b;
    protected INDArray input;
    protected NeuralNetConfiguration conf;
    protected INDArray dropoutMask;

    /**
     * Creates a layer.
     *
     * @param conf    layer configuration; read by {@link #createWeightMatrix()} and
     *                {@link #createBias()} when {@code weights} / {@code bias} are {@code null}
     * @param weights weight matrix, or {@code null} to create one via {@link #createWeightMatrix()}
     * @param bias    bias, or {@code null} to create one via {@link #createBias()}
     * @param input   initial input; may be {@code null}
     */
    public BaseLayer(NeuralNetConfiguration conf, INDArray weights, INDArray bias, INDArray input) {
        this.input = input;
        // conf must be assigned before the (overridable) factory hooks run, since they read it.
        this.conf = conf;
        this.W = weights != null ? weights : createWeightMatrix();
        this.b = bias != null ? bias : createBias();
    }

    /**
     * Creates the default bias: a zero-filled array sized by {@code conf.getnOut()}.
     * Subclasses may override; note it is called from the constructor.
     */
    protected INDArray createBias() {
        return Nd4j.zeros(this.conf.getnOut());
    }

    /**
     * Creates the default weight matrix by delegating to {@link WeightInitUtil#initWeights} with
     * this configuration's nIn, nOut, weight-init scheme, activation function and distribution.
     * Subclasses may override; note it is called from the constructor.
     */
    protected INDArray createWeightMatrix() {
        return WeightInitUtil.initWeights(this.conf.getnIn(), this.conf.getnOut(), this.conf.getWeightInit(), this.conf.getActivationFunction(), this.conf.getDist());
    }

    /**
     * Stores {@code newInput} as the layer input and returns {@code newInput.mmul(W)} combined
     * with the bias: appended as extra columns when {@code conf.isConcatBiases()}, otherwise
     * added to every row.
     *
     * @throws IllegalArgumentException if {@code newInput} is {@code null}
     * @throws IllegalStateException    if the bias column count differs from that of
     *                                  {@code newInput.mmul(W)}
     */
    @Override
    public INDArray preOutput(INDArray newInput) {
        if (newInput == null) {
            throw new IllegalArgumentException("No null input allowed");
        }
        this.input = newInput;
        INDArray linear = this.input.mmul(this.W);
        if (linear.columns() != this.b.columns()) {
            throw new IllegalStateException("Bias has " + this.b.columns()
                    + " columns but input.mmul(W) has " + linear.columns() + " columns");
        }
        if (this.conf.isConcatBiases()) {
            linear = Nd4j.hstack(new INDArray[]{linear, this.b});
        } else {
            linear.addiRowVector(this.b);
        }
        return linear;
    }

    /**
     * Applies the configured activation function to {@code getInput().mmul(getW())} plus
     * {@code getB()} added to every row. Unlike {@link #preOutput(INDArray)}, this ignores
     * {@code conf.isConcatBiases()} and does not modify the stored input.
     *
     * @throws IllegalStateException if no input has been set
     */
    @Override
    public INDArray activate() {
        INDArray currentInput = getInput();
        if (currentInput == null) {
            throw new IllegalStateException("No input set; call setInput or activate(INDArray) first");
        }
        return (INDArray) this.conf.getActivationFunction().apply(currentInput.mmul(getW()).addRowVector(getB()));
    }

    /**
     * Activates on {@code newInput}. A non-null argument is first passed through
     * {@code Transforms.stabilize(newInput, 1.0)} and stored as the layer input; a {@code null}
     * argument reuses the current input.
     */
    @Override
    public INDArray activate(INDArray newInput) {
        if (newInput != null) {
            this.input = Transforms.stabilize(newInput, 1.0d);
        }
        return activate();
    }

    @Override
    public NeuralNetConfiguration conf() {
        return this.conf;
    }

    @Override
    public void setConfiguration(NeuralNetConfiguration neuralNetConfiguration) {
        this.conf = neuralNetConfiguration;
    }

    @Override
    public INDArray getW() {
        return this.W;
    }

    /**
     * Replaces the weight matrix. When assertions are enabled ({@code -ea}), checks that it has
     * {@code conf().getnIn()} rows and {@code conf.getnOut()} columns.
     */
    @Override
    public void setW(INDArray weights) {
        assert weights.rows() == conf().getnIn() && weights.columns() == this.conf.getnOut()
                : "Weight matrix must be of shape " + Arrays.toString(new int[]{conf().getnIn(), this.conf.getnOut()});
        this.W = weights;
    }

    @Override
    public INDArray getB() {
        return this.b;
    }

    /**
     * Replaces the bias. When assertions are enabled ({@code -ea}), checks that it has
     * {@code conf().getnOut()} columns.
     */
    @Override
    public void setB(INDArray bias) {
        assert bias.columns() == conf().getnOut() : "The bias must have " + conf().getnOut() + " columns";
        this.b = bias;
    }

    @Override
    public INDArray getInput() {
        return this.input;
    }

    @Override
    public void setInput(INDArray newInput) {
        this.input = newInput;
    }

    /**
     * Builds a new dropout mask of shape {@code (activations.rows(), conf.getnOut())} and
     * multiplies {@code activations} by it in place. When {@code conf.getDropOut() > 0} the mask
     * is {@code Nd4j.rand(...).gt(dropOut)}, i.e. it keeps entries whose random draw exceeds the
     * dropout value; otherwise it is all ones. The mask is kept in {@link #dropoutMask}.
     */
    // JADX: access was widened from protected; kept public so existing callers still compile.
    public void applyDropOutIfNecessary(INDArray activations) {
        if (this.conf.getDropOut() > 0.0f) {
            this.dropoutMask = Nd4j.rand(activations.rows(), this.conf.getnOut()).gt(Float.valueOf(this.conf.getDropOut()));
        } else {
            this.dropoutMask = Nd4j.ones(activations.rows(), this.conf.getnOut());
        }
        activations.muli(this.dropoutMask);
    }

    /**
     * Moves this layer's parameters towards those of {@code layer}, updating {@code W} and
     * {@code b} in place. With regularization enabled the update is
     * {@code W += (layer.W - W) / divisor} (likewise for {@code b}). Otherwise it is
     * {@code W += layer.W - W}, which sets the parameters to {@code layer}'s up to rounding.
     * {@code layer} itself is not modified.
     *
     * @param layer   layer whose parameters are merged in
     * @param divisor divisor applied to the parameter difference when regularization is enabled
     */
    public void merge(org.deeplearning4j.nn.api.Layer layer, int divisor) {
        // sub (not subi): an in-place subi would overwrite the other layer's W/b with the delta.
        if (this.conf.isUseRegularization()) {
            this.W.addi(layer.getW().sub(this.W).div(Integer.valueOf(divisor)));
            this.b.addi(layer.getB().sub(this.b).div(Integer.valueOf(divisor)));
        } else {
            this.W.addi(layer.getW().sub(this.W));
            this.b.addi(layer.getB().sub(this.b));
        }
    }

    /**
     * Returns a copy of this layer of the same runtime class. {@code W}, {@code b} and the input
     * (if any) are duplicated. The configuration object is shared, not copied.
     * (JADX: renamed from {@code clone}, merged with its bridge method.)
     *
     * @throws IllegalStateException if the runtime class cannot be instantiated through its
     *                               {@code (NeuralNetConfiguration, INDArray, INDArray, INDArray)}
     *                               constructor
     */
    @Override
    public org.deeplearning4j.nn.api.Layer mo25clone() {
        return newLayerOfSameType(this.conf, this.W.dup(), this.b.dup(), this.input != null ? this.input.dup() : null);
    }

    /**
     * Returns a new layer of the same runtime class built from copies of the transposed
     * {@code W}, {@code b} and input (if any). It uses a cloned configuration with nIn and nOut
     * swapped.
     *
     * @throws IllegalStateException if the runtime class cannot be instantiated through its
     *                               {@code (NeuralNetConfiguration, INDArray, INDArray, INDArray)}
     *                               constructor
     */
    @Override
    public org.deeplearning4j.nn.api.Layer transpose() {
        NeuralNetConfiguration transposedConf = this.conf.m21clone();
        int originalIn = transposedConf.getnIn();
        int originalOut = transposedConf.getnOut();
        transposedConf.setnIn(originalOut);
        transposedConf.setnOut(originalIn);
        // NOTE(review): the bias is transposed as-is rather than re-created with the new nOut;
        // confirm this is the intended shape for the transposed layer.
        return newLayerOfSameType(transposedConf, this.W.transpose().dup(), this.b.transpose().dup(),
                this.input != null ? this.input.transpose().dup() : null);
    }

    /**
     * Instantiates the runtime class reflectively through its public
     * {@code (NeuralNetConfiguration, INDArray, INDArray, INDArray)} constructor.
     */
    private org.deeplearning4j.nn.api.Layer newLayerOfSameType(NeuralNetConfiguration layerConf, INDArray weights, INDArray bias, INDArray layerInput) {
        try {
            Constructor<?> constructor = getClass().getConstructor(NeuralNetConfiguration.class, INDArray.class, INDArray.class, INDArray.class);
            return (org.deeplearning4j.nn.api.Layer) constructor.newInstance(layerConf, weights, bias, layerInput);
        } catch (Exception e) {
            // Reflection boundary: surface the failure with its cause instead of returning null.
            throw new IllegalStateException("Unable to instantiate " + getClass().getName()
                    + " via its (NeuralNetConfiguration, INDArray, INDArray, INDArray) constructor", e);
        }
    }

    @Override
    public String toString() {
        return "BaseLayer{W=" + this.W + ", b=" + this.b + ", input=" + this.input + ", conf=" + this.conf + ", dropoutMask=" + this.dropoutMask + '}';
    }
}
