package de.jungblut.math.activation;

import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import de.jungblut.math.sparse.SparseDoubleVector;
import java.util.Iterator;

/* loaded from: input_file:de/jungblut/math/activation/AbstractActivationFunction.class */
/**
 * Skeleton implementation of {@link ActivationFunction} that lifts the element-wise
 * {@code apply} and {@code gradient} overloads (not defined in this class, so they have to be
 * supplied by the interface or a concrete subclass) to whole vectors and matrices.
 *
 * <p>Every method writes into a freshly created container obtained from one of the
 * {@code newInstance} factories; the argument itself is only read through its accessors.
 *
 * <p>For sparse inputs only the entries handed out by the sparse iteration are transformed.
 * Positions that are never visited keep whatever value a fresh container holds.
 * NOTE(review): presumably that is zero, which would differ from {@code f(0)} for functions
 * that do not map zero to zero - confirm this is the intended trade-off for keeping sparsity.
 */
public abstract class AbstractActivationFunction implements ActivationFunction {

    /**
     * Applies the element-wise activation to a vector.
     *
     * @param doubleVector the input vector
     * @return a new vector (sparse if the input reports itself as sparse, dense otherwise)
     *     holding the transformed entries
     */
    @Override
    public DoubleVector apply(DoubleVector doubleVector) {
        DoubleVector result = newInstance(doubleVector);
        if (doubleVector.isSparse()) {
            // Only touch the entries the sparse iterator yields.
            Iterator<DoubleVector.DoubleVectorElement> entries = doubleVector.iterateNonZero();
            while (entries.hasNext()) {
                DoubleVector.DoubleVectorElement entry = entries.next();
                result.set(entry.getIndex(), apply(entry.getValue()));
            }
        } else {
            final int dimension = doubleVector.getDimension();
            for (int index = 0; index < dimension; index++) {
                result.set(index, apply(doubleVector.get(index)));
            }
        }
        return result;
    }

    /**
     * Applies the element-wise activation to a matrix.
     *
     * <p>Sparse matrices are processed row by row via {@link #apply(DoubleVector)}; rows whose
     * {@code getLength()} is not positive are skipped. Dense matrices are processed cell by cell.
     *
     * @param doubleMatrix the input matrix
     * @return a new matrix with the same row and column count holding the transformed entries
     */
    @Override
    public DoubleMatrix apply(DoubleMatrix doubleMatrix) {
        DoubleMatrix result = newInstance(doubleMatrix);
        if (doubleMatrix.isSparse()) {
            for (int row : doubleMatrix.rowIndices()) {
                DoubleVector rowVector = doubleMatrix.getRowVector(row);
                if (rowVector.getLength() > 0) {
                    result.setRowVector(row, apply(rowVector));
                }
            }
        } else {
            final int rowCount = doubleMatrix.getRowCount();
            final int columnCount = doubleMatrix.getColumnCount();
            for (int row = 0; row < rowCount; row++) {
                for (int column = 0; column < columnCount; column++) {
                    result.set(row, column, apply(doubleMatrix.get(row, column)));
                }
            }
        }
        return result;
    }

    /**
     * Applies the element-wise gradient to a vector.
     *
     * @param doubleVector the input vector
     * @return a new vector (sparse if the input reports itself as sparse, dense otherwise)
     *     holding the gradient of each processed entry
     */
    @Override
    public DoubleVector gradient(DoubleVector doubleVector) {
        DoubleVector result = newInstance(doubleVector);
        if (doubleVector.isSparse()) {
            // Only touch the entries the sparse iterator yields.
            Iterator<DoubleVector.DoubleVectorElement> entries = doubleVector.iterateNonZero();
            while (entries.hasNext()) {
                DoubleVector.DoubleVectorElement entry = entries.next();
                result.set(entry.getIndex(), gradient(entry.getValue()));
            }
        } else {
            final int dimension = doubleVector.getDimension();
            for (int index = 0; index < dimension; index++) {
                result.set(index, gradient(doubleVector.get(index)));
            }
        }
        return result;
    }

    /**
     * Applies the element-wise gradient to a matrix.
     *
     * <p>Sparse matrices are processed column by column via {@link #gradient(DoubleVector)};
     * dense matrices are processed cell by cell.
     *
     * @param doubleMatrix the input matrix
     * @return a new matrix with the same row and column count holding the gradients
     */
    @Override
    public DoubleMatrix gradient(DoubleMatrix doubleMatrix) {
        DoubleMatrix result = newInstance(doubleMatrix);
        if (doubleMatrix.isSparse()) {
            // NOTE(review): apply(DoubleMatrix) walks rows and skips empty ones, while this
            // walks columns without such a guard. Traversal is kept as found; confirm whether
            // the asymmetry is intentional.
            for (int column : doubleMatrix.columnIndices()) {
                result.setColumnVector(column, gradient(doubleMatrix.getColumnVector(column)));
            }
        } else {
            final int rowCount = doubleMatrix.getRowCount();
            final int columnCount = doubleMatrix.getColumnCount();
            for (int row = 0; row < rowCount; row++) {
                for (int column = 0; column < columnCount; column++) {
                    result.set(row, column, gradient(doubleMatrix.get(row, column)));
                }
            }
        }
        return result;
    }

    /**
     * Creates the empty result container for a matrix operation.
     *
     * <p>Visibility is kept {@code public} as found in this file for backward compatibility;
     * the decompiler recorded that the access modifier was changed from {@code protected}.
     *
     * @param doubleMatrix the matrix whose sparsity and dimensions are mirrored
     * @return a new {@link SparseDoubleRowMatrix} if the argument is sparse, otherwise a new
     *     {@link DenseDoubleMatrix}, constructed with the argument's row and column count
     */
    public DoubleMatrix newInstance(DoubleMatrix doubleMatrix) {
        final int rowCount = doubleMatrix.getRowCount();
        final int columnCount = doubleMatrix.getColumnCount();
        if (doubleMatrix.isSparse()) {
            return new SparseDoubleRowMatrix(rowCount, columnCount);
        }
        return new DenseDoubleMatrix(rowCount, columnCount);
    }

    /**
     * Creates the empty result container for a vector operation.
     *
     * @param doubleVector the vector whose sparsity and dimension are mirrored
     * @return a new {@link SparseDoubleVector} if the argument is sparse, otherwise a new
     *     {@link DenseDoubleVector}, constructed with the argument's dimension
     */
    protected DoubleVector newInstance(DoubleVector doubleVector) {
        final int dimension = doubleVector.getDimension();
        if (doubleVector.isSparse()) {
            return new SparseDoubleVector(dimension);
        }
        return new DenseDoubleVector(dimension);
    }

    /**
     * Returns the simple class name of the concrete activation function.
     *
     * @return {@code getClass().getSimpleName()}
     */
    @Override
    public String toString() {
        return getClass().getSimpleName();
    }
}
