package org.bigml.mimir.deepnet.layers;

import java.io.IOException;
import java.io.ObjectInputStream;
import org.bigml.mimir.deepnet.layers.Activation;
import org.bigml.mimir.deepnet.layers.twod.OutputTensor;
import org.bigml.mimir.math.Matrices;

/* loaded from: input_file:org/bigml/mimir/deepnet/layers/Dense.class */
/**
 * A fully connected layer.
 *
 * <p>For an input vector {@code x}, output element {@code i} is the dot product of {@code x} with
 * row {@code i} of the weight matrix, plus {@code offset[i]} when a bias is present. The
 * activation function is then applied to the whole output vector, unless it is {@code null} or
 * identity.
 *
 * <p>NOTE(review): results are written into a buffer obtained from a transient
 * {@link OutputTensor}. If {@code OutputTensor.get()} returns the same array on every call, as
 * appears intended, the returned array is overwritten by subsequent calls. Instances would then
 * not be safe to share between threads. Confirm against OutputTensor.
 */
public class Dense implements Layer {
    /** Per-output bias, one entry per weight row; {@code null} when the layer has no bias. */
    private final float[] _offset;
    /** Weight matrix: one row per output element, one column per input element. */
    private final float[][] _weights;
    /** Activation applied after the bias; {@code null} or IDENTITY means no activation. */
    private final Activation.ActivationFn _afn;
    /** Reusable output buffer; not serialized, rebuilt in {@link #readObject}. */
    private transient OutputTensor _output;
    private static final long serialVersionUID = 1;

    /**
     * Builds a dense layer from double-precision parameters, converting them to float.
     *
     * @param str name of the activation function, resolved via {@code Activation.getActivator}
     * @param dArr weight matrix, one row per output element; all rows must have equal length
     * @param dArr2 bias vector with one entry per weight row, or {@code null} for no bias
     * @throws IllegalArgumentException if the weight rows differ in length, or if the bias length
     *     does not match the number of weight rows
     */
    public Dense(String str, double[][] dArr, double[] dArr2) {
        this._afn = Activation.getActivator(str);
        this._weights = Matrices.toFloat(dArr);
        this._offset = dArr2 != null ? Matrices.toFloat(dArr2) : null;

        // Fail fast on malformed parameters. Otherwise propagate() would either throw halfway
        // through filling the output buffer or silently produce wrong values.
        int rows = this._weights.length;
        if (rows > 0) {
            int cols = this._weights[0].length;
            for (int r = 1; r < rows; r++) {
                if (this._weights[r].length != cols) {
                    throw new IllegalArgumentException(
                        "Weight row " + r + " has length " + this._weights[r].length
                            + " != " + cols);
                }
            }
        }
        if (this._offset != null && this._offset.length != rows) {
            throw new IllegalArgumentException(
                "Bias length " + this._offset.length + " != " + rows);
        }

        this._output = new OutputTensor(new int[]{rows});
    }

    /**
     * Returns the number of output elements, which equals the number of weight rows.
     *
     * @return length of the vector produced by {@link #propagate}
     */
    @Override // org.bigml.mimir.deepnet.layers.Layer
    public int getOutputLength() {
        return this._weights.length;
    }

    /**
     * Computes {@code activation(W * x + b)} for the given input vector.
     *
     * @param fArr input vector; its length must equal the weight-row length
     * @return the output vector. When the activation is {@code null} or IDENTITY this is the
     *     layer's internal output buffer; otherwise it is whatever {@code Activation.activate}
     *     returns.
     * @throws IllegalArgumentException if the input length does not match the weight-row length
     */
    @Override // org.bigml.mimir.deepnet.layers.Layer
    public float[] propagate(float[] fArr) {
        // NOTE(review): a layer with zero weight rows still throws
        // ArrayIndexOutOfBoundsException here, because the expected input size is unknown.
        int inputLength = this._weights[0].length;
        if (fArr.length != inputLength) {
            throw new IllegalArgumentException("Input size " + fArr.length + " != " + inputLength);
        }

        float[] out = this._output.get();
        for (int row = 0; row < this._weights.length; row++) {
            float[] w = this._weights[row];
            float sum = 0.0f;
            for (int col = 0; col < inputLength; col++) {
                sum += fArr[col] * w[col];
            }
            out[row] = sum;
        }

        // The bias is added after the full dot product, matching the original summation order.
        if (this._offset != null) {
            for (int row = 0; row < this._offset.length; row++) {
                out[row] += this._offset[row];
            }
        }

        if (this._afn == null || this._afn.equals(Activation.ActivationFn.IDENTITY)) {
            return out;
        }
        return Activation.activate(out, this._afn);
    }

    /**
     * Restores serialized state, then recreates the transient output buffer sized to the layer's
     * output length.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        this._output = new OutputTensor(new int[]{getOutputLength()});
    }
}
