package org.nd4j.linalg.learning;

import org.apache.camel.util.URISupport;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * {@link GradientUpdater} implementing the AdaDelta update rule.
 *
 * <p>State consists of two arrays, each reshaped to the gradient's shape and obtained as views
 * into one row-vector state array supplied through {@link #setStateViewArray}:
 * <ul>
 *   <li>{@code msg}: running average of squared gradients, E[g^2];</li>
 *   <li>{@code msdx}: running average of squared updates, E[dx^2].</li>
 * </ul>
 *
 * <p>NOTE(review): {@link #equals} and {@link #hashCode} include the state arrays, which
 * {@link #applyUpdater} mutates in place. If {@code INDArray#hashCode} depends on array
 * contents, an instance's hash code changes over training, so do not use instances as hash
 * keys.
 */
public class AdaDeltaUpdater implements GradientUpdater<AdaDelta> {
    private final AdaDelta config;
    private INDArray msg;
    private INDArray msdx;

    /**
     * Creates an updater for the given configuration. No state is allocated here;
     * {@link #setStateViewArray} must be called before {@link #applyUpdater}.
     *
     * @param adaDelta the AdaDelta configuration (source of rho and epsilon)
     */
    public AdaDeltaUpdater(AdaDelta adaDelta) {
        this.config = adaDelta;
    }

    /**
     * Binds this updater's state to a caller-owned view array.
     *
     * <p>The first half of {@code viewArray} becomes {@code msg} and the second half becomes
     * {@code msdx}. Each half is then reshaped to {@code gradientShape}.
     *
     * @param viewArray     state array; must be a row vector
     * @param gradientShape shape of the gradient arrays this updater will receive
     * @param gradientOrder memory order of the gradient; {@code 'f'} selects Fortran order for
     *                      the reshape, any other value does not
     * @param initialize    if {@code true}, the whole view array is first set to zero
     * @throws IllegalArgumentException if {@code viewArray} is not a row vector
     * @throws IllegalStateException    if either half cannot be reshaped to {@code gradientShape}
     */
    @Override
    public void setStateViewArray(INDArray viewArray, int[] gradientShape, char gradientOrder, boolean initialize) {
        if (!viewArray.isRowVector()) {
            throw new IllegalArgumentException("Invalid input: expect row vector input");
        }
        if (initialize) {
            viewArray.assign((Number) 0);
        }
        int length = viewArray.length();
        int half = length / 2;
        this.msg = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, half));
        this.msdx = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(half, length));

        // newShapeNoCopy signals failure by returning null (checked below), so the views keep
        // aliasing viewArray when the reshape succeeds.
        boolean fortranOrder = gradientOrder == 'f';
        this.msg = Shape.newShapeNoCopy(this.msg, gradientShape, fortranOrder);
        this.msdx = Shape.newShapeNoCopy(this.msdx, gradientShape, fortranOrder);
        if (this.msg == null || this.msdx == null) {
            throw new IllegalStateException("Could not correctly reshape gradient view arrays");
        }
    }

    /**
     * Applies one AdaDelta step.
     *
     * <p>The gradient is transformed in place into the update, and both running averages are
     * updated in place. The {@code iteration} and {@code epoch} arguments are not used by this
     * rule.
     *
     * @param gradient  gradient array; overwritten with the computed update
     * @param iteration current iteration (unused)
     * @param epoch     current epoch (unused)
     * @throws IllegalStateException if {@link #setStateViewArray} has not been called
     */
    @Override
    public void applyUpdater(INDArray gradient, int iteration, int epoch) {
        if (this.msg == null || this.msdx == null) {
            throw new IllegalStateException("Updater has not been initialized with view state");
        }
        double rho = this.config.getRho();
        double epsilon = this.config.getEpsilon();

        // E[g^2]_t = rho * E[g^2]_{t-1} + (1 - rho) * g_t^2   (msg updated in place)
        this.msg.muli(Double.valueOf(rho)).addi(gradient.mul(gradient).muli(Double.valueOf(1.0d - rho)));

        // update = g_t * RMS[dx]_{t-1} / RMS[g]_t, where RMS[v] = sqrt(E[v^2] + epsilon).
        // add() (the non-"i" variant) presumably yields fresh arrays, so the in-place sqrt/divi
        // only touch temporaries. The final muli writes the update into the caller's gradient.
        INDArray rmsDxPrev = Transforms.sqrt(this.msdx.add(Double.valueOf(epsilon)), false);
        INDArray rmsG = Transforms.sqrt(this.msg.add(Double.valueOf(epsilon)), false);
        INDArray update = gradient.muli(rmsDxPrev.divi(rmsG));

        // E[dx^2]_t = rho * E[dx^2]_{t-1} + (1 - rho) * update^2   (msdx updated in place)
        this.msdx.muli(Double.valueOf(rho)).addi(update.mul(update).muli(Double.valueOf(1.0d - rho)));
    }

    /** @return the AdaDelta configuration passed to the constructor */
    @Override
    public AdaDelta getConfig() {
        return this.config;
    }

    /** @return the running average of squared gradients, or {@code null} before initialization */
    public INDArray getMsg() {
        return this.msg;
    }

    /** @return the running average of squared updates, or {@code null} before initialization */
    public INDArray getMsdx() {
        return this.msdx;
    }

    /** Replaces the squared-gradient state array (no validation is performed). */
    public void setMsg(INDArray iNDArray) {
        this.msg = iNDArray;
    }

    /** Replaces the squared-update state array (no validation is performed). */
    public void setMsdx(INDArray iNDArray) {
        this.msdx = iNDArray;
    }

    /**
     * Compares config and both state arrays. Uses the {@link #canEqual} handshake so that
     * subclasses can opt out of equality with this class.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof AdaDeltaUpdater)) {
            return false;
        }
        AdaDeltaUpdater other = (AdaDeltaUpdater) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        AdaDelta thisConfig = getConfig();
        AdaDelta otherConfig = other.getConfig();
        if (thisConfig == null ? otherConfig != null : !thisConfig.equals(otherConfig)) {
            return false;
        }
        INDArray thisMsg = getMsg();
        INDArray otherMsg = other.getMsg();
        if (thisMsg == null ? otherMsg != null : !thisMsg.equals(otherMsg)) {
            return false;
        }
        INDArray thisMsdx = getMsdx();
        INDArray otherMsdx = other.getMsdx();
        return thisMsdx == null ? otherMsdx == null : thisMsdx.equals(otherMsdx);
    }

    /** Symmetry hook for {@link #equals}; subclasses overriding equals should override this. */
    protected boolean canEqual(Object obj) {
        return obj instanceof AdaDeltaUpdater;
    }

    /** Hash over config, msg and msdx (prime 59, 43 for null fields). */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        AdaDelta cfg = getConfig();
        result = result * prime + (cfg == null ? 43 : cfg.hashCode());
        INDArray m = getMsg();
        result = result * prime + (m == null ? 43 : m.hashCode());
        INDArray dx = getMsdx();
        result = result * prime + (dx == null ? 43 : dx.hashCode());
        return result;
    }

    @Override
    public String toString() {
        // Plain ")" literal: the closing parenthesis must not depend on an unrelated library
        // constant, whose type or value may differ between versions.
        return "AdaDeltaUpdater(config=" + getConfig() + ", msg=" + getMsg() + ", msdx=" + getMsdx() + ")";
    }
}
