package org.deeplearning4j.nn.updater.graph;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.updater.UpdaterCreator;
import org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/updater/graph/ComputationGraphUpdater.class */
/**
 * Updater for a {@link ComputationGraph}: holds one {@link Updater} per layer and dispatches each
 * layer's slice of a graph-level {@link Gradient} to the matching per-layer updater.
 *
 * <p>Graph-level gradient keys are expected to have the form {@code <layerName>_<paramName>}. Keys
 * are split on the LAST underscore, so layer names may contain underscores but parameter names
 * must not.
 */
public class ComputationGraphUpdater implements Serializable, Cloneable {
    /** Per-layer updaters, indexed in the order the graph's layers were iterated at construction. */
    private final Updater[] layerUpdaters;
    /** Maps a layer name to its index in {@link #layerUpdaters}. */
    private final Map<String, Integer> layerUpdatersMap;

    /**
     * Combines the state of several {@link ComputationGraphUpdater} instances by keeping one
     * {@link UpdaterAggregator} per layer, then rebuilds a single updater via {@link #getUpdater()}.
     */
    public static class Aggregator implements Serializable {
        /** One aggregator per layer; {@code null} until the first updater is aggregated or merged. */
        private UpdaterAggregator[] aggregators;
        /** Layer name -> layer index, captured from the first aggregated/merged source. */
        private Map<String, Integer> layerNamesMap;

        /**
         * Folds the given updater's per-layer state into this aggregator.
         *
         * <p>On the first call this creates one {@link UpdaterAggregator} per layer via
         * {@code Updater.getAggregator(true)} and snapshots the updater's layer-name map. On later
         * calls each layer's updater is passed to the aggregator at the same index.
         *
         * @param computationGraphUpdater the updater to aggregate
         * @throws IllegalArgumentException if its layer count differs from what was aggregated before
         */
        public void aggregate(ComputationGraphUpdater computationGraphUpdater) {
            Updater[] updaters = computationGraphUpdater.layerUpdaters;
            // Guard before any dereference of the array (the original checked only after using it).
            if (updaters == null) {
                return;
            }
            if (this.aggregators == null) {
                this.aggregators = new UpdaterAggregator[updaters.length];
                for (int i = 0; i < updaters.length; i++) {
                    this.aggregators[i] = updaters[i].getAggregator(true);
                }
                this.layerNamesMap = new HashMap<>(computationGraphUpdater.layerUpdatersMap);
                return;
            }
            checkSameLayerCount(this.aggregators.length, updaters.length);
            for (int i = 0; i < this.aggregators.length; i++) {
                this.aggregators[i].aggregate(updaters[i]);
            }
        }

        /**
         * Merges another aggregator's state into this one.
         *
         * <p>If this aggregator is still empty it adopts the other's per-layer aggregators (the
         * {@link UpdaterAggregator} instances are shared, not copied) together with its layer-name
         * map; otherwise each layer's aggregator is merged index-by-index.
         *
         * @param aggregator the aggregator to merge from; ignored if it holds no state
         * @throws IllegalArgumentException if both hold state for a different number of layers
         */
        public void merge(Aggregator aggregator) {
            if (aggregator.aggregators == null) {
                return;
            }
            if (this.aggregators == null) {
                this.aggregators = aggregator.aggregators;
                // The original adopted only the aggregators, leaving layerNamesMap null so that the
                // updater built by getUpdater() could not resolve layer names.
                this.layerNamesMap = aggregator.layerNamesMap == null
                        ? null
                        : new HashMap<>(aggregator.layerNamesMap);
                return;
            }
            checkSameLayerCount(this.aggregators.length, aggregator.aggregators.length);
            for (int i = 0; i < this.aggregators.length; i++) {
                this.aggregators[i].merge(aggregator.aggregators[i]);
            }
        }

        /**
         * Builds a new {@link ComputationGraphUpdater} from the aggregated per-layer state.
         *
         * @return a fresh updater whose per-layer updaters come from the per-layer aggregators
         * @throws IllegalStateException if nothing has been aggregated or merged yet
         */
        public ComputationGraphUpdater getUpdater() {
            if (this.aggregators == null || this.layerNamesMap == null) {
                throw new IllegalStateException(
                        "Cannot create ComputationGraphUpdater: no updater state has been aggregated");
            }
            // Copy the name map so the returned updater does not share mutable state with this aggregator.
            ComputationGraphUpdater computationGraphUpdater =
                    new ComputationGraphUpdater(this.aggregators.length, new HashMap<>(this.layerNamesMap));
            for (int i = 0; i < this.aggregators.length; i++) {
                computationGraphUpdater.layerUpdaters[i] = this.aggregators[i].getUpdater();
            }
            return computationGraphUpdater;
        }

        /** Fails fast when two sources describe a different number of layers. */
        private static void checkSameLayerCount(int expected, int actual) {
            if (expected != actual) {
                throw new IllegalArgumentException("Layer count mismatch: aggregator has " + expected
                        + " layers but source has " + actual);
            }
        }
    }

    /**
     * Creates one updater per layer of the given graph (via {@link UpdaterCreator#getUpdater}) and
     * indexes them by layer name.
     *
     * @param computationGraph the graph whose layers will be updated
     */
    public ComputationGraphUpdater(ComputationGraph computationGraph) {
        this.layerUpdaters = new Updater[computationGraph.getNumLayers()];
        this.layerUpdatersMap = new HashMap<>();
        int i = 0;
        for (Layer layer : computationGraph.getLayers()) {
            this.layerUpdaters[i] = UpdaterCreator.getUpdater(layer);
            this.layerUpdatersMap.put(layer.conf().getLayer().getLayerName(), i);
            i++;
        }
    }

    /** Creates an updater with {@code i} empty slots that uses {@code map} as its name index (not copied). */
    private ComputationGraphUpdater(int i, Map<String, Integer> map) {
        this.layerUpdaters = new Updater[i];
        this.layerUpdatersMap = map;
    }

    /** Copy constructor: clones every per-layer updater and copies the layer-name map. */
    private ComputationGraphUpdater(ComputationGraphUpdater computationGraphUpdater) {
        this.layerUpdaters = new Updater[computationGraphUpdater.layerUpdaters.length];
        for (int i = 0; i < this.layerUpdaters.length; i++) {
            this.layerUpdaters[i] = computationGraphUpdater.layerUpdaters[i].m72clone();
        }
        this.layerUpdatersMap = new HashMap<>(computationGraphUpdater.layerUpdatersMap);
    }

    /**
     * Returns a copy of this updater with every per-layer updater cloned.
     *
     * <p>NOTE(review): named {@code m76clone} by the decompiler (originally {@code clone()}, merged
     * with its bridge method); the name is kept because other code may call it.
     *
     * @return an independent copy of this updater
     */
    public ComputationGraphUpdater m76clone() {
        return new ComputationGraphUpdater(this);
    }

    /**
     * Applies the per-layer updaters to a graph-level gradient.
     *
     * <p>Splits {@code gradient}'s entries by layer (on the last underscore in each key), runs
     * the matching layer updater on each per-layer gradient, then writes the (possibly updated)
     * per-layer entries back into {@code gradient} under their original graph-level keys.
     *
     * @param computationGraph the graph whose layers are looked up by name
     * @param gradient graph-level gradient; modified in place
     * @param i forwarded unchanged to each layer's {@code Updater.update} (presumably the
     *     iteration number — confirm against the Updater API)
     * @param i2 forwarded unchanged to each layer's {@code Updater.update} (presumably the
     *     minibatch size — confirm against the Updater API)
     * @throws IllegalStateException if a key has no layer separator or names a layer with no updater
     */
    public void update(ComputationGraph computationGraph, Gradient gradient, int i, int i2) {
        Map<String, Gradient> layerGradients = new HashMap<>();
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            int separatorIdx = key.lastIndexOf('_');
            if (separatorIdx == -1) {
                throw new IllegalStateException("Invalid key: ComputationGraph Gradient key does not have layer separator: \"" + key + "\"");
            }
            String layerName = key.substring(0, separatorIdx);
            Gradient layerGradient = layerGradients.get(layerName);
            if (layerGradient == null) {
                layerGradient = new DefaultGradient();
                layerGradients.put(layerName, layerGradient);
            }
            layerGradient.setGradientFor(key.substring(separatorIdx + 1), entry.getValue());
        }

        for (Map.Entry<String, Gradient> layerEntry : layerGradients.entrySet()) {
            String layerName = layerEntry.getKey();
            Gradient layerGradient = layerEntry.getValue();
            Integer updaterIdx = this.layerUpdatersMap.get(layerName);
            if (updaterIdx == null) {
                // Previously surfaced as an unboxing NullPointerException with no context.
                throw new IllegalStateException("No updater found for layer \"" + layerName
                        + "\" referenced by gradient keys");
            }
            this.layerUpdaters[updaterIdx].update(computationGraph.getLayer(layerName), layerGradient, i, i2);
            // Copy the per-layer entries back so the caller's graph-level gradient reflects the update.
            for (Map.Entry<String, INDArray> paramEntry : layerGradient.gradientForVariable().entrySet()) {
                gradient.setGradientFor(layerName + "_" + paramEntry.getKey(), paramEntry.getValue());
            }
        }
    }

    /**
     * Creates an {@link Aggregator} for this updater.
     *
     * @param z if {@code true}, this updater's state is aggregated immediately
     * @return a new aggregator, empty when {@code z} is {@code false}
     */
    public Aggregator getAggregator(boolean z) {
        Aggregator aggregator = new Aggregator();
        if (z) {
            aggregator.aggregate(this);
        }
        return aggregator;
    }
}
