package de.jungblut.math.activation;

import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import org.apache.commons.math3.util.FastMath;

/* loaded from: input_file:de/jungblut/math/activation/ActivationFunctionSelector.class */
/**
 * Selects one of the built-in {@link ActivationFunction} implementations by name.
 *
 * <p>Each constant owns exactly one function instance, which {@link #get()} returns on every call.
 * The instance is bound to its constant through the enum constructor rather than through a
 * parallel array indexed by {@link #ordinal()}, so reordering or inserting constants cannot pair a
 * constant with the wrong function.
 */
public enum ActivationFunctionSelector {

    /** Identity activation; delegates to {@link LinearActivationFunction}. */
    LINEAR(new LinearActivationFunction()),

    /**
     * Symmetric logarithmic activation {@code f(x) = sign(x) * ln(1 + |x|)}, whose derivative is
     * {@code f'(x) = 1 / (1 + |x|)}.
     */
    LOG(new AbstractActivationFunction() {
        @Override
        public double apply(double d) {
            // log1p keeps full precision for |d| near zero, where log(1.0 + d) would round away
            // the low-order bits of d.
            return d >= 0.0d ? Math.log1p(d) : -Math.log1p(-d);
        }

        @Override
        public double gradient(double d) {
            // Equals 1/(1 + d) for d >= 0 and 1/(1 - d) for d < 0.
            return 1.0d / (1.0d + Math.abs(d));
        }
    }),

    /** Logistic sigmoid; delegates to {@link SigmoidActivationFunction}. */
    SIGMOID(new SigmoidActivationFunction()),

    /**
     * Softmax. The vector form returns {@code exp(v - max(v)) / sum(exp(v - max(v)))}; the matrix
     * form applies that to every row independently. The scalar {@code apply} is the identity,
     * because softmax is not defined element-wise.
     *
     * <p>All {@code gradient} overloads return their argument unchanged. NOTE(review): presumably
     * the caller computes the output-layer delta for softmax directly (e.g. combined with a
     * cross-entropy loss) and never relies on a true softmax Jacobian here — confirm before using
     * these gradients elsewhere.
     */
    SOFTMAX(new AbstractActivationFunction() {
        @Override
        public double apply(double d) {
            return d;
        }

        @Override
        public DoubleVector apply(DoubleVector doubleVector) {
            // Shifting by the maximum leaves the result unchanged mathematically but keeps exp()
            // from overflowing on large inputs.
            DoubleVector exp = doubleVector.subtract(doubleVector.max()).exp();
            return exp.divide(exp.sum());
        }

        @Override
        public DoubleMatrix apply(DoubleMatrix doubleMatrix) {
            // NOTE(review): newInstance is inherited from AbstractActivationFunction; presumably
            // it allocates an empty matrix shaped like the input — confirm.
            DoubleMatrix result = newInstance(doubleMatrix);
            for (int i = 0; i < doubleMatrix.getRowCount(); i++) {
                DoubleVector row = apply(doubleMatrix.getRowVector(i));
                // Rows that come back with length 0 are not written, leaving whatever
                // newInstance put there.
                if (row.getLength() != 0) {
                    result.setRowVector(i, row);
                }
            }
            return result;
        }

        @Override
        public double gradient(double d) {
            return d;
        }

        @Override
        public DoubleVector gradient(DoubleVector doubleVector) {
            return doubleVector;
        }

        @Override
        public DoubleMatrix gradient(DoubleMatrix doubleMatrix) {
            return doubleMatrix;
        }
    }),

    /** Hyperbolic tangent {@code f(x) = tanh(x)}, with derivative {@code 1 - tanh(x)^2}. */
    TANH(new AbstractActivationFunction() {
        @Override
        public double apply(double d) {
            return FastMath.tanh(d);
        }

        @Override
        public double gradient(double d) {
            // The derivative is evaluated at the input x, so tanh is recomputed here.
            double tanh = FastMath.tanh(d);
            return 1.0d - (tanh * tanh);
        }
    }),

    /**
     * Elliot (softsign-style) sigmoid {@code f(x) = x / (2(1 + |x|)) + 1/2}, with range (0, 1)
     * and derivative {@code f'(x) = 1 / (2(1 + |x|)^2)}.
     */
    ELLIOT(new AbstractActivationFunction() {
        @Override
        public double apply(double d) {
            return ((d / 2.0d) / (1.0d + Math.abs(d))) + 0.5d;
        }

        @Override
        public double gradient(double d) {
            // Exact derivative of apply(): 0.5 / (1 + |x|)^2. The previous implementation
            // returned 2 / (1 + |x|)^2, which is four times too large.
            double denominator = 1.0d + Math.abs(d);
            return 0.5d / (denominator * denominator);
        }
    });

    /** The function bound to this constant; the same instance is handed to every caller. */
    private final ActivationFunction function;

    /**
     * Binds a constant to its activation function.
     *
     * @param function the implementation returned by {@link #get()} for this constant
     */
    ActivationFunctionSelector(ActivationFunction function) {
        this.function = function;
    }

    /**
     * Returns the activation function selected by this constant.
     *
     * @return the function instance bound to this constant (the same object on every call)
     */
    public ActivationFunction get() {
        return function;
    }
}
