package org.nd4j.linalg.activations.impl;

import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;

/**
 * Randomized leaky rectifier activation ("rrelu").
 *
 * <p>In training mode every negative input element is multiplied by its own
 * slope, taken from an {@code alpha} array that is freshly generated with
 * {@code Nd4j.rand(l, u, ...)} on each forward pass. In inference mode a single
 * fixed slope, the midpoint {@code (l + u) / 2}, is used instead.
 *
 * <p>The {@code alpha} array of the most recent training-mode forward pass is
 * cached on the instance because {@link #backprop(INDArray, INDArray)} needs
 * it. The cache is transient and excluded from JSON serialization. Since it is
 * plain mutable instance state, one instance should not be shared between
 * concurrent forward/backward passes.
 */
@JsonIgnoreProperties({"alpha"})
public class ActivationRReLU extends BaseActivationFunction {
    /** Default lower bound of the random slope range. */
    public static final double DEFAULT_L = 0.125d;
    /** Default upper bound of the random slope range. */
    public static final double DEFAULT_U = 0.3333333333333333d;

    // NOTE(review): kept non-final on purpose; a JSON mapper may populate these
    // via field access -- confirm before tightening.
    private double l;
    private double u;
    /** Slopes generated by the last training-mode forward pass; {@code null} otherwise. */
    private transient INDArray alpha;

    /** Creates the activation with the default range [{@link #DEFAULT_L}, {@link #DEFAULT_U}]. */
    public ActivationRReLU() {
        this(DEFAULT_L, DEFAULT_U);
    }

    /**
     * Creates the activation with a custom slope range.
     *
     * @param d  lower bound of the slope range
     * @param d2 upper bound of the slope range
     * @throws IllegalArgumentException if the lower bound is greater than the upper bound
     */
    public ActivationRReLU(double d, double d2) {
        if (d > d2) {
            throw new IllegalArgumentException("Cannot have lower value (" + d + ") greater than upper (" + d2 + ")");
        }
        this.l = d;
        this.u = d2;
    }

    /**
     * Applies the activation.
     *
     * @param iNDArray pre-activation input; in training mode it is modified in place
     * @param z        {@code true} for training mode (random per-element slopes),
     *                 {@code false} for inference mode (fixed midpoint slope)
     * @return the activated array; in training mode this is {@code iNDArray} itself
     */
    @Override // org.nd4j.linalg.activations.IActivation
    public INDArray getActivation(INDArray iNDArray, boolean z) {
        if (!z) {
            // Inference: drop any cached slopes and use the midpoint of [l, u].
            this.alpha = null;
            double meanSlope = 0.5d * (this.l + this.u);
            return Nd4j.getExecutioner().exec((ScalarOp) new RectifiedLinear(iNDArray, meanSlope));
        }

        // Only the allocation of alpha happens while scoped out of workspaces
        // (presumably so that it stays valid until backprop is invoked). The scope
        // is closed exactly once, whether or not the allocation throws.
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            this.alpha = Nd4j.rand(this.l, this.u, Nd4j.getRandom(), iNDArray.shape());
        }

        // Negative inputs become input * alpha; all other elements are left untouched.
        INDArray scaled = iNDArray.mul(this.alpha);
        BooleanIndexing.replaceWhere(iNDArray, scaled, Conditions.lessThan(0));
        return iNDArray;
    }

    /**
     * Computes the gradient with respect to the pre-activation input.
     *
     * <p>The local derivative is 1 where the input is positive and the cached
     * {@code alpha} value where the input is less than or equal to zero; the
     * result is that derivative multiplied element-wise by {@code iNDArray2}.
     *
     * @param iNDArray  pre-activation input (not modified)
     * @param iNDArray2 gradient flowing back from the next layer (epsilon)
     * @return a pair whose first element is the computed gradient and whose
     *         second element is {@code null}
     * @throws IllegalStateException if no training-mode forward pass has
     *                               populated {@code alpha}
     */
    @Override // org.nd4j.linalg.activations.IActivation
    public Pair<INDArray, INDArray> backprop(INDArray iNDArray, INDArray iNDArray2) {
        assertShape(iNDArray, iNDArray2);
        if (this.alpha == null) {
            throw new IllegalStateException(
                    "backprop called without a preceding training-mode getActivation: alpha is not set");
        }

        // replaceWhere tests the elements of its FIRST argument, so the slope array
        // has to start out as a copy of the input; testing an all-ones array against
        // "<= 0" would never match and alpha would never be applied.
        INDArray dLdz = iNDArray.dup();
        // Order matters: positives are turned into 1 first, so that the "<= 0" pass
        // below only sees the original non-positive inputs.
        BooleanIndexing.replaceWhere(dLdz, Nd4j.ones(iNDArray.shape()), Conditions.greaterThan(0.0d));
        BooleanIndexing.replaceWhere(dLdz, this.alpha, Conditions.lessThanOrEqual(0.0d));
        dLdz.muli(iNDArray2);
        return new Pair<>(dLdz, null);
    }

    @Override
    public String toString() {
        return "rrelu(l=" + this.l + ", u=" + this.u + ")";
    }

    /**
     * Two instances are equal when both slope bounds are equal; the cached
     * {@code alpha} array does not take part in the comparison.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ActivationRReLU)) {
            return false;
        }
        ActivationRReLU other = (ActivationRReLU) obj;
        return other.canEqual(this)
                && Double.compare(getL(), other.getL()) == 0
                && Double.compare(getU(), other.getU()) == 0;
    }

    /** Hook that lets subclasses opt out of equality with this class. */
    protected boolean canEqual(Object obj) {
        return obj instanceof ActivationRReLU;
    }

    /** Hash over the two slope bounds, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        int result = 59 + hashOf(getL());
        return (result * 59) + hashOf(getU());
    }

    /** Folds the 64-bit representation of a double into an int hash. */
    private static int hashOf(double value) {
        long bits = Double.doubleToLongBits(value);
        return (int) ((bits >>> 32) ^ bits);
    }

    /** @return lower bound of the slope range */
    public double getL() {
        return this.l;
    }

    /** @return upper bound of the slope range */
    public double getU() {
        return this.u;
    }

    /**
     * @return the slopes generated by the most recent training-mode forward
     *         pass, or {@code null} if none is cached
     */
    public INDArray getAlpha() {
        return this.alpha;
    }
}
