package org.nd4j.autodiff.samediff.transform;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

/* loaded from: input_file:org/nd4j/autodiff/samediff/transform/SubGraph.class */
public class SubGraph {
    protected SameDiff sameDiff;
    protected DifferentialFunction rootNode;
    protected List<DifferentialFunction> childNodes;

    /* loaded from: input_file:org/nd4j/autodiff/samediff/transform/SubGraph$SubGraphBuilder.class */
    public static class SubGraphBuilder {
        private SameDiff sameDiff;
        private DifferentialFunction rootNode;
        private List<DifferentialFunction> childNodes;

        SubGraphBuilder() {
        }

        public SubGraphBuilder sameDiff(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
            return this;
        }

        public SubGraphBuilder rootNode(DifferentialFunction differentialFunction) {
            this.rootNode = differentialFunction;
            return this;
        }

        public SubGraphBuilder childNodes(List<DifferentialFunction> list) {
            this.childNodes = list;
            return this;
        }

        public SubGraph build() {
            return new SubGraph(this.sameDiff, this.rootNode, this.childNodes);
        }

        public String toString() {
            return "SubGraph.SubGraphBuilder(sameDiff=" + this.sameDiff + ", rootNode=" + this.rootNode + ", childNodes=" + this.childNodes + ")";
        }
    }

    public List<SDVariable> outputs() {
        ArrayList<SDVariable> arrayList = new ArrayList();
        if (this.rootNode.outputVariables() != null) {
            Collections.addAll(arrayList, this.rootNode.outputVariables());
        }
        if (this.childNodes != null && !this.childNodes.isEmpty()) {
            HashSet hashSet = new HashSet();
            if (this.rootNode.args() != null) {
                Collections.addAll(hashSet, this.rootNode.args());
            }
            for (DifferentialFunction differentialFunction : this.childNodes) {
                if (differentialFunction.args() != null) {
                    Collections.addAll(hashSet, differentialFunction.args());
                }
                if (differentialFunction.outputVariables() != null) {
                    Collections.addAll(arrayList, differentialFunction.outputVariables());
                }
            }
        }
        ArrayList arrayList2 = new ArrayList(arrayList.size());
        for (SDVariable sDVariable : arrayList) {
            List<String> inputsForOp = this.sameDiff.getVariables().get(sDVariable.getVarName()).getInputsForOp();
            boolean z = true;
            if (inputsForOp != null) {
                Iterator<String> it = inputsForOp.iterator();
                while (true) {
                    if (!it.hasNext()) {
                        break;
                    }
                    if (!inSubgraph(this.sameDiff.getOpById(it.next()))) {
                        z = false;
                        break;
                    }
                }
            }
            if (!z) {
                arrayList2.add(sDVariable);
            }
        }
        return arrayList2;
    }

    public List<SDVariable> inputs() {
        HashSet hashSet = new HashSet();
        Iterator<DifferentialFunction> it = allFunctionsInSubgraph().iterator();
        while (it.hasNext()) {
            SDVariable[] outputVariables = it.next().outputVariables();
            if (outputVariables != null) {
                Collections.addAll(hashSet, outputVariables);
            }
        }
        ArrayList arrayList = new ArrayList();
        Iterator<DifferentialFunction> it2 = allFunctionsInSubgraph().iterator();
        while (it2.hasNext()) {
            SDVariable[] args = it2.next().args();
            if (args != null) {
                for (SDVariable sDVariable : args) {
                    if (!hashSet.contains(sDVariable)) {
                        arrayList.add(sDVariable);
                    }
                }
            }
        }
        return arrayList;
    }

    public boolean inSubgraph(DifferentialFunction differentialFunction) {
        if (this.rootNode == differentialFunction) {
            return true;
        }
        if (this.childNodes == null) {
            return false;
        }
        Iterator<DifferentialFunction> it = this.childNodes.iterator();
        while (it.hasNext()) {
            if (it.next() == differentialFunction) {
                return true;
            }
        }
        return false;
    }

    public List<DifferentialFunction> allFunctionsInSubgraph() {
        ArrayList arrayList = new ArrayList();
        arrayList.add(this.rootNode);
        if (this.childNodes != null) {
            arrayList.addAll(this.childNodes);
        }
        return arrayList;
    }

    public static SubGraphBuilder builder() {
        return new SubGraphBuilder();
    }

    public SubGraph(SameDiff sameDiff, DifferentialFunction differentialFunction, List<DifferentialFunction> list) {
        this.sameDiff = sameDiff;
        this.rootNode = differentialFunction;
        this.childNodes = list;
    }

    public SubGraph() {
    }

    public SameDiff getSameDiff() {
        return this.sameDiff;
    }

    public DifferentialFunction getRootNode() {
        return this.rootNode;
    }

    public List<DifferentialFunction> getChildNodes() {
        return this.childNodes;
    }

    public void setSameDiff(SameDiff sameDiff) {
        this.sameDiff = sameDiff;
    }

    public void setRootNode(DifferentialFunction differentialFunction) {
        this.rootNode = differentialFunction;
    }

    public void setChildNodes(List<DifferentialFunction> list) {
        this.childNodes = list;
    }

    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof SubGraph)) {
            return false;
        }
        SubGraph subGraph = (SubGraph) obj;
        if (!subGraph.canEqual(this)) {
            return false;
        }
        SameDiff sameDiff = getSameDiff();
        SameDiff sameDiff2 = subGraph.getSameDiff();
        if (sameDiff == null) {
            if (sameDiff2 != null) {
                return false;
            }
        } else if (!sameDiff.equals(sameDiff2)) {
            return false;
        }
        DifferentialFunction rootNode = getRootNode();
        DifferentialFunction rootNode2 = subGraph.getRootNode();
        if (rootNode == null) {
            if (rootNode2 != null) {
                return false;
            }
        } else if (!rootNode.equals(rootNode2)) {
            return false;
        }
        List<DifferentialFunction> childNodes = getChildNodes();
        List<DifferentialFunction> childNodes2 = subGraph.getChildNodes();
        return childNodes == null ? childNodes2 == null : childNodes.equals(childNodes2);
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof SubGraph;
    }

    public int hashCode() {
        SameDiff sameDiff = getSameDiff();
        int hashCode = (1 * 59) + (sameDiff == null ? 43 : sameDiff.hashCode());
        DifferentialFunction rootNode = getRootNode();
        int hashCode2 = (hashCode * 59) + (rootNode == null ? 43 : rootNode.hashCode());
        List<DifferentialFunction> childNodes = getChildNodes();
        return (hashCode2 * 59) + (childNodes == null ? 43 : childNodes.hashCode());
    }

    public String toString() {
        return "SubGraph(sameDiff=" + getSameDiff() + ", rootNode=" + getRootNode() + ", childNodes=" + getChildNodes() + ")";
    }
}
