package org.nd4j.autodiff.util;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JException;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/autodiff/util/TrainingUtils.class */
public class TrainingUtils {
    public static Map<String, INDArray> stackOutputs(List<Map<String, INDArray>> list) {
        HashMap hashMap = new HashMap();
        for (Map<String, INDArray> map : list) {
            for (String str : map.keySet()) {
                if (!hashMap.containsKey(str)) {
                    hashMap.put(str, new ArrayList());
                }
                ((List) hashMap.get(str)).add(map.get(str));
            }
        }
        HashMap hashMap2 = new HashMap();
        for (String str2 : hashMap.keySet()) {
            try {
                hashMap2.put(str2, Nd4j.concat(0, (INDArray[]) ((List) hashMap.get(str2)).toArray(new INDArray[0])));
            } catch (Exception e) {
                throw new ND4JException("Error concatenating batch outputs", e);
            }
        }
        return hashMap2;
    }

    public static List<INDArray> getSingleOutput(List<Map<String, INDArray>> list, String str) {
        ArrayList arrayList = new ArrayList();
        Iterator<Map<String, INDArray>> it2 = list.iterator();
        while (it2.hasNext()) {
            arrayList.add(it2.next().get(str));
        }
        return arrayList;
    }

    private TrainingUtils() {
    }
}
