package org.deeplearning4j.nn.gradient;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/gradient/DefaultGradient.class */
/**
 * Default {@link Gradient} implementation that stores one gradient array per variable key.
 *
 * <p>Entries are kept in a {@link LinkedHashMap}, so iteration order (and therefore the order
 * used by {@link #gradient()} when flattening) is the order in which keys were first inserted
 * into the map returned by {@link #gradientForVariable()}.
 *
 * <p>NOTE(review): no synchronization is performed; callers sharing an instance across threads
 * must coordinate access themselves.
 */
public class DefaultGradient implements Gradient {
    /** Backing store: variable key -> gradient array, in insertion order. */
    private final Map<String, INDArray> gradients = new LinkedHashMap<>();

    /**
     * Returns the live, mutable map of variable key to gradient array.
     *
     * <p>This is the backing map itself, not a copy; modifications made through it are visible
     * to every other method of this instance.
     *
     * @return the backing key-to-gradient map
     */
    @Override // org.deeplearning4j.nn.gradient.Gradient
    public Map<String, INDArray> gradientForVariable() {
        return this.gradients;
    }

    /**
     * Flattens the gradients for the given keys into a single array via {@code Nd4j.toFlattened}.
     *
     * <p>The gradients are concatenated in the order the keys appear in {@code list}.
     *
     * @param list the variable keys whose gradients should be flattened, in the desired order
     * @return the flattened gradient array
     * @throws IllegalStateException if any key in {@code list} has no entry in the map
     */
    @Override // org.deeplearning4j.nn.gradient.Gradient
    public INDArray gradient(List<String> list) {
        List<INDArray> toFlatten = new ArrayList<>(list.size());
        for (String key : list) {
            // Resolve through gradientForVariable() (not the field) so subclass overrides are honored.
            Map<String, INDArray> byVariable = gradientForVariable();
            // containsKey (not a null check on get) so a key explicitly mapped to null is accepted,
            // exactly as before.
            if (!byVariable.containsKey(key)) {
                throw new IllegalStateException("Illegal key " + key + " no gradient with key found");
            }
            toFlatten.add(byVariable.get(key));
        }
        return Nd4j.toFlattened(toFlatten);
    }

    /**
     * Flattens every stored gradient into a single array via {@code Nd4j.toFlattened}, in the
     * map's insertion order.
     *
     * @return the flattened gradient array
     */
    @Override // org.deeplearning4j.nn.gradient.Gradient
    public INDArray gradient() {
        return Nd4j.toFlattened(this.gradients.values());
    }

    /** Removes all stored gradients. */
    @Override // org.deeplearning4j.nn.gradient.Gradient
    public void clear() {
        this.gradients.clear();
    }

    /**
     * Looks up the gradient stored for a single variable.
     *
     * @param str the variable key
     * @return the stored gradient, or {@code null} if the key is absent (or mapped to null)
     */
    @Override // org.deeplearning4j.nn.gradient.Gradient
    public INDArray getGradientFor(String str) {
        return this.gradients.get(str);
    }

    @Override
    public String toString() {
        return "DefaultGradient{gradients=" + this.gradients + '}';
    }
}
