package org.neo4j.gds.ml.metrics;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.mem.MemoryUsage;

/* loaded from: input_file:org/neo4j/gds/ml/metrics/StatsMap.class */
/**
 * Collects {@link ModelStats} entries per {@link Metric}.
 *
 * <p>The set of metrics is fixed at creation time by {@link #create(List)}; each registered metric
 * starts with an empty, mutable list that {@link #add(Metric, ModelStats)} appends to. Backed by a
 * plain {@link HashMap} and {@link ArrayList}; no synchronization is performed.
 */
public final class StatsMap {

    /**
     * Value used for the third factor by {@link #memoryEstimation(int, int)}.
     * NOTE(review): presumably a pessimistic guess at the number of classes — confirm with callers.
     */
    private static final int DEFAULT_NUMBER_OF_CLASSES = 1000;

    private final Map<Metric, List<ModelStats>> map;

    /**
     * Estimates memory for a {@code StatsMap}, using {@value #DEFAULT_NUMBER_OF_CLASSES} as the
     * third multiplication factor.
     *
     * @param numberOfMetrics         first factor of the "model stats" estimate
     *                                (NOTE(review): name inferred — confirm against callers)
     * @param numberOfModelCandidates second factor of the "model stats" estimate
     *                                (NOTE(review): name inferred — confirm against callers)
     * @return the memory estimation
     */
    public static MemoryEstimation memoryEstimation(int numberOfMetrics, int numberOfModelCandidates) {
        return memoryEstimation(numberOfMetrics, numberOfModelCandidates, DEFAULT_NUMBER_OF_CLASSES);
    }

    /**
     * Estimates memory for a {@code StatsMap}: one {@link ArrayList} instance plus one
     * {@code ImmutableModelStats} instance per combination of the three factors.
     *
     * @param numberOfMetrics         first factor of the "model stats" estimate
     * @param numberOfModelCandidates second factor of the "model stats" estimate
     * @param numberOfClasses         third factor of the "model stats" estimate
     * @return the memory estimation
     */
    public static MemoryEstimation memoryEstimation(
        int numberOfMetrics,
        int numberOfModelCandidates,
        int numberOfClasses
    ) {
        // Keep the instance size as the left-most operand so the product is promoted to the
        // (wider) type of sizeOfInstance before any int factors are multiplied together.
        long modelStatsBytes = MemoryUsage.sizeOfInstance(ImmutableModelStats.class)
            * numberOfMetrics
            * numberOfClasses
            * numberOfModelCandidates;
        return MemoryEstimations.builder(StatsMap.class)
            .fixed("array list", MemoryUsage.sizeOfInstance(ArrayList.class))
            .fixed("model stats", modelStatsBytes)
            .build();
    }

    /**
     * Creates a map with an empty, mutable stats list for each of the given metrics.
     * If {@code metrics} contains the same metric twice, it ends up with a single empty list.
     *
     * @param metrics the metrics to register
     * @return a new {@code StatsMap}
     */
    public static StatsMap create(List<Metric> metrics) {
        Map<Metric, List<ModelStats>> statsPerMetric = new HashMap<>();
        metrics.forEach(metric -> statsPerMetric.put(metric, new ArrayList<>()));
        return new StatsMap(statsPerMetric);
    }

    private StatsMap(Map<Metric, List<ModelStats>> map) {
        this.map = map;
    }

    /**
     * Appends {@code modelStats} to the list of the given metric.
     *
     * @param metric     a metric that was passed to {@link #create(List)}
     * @param modelStats the stats to record
     * @throws IllegalArgumentException if {@code metric} was not registered at creation time
     */
    public void add(Metric metric, ModelStats modelStats) {
        List<ModelStats> stats = map.get(metric);
        if (stats == null) {
            // Previously this surfaced as a bare NullPointerException with no context.
            throw new IllegalArgumentException("Metric `" + metric + "` is not registered in this StatsMap");
        }
        stats.add(modelStats);
    }

    /**
     * Returns the stats recorded for {@code metric}.
     *
     * <p>The returned list is the live internal list (later {@link #add} calls are visible through
     * it, and mutating it mutates this map). Returns {@code null} if {@code metric} was not
     * registered at creation time.
     *
     * @param metric the metric to look up
     * @return the recorded stats, or {@code null} for an unregistered metric
     */
    public List<ModelStats> getMetricStats(Metric metric) {
        return map.get(metric);
    }

    /**
     * Converts the contents to plain maps: each metric's {@link Metric#name()} maps to the list of
     * {@link ModelStats#toMap()} results, in insertion order.
     *
     * @return a new map; the result is not backed by this {@code StatsMap}
     * @throws IllegalStateException if two registered metrics share the same name
     *                               (per {@link Collectors#toMap(java.util.function.Function, java.util.function.Function)})
     */
    public Map<String, List<Map<String, Object>>> toMap() {
        return map.entrySet().stream().collect(Collectors.toMap(
            entry -> entry.getKey().name(),
            entry -> entry.getValue().stream()
                .map(ModelStats::toMap)
                .collect(Collectors.toList())
        ));
    }
}
