diff --git a/diagrams/__init__.py b/diagrams/__init__.py index c307cb2f..1f953e6d 100644 --- a/diagrams/__init__.py +++ b/diagrams/__init__.py @@ -60,28 +60,40 @@ class _Cluster: setcluster(self) return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, *args): setcluster(self._parent) - for nodeid, node in self.nodes.items(): - self.dot.node(nodeid, label=node['label'], **node['attrs']) + if not (self.nodes or self.subgraphs): + return - for dot in self.subgraphs: - self.dot.subgraph(dot) + for node in self.nodes.values(): + self.dot.node(node.nodeid, label=node.label, **node._attrs) + + for subgraph in self.subgraphs: + self.dot.subgraph(subgraph.dot) if self._parent: - self._parent.subgraph(self.dot) + self._parent.remove_node(self.nodeid) + self._parent.subgraph(self) - def node(self, nodeid: str, label: str, **attrs) -> None: + def node(self, node: "Node") -> None: """Create a new node.""" - self.nodes[nodeid] = {'label': label, 'attrs': attrs} + self.nodes[node.nodeid] = node def remove_node(self, nodeid: str) -> None: del self.nodes[nodeid] - def subgraph(self, dot: Digraph) -> None: + def subgraph(self, subgraph: "_Cluster") -> None: """Create a subgraph for clustering""" - self.subgraphs.append(dot) + self.subgraphs.append(subgraph) + + @property + def nodes_iter(self): + if self.nodes: + yield from self.nodes.values() + if self.subgraphs: + for subgraph in self.subgraphs: + yield from subgraph.nodes_iter def _validate_direction(self, direction: str): direction = direction.upper() @@ -202,21 +214,21 @@ class Diagram(_Cluster): setdiagram(self) super().__enter__() return self - - def __exit__(self, exc_type, exc_value, traceback): - super().__exit__(exc_type, exc_value, traceback) + + def __exit__(self, *args): + super().__exit__(*args) setdiagram(None) - for nodes, edge in self.edges.items(): - node1, node2 = nodes - nodeid1, nodeid2 = node1.nodeid, node2.nodeid - if node1.nodes: - edge._attrs['ltail'] = nodeid1 - nodeid1 = next(iter(node1.nodes.keys())) - if node2.nodes: - edge._attrs['lhead'] = nodeid2 - nodeid2 = next(iter(node2.nodes.keys())) - self.dot.edge(nodeid1, nodeid2, **edge.attrs) + for (node1, node2), edge in self.edges.items(): + cluster_node1 = next(node1.nodes_iter, None) + if cluster_node1: + edge._attrs['ltail'] = node1.nodeid + node1 = cluster_node1 + cluster_node2 = next(node2.nodes_iter, None) + if cluster_node2: + edge._attrs['lhead'] = node2.nodeid + node2 = cluster_node2 + self.dot.edge(node1.nodeid, node2.nodeid, **edge.attrs) self.render() # Remove the graphviz file leaving only the image. @@ -322,11 +334,10 @@ class Node(_Cluster): self._attrs.update(attrs) # If a node is in the cluster context, add it to cluster. - self._parent.node(self._id, self.label, **self._attrs) + self._parent.node(self) def __enter__(self): super().__enter__() - setcluster(self) # Set attributes. for k, v in self._default_graph_attrs.items(): @@ -350,17 +361,10 @@ class Node(_Cluster): return self - def __exit__(self, exc_type, exc_value, traceback): - if not (self.nodes or self.subgraphs): - return - - self._parent.remove_node(self._id) - - self._id = "cluster_" + self._id - self.dot.name = self._id - - super().__exit__(exc_type, exc_value, traceback) - + def __exit__(self, *args): + super().__exit__(*args) + self._id = "cluster_" + self.nodeid + self.dot.name = self.nodeid def __repr__(self): _name = self.__class__.__name__ @@ -435,7 +439,7 @@ class Node(_Cluster): @property def nodeid(self): return self._id - + # TODO: option for adding flow description to the connection edge def connect(self, node: "Node", edge: "Edge"): """Connect to other node.