diff --git a/diagrams/__init__.py b/diagrams/__init__.py
index 44f016d1..be74b2e4 100644
--- a/diagrams/__init__.py
+++ b/diagrams/__init__.py
@@ -21,6 +21,7 @@ def getdiagram():
return __diagram.get()
except LookupError:
raise EnvironmentError("Global diagrams context not set up")
+ raise EnvironmentError("Global diagrams context not set up")
def setdiagram(diagram):
@@ -46,6 +47,7 @@ def new_init(cls, init):
cls.__init__ = init
return reset_init
+class _Cluster:
class _Cluster:
__directions = ("TB", "BT", "LR", "RL")
@@ -65,40 +67,28 @@ class _Cluster:
setcluster(self)
return self
- def __exit__(self, *args):
+ def __exit__(self, exc_type, exc_value, traceback):
setcluster(self._parent)
- if not (self.nodes or self.subgraphs):
- return
-
- for node in self.nodes.values():
- self.dot.node(node.nodeid, label=node.label, **node._attrs)
+ for nodeid, node in self.nodes.items():
+ self.dot.node(nodeid, label=node['label'], **node['attrs'])
- for subgraph in self.subgraphs:
- self.dot.subgraph(subgraph.dot)
+ for dot in self.subgraphs:
+ self.dot.subgraph(dot)
if self._parent:
- self._parent.remove_node(self.nodeid)
- self._parent.subgraph(self)
+ self._parent.subgraph(self.dot)
- def node(self, node: "Node") -> None:
+ def node(self, nodeid: str, label: str, **attrs) -> None:
"""Create a new node."""
- self.nodes[node.nodeid] = node
+ self.nodes[nodeid] = {'label': label, 'attrs': attrs}
def remove_node(self, nodeid: str) -> None:
del self.nodes[nodeid]
- def subgraph(self, subgraph: "_Cluster") -> None:
+ def subgraph(self, dot: Digraph) -> None:
"""Create a subgraph for clustering"""
- self.subgraphs.append(subgraph)
-
- @property
- def nodes_iter(self):
- if self.nodes:
- yield from self.nodes.values()
- if self.subgraphs:
- for subgraph in self.subgraphs:
- yield from subgraph.nodes_iter
+ self.subgraphs.append(dot)
def _validate_direction(self, direction: str):
direction = direction.upper()
@@ -176,6 +166,7 @@ class Diagram(_Cluster):
:param edge_attr: Provide edge_attr dot config attributes.
"""
+
self.name = name
if not name and not filename:
filename = "diagrams_image"
@@ -183,9 +174,8 @@ class Diagram(_Cluster):
filename = "_".join(self.name.split()).lower()
self.filename = filename
- self.dot = Digraph(self.name, filename=self.filename)
- self._nodes = {}
- self._edges = {}
+ super().__init__(self.name, filename=self.filename)
+ self.edges = {}
self.dot.attr(compound="true")
# Set attributes.
@@ -220,22 +210,21 @@ class Diagram(_Cluster):
def __enter__(self):
setdiagram(self)
super().__enter__()
- super().__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback):
- for nodeid, node in self._nodes.items():
- self.dot.node(nodeid, label=node['label'], **node['attrs'])
+ super().__exit__(exc_type, exc_value, traceback)
+ setdiagram(None)
- for nodes, edge in self._edges.items():
+ for nodes, edge in self.edges.items():
node1, node2 = nodes
nodeid1, nodeid2 = node1.nodeid, node2.nodeid
- if hasattr(node1, '_nodes') and node1._nodes:
+ if node1.nodes:
edge._attrs['ltail'] = nodeid1
- nodeid1 = next(iter(node1._nodes.keys()))
- if hasattr(node2, '_nodes') and node2._nodes:
+ nodeid1 = next(iter(node1.nodes.keys()))
+ if node2.nodes:
edge._attrs['lhead'] = nodeid2
- nodeid2 = next(iter(node2._nodes.keys()))
+ nodeid2 = next(iter(node2.nodes.keys()))
self.dot.edge(nodeid1, nodeid2, **edge.attrs)
self.render()
@@ -268,11 +257,7 @@ class Diagram(_Cluster):
def connect(self, node: "Node", node2: "Node", edge: "Edge") -> None:
"""Connect the two Nodes."""
- self._edges[(node, node2)] = edge
-
- def subgraph(self, dot: Digraph) -> None:
- """Create a subgraph for clustering"""
- self.dot.subgraph(dot)
+ self.edges[(node, node2)] = edge
def render(self) -> None:
self.dot.render(format=self.outformat, view=self.show, quiet=True)
@@ -292,116 +277,6 @@ class Node(_Cluster):
"fontsize": "12",
}
- _icon = None
- _icon_size = 0
-
- # fmt: on
-
- # FIXME:
- # Cluster direction does not work now. Graphviz couldn't render
- # correctly for a subgraph that has a different rank direction.
- def __init__(
- self,
- label: str = "cluster",
- direction: str = "LR",
- icon: object = None,
- icon_size: int = 30
- ):
- """Cluster represents a cluster context.
-
- :param label: Cluster label.
- :param direction: Data flow direction. Default is 'left to right'.
- :param graph_attr: Provide graph_attr dot config attributes.
- """
- self.label = label
- self.name = "cluster_" + self.label
- if not self._icon:
- self._icon = icon
- if not self._icon_size:
- self._icon_size = icon_size
-
- self.dot = Digraph(self.name)
- self._nodes = {}
-
- # Set attributes.
- for k, v in self._default_graph_attrs.items():
- self.dot.graph_attr[k] = v
-
- # if an icon is set, try to find and instantiate a Node without calling __init__()
- # then find it's icon by calling _load_icon()
- if self._icon:
- _node = self._icon(_no_init=True)
- if isinstance(_node,Node):
- self._icon_label = '<
 + ') | ' + self.label + ' |
>'
- self.dot.graph_attr["label"] = self._icon_label
- else:
- self.dot.graph_attr["label"] = self.label
-
- if not self._validate_direction(direction):
- raise ValueError(f'"{direction}" is not a valid direction')
- self.dot.graph_attr["rankdir"] = direction
-
- # Node must be belong to a diagrams.
- self._diagram = getdiagram()
- if self._diagram is None:
- raise EnvironmentError("Global diagrams context not set up")
- self._parent = getcluster()
-
- # Set cluster depth for distinguishing the background color
- self.depth = self._parent.depth + 1 if self._parent else 0
- coloridx = self.depth % len(self.__bgcolors)
- self.dot.graph_attr["bgcolor"] = self.__bgcolors[coloridx]
-
- # Merge passed in attributes
- self.dot.graph_attr.update(graph_attr)
-
- def __enter__(self):
- setcluster(self)
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- for nodeid, node in self._nodes.items():
- self.dot.node(nodeid, label=node['label'], **node['attrs'])
-
- if self._parent:
- self._parent.subgraph(self.dot)
- else:
- self._diagram.subgraph(self.dot)
- setcluster(self._parent)
-
- def _validate_direction(self, direction: str) -> bool:
- direction = direction.upper()
- for v in self.__directions:
- if v == direction:
- return True
- return False
-
- def node(self, nodeid: str, label: str, **attrs) -> None:
- """Create a new node in the cluster."""
- self._nodes[nodeid] = {'label': label, 'attrs': attrs}
-
- def remove_node(self, nodeid: str) -> None:
- del self._nodes[nodeid]
-
- def subgraph(self, dot: Digraph) -> None:
- self.dot.subgraph(dot)
-
-
-class Node:
- """Node represents a node for a specific backend service."""
- __directions = ("TB", "BT", "LR", "RL")
- __bgcolors = ("#E5F5FD", "#EBF3E7", "#ECE8F6", "#FDF7E3")
-
- # fmt: off
- _default_graph_attrs = {
- "shape": "box",
- "style": "rounded",
- "labeljust": "l",
- "pencolor": "#AEB6BE",
- "fontname": "Sans-Serif",
- "fontsize": "12",
- }
-
_provider = None
_type = None
@@ -437,73 +312,41 @@ class Node:
"""
# Generates an ID for identifying a node.
self._id = self._rand_id()
- if isinstance(label, str):
- self.label = label
- elif isinstance(label, Sequence):
- self.label = "\n".join(label)
- else:
- self.label = str(label)
+ self.label = label
super().__init__()
- if direction:
- if not self._validate_direction(direction):
- raise ValueError(f'"{direction}" is not a valid direction')
- self._direction = direction
- if icon:
- _node = icon(_no_init=True)
- self._icon = _node._icon
- self._icon_dir = _node._icon_dir
- if icon_size:
- self._icon_size = icon_size
-
# fmt: off
# If a node has an icon, increase the height slightly to avoid
# that label being spanned between icon image and white space.
# Increase the height by the number of new lines included in the label.
- padding = 0.4 * (self.label.count('\n'))
- icon_path = self._load_icon()
+ padding = 0.4 * (label.count('\n'))
+ icon = self._load_icon()
self._attrs = {
"shape": "none",
"height": str(self._height + padding),
- "image": icon_path,
- } if icon_path else {}
-
- self._attrs['tooltip'] = (icon if icon else self).__class__.__name__
+ "image": icon,
+ } if icon else {}
# fmt: on
self._attrs.update(attrs)
# If a node is in the cluster context, add it to cluster.
- if not self._parent:
- raise EnvironmentError("Global diagrams context not set up")
- self._cluster = getcluster()
-
- # If a node is in the cluster context, add it to cluster.
- if self._cluster:
- self._cluster.node(self._id, self.label, **self._attrs)
- else:
- self._diagram.node(self._id, self.label, **self._attrs)
+ self._parent.node(self._id, self.label, **self._attrs)
def __enter__(self):
- if self._cluster:
- self._cluster.remove_node(self._id)
- else:
- self._diagram.remove_node(self._id)
-
+ super().__enter__()
setcluster(self)
- self._id = "cluster_" + self.label
- self.dot = Digraph(self._id)
- self._nodes = {}
# Set attributes.
for k, v in self._default_graph_attrs.items():
self.dot.graph_attr[k] = v
- if self._icon:
+ icon = self._load_icon()
+ if icon:
self.dot.graph_attr["label"] = '<'\
''\
- ' + ') | '\
+ '
'\
'' + self.label + ' |
>'
if not self._validate_direction(self._direction):
@@ -511,38 +354,23 @@ class Node:
self.dot.graph_attr["rankdir"] = self._direction
# Set cluster depth for distinguishing the background color
- self.depth = self._cluster.depth + 1 if self._cluster else 0
+ self.depth = self._parent.depth + 1
coloridx = self.depth % len(self.__bgcolors)
self.dot.graph_attr["bgcolor"] = self.__bgcolors[coloridx]
return self
def __exit__(self, exc_type, exc_value, traceback):
- for nodeid, node in self._nodes.items():
- self.dot.node(nodeid, label=node['label'], **node['attrs'])
+ if not (self.nodes or self.subgraphs):
+ return
- if self._cluster:
- self._cluster.subgraph(self.dot)
- else:
- self._diagram.subgraph(self.dot)
- setcluster(self._cluster)
+ self._parent.remove_node(self._id)
- def _validate_direction(self, direction: str):
- direction = direction.upper()
- for v in self.__directions:
- if v == direction:
- return True
- return False
+ self._id = "cluster_" + self._id
+ self.dot.name = self._id
- def node(self, nodeid: str, label: str, **attrs) -> None:
- """Create a new node in the cluster."""
- self._nodes[nodeid] = {'label': label, 'attrs': attrs}
+ super().__exit__(exc_type, exc_value, traceback)
- def remove_node(self, nodeid: str) -> None:
- del self._nodes[nodeid]
-
- def subgraph(self, dot: Digraph) -> None:
- self.dot.subgraph(dot)
def __repr__(self):
_name = self.__class__.__name__
@@ -632,6 +460,7 @@ class Node:
ValueError(f"{node} is not a valid Edge")
# An edge must be added on the global diagrams, not a cluster.
getdiagram().connect(self, node, edge)
+ getdiagram().connect(self, node, edge)
return node
@staticmethod
@@ -645,6 +474,32 @@ class Node:
return None
+class Cluster(Node):
+ def __init__(
+ self,
+ label: str = "",
+ direction: str = "LR",
+ icon: object = None,
+ icon_size: int = 30,
+ **attrs: Dict
+ ):
+ """Cluster represents a cluster context.
+
+ :param label: Cluster label.
+ :param direction: Data flow direction. Default is "LR" (left to right).
+ :param icon: Custom icon for tihs cluster. Must be a node class or reference.
+ :param icon_size: The icon size. Default is 30.
+ """
+ self._direction = direction
+ if icon:
+ _node = icon(_no_init=True)
+ self._icon = _node._icon
+ self._icon_dir = _node._icon_dir
+ if icon_size:
+ self._icon_size = icon_size
+ super().__init__(label, **attrs)
+
+
class Edge:
"""Edge represents an edge between two nodes."""