diff --git a/diagrams/__init__.py b/diagrams/__init__.py index 7b6d2461..6f75907b 100644 --- a/diagrams/__init__.py +++ b/diagrams/__init__.py @@ -36,6 +36,10 @@ def getcluster() -> "Cluster": def setcluster(cluster: "Cluster"): __cluster.set(cluster) +def getIconLabel(iconPath: str = '', iconSize: int = 30, label: str = ''): + return '<
' + label + '
>' class Diagram: __directions = ("TB", "BT", "LR", "RL") @@ -197,7 +201,6 @@ class Diagram: else: self.dot.render(format=self.outformat, view=self.show, quiet=True) - class Cluster: __directions = ("TB", "BT", "LR", "RL") __bgcolors = ("#E5F5FD", "#EBF3E7", "#ECE8F6", "#FDF7E3") @@ -228,18 +231,27 @@ class Cluster: :param label: Cluster label. :param direction: Data flow direction. Default is 'left to right'. :param graph_attr: Provide graph_attr dot config attributes. + :param providerIconNode: Provide a node to be included as a link + :param icon_size: The icon size """ if graph_attr is None: graph_attr = {} self.label = label self.name = "cluster_" + self.label + self.providerIconNode = providerIconNode + self.icon_size = icon_size self.dot = Digraph(self.name) - # Set attributes. + # Set attributes for k, v in self._default_graph_attrs.items(): self.dot.graph_attr[k] = v - self.dot.graph_attr["label"] = self.label + + if self.providerIconNode: + _iconNode = self.providerIconNode() #load the object + self.dot.graph_attr["label"] = getIconLabel(_iconNode._load_icon(), self.icon_size, self.label) + else: + self.dot.graph_attr["label"] = self.label if not self._validate_direction(direction): raise ValueError(f'"{direction}" is not a valid direction') @@ -273,6 +285,7 @@ class Cluster: def _validate_direction(self, direction: str) -> bool: return direction.upper() in self.__directions + def node(self, nodeid: str, label: str, **attrs) -> None: """Create a new node in the cluster.""" self.dot.node(nodeid, label=label, **attrs) @@ -292,6 +305,9 @@ class Node: _height = 1.9 + def __new__(cls, *args, **kwargs): + return object.__new__(cls) + def __init__(self, label: str = "", *, nodeid: str = None, **attrs: Dict): """Node represents a system component. diff --git a/tests/test_diagram.py b/tests/test_diagram.py index 00bdacc6..c74493a9 100644 --- a/tests/test_diagram.py +++ b/tests/test_diagram.py @@ -5,7 +5,7 @@ import pathlib from diagrams import Cluster, Diagram, Edge, Node from diagrams import getcluster, getdiagram, setcluster, setdiagram - +from diagrams.aws.network import VPC class DiagramTest(unittest.TestCase): def setUp(self): @@ -23,6 +23,13 @@ class DiagramTest(unittest.TestCase): os.remove(self.name + ".png") except FileNotFoundError: pass + + def test_cluster_icon(self): + self.name = "example-cluster-icon" + with Diagram(name=self.name, show=False): + cluster = Cluster("example_cluster_icon", providerIconNode=VPC, icon_size=30) + self.assertIsNotNone(cluster) + self.assertTrue(os.path.exists(f"{self.name}.png")) def test_validate_direction(self): # Normal directions.