diff --git a/diagrams/__init__.py b/diagrams/__init__.py index 44f016d1..be74b2e4 100644 --- a/diagrams/__init__.py +++ b/diagrams/__init__.py @@ -21,6 +21,7 @@ def getdiagram(): return __diagram.get() except LookupError: raise EnvironmentError("Global diagrams context not set up") + raise EnvironmentError("Global diagrams context not set up") def setdiagram(diagram): @@ -46,6 +47,7 @@ def new_init(cls, init): cls.__init__ = init return reset_init +class _Cluster: class _Cluster: __directions = ("TB", "BT", "LR", "RL") @@ -65,40 +67,28 @@ class _Cluster: setcluster(self) return self - def __exit__(self, *args): + def __exit__(self, exc_type, exc_value, traceback): setcluster(self._parent) - if not (self.nodes or self.subgraphs): - return - - for node in self.nodes.values(): - self.dot.node(node.nodeid, label=node.label, **node._attrs) + for nodeid, node in self.nodes.items(): + self.dot.node(nodeid, label=node['label'], **node['attrs']) - for subgraph in self.subgraphs: - self.dot.subgraph(subgraph.dot) + for dot in self.subgraphs: + self.dot.subgraph(dot) if self._parent: - self._parent.remove_node(self.nodeid) - self._parent.subgraph(self) + self._parent.subgraph(self.dot) - def node(self, node: "Node") -> None: + def node(self, nodeid: str, label: str, **attrs) -> None: """Create a new node.""" - self.nodes[node.nodeid] = node + self.nodes[nodeid] = {'label': label, 'attrs': attrs} def remove_node(self, nodeid: str) -> None: del self.nodes[nodeid] - def subgraph(self, subgraph: "_Cluster") -> None: + def subgraph(self, dot: Digraph) -> None: """Create a subgraph for clustering""" - self.subgraphs.append(subgraph) - - @property - def nodes_iter(self): - if self.nodes: - yield from self.nodes.values() - if self.subgraphs: - for subgraph in self.subgraphs: - yield from subgraph.nodes_iter + self.subgraphs.append(dot) def _validate_direction(self, direction: str): direction = direction.upper() @@ -176,6 +166,7 @@ class Diagram(_Cluster): :param edge_attr: Provide edge_attr dot config attributes. """ + self.name = name if not name and not filename: filename = "diagrams_image" @@ -183,9 +174,8 @@ class Diagram(_Cluster): filename = "_".join(self.name.split()).lower() self.filename = filename - self.dot = Digraph(self.name, filename=self.filename) - self._nodes = {} - self._edges = {} + super().__init__(self.name, filename=self.filename) + self.edges = {} self.dot.attr(compound="true") # Set attributes. @@ -220,22 +210,21 @@ class Diagram(_Cluster): def __enter__(self): setdiagram(self) super().__enter__() - super().__enter__() return self def __exit__(self, exc_type, exc_value, traceback): - for nodeid, node in self._nodes.items(): - self.dot.node(nodeid, label=node['label'], **node['attrs']) + super().__exit__(exc_type, exc_value, traceback) + setdiagram(None) - for nodes, edge in self._edges.items(): + for nodes, edge in self.edges.items(): node1, node2 = nodes nodeid1, nodeid2 = node1.nodeid, node2.nodeid - if hasattr(node1, '_nodes') and node1._nodes: + if node1.nodes: edge._attrs['ltail'] = nodeid1 - nodeid1 = next(iter(node1._nodes.keys())) - if hasattr(node2, '_nodes') and node2._nodes: + nodeid1 = next(iter(node1.nodes.keys())) + if node2.nodes: edge._attrs['lhead'] = nodeid2 - nodeid2 = next(iter(node2._nodes.keys())) + nodeid2 = next(iter(node2.nodes.keys())) self.dot.edge(nodeid1, nodeid2, **edge.attrs) self.render() @@ -268,11 +257,7 @@ class Diagram(_Cluster): def connect(self, node: "Node", node2: "Node", edge: "Edge") -> None: """Connect the two Nodes.""" - self._edges[(node, node2)] = edge - - def subgraph(self, dot: Digraph) -> None: - """Create a subgraph for clustering""" - self.dot.subgraph(dot) + self.edges[(node, node2)] = edge def render(self) -> None: self.dot.render(format=self.outformat, view=self.show, quiet=True) @@ -292,116 +277,6 @@ class Node(_Cluster): "fontsize": "12", } - _icon = None - _icon_size = 0 - - # fmt: on - - # FIXME: - # Cluster direction does not work now. Graphviz couldn't render - # correctly for a subgraph that has a different rank direction. - def __init__( - self, - label: str = "cluster", - direction: str = "LR", - icon: object = None, - icon_size: int = 30 - ): - """Cluster represents a cluster context. - - :param label: Cluster label. - :param direction: Data flow direction. Default is 'left to right'. - :param graph_attr: Provide graph_attr dot config attributes. - """ - self.label = label - self.name = "cluster_" + self.label - if not self._icon: - self._icon = icon - if not self._icon_size: - self._icon_size = icon_size - - self.dot = Digraph(self.name) - self._nodes = {} - - # Set attributes. - for k, v in self._default_graph_attrs.items(): - self.dot.graph_attr[k] = v - - # if an icon is set, try to find and instantiate a Node without calling __init__() - # then find it's icon by calling _load_icon() - if self._icon: - _node = self._icon(_no_init=True) - if isinstance(_node,Node): - self._icon_label = '<
' + self.label + '
>' - self.dot.graph_attr["label"] = self._icon_label - else: - self.dot.graph_attr["label"] = self.label - - if not self._validate_direction(direction): - raise ValueError(f'"{direction}" is not a valid direction') - self.dot.graph_attr["rankdir"] = direction - - # Node must be belong to a diagrams. - self._diagram = getdiagram() - if self._diagram is None: - raise EnvironmentError("Global diagrams context not set up") - self._parent = getcluster() - - # Set cluster depth for distinguishing the background color - self.depth = self._parent.depth + 1 if self._parent else 0 - coloridx = self.depth % len(self.__bgcolors) - self.dot.graph_attr["bgcolor"] = self.__bgcolors[coloridx] - - # Merge passed in attributes - self.dot.graph_attr.update(graph_attr) - - def __enter__(self): - setcluster(self) - return self - - def __exit__(self, exc_type, exc_value, traceback): - for nodeid, node in self._nodes.items(): - self.dot.node(nodeid, label=node['label'], **node['attrs']) - - if self._parent: - self._parent.subgraph(self.dot) - else: - self._diagram.subgraph(self.dot) - setcluster(self._parent) - - def _validate_direction(self, direction: str) -> bool: - direction = direction.upper() - for v in self.__directions: - if v == direction: - return True - return False - - def node(self, nodeid: str, label: str, **attrs) -> None: - """Create a new node in the cluster.""" - self._nodes[nodeid] = {'label': label, 'attrs': attrs} - - def remove_node(self, nodeid: str) -> None: - del self._nodes[nodeid] - - def subgraph(self, dot: Digraph) -> None: - self.dot.subgraph(dot) - - -class Node: - """Node represents a node for a specific backend service.""" - __directions = ("TB", "BT", "LR", "RL") - __bgcolors = ("#E5F5FD", "#EBF3E7", "#ECE8F6", "#FDF7E3") - - # fmt: off - _default_graph_attrs = { - "shape": "box", - "style": "rounded", - "labeljust": "l", - "pencolor": "#AEB6BE", - "fontname": "Sans-Serif", - "fontsize": "12", - } - _provider = None _type = None @@ -437,73 +312,41 @@ class Node: """ # Generates an ID for identifying a node. self._id = self._rand_id() - if isinstance(label, str): - self.label = label - elif isinstance(label, Sequence): - self.label = "\n".join(label) - else: - self.label = str(label) + self.label = label super().__init__() - if direction: - if not self._validate_direction(direction): - raise ValueError(f'"{direction}" is not a valid direction') - self._direction = direction - if icon: - _node = icon(_no_init=True) - self._icon = _node._icon - self._icon_dir = _node._icon_dir - if icon_size: - self._icon_size = icon_size - # fmt: off # If a node has an icon, increase the height slightly to avoid # that label being spanned between icon image and white space. # Increase the height by the number of new lines included in the label. - padding = 0.4 * (self.label.count('\n')) - icon_path = self._load_icon() + padding = 0.4 * (label.count('\n')) + icon = self._load_icon() self._attrs = { "shape": "none", "height": str(self._height + padding), - "image": icon_path, - } if icon_path else {} - - self._attrs['tooltip'] = (icon if icon else self).__class__.__name__ + "image": icon, + } if icon else {} # fmt: on self._attrs.update(attrs) # If a node is in the cluster context, add it to cluster. - if not self._parent: - raise EnvironmentError("Global diagrams context not set up") - self._cluster = getcluster() - - # If a node is in the cluster context, add it to cluster. - if self._cluster: - self._cluster.node(self._id, self.label, **self._attrs) - else: - self._diagram.node(self._id, self.label, **self._attrs) + self._parent.node(self._id, self.label, **self._attrs) def __enter__(self): - if self._cluster: - self._cluster.remove_node(self._id) - else: - self._diagram.remove_node(self._id) - + super().__enter__() setcluster(self) - self._id = "cluster_" + self.label - self.dot = Digraph(self._id) - self._nodes = {} # Set attributes. for k, v in self._default_graph_attrs.items(): self.dot.graph_attr[k] = v - if self._icon: + icon = self._load_icon() + if icon: self.dot.graph_attr["label"] = '<'\ ''\ + ''\ '
'\ - '' + self.label + '
>' if not self._validate_direction(self._direction): @@ -511,38 +354,23 @@ class Node: self.dot.graph_attr["rankdir"] = self._direction # Set cluster depth for distinguishing the background color - self.depth = self._cluster.depth + 1 if self._cluster else 0 + self.depth = self._parent.depth + 1 coloridx = self.depth % len(self.__bgcolors) self.dot.graph_attr["bgcolor"] = self.__bgcolors[coloridx] return self def __exit__(self, exc_type, exc_value, traceback): - for nodeid, node in self._nodes.items(): - self.dot.node(nodeid, label=node['label'], **node['attrs']) + if not (self.nodes or self.subgraphs): + return - if self._cluster: - self._cluster.subgraph(self.dot) - else: - self._diagram.subgraph(self.dot) - setcluster(self._cluster) + self._parent.remove_node(self._id) - def _validate_direction(self, direction: str): - direction = direction.upper() - for v in self.__directions: - if v == direction: - return True - return False + self._id = "cluster_" + self._id + self.dot.name = self._id - def node(self, nodeid: str, label: str, **attrs) -> None: - """Create a new node in the cluster.""" - self._nodes[nodeid] = {'label': label, 'attrs': attrs} + super().__exit__(exc_type, exc_value, traceback) - def remove_node(self, nodeid: str) -> None: - del self._nodes[nodeid] - - def subgraph(self, dot: Digraph) -> None: - self.dot.subgraph(dot) def __repr__(self): _name = self.__class__.__name__ @@ -632,6 +460,7 @@ class Node: ValueError(f"{node} is not a valid Edge") # An edge must be added on the global diagrams, not a cluster. getdiagram().connect(self, node, edge) + getdiagram().connect(self, node, edge) return node @staticmethod @@ -645,6 +474,32 @@ class Node: return None +class Cluster(Node): + def __init__( + self, + label: str = "", + direction: str = "LR", + icon: object = None, + icon_size: int = 30, + **attrs: Dict + ): + """Cluster represents a cluster context. + + :param label: Cluster label. + :param direction: Data flow direction. Default is "LR" (left to right). + :param icon: Custom icon for tihs cluster. Must be a node class or reference. + :param icon_size: The icon size. Default is 30. + """ + self._direction = direction + if icon: + _node = icon(_no_init=True) + self._icon = _node._icon + self._icon_dir = _node._icon_dir + if icon_size: + self._icon_size = icon_size + super().__init__(label, **attrs) + + class Edge: """Edge represents an edge between two nodes."""