package org.nd4j.linalg.learning;

import java.io.Serializable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/nd4j/linalg/learning/AdaGrad.class */
/**
 * Per-element AdaGrad step-size computation.
 *
 * <p>Each call to {@link #getLearningRates(INDArray)} adds the element-wise square of the supplied
 * gradient to {@link #historicalGradient} and returns
 * {@code |gradient| / (sqrt(historicalGradient) + fudgeFactor) * masterStepSize}.
 *
 * <p>NOTE(review): {@code lrDecay}, {@code decayLr}, {@code minLearningRate} and
 * {@code numIterations} are stored/updated but never read by code in this class; presumably they are
 * consumed by subclasses or callers — verify before relying on learning-rate decay.
 */
public class AdaGrad implements Serializable {
    protected static final long serialVersionUID = -4754127927704099888L;

    /** Master step size used when the caller does not supply one. */
    private static final double DEFAULT_MASTER_STEP_SIZE = 0.1d;

    /** Global multiplier applied to every per-element rate. */
    protected double masterStepSize = DEFAULT_MASTER_STEP_SIZE;
    /** Running sum of squared gradients; re-created when the incoming gradient's length changes. */
    public INDArray historicalGradient;
    /** Result of the most recent {@link #getLearningRates(INDArray)} call. */
    public INDArray adjustedGradient;
    /** Small constant added to the denominator to avoid division by zero. */
    public double fudgeFactor = 1.0E-6d;
    /** Gradient passed to the most recent {@link #getLearningRates(INDArray)} call. */
    public INDArray gradient;
    /** Shape used to allocate the history and adjusted-gradient arrays at construction time. */
    public int[] shape;
    /** Number of completed {@link #getLearningRates(INDArray)} calls. */
    protected int numIterations = 0;
    protected double lrDecay = 0.95d;
    protected boolean decayLr;
    protected double minLearningRate = 1.0E-4d;

    /**
     * Creates an AdaGrad instance for a {@code rows x cols} parameter matrix.
     *
     * @param rows number of rows of the parameter matrix
     * @param cols number of columns of the parameter matrix
     * @param stepSize master step size applied to every element
     */
    public AdaGrad(int rows, int cols, double stepSize) {
        this(new int[] {rows, cols}, stepSize);
    }

    /**
     * Creates an AdaGrad instance for parameters of the given shape with the default master step
     * size ({@code 0.1}).
     *
     * @param shape parameter shape; the array is stored by reference, not copied
     */
    public AdaGrad(int[] shape) {
        this(shape, DEFAULT_MASTER_STEP_SIZE);
    }

    /**
     * Creates an AdaGrad instance for a {@code rows x cols} parameter matrix with the default master
     * step size ({@code 0.1}).
     *
     * @param rows number of rows of the parameter matrix
     * @param cols number of columns of the parameter matrix
     */
    public AdaGrad(int rows, int cols) {
        this(rows, cols, DEFAULT_MASTER_STEP_SIZE);
    }

    /** Shared initialization for the public constructors. */
    private AdaGrad(int[] shape, double stepSize) {
        this.shape = shape;
        createHistoricalGradient();
        createAdjustedGradient();
        // Assigned after the overridable create* hooks so that, as before, any override observes
        // the default step size during construction.
        this.masterStepSize = stepSize;
        this.decayLr = false;
    }

    /** Allocates {@link #historicalGradient} with {@link #shape}. */
    protected void createHistoricalGradient() {
        this.historicalGradient = Nd4j.create(this.shape);
    }

    /** Allocates {@link #adjustedGradient} with {@link #shape}. */
    protected void createAdjustedGradient() {
        this.adjustedGradient = Nd4j.create(this.shape);
    }

    /**
     * Accumulates {@code gradient^2} into the history and returns the per-element AdaGrad values.
     *
     * <p>If the history is missing or its length differs from {@code gradient}'s, it is reset to
     * zeros with the gradient's full shape before accumulation.
     *
     * <p>NOTE(review): the result is built from {@code |gradient|}, so it carries no sign
     * information — confirm callers apply the gradient direction themselves.
     *
     * @param gradient current gradient; also stored in {@link #gradient}
     * @return {@code |gradient| / (sqrt(history) + fudgeFactor) * masterStepSize}, also stored in
     *     {@link #adjustedGradient}
     */
    public INDArray getLearningRates(INDArray gradient) {
        this.gradient = gradient;
        INDArray squared = Transforms.pow(gradient, (Number) 2);
        if (this.historicalGradient == null
                || this.historicalGradient.length() != gradient.length()) {
            // Use the full shape: rows()/columns() only describe 2-D arrays and would mis-size
            // (or fail on) gradients of any other rank.
            this.historicalGradient = Nd4j.zeros(gradient.shape());
        }
        this.historicalGradient.addi(squared);
        this.numIterations++;
        INDArray denominator = Transforms.sqrt(this.historicalGradient).addi(this.fudgeFactor);
        this.adjustedGradient =
                Transforms.abs(gradient).divi(denominator).muli(this.masterStepSize);
        return this.adjustedGradient;
    }

    /** Returns the global step-size multiplier. */
    public double getMasterStepSize() {
        return this.masterStepSize;
    }

    /** Sets the global step-size multiplier. */
    public void setMasterStepSize(double masterStepSize) {
        this.masterStepSize = masterStepSize;
    }

    /** Returns the learning-rate-decay flag. */
    public synchronized boolean isDecayLr() {
        return this.decayLr;
    }

    /** Sets the learning-rate-decay flag. */
    public synchronized void setDecayLr(boolean decayLr) {
        this.decayLr = decayLr;
    }
}
