package org.nd4j.autodiff.samediff.config;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Fluent configuration for running inference on a {@link SameDiff} graph.
 *
 * <p>Collects the names of the variables to compute, the placeholder arrays to feed, and the
 * listeners to attach. {@link #exec()} then passes them to the graph's {@code output(...)} call.
 */
public class BatchOutputConfig {
    /** Graph that {@link #exec()} runs against; assigned once in the constructor. */
    private final SameDiff sd;

    /** Names of the variables to compute, in the order they were requested. */
    @NonNull
    private List<String> outputs = new ArrayList<>();

    /**
     * Placeholder arrays keyed by variable name.
     *
     * <p>May be {@code null} after {@code inputs(null)} or {@code setPlaceholders(null)}.
     * {@link #input(String, INDArray)} recreates the map when needed.
     */
    private Map<String, INDArray> placeholders = new HashMap<>();

    /** Listeners passed to the graph's {@code output(...)} call in {@link #exec()}. */
    @NonNull
    private List<Listener> listeners = new ArrayList<>();

    /**
     * Creates an empty configuration bound to the given graph.
     *
     * @param sameDiff graph to execute; must not be {@code null}
     * @throws NullPointerException if {@code sameDiff} is {@code null}
     */
    public BatchOutputConfig(@NonNull SameDiff sameDiff) {
        if (sameDiff == null) {
            throw new NullPointerException("sd is marked @NonNull but is null");
        }
        this.sd = sameDiff;
    }

    /**
     * Adds variables, by name, to the list of requested outputs.
     *
     * @param strArr names of the variables to compute
     * @return this configuration, for chaining
     * @throws NullPointerException if the array itself is {@code null}
     */
    public BatchOutputConfig output(@NonNull String... strArr) {
        if (strArr == null) {
            throw new NullPointerException("outputs is marked @NonNull but is null");
        }
        this.outputs.addAll(Arrays.asList(strArr));
        return this;
    }

    /**
     * Adds variables to the list of requested outputs, using each variable's name.
     *
     * @param sDVariableArr variables to compute
     * @return this configuration, for chaining
     * @throws NullPointerException if the array itself is {@code null}
     */
    public BatchOutputConfig output(@NonNull SDVariable... sDVariableArr) {
        if (sDVariableArr == null) {
            throw new NullPointerException("outputs is marked @NonNull but is null");
        }
        String[] names = new String[sDVariableArr.length];
        for (int i = 0; i < sDVariableArr.length; i++) {
            names[i] = sDVariableArr[i].getVarName();
        }
        return output(names);
    }

    /**
     * Requests every variable returned by the graph's {@code variables()} as an output.
     *
     * @return this configuration, for chaining
     */
    public BatchOutputConfig outputAll() {
        return output(this.sd.variables().toArray(new SDVariable[0]));
    }

    /**
     * Supplies the array to feed for one placeholder variable.
     *
     * @param str name of a variable that exists in the graph
     * @param iNDArray value to feed for that variable
     * @return this configuration, for chaining
     * @throws NullPointerException if either argument is {@code null}
     */
    public BatchOutputConfig input(@NonNull String str, @NonNull INDArray iNDArray) {
        if (str == null) {
            throw new NullPointerException("variable is marked @NonNull but is null");
        }
        if (iNDArray == null) {
            throw new NullPointerException("placeholder is marked @NonNull but is null");
        }
        if (this.placeholders == null) {
            // inputs(null) or setPlaceholders(null) may have cleared the map; without this the
            // containsKey() call below would throw a NullPointerException.
            this.placeholders = new HashMap<>();
        }
        Preconditions.checkState(!this.placeholders.containsKey(str), "Placeholder for variable %s already specified", str);
        Preconditions.checkNotNull(this.sd.getVariable(str), "Variable %s does not exist in this SameDiff graph", str);
        this.placeholders.put(str, iNDArray);
        return this;
    }

    /**
     * Supplies the array to feed for one placeholder variable, identified by the variable object.
     *
     * @param sDVariable variable whose name is used as the placeholder key
     * @param iNDArray value to feed for that variable
     * @return this configuration, for chaining
     * @throws NullPointerException if either argument is {@code null}
     */
    public BatchOutputConfig input(@NonNull SDVariable sDVariable, @NonNull INDArray iNDArray) {
        if (sDVariable == null) {
            throw new NullPointerException("variable is marked @NonNull but is null");
        }
        if (iNDArray == null) {
            throw new NullPointerException("placeholder is marked @NonNull but is null");
        }
        return input(sDVariable.getVarName(), iNDArray);
    }

    /**
     * Adds every entry of {@code map} as a placeholder via {@link #input(String, INDArray)}.
     *
     * <p>Passing {@code null} discards all placeholders configured so far. The placeholder map
     * becomes {@code null} until another input is added.
     *
     * @param map placeholder arrays keyed by variable name, or {@code null} to clear
     * @return this configuration, for chaining
     */
    public BatchOutputConfig inputs(Map<String, INDArray> map) {
        if (map == null) {
            this.placeholders = null;
            return this;
        }
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            input(entry.getKey(), entry.getValue());
        }
        return this;
    }

    /**
     * Adds listeners to pass to the graph when {@link #exec()} runs.
     *
     * @param listenerArr listeners to add
     * @return this configuration, for chaining
     * @throws NullPointerException if the array itself is {@code null}
     */
    public BatchOutputConfig listeners(@NonNull Listener... listenerArr) {
        if (listenerArr == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        this.listeners.addAll(Arrays.asList(listenerArr));
        return this;
    }

    /**
     * Runs the graph with the configured placeholders and listeners.
     *
     * @return the map returned by the graph's {@code output(...)} call for the requested outputs
     */
    public Map<String, INDArray> exec() {
        return this.sd.output(this.placeholders, this.listeners, this.outputs.toArray(new String[0]));
    }

    /**
     * Runs the graph and returns the value of the single requested output.
     *
     * @return the entry of {@link #exec()}'s result keyed by the only requested output name
     * @throws IllegalStateException if the number of requested outputs is not exactly one
     */
    public INDArray execSingle() {
        Preconditions.checkState(this.outputs.size() == 1, "Can only use execSingle() when exactly one output is specified, there were %s", this.outputs.size());
        return exec().get(this.outputs.get(0));
    }

    /** @return the graph this configuration runs against */
    public SameDiff getSd() {
        return this.sd;
    }

    /** @return the live (mutable) list of requested output names */
    @NonNull
    public List<String> getOutputs() {
        return this.outputs;
    }

    /** @return the live placeholder map; may be {@code null} after {@code inputs(null)} */
    public Map<String, INDArray> getPlaceholders() {
        return this.placeholders;
    }

    /** @return the live (mutable) list of listeners */
    @NonNull
    public List<Listener> getListeners() {
        return this.listeners;
    }

    /**
     * Replaces the list of requested output names; the list is stored, not copied.
     *
     * @param list new output names
     * @throws NullPointerException if {@code list} is {@code null}
     */
    public void setOutputs(@NonNull List<String> list) {
        if (list == null) {
            throw new NullPointerException("outputs is marked @NonNull but is null");
        }
        this.outputs = list;
    }

    /**
     * Replaces the placeholder map; the map is stored, not copied.
     *
     * @param map new placeholder map, or {@code null}
     */
    public void setPlaceholders(Map<String, INDArray> map) {
        this.placeholders = map;
    }

    /**
     * Replaces the list of listeners; the list is stored, not copied.
     *
     * @param list new listeners
     * @throws NullPointerException if {@code list} is {@code null}
     */
    public void setListeners(@NonNull List<Listener> list) {
        if (list == null) {
            throw new NullPointerException("listeners is marked @NonNull but is null");
        }
        this.listeners = list;
    }
}
