package org.deeplearning4j.spark.impl.common.updater;

import org.apache.spark.api.java.function.Function2;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator;

/* loaded from: input_file:org/deeplearning4j/spark/impl/common/updater/UpdaterElementCombiner.class */
/**
 * Spark {@link Function2} that folds a single {@link Updater} into a running
 * {@link UpdaterAggregator}.
 *
 * <p>Either argument may be {@code null}; a {@code null} value is treated as
 * "nothing to contribute" rather than as an error.
 */
public class UpdaterElementCombiner implements Function2<UpdaterAggregator, Updater, UpdaterAggregator> {

    /**
     * Combines the running aggregator with one more updater.
     *
     * @param updaterAggregator the aggregator accumulated so far, or {@code null} if there is none yet
     * @param updater           the updater to fold in, or {@code null} if there is nothing to add
     * @return {@code updaterAggregator} unchanged when {@code updater} is {@code null} (which is
     *         {@code null} when both arguments are {@code null}); the result of
     *         {@code updater.getAggregator(true)} when only the aggregator is {@code null};
     *         otherwise the same {@code updaterAggregator} instance after
     *         {@code aggregate(updater)} has been invoked on it
     * @throws Exception declared by the {@code Function2} contract
     */
    public UpdaterAggregator call(UpdaterAggregator updaterAggregator, Updater updater) throws Exception {
        // No updater to fold in: hand back whatever has been accumulated so far.
        // This also covers the case where both arguments are null (returns null).
        if (updater == null) {
            return updaterAggregator;
        }

        // First non-null updater seen: obtain the initial aggregator from it.
        // NOTE(review): the meaning of the 'true' flag is defined by Updater.getAggregator — confirm there.
        if (updaterAggregator == null) {
            return updater.getAggregator(true);
        }

        // Both present: merge the updater into the existing aggregator and return that same instance.
        updaterAggregator.aggregate(updater);
        return updaterAggregator;
    }
}
