package org.deeplearning4j.iterativereduce.akka.gradient;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.deeplearning4j.nn.gradient.MultiLayerGradient;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/akka/gradient/GradientAccumulator.class */
/**
 * Collects {@link MultiLayerGradient} instances and reduces them to their mean.
 *
 * <p>Gradients are stored in the order they are passed to
 * {@link #accumulate(MultiLayerGradient)}. This class performs no synchronization
 * of its own.
 */
public class GradientAccumulator {
    /** Every gradient handed to {@link #accumulate(MultiLayerGradient)}, in arrival order. */
    private final List<MultiLayerGradient> workers = new ArrayList<MultiLayerGradient>();

    /**
     * Stores a gradient so that it is included in the next call to {@link #averaged()}.
     *
     * @param multiLayerGradient the gradient to record; must not be {@code null}
     * @throws NullPointerException if {@code multiLayerGradient} is {@code null}
     */
    public void accumulate(MultiLayerGradient multiLayerGradient) {
        // Fail at the point of insertion: a null entry would otherwise only surface
        // as an unexplained NullPointerException later, inside averaged().
        if (multiLayerGradient == null) {
            throw new NullPointerException("multiLayerGradient must not be null");
        }
        this.workers.add(multiLayerGradient);
    }

    /**
     * Sums all accumulated gradients into the first one and divides by their count.
     *
     * <p><strong>The reduction is done in place.</strong> The first accumulated gradient
     * is used as the accumulator, is modified, and is the object returned. The stored
     * list is not cleared, so calling this method again adds the remaining gradients to
     * the already-averaged first gradient and divides again.
     *
     * <p>Per-layer gradients are paired by index, iterating over each later gradient's
     * own layer list; no check is made that all gradients have the same number of layers.
     *
     * @return the first accumulated gradient after in-place averaging, or {@code null}
     *         if nothing has been accumulated
     */
    public MultiLayerGradient averaged() {
        if (this.workers.isEmpty()) {
            return null;
        }

        final int workerCount = this.workers.size();
        final MultiLayerGradient total = this.workers.get(0);
        final List<NeuralNetworkGradient> totalLayers = total.getGradients();

        // Fold every gradient after the first into the first one.
        for (int w = 1; w < workerCount; w++) {
            MultiLayerGradient other = this.workers.get(w);
            List<NeuralNetworkGradient> otherLayers = other.getGradients();
            for (int layer = 0; layer < otherLayers.size(); layer++) {
                totalLayers.get(layer).add(otherLayers.get(layer));
            }
            total.getLogRegGradient().add(other.getLogRegGradient());
        }

        // Turn the sums into means.
        for (NeuralNetworkGradient layerGradient : totalLayers) {
            layerGradient.div(workerCount);
        }
        total.getLogRegGradient().div(workerCount);
        return total;
    }
}
