package org.nd4j.linalg.api.activation;

import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.ElementWiseOp;
import org.nd4j.linalg.util.ComplexUtil;

/* loaded from: input_file:org/nd4j/linalg/api/activation/HardTanh.class */
public class HardTanh extends BaseActivationFunction {
    private static final long serialVersionUID = -8484119406683594852L;

    @Override // org.nd4j.linalg.api.activation.ActivationFunction
    public Class<? extends ElementWiseOp> transformClazz() {
        return org.nd4j.linalg.ops.transforms.HardTanh.class;
    }

    @Override // org.nd4j.linalg.api.activation.BaseActivationFunction, org.nd4j.linalg.api.activation.ActivationFunction
    public String type() {
        return "hardtanh";
    }

    @Override // org.nd4j.linalg.api.activation.ActivationFunction
    public INDArray applyDerivative(INDArray iNDArray) {
        if (iNDArray instanceof IComplexNDArray) {
            IComplexNDArray linearView = ((IComplexNDArray) iNDArray).linearView();
            for (int i = 0; i < linearView.length(); i++) {
                IComplexNumber complex = linearView.getComplex(i);
                if (complex.realComponent().doubleValue() < -1.0d) {
                    complex.set(-1, Double.valueOf(complex.imaginaryComponent().doubleValue()));
                } else if (complex.realComponent().doubleValue() > 1.0d) {
                    complex.set(1, Double.valueOf(complex.imaginaryComponent().doubleValue()));
                } else {
                    complex = Nd4j.createDouble(1.0d, 0.0d).subi(ComplexUtil.pow(ComplexUtil.tanh(complex), 2.0d));
                }
                linearView.putScalar(i, complex);
            }
        } else {
            INDArray linearView2 = iNDArray.linearView();
            for (int i2 = 0; i2 < linearView2.length(); i2++) {
                float f = linearView2.getFloat(i2);
                linearView2.putScalar(i2, f < -1.0f ? -1.0f : f > 1.0f ? 1.0f : 1.0f - ((float) Math.pow(Math.tanh(f), 2.0d)));
            }
        }
        return iNDArray;
    }
}
