diff --git a/diagrams/__init__.py b/diagrams/__init__.py index be74b2e4..152d7e54 100644 --- a/diagrams/__init__.py +++ b/diagrams/__init__.py @@ -67,28 +67,40 @@ class _Cluster: setcluster(self) return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, *args): setcluster(self._parent) - for nodeid, node in self.nodes.items(): - self.dot.node(nodeid, label=node['label'], **node['attrs']) + if not (self.nodes or self.subgraphs): + return - for dot in self.subgraphs: - self.dot.subgraph(dot) + for node in self.nodes.values(): + self.dot.node(node.nodeid, label=node.label, **node._attrs) + + for subgraph in self.subgraphs: + self.dot.subgraph(subgraph.dot) if self._parent: - self._parent.subgraph(self.dot) + self._parent.remove_node(self.nodeid) + self._parent.subgraph(self) - def node(self, nodeid: str, label: str, **attrs) -> None: + def node(self, node: "Node") -> None: """Create a new node.""" - self.nodes[nodeid] = {'label': label, 'attrs': attrs} + self.nodes[node.nodeid] = node def remove_node(self, nodeid: str) -> None: del self.nodes[nodeid] - def subgraph(self, dot: Digraph) -> None: + def subgraph(self, subgraph: "_Cluster") -> None: """Create a subgraph for clustering""" - self.subgraphs.append(dot) + self.subgraphs.append(subgraph) + + @property + def nodes_iter(self): + if self.nodes: + yield from self.nodes.values() + if self.subgraphs: + for subgraph in self.subgraphs: + yield from subgraph.nodes_iter def _validate_direction(self, direction: str): direction = direction.upper() @@ -212,20 +224,20 @@ class Diagram(_Cluster): super().__enter__() return self - def __exit__(self, exc_type, exc_value, traceback): - super().__exit__(exc_type, exc_value, traceback) + def __exit__(self, *args): + super().__exit__(*args) setdiagram(None) - for nodes, edge in self.edges.items(): - node1, node2 = nodes - nodeid1, nodeid2 = node1.nodeid, node2.nodeid - if node1.nodes: - edge._attrs['ltail'] = nodeid1 - nodeid1 = next(iter(node1.nodes.keys())) - if node2.nodes: - edge._attrs['lhead'] = nodeid2 - nodeid2 = next(iter(node2.nodes.keys())) - self.dot.edge(nodeid1, nodeid2, **edge.attrs) + for (node1, node2), edge in self.edges.items(): + cluster_node1 = next(node1.nodes_iter, None) + if cluster_node1: + edge._attrs['ltail'] = node1.nodeid + node1 = cluster_node1 + cluster_node2 = next(node2.nodes_iter, None) + if cluster_node2: + edge._attrs['lhead'] = node2.nodeid + node2 = cluster_node2 + self.dot.edge(node1.nodeid, node2.nodeid, **edge.attrs) self.render() # Remove the graphviz file leaving only the image. @@ -332,11 +344,10 @@ class Node(_Cluster): self._attrs.update(attrs) # If a node is in the cluster context, add it to cluster. - self._parent.node(self._id, self.label, **self._attrs) + self._parent.node(self) def __enter__(self): super().__enter__() - setcluster(self) # Set attributes. for k, v in self._default_graph_attrs.items(): @@ -360,17 +371,10 @@ class Node(_Cluster): return self - def __exit__(self, exc_type, exc_value, traceback): - if not (self.nodes or self.subgraphs): - return - - self._parent.remove_node(self._id) - - self._id = "cluster_" + self._id - self.dot.name = self._id - - super().__exit__(exc_type, exc_value, traceback) - + def __exit__(self, *args): + super().__exit__(*args) + self._id = "cluster_" + self.nodeid + self.dot.name = self.nodeid def __repr__(self): _name = self.__class__.__name__