package org.deeplearning4j.nn.conf.constraint;

import java.util.Collections;
import java.util.Set;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/conf/constraint/MinMaxNormConstraint.class */
/**
 * Constraint that keeps the L2 norm of parameter slices within {@code [min, max]}.
 *
 * <p>For each slice, the norm is reduced over the configured {@code dimensions} and clipped to
 * {@code [min, max]}. It is then optionally blended with the unclipped norm according to
 * {@code rate}:
 *
 * <pre>
 *   desired = rate * clip(norm, min, max) + (1 - rate) * norm
 *   param  *= desired / (norm + epsilon)
 * </pre>
 *
 * <p>A {@code rate} of 1.0 enforces the bounds strictly. Smaller values only move each norm part of
 * the way toward the allowed interval on each application.
 */
public class MinMaxNormConstraint extends BaseConstraint {
    /** Default blending rate: 1.0 means the constraint is enforced strictly. */
    public static final double DEFAULT_RATE = 1.0d;
    private double min;
    private double max;
    private double rate;

    /** No-arg constructor; presumably used by serialization frameworks (not verified here). */
    private MinMaxNormConstraint() {
    }

    /**
     * Creates a strict constraint (rate 1.0).
     *
     * @param d   lower bound for the norm
     * @param d2  upper bound for the norm
     * @param iArr dimensions over which the L2 norm is reduced
     */
    public MinMaxNormConstraint(double d, double d2, int... iArr) {
        // Param-name set is left null here (as in the original); presumably filled in later.
        this(d, d2, DEFAULT_RATE, null, iArr);
    }

    /**
     * Creates a constraint with the given blending rate and no explicit parameter names.
     *
     * @param d   lower bound for the norm
     * @param d2  upper bound for the norm
     * @param d3  blending rate in (0, 1]
     * @param iArr dimensions over which the L2 norm is reduced
     * @throws IllegalStateException if {@code d3} is not in (0, 1]
     */
    public MinMaxNormConstraint(double d, double d2, double d3, int... iArr) {
        this(d, d2, d3, Collections.emptySet(), iArr);
    }

    /**
     * Fully-specified constructor.
     *
     * @param d   lower bound for the norm
     * @param d2  upper bound for the norm
     * @param d3  blending rate in (0, 1]
     * @param set names of the parameters this constraint applies to
     * @param iArr dimensions over which the L2 norm is reduced
     * @throws IllegalStateException if {@code d3} is not in (0, 1]
     */
    public MinMaxNormConstraint(double d, double d2, double d3, Set<String> set, int... iArr) {
        super(set, iArr);
        if (d3 <= 0.0d || d3 > 1.0d) {
            throw new IllegalStateException("Invalid rate: must be in interval (0,1]: got " + d3);
        }
        this.min = d;
        this.max = d2;
        this.rate = d3;
    }

    /**
     * Rescales {@code iNDArray} in place so that its slice norms move toward {@code [min, max]}.
     *
     * @param iNDArray the parameter array to constrain; modified in place
     */
    @Override // org.deeplearning4j.nn.conf.constraint.BaseConstraint
    public void apply(INDArray iNDArray) {
        // L2 norm of each slice, reduced over the configured dimensions.
        INDArray norm = iNDArray.norm2(this.dimensions);

        // desired = clip(norm, min, max), computed on a copy so `norm` keeps the raw values.
        INDArray desired = norm.unsafeDuplication();
        Nd4j.getExecutioner().exec(DynamicCustomOp.builder("clipbyvalue")
                .addInputs(desired)
                .callInplace(true)
                .addFloatingPointArguments(Double.valueOf(this.min), Double.valueOf(this.max))
                .build());

        if (this.rate != 1.0d) {
            // Soft constraint: interpolate between the clipped and the raw norm.
            // This must happen BEFORE dividing by the norm: both terms are norms, not scale
            // factors. Previously the raw (epsilon-shifted) norm was added to an already
            // normalised factor, which produced wrong scales whenever rate < 1.
            desired.muli(Double.valueOf(this.rate))
                    .addi(norm.mul(Double.valueOf(1.0d - this.rate)));
        }

        // Scale factor = desired / (norm + epsilon); epsilon guards against division by zero.
        norm.addi(Double.valueOf(this.epsilon));
        desired.divi(norm);

        Broadcast.mul(iNDArray, desired, iNDArray, getBroadcastDims(this.dimensions, iNDArray.rank()));
    }

    /**
     * Returns a copy of this constraint with the same bounds, rate, parameter names and dimensions.
     */
    @Override // org.deeplearning4j.nn.conf.constraint.BaseConstraint
    /* renamed from: clone */
    public MinMaxNormConstraint mo5683clone() {
        return new MinMaxNormConstraint(this.min, this.max, this.rate, this.params, this.dimensions);
    }

    /** @return the lower bound for the norm */
    public double getMin() {
        return this.min;
    }

    /** @return the upper bound for the norm */
    public double getMax() {
        return this.max;
    }

    /** @return the blending rate */
    public double getRate() {
        return this.rate;
    }

    /** @param d the new lower bound for the norm (not validated) */
    public void setMin(double d) {
        this.min = d;
    }

    /** @param d the new upper bound for the norm (not validated) */
    public void setMax(double d) {
        this.max = d;
    }

    /** @param d the new blending rate (not validated, unlike the constructor) */
    public void setRate(double d) {
        this.rate = d;
    }

    @Override // org.deeplearning4j.nn.conf.constraint.BaseConstraint
    public String toString() {
        return "MinMaxNormConstraint(min=" + getMin() + ", max=" + getMax() + ", rate=" + getRate() + ")";
    }

    /** Equality covers superclass state plus min, max and rate (compared bitwise via Double.compare). */
    @Override // org.deeplearning4j.nn.conf.constraint.BaseConstraint
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof MinMaxNormConstraint)) {
            return false;
        }
        MinMaxNormConstraint minMaxNormConstraint = (MinMaxNormConstraint) obj;
        return minMaxNormConstraint.canEqual(this) && super.equals(obj) && Double.compare(getMin(), minMaxNormConstraint.getMin()) == 0 && Double.compare(getMax(), minMaxNormConstraint.getMax()) == 0 && Double.compare(getRate(), minMaxNormConstraint.getRate()) == 0;
    }

    @Override // org.deeplearning4j.nn.conf.constraint.BaseConstraint
    protected boolean canEqual(Object obj) {
        return obj instanceof MinMaxNormConstraint;
    }

    /** Hash consistent with {@link #equals(Object)}: folds in superclass hash, min, max and rate. */
    @Override // org.deeplearning4j.nn.conf.constraint.BaseConstraint
    public int hashCode() {
        int hashCode = super.hashCode();
        long doubleToLongBits = Double.doubleToLongBits(getMin());
        int i = (hashCode * 59) + ((int) ((doubleToLongBits >>> 32) ^ doubleToLongBits));
        long doubleToLongBits2 = Double.doubleToLongBits(getMax());
        int i2 = (i * 59) + ((int) ((doubleToLongBits2 >>> 32) ^ doubleToLongBits2));
        long doubleToLongBits3 = Double.doubleToLongBits(getRate());
        return (i2 * 59) + ((int) ((doubleToLongBits3 >>> 32) ^ doubleToLongBits3));
    }
}
