package org.deeplearning4j.nn.updater;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.updater.BaseUpdater;
import org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.Adam;
import org.nd4j.linalg.learning.GradientUpdater;

/**
 * {@link BaseUpdater} that applies the Adam update rule.
 *
 * <p>One {@link Adam} gradient updater is created lazily per parameter variable and cached in
 * {@code updaterForVariable}. Its hyperparameters come from the layer configuration: the
 * per-parameter learning rate plus the layer's Adam mean-decay and variance-decay values.
 */
public class AdamUpdater extends BaseUpdater {

    /**
     * Performs no eager setup; per-variable {@link Adam} state is instead created on demand by
     * {@link #init(String, INDArray, Layer)}.
     */
    @Override
    public void init() {
        // intentionally empty
    }

    /**
     * Returns the cached {@link Adam} updater for {@code variable}, creating and caching one on
     * first use.
     *
     * @param variable parameter name; used as the cache key and to look up its learning rate
     * @param gradient not used by this implementation
     * @param layer layer whose configuration supplies the Adam hyperparameters
     * @return the {@link Adam} gradient updater associated with {@code variable}
     * @throws ClassCastException if an updater that is not an {@link Adam} is already cached
     *     under {@code variable} (the cached entry is cast to {@link Adam})
     */
    @Override
    public GradientUpdater init(String variable, INDArray gradient, Layer layer) {
        Adam adam = (Adam) this.updaterForVariable.get(variable);
        if (adam != null) {
            return adam;
        }

        // First request for this variable: build it from the layer config and remember it.
        adam = new Adam(
                layer.conf().getLearningRateByParam(variable),
                layer.conf().getLayer().getAdamMeanDecay(),
                layer.conf().getLayer().getAdamVarDecay());
        this.updaterForVariable.put(variable, adam);
        return adam;
    }

    /**
     * Creates a new aggregator for Adam updaters.
     *
     * @param addThis when {@code true}, this updater is aggregated into the returned aggregator
     *     before it is handed back
     * @return a new {@link AdamAggregator}
     */
    @Override
    public UpdaterAggregator getAggregator(boolean addThis) {
        final AdamAggregator aggregator = new AdamAggregator();
        if (!addThis) {
            return aggregator;
        }
        aggregator.aggregate(this);
        return aggregator;
    }

    /** {@link UpdaterAggregator} whose {@link #getUpdater()} yields {@link AdamUpdater} instances. */
    protected static class AdamAggregator extends BaseUpdater.UpdaterAggregatorImpl {

        protected AdamAggregator() {
            // explicit constructor kept with the same (protected) visibility as the class
        }

        /**
         * Builds a fresh {@link AdamUpdater} and returns the result of passing it through the
         * inherited {@code setUpdaterState}.
         *
         * <p>NOTE(review): {@code setUpdaterState} is defined in
         * {@code BaseUpdater.UpdaterAggregatorImpl}. It presumably copies the aggregated state
         * into the new updater; confirm this against that class.
         *
         * @return the updater produced by {@code setUpdaterState}
         */
        @Override
        public Updater getUpdater() {
            AdamUpdater fresh = new AdamUpdater();
            return setUpdaterState(fresh);
        }
    }
}
