package org.nd4j.linalg.learning.config;

import java.util.Arrays;
import java.util.Map;
import org.apache.camel.util.URISupport;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.AdaMaxUpdater;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/nd4j/linalg/learning/config/AdaMax.class */
/**
 * Configuration for the AdaMax updater (a variant of Adam described in Kingma &amp; Ba,
 * "Adam: A Method for Stochastic Optimization").
 *
 * <p>Holds a fixed learning rate or an optional learning-rate schedule, the two decay rates
 * {@code beta1}/{@code beta2}, and {@code epsilon}. It creates {@link AdaMaxUpdater} instances
 * bound to this configuration.
 */
public class AdaMax implements IUpdater {
    public static final double DEFAULT_ADAMAX_LEARNING_RATE = 0.001d;
    public static final double DEFAULT_ADAMAX_EPSILON = 1.0E-8d;
    public static final double DEFAULT_ADAMAX_BETA1_MEAN_DECAY = 0.9d;
    public static final double DEFAULT_ADAMAX_BETA2_VAR_DECAY = 0.999d;

    /** Fixed learning rate, used by {@link #getLearningRate(int, int)} only when no schedule is set. */
    private double learningRate;
    /** Optional learning-rate schedule; {@code null} means "use {@link #learningRate}". */
    private ISchedule learningRateSchedule;
    private double beta1;
    private double beta2;
    private double epsilon;

    /**
     * Fluent builder for {@link AdaMax}. Properties that are never set fall back to the
     * {@code DEFAULT_ADAMAX_*} constants, and no learning-rate schedule is used.
     */
    public static class Builder {
        private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE;
        private ISchedule learningRateSchedule;
        private double beta1 = DEFAULT_ADAMAX_BETA1_MEAN_DECAY;
        private double beta2 = DEFAULT_ADAMAX_BETA2_VAR_DECAY;
        private double epsilon = DEFAULT_ADAMAX_EPSILON;

        public Builder() {
        }

        public Builder learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        public Builder learningRateSchedule(ISchedule learningRateSchedule) {
            this.learningRateSchedule = learningRateSchedule;
            return this;
        }

        public Builder beta1(double beta1) {
            this.beta1 = beta1;
            return this;
        }

        public Builder beta2(double beta2) {
            this.beta2 = beta2;
            return this;
        }

        public Builder epsilon(double epsilon) {
            this.epsilon = epsilon;
            return this;
        }

        /** Creates a new {@link AdaMax} from the values collected so far. */
        public AdaMax build() {
            return new AdaMax(learningRate, learningRateSchedule, beta1, beta2, epsilon);
        }
    }

    /** Returns a new {@link Builder} pre-populated with the default hyper-parameters. */
    public static Builder builder() {
        return new Builder();
    }

    /** Creates an AdaMax configuration with all default hyper-parameters and no schedule. */
    public AdaMax() {
        this(DEFAULT_ADAMAX_LEARNING_RATE);
    }

    /**
     * Creates an AdaMax configuration with the given fixed learning rate and default betas/epsilon.
     *
     * @param learningRate fixed learning rate
     */
    public AdaMax(double learningRate) {
        this(learningRate, null, DEFAULT_ADAMAX_BETA1_MEAN_DECAY, DEFAULT_ADAMAX_BETA2_VAR_DECAY,
                DEFAULT_ADAMAX_EPSILON);
    }

    /**
     * Creates an AdaMax configuration driven by a learning-rate schedule. The fixed learning rate
     * is set to {@code NaN}, so {@link #getLearningRate()} returns {@code NaN} for such instances.
     *
     * @param learningRateSchedule schedule consulted by {@link #getLearningRate(int, int)}
     */
    public AdaMax(ISchedule learningRateSchedule) {
        this(Double.NaN, learningRateSchedule, DEFAULT_ADAMAX_BETA1_MEAN_DECAY,
                DEFAULT_ADAMAX_BETA2_VAR_DECAY, DEFAULT_ADAMAX_EPSILON);
    }

    /**
     * Creates an AdaMax configuration with explicit hyper-parameters and no schedule.
     *
     * @param learningRate fixed learning rate
     * @param beta1        first decay rate
     * @param beta2        second decay rate
     * @param epsilon      epsilon value
     */
    public AdaMax(double learningRate, double beta1, double beta2, double epsilon) {
        this(learningRate, null, beta1, beta2, epsilon);
    }

    /** Canonical constructor; also the JSON deserialization target (see the property names). */
    private AdaMax(@JsonProperty("learningRate") double learningRate,
                   @JsonProperty("learningRateSchedule") ISchedule learningRateSchedule,
                   @JsonProperty("beta1") double beta1,
                   @JsonProperty("beta2") double beta2,
                   @JsonProperty("epsilon") double epsilon) {
        this.learningRate = learningRate;
        this.learningRateSchedule = learningRateSchedule;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.epsilon = epsilon;
    }

    /**
     * Returns the updater state size: two state values per parameter.
     *
     * @param numParams number of parameters being updated
     * @return {@code 2 * numParams}
     */
    @Override
    public long stateSize(long numParams) {
        return 2 * numParams;
    }

    /**
     * Creates an {@link AdaMaxUpdater} whose state lives in the given view array.
     *
     * @param viewArray           state view array (sized per {@link #stateSize(long)})
     * @param initializeViewArray forwarded unchanged to {@code setStateViewArray}
     * @return the new updater
     */
    @Override
    public GradientUpdater instantiate(INDArray viewArray, boolean initializeViewArray) {
        AdaMaxUpdater updater = new AdaMaxUpdater(this);
        long[] viewShape = viewArray.shape();
        // Copy before mutating so the array's own shape is never touched. The view holds two
        // equally sized state blocks (stateSize == 2 * n), so the shape handed on has dimension 1
        // halved. NOTE(review): assumes a rank >= 2 view (e.g. a row vector) — confirm.
        long[] gradientShape = Arrays.copyOf(viewShape, viewShape.length);
        gradientShape[1] = gradientShape[1] / 2;
        updater.setStateViewArray(viewArray, gradientShape, viewArray.ordering(), initializeViewArray);
        return updater;
    }

    /**
     * Creates an {@link AdaMaxUpdater} initialised from an existing state map.
     *
     * @param stateMap   named state arrays, forwarded to {@code setState}
     * @param initialize forwarded unchanged to {@code setState}
     * @return the new updater
     */
    @Override
    public GradientUpdater instantiate(Map<String, INDArray> stateMap, boolean initialize) {
        AdaMaxUpdater updater = new AdaMaxUpdater(this);
        updater.setState(stateMap, initialize);
        return updater;
    }

    /**
     * Returns a copy of this configuration. The learning-rate schedule is shared by reference,
     * not copied.
     */
    @Override
    public AdaMax clone() {
        return new AdaMax(learningRate, learningRateSchedule, beta1, beta2, epsilon);
    }

    /**
     * Alias kept for source compatibility with the decompiler-generated name.
     *
     * @deprecated use {@link #clone()} instead.
     */
    @Deprecated
    public IUpdater m6290clone() {
        return clone();
    }

    /**
     * Returns the effective learning rate. If a schedule is set, this is the schedule's value for
     * the two counters (forwarded unchanged to {@code ISchedule.valueAt}); otherwise it is the
     * fixed learning rate.
     */
    @Override
    public double getLearningRate(int iteration, int epoch) {
        if (learningRateSchedule != null) {
            return learningRateSchedule.valueAt(iteration, epoch);
        }
        return learningRate;
    }

    /** AdaMax always uses a learning rate. */
    @Override
    public boolean hasLearningRate() {
        return true;
    }

    /** Replaces both the fixed learning rate and the schedule ({@code null} clears the schedule). */
    @Override
    public void setLrAndSchedule(double lr, ISchedule lrSchedule) {
        this.learningRate = lr;
        this.learningRateSchedule = lrSchedule;
    }

    public double getLearningRate() {
        return this.learningRate;
    }

    public ISchedule getLearningRateSchedule() {
        return this.learningRateSchedule;
    }

    public double getBeta1() {
        return this.beta1;
    }

    public double getBeta2() {
        return this.beta2;
    }

    public double getEpsilon() {
        return this.epsilon;
    }

    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    public void setLearningRateSchedule(ISchedule learningRateSchedule) {
        this.learningRateSchedule = learningRateSchedule;
    }

    public void setBeta1(double beta1) {
        this.beta1 = beta1;
    }

    public void setBeta2(double beta2) {
        this.beta2 = beta2;
    }

    public void setEpsilon(double epsilon) {
        this.epsilon = epsilon;
    }

    /**
     * Value equality over all five properties. Doubles are compared with {@link Double#compare},
     * so {@code NaN} equals {@code NaN} and {@code 0.0} differs from {@code -0.0}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof AdaMax)) {
            return false;
        }
        AdaMax other = (AdaMax) obj;
        if (!other.canEqual(this) || Double.compare(getLearningRate(), other.getLearningRate()) != 0) {
            return false;
        }
        ISchedule mySchedule = getLearningRateSchedule();
        ISchedule otherSchedule = other.getLearningRateSchedule();
        if (mySchedule == null ? otherSchedule != null : !mySchedule.equals(otherSchedule)) {
            return false;
        }
        return Double.compare(getBeta1(), other.getBeta1()) == 0
                && Double.compare(getBeta2(), other.getBeta2()) == 0
                && Double.compare(getEpsilon(), other.getEpsilon()) == 0;
    }

    /** Symmetry hook for {@link #equals(Object)} so subclasses can opt out of cross-type equality. */
    protected boolean canEqual(Object obj) {
        return obj instanceof AdaMax;
    }

    /** Hash consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + hashDouble(getLearningRate());
        ISchedule schedule = getLearningRateSchedule();
        result = result * prime + (schedule == null ? 43 : schedule.hashCode());
        result = result * prime + hashDouble(getBeta1());
        result = result * prime + hashDouble(getBeta2());
        result = result * prime + hashDouble(getEpsilon());
        return result;
    }

    /** Folds the IEEE-754 bits of {@code value} into an int (same value as {@code Double.hashCode}). */
    private static int hashDouble(double value) {
        long bits = Double.doubleToLongBits(value);
        return (int) ((bits >>> 32) ^ bits);
    }

    @Override
    public String toString() {
        return "AdaMax(learningRate=" + getLearningRate()
                + ", learningRateSchedule=" + getLearningRateSchedule()
                + ", beta1=" + getBeta1()
                + ", beta2=" + getBeta2()
                + ", epsilon=" + getEpsilon() + ")";
    }
}
