package org.bigml.mimir.deepnet.layers;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Locale;
import org.bigml.mimir.deepnet.layers.twod.Layer2D;
import org.bigml.mimir.deepnet.layers.twod.OutputTensor;
import org.bigml.mimir.image.FeaturizeTest;
import org.bigml.mimir.image.WaveletTest;
import org.bigml.mimir.image.featurize.HOGFeaturizer;
import org.bigml.mimir.math.Matrices;
import org.bigml.mimir.math.Vectors;

/* loaded from: input_file:org/bigml/mimir/deepnet/layers/Activation.class */
/**
 * Layer that applies an activation function to its input.
 *
 * <p>All functions except softmax are delegated whole-array to the matching {@link Vectors}
 * helper. In {@link #propagate2D(float[])}, softmax is applied separately to each contiguous
 * slice whose length is the last entry of the input shape.
 *
 * <p>Results are computed into the array returned by {@code _output.get()}. NOTE(review): if
 * {@link OutputTensor#get()} hands back a shared buffer, a returned array is overwritten by the
 * next propagation call. Confirm this before caching results.
 */
public class Activation implements Layer, Layer2D {
    private final ActivationFn _afn;
    private int[] _inputShape;
    private int _index;
    // Rebuilt from _inputShape in readObject, so it is not serialized.
    private transient OutputTensor _output;
    private static final long serialVersionUID = 1;

    /**
     * Decompiler-generated switch map (originally the synthetic class {@code Activation$1}).
     * It maps {@link ActivationFn} ordinals to case labels.
     *
     * <p>{@link #activate(float[], ActivationFn)} now switches on the enum directly, so this class
     * is no longer used here. It is kept only so that any external code referencing it still links.
     */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$bigml$mimir$deepnet$layers$Activation$ActivationFn = new int[ActivationFn.values().length];

        static {
            try {
                $SwitchMap$org$bigml$mimir$deepnet$layers$Activation$ActivationFn[ActivationFn.RELU.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$bigml$mimir$deepnet$layers$Activation$ActivationFn[ActivationFn.SOFTPLUS.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$bigml$mimir$deepnet$layers$Activation$ActivationFn[ActivationFn.SOFTMAX.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$bigml$mimir$deepnet$layers$Activation$ActivationFn[ActivationFn.SIGMOID.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$bigml$mimir$deepnet$layers$Activation$ActivationFn[ActivationFn.TANH.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$org$bigml$mimir$deepnet$layers$Activation$ActivationFn[ActivationFn.SELU.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$org$bigml$mimir$deepnet$layers$Activation$ActivationFn[ActivationFn.RELU6.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$org$bigml$mimir$deepnet$layers$Activation$ActivationFn[ActivationFn.LEAKY_RELU.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$org$bigml$mimir$deepnet$layers$Activation$ActivationFn[ActivationFn.SWISH.ordinal()] = 9;
            } catch (NoSuchFieldError e9) {
            }
            try {
                $SwitchMap$org$bigml$mimir$deepnet$layers$Activation$ActivationFn[ActivationFn.MISH.ordinal()] = 10;
            } catch (NoSuchFieldError e10) {
            }
        }
    }

    /** Supported activation functions; see {@link #getActivator(String)} for the accepted names. */
    public enum ActivationFn {
        IDENTITY,
        RELU,
        SOFTPLUS,
        SOFTMAX,
        RELU6,
        SIGMOID,
        TANH,
        SELU,
        SWISH,
        MISH,
        LEAKY_RELU
    }

    /**
     * Resolves an activation name, case-insensitively, to its {@link ActivationFn}.
     *
     * <p>{@code "linear"} is accepted as an alias for {@link ActivationFn#IDENTITY}.
     *
     * @param str the activation name; may be {@code null}
     * @return the matching function, or {@code null} when {@code str} is {@code null}
     * @throws IllegalArgumentException if the name is not recognised
     */
    public static ActivationFn getActivator(String str) {
        if (str == null) {
            return null;
        }
        // Locale.ROOT keeps matching stable under locales such as Turkish, where the default
        // lower-casing maps 'I' to a dotless 'ı' and breaks names like "SIGMOID".
        String name = str.toLowerCase(Locale.ROOT);
        switch (name) {
            case "identity":
            case "linear":
                return ActivationFn.IDENTITY;
            case "leaky_relu":
                return ActivationFn.LEAKY_RELU;
            case "mish":
                return ActivationFn.MISH;
            case "relu":
                return ActivationFn.RELU;
            case "relu6":
                return ActivationFn.RELU6;
            case "selu":
                return ActivationFn.SELU;
            case "sigmoid":
                return ActivationFn.SIGMOID;
            case "softmax":
                return ActivationFn.SOFTMAX;
            case "softplus":
                return ActivationFn.SOFTPLUS;
            case "swish":
                return ActivationFn.SWISH;
            case "tanh":
                return ActivationFn.TANH;
            default:
                throw new IllegalArgumentException(String.format("Unknown activation fn '%s'", name));
        }
    }

    /**
     * Applies {@code activationFn} to the whole of {@code fArr}.
     *
     * @param fArr the values to activate
     * @param activationFn the function to apply; {@code null} and {@link ActivationFn#IDENTITY}
     *     both return {@code fArr} itself
     * @return the array returned by the corresponding {@link Vectors} helper, or {@code fArr}
     *     for the identity/null case
     */
    public static float[] activate(float[] fArr, ActivationFn activationFn) {
        if (activationFn == null) {
            return fArr;
        }
        // Switch on the enum constants directly instead of on ordinal() via a lookup table.
        switch (activationFn) {
            case RELU:
                return Vectors.ReLU(fArr);
            case SOFTPLUS:
                return Vectors.softPlus(fArr);
            case SOFTMAX:
                return Vectors.softmax(fArr);
            case SIGMOID:
                return Vectors.sigmoid(fArr);
            case TANH:
                return Vectors.tanh(fArr);
            case SELU:
                return Vectors.selu(fArr);
            case RELU6:
                return Vectors.relu6(fArr);
            case LEAKY_RELU:
                return Vectors.leakyReLU(fArr);
            case SWISH:
                return Vectors.swish(fArr);
            case MISH:
                return Vectors.mish(fArr);
            case IDENTITY:
            default:
                return fArr;
        }
    }

    /**
     * Copies the first {@code fArr2.length} values of {@code fArr} into {@code fArr2}, then
     * applies {@code activationFn} to that copy. {@code fArr} itself is not passed to the
     * activation.
     *
     * @param fArr source values; must hold at least {@code fArr2.length} elements
     * @param fArr2 destination buffer
     * @param activationFn the function to apply; {@code null} means no activation
     * @return the result of {@link #activate(float[], ActivationFn)} on {@code fArr2}
     * @throws IndexOutOfBoundsException if {@code fArr} is shorter than {@code fArr2}
     */
    public static float[] activate(float[] fArr, float[] fArr2, ActivationFn activationFn) {
        System.arraycopy(fArr, 0, fArr2, 0, fArr2.length);
        return activate(fArr2, activationFn);
    }

    /**
     * Creates an activation layer.
     *
     * @param str activation name as accepted by {@link #getActivator(String)}; {@code null}
     *     means no activation is applied
     * @throws IllegalArgumentException if the name is not recognised
     */
    public Activation(String str) {
        this._afn = getActivator(str);
    }

    @Override // org.bigml.mimir.deepnet.layers.twod.Layer2D
    public int getIndex() {
        return this._index;
    }

    @Override // org.bigml.mimir.deepnet.layers.twod.Layer2D
    public void setIndex(int i) {
        this._index = i;
    }

    /**
     * Records the input shape and allocates an output tensor of the same shape.
     *
     * @param iArr the input shape (stored by reference, not copied)
     * @return {@code iArr}, since activation does not change the shape
     */
    @Override // org.bigml.mimir.deepnet.layers.twod.Layer2D
    public int[] initialize(int[] iArr) {
        this._inputShape = iArr;
        this._output = new OutputTensor(this._inputShape);
        return iArr;
    }

    /** Returns {@code getOutputShape()[0]}, the leading dimension of the output shape. */
    @Override // org.bigml.mimir.deepnet.layers.Layer
    public int getOutputLength() {
        return getOutputShape()[0];
    }

    @Override // org.bigml.mimir.deepnet.layers.twod.Layer2D
    public int[] getOutputShape() {
        return this._output.getShape();
    }

    @Override // org.bigml.mimir.deepnet.layers.twod.Layer2D
    public int[] getInputShape() {
        return this._inputShape;
    }

    /**
     * Sets up a one-dimensional input/output shape of length {@code i}.
     *
     * <p>This is a no-op if an output tensor already exists.
     *
     * @param i the input length
     */
    public synchronized void initialize(int i) {
        if (this._output == null) {
            this._inputShape = new int[]{i};
            this._output = new OutputTensor(this._inputShape);
        }
    }

    /**
     * Applies the activation to a flat input.
     *
     * <p>If the layer has not been initialized, it is lazily initialized as 1-D with
     * {@code fArr.length}.
     */
    @Override // org.bigml.mimir.deepnet.layers.Layer
    public float[] propagate(float[] fArr) {
        if (this._output == null) {
            initialize(fArr.length);
        }
        return activate(fArr, this._output.get(), this._afn);
    }

    /**
     * Applies the activation to a flattened multi-dimensional input.
     *
     * <p>Softmax is normalised independently over each contiguous run of {@code lastDim} values,
     * where {@code lastDim} is the last entry of the input shape. Every other function is applied
     * to the whole buffer. Requires {@link #initialize(int[])} to have been called first.
     */
    @Override // org.bigml.mimir.deepnet.layers.twod.Layer2D
    public float[] propagate2D(float[] fArr) {
        if (this._afn != ActivationFn.SOFTMAX) {
            return activate(fArr, this._output.get(), this._afn);
        }
        float[] out = this._output.get();
        System.arraycopy(fArr, 0, out, 0, out.length);
        int lastDim = this._inputShape[this._inputShape.length - 1];
        // Bound the walk by the buffer actually filled above, not by the caller's array.
        // Otherwise a longer input would push softmax past the end of 'out'.
        for (int start = 0; start < out.length; start += lastDim) {
            Vectors.softmax(out, start, start + lastDim);
        }
        return out;
    }

    /** Unrolls {@code dArr}, propagates it, then reshapes the result to the output shape. */
    @Override // org.bigml.mimir.deepnet.layers.twod.Layer2D
    public double[][][] propagateArray(double[][][] dArr) {
        return Matrices.reshape(propagate2D(Matrices.unroll(dArr)), this._output.getShape());
    }

    @Override // org.bigml.mimir.deepnet.layers.twod.Layer2D
    public float[] getLastOutput() {
        return this._output.get();
    }

    /**
     * Restores the serialized fields, then re-creates the transient output tensor when the input
     * shape is known.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        if (this._inputShape != null) {
            this._output = new OutputTensor(this._inputShape);
        }
    }

    /** Returns the configured activation function, or {@code null} if none was given. */
    public ActivationFn getFunction() {
        return this._afn;
    }
}
