package org.nd4j.linalg.learning;

import org.apache.camel.util.URISupport;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/nd4j/linalg/learning/AdamUpdater.class */
/**
 * {@link GradientUpdater} implementing the Adam update rule.
 *
 * <p>The updater keeps two state arrays: {@code m}, an exponential moving average of the gradient,
 * and {@code v}, an exponential moving average of the element-wise squared gradient. When bound via
 * {@link #setStateViewArray(INDArray, int[], char, boolean)} they are taken as the two halves of a
 * single row-vector state array and reshaped with a no-copy reshape, so in-place updates to them are
 * presumably reflected in that shared state array.
 */
public class AdamUpdater implements GradientUpdater<Adam> {
    /** Adam hyper-parameters: learning rate (schedule), beta1, beta2 and epsilon. */
    private Adam config;
    /** Moving average of the gradient (first half of the state view when set via setStateViewArray). */
    private INDArray m;
    /** Moving average of the squared gradient (second half of the state view when set via setStateViewArray). */
    private INDArray v;
    /** Gradient array order recorded by setStateViewArray; used when duplicating {@code v} in applyUpdater. */
    private char gradientReshapeOrder;

    /**
     * Creates an updater for the given Adam configuration. State must be supplied through
     * {@link #setStateViewArray(INDArray, int[], char, boolean)} (or {@link #setM}/{@link #setV})
     * before {@link #applyUpdater(INDArray, int, int)} is called.
     *
     * @param adam Adam hyper-parameters
     */
    public AdamUpdater(Adam adam) {
        this.config = adam;
    }

    /**
     * Binds this updater's state to {@code viewArray}.
     *
     * <p>The first half of the array becomes {@code m} and the second half becomes {@code v}; each
     * half is then reshaped to {@code gradientShape} without copying, in Fortran order when
     * {@code gradientOrder == 'f'} and C order otherwise.
     *
     * @param viewArray     row-vector state array, expected to hold two gradient-sized halves
     * @param gradientShape shape of the gradients this updater will process
     * @param gradientOrder array order of the gradients; also recorded for use by applyUpdater
     * @param initialize    if {@code true}, {@code viewArray} is zero-filled before being split
     * @throws IllegalArgumentException if {@code viewArray} is not a row vector
     * @throws IllegalStateException    if either half cannot be reshaped to {@code gradientShape}
     *                                  without copying
     */
    @Override
    public void setStateViewArray(INDArray viewArray, int[] gradientShape, char gradientOrder, boolean initialize) {
        if (!viewArray.isRowVector()) {
            throw new IllegalArgumentException("Invalid input: expect row vector input");
        }
        if (initialize) {
            viewArray.assign(0);
        }
        int length = viewArray.length();
        // Integer division: for an odd length the halves differ in size, so they cannot both
        // match gradientShape and the null check below will reject the state.
        int half = length / 2;
        INDArray firstHalf = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, half));
        INDArray secondHalf = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(half, length));

        boolean fortranOrder = gradientOrder == 'f';
        this.m = Shape.newShapeNoCopy(firstHalf, gradientShape, fortranOrder);
        this.v = Shape.newShapeNoCopy(secondHalf, gradientShape, fortranOrder);
        // newShapeNoCopy is expected to return null when a no-copy reshape is impossible.
        if (this.m == null || this.v == null) {
            throw new IllegalStateException("Could not correctly reshape gradient view arrays");
        }
        this.gradientReshapeOrder = gradientOrder;
    }

    /**
     * Applies one Adam step, overwriting {@code gradient} in place with the update to apply.
     *
     * <p>With {@code t = iteration + 1}:
     * <pre>
     * m       = beta1 * m + (1 - beta1) * g
     * v       = beta2 * v + (1 - beta2) * g * g
     * alpha_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
     * g       = alpha_t * m / (sqrt(v) + epsilon)
     * </pre>
     * If {@code alpha_t} is NaN or exactly zero, {@code epsilon} is used as the step size instead.
     * {@code m} and {@code v} are updated in place; {@code v} itself is not square-rooted.
     *
     * @param gradient  gradient for this step; replaced by the computed update
     * @param iteration iteration index; the bias-correction step is {@code iteration + 1} and it is
     *                  also passed to the learning-rate lookup
     * @param epoch     passed through to the learning-rate lookup
     * @throws IllegalStateException if the state arrays have not been set
     */
    @Override
    public void applyUpdater(INDArray gradient, int iteration, int epoch) {
        if (this.m == null || this.v == null) {
            throw new IllegalStateException("Updater has not been initialized with view state");
        }
        double beta1 = this.config.getBeta1();
        double beta2 = this.config.getBeta2();
        double learningRate = this.config.getLearningRate(iteration, epoch);
        double epsilon = this.config.getEpsilon();

        // Exponential moving averages of the gradient and of its element-wise square.
        this.m.muli(beta1).addi(gradient.mul(1.0d - beta1));
        this.v.muli(beta2).addi(gradient.mul(gradient).muli(1.0d - beta2));

        // Bias-corrected step size.
        int t = iteration + 1;
        double alphat = (learningRate * FastMath.sqrt(1.0d - FastMath.pow(beta2, t))) / (1.0d - FastMath.pow(beta1, t));
        if (Double.isNaN(alphat) || alphat == 0.0d) {
            alphat = epsilon;
        }

        // Transforms.sqrt(..., false) operates on its argument without copying, so a copy of v is
        // taken first to keep the stored second-moment estimate intact.
        INDArray denominator = Transforms.sqrt(this.v.dup(this.gradientReshapeOrder), false).addi(epsilon);
        gradient.assign(this.m).muli(alphat).divi(denominator);
    }

    /** @return the Adam configuration used by this updater */
    @Override
    public Adam getConfig() {
        return this.config;
    }

    /** @return the first-moment state array, or {@code null} if not yet set */
    public INDArray getM() {
        return this.m;
    }

    /** @return the second-moment state array, or {@code null} if not yet set */
    public INDArray getV() {
        return this.v;
    }

    /** @return the gradient order recorded by setStateViewArray */
    public char getGradientReshapeOrder() {
        return this.gradientReshapeOrder;
    }

    /** @param adam the Adam configuration to use from now on */
    public void setConfig(Adam adam) {
        this.config = adam;
    }

    /** @param m the first-moment state array to use */
    public void setM(INDArray m) {
        this.m = m;
    }

    /** @param v the second-moment state array to use */
    public void setV(INDArray v) {
        this.v = v;
    }

    /** @param order array order used when duplicating {@code v} in applyUpdater */
    public void setGradientReshapeOrder(char order) {
        this.gradientReshapeOrder = order;
    }

    /**
     * Two updaters are equal when their config, {@code m}, {@code v} and reshape order are equal
     * and {@code obj} accepts this instance via {@link #canEqual(Object)}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof AdamUpdater)) {
            return false;
        }
        AdamUpdater other = (AdamUpdater) obj;
        return other.canEqual(this)
                && nullSafeEquals(getConfig(), other.getConfig())
                && nullSafeEquals(getM(), other.getM())
                && nullSafeEquals(getV(), other.getV())
                && getGradientReshapeOrder() == other.getGradientReshapeOrder();
    }

    /** Null-safe equality that always delegates to {@code a.equals(b)} when {@code a} is non-null. */
    private static boolean nullSafeEquals(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Symmetry hook for subclasses overriding equals. */
    protected boolean canEqual(Object obj) {
        return obj instanceof AdamUpdater;
    }

    /** Hash combining config, {@code m}, {@code v} and reshape order (null fields hash as 43). */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        Adam cfg = getConfig();
        result = result * prime + (cfg == null ? 43 : cfg.hashCode());
        INDArray first = getM();
        result = result * prime + (first == null ? 43 : first.hashCode());
        INDArray second = getV();
        result = result * prime + (second == null ? 43 : second.hashCode());
        result = result * prime + getGradientReshapeOrder();
        return result;
    }

    @Override
    public String toString() {
        return "AdamUpdater(config=" + getConfig() + ", m=" + getM() + ", v=" + getV()
                + ", gradientReshapeOrder=" + getGradientReshapeOrder() + ")";
    }
}
