pull/823/merge
Rob Lazzurs 3 years ago committed by GitHub
commit 9289e7eb41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,8 +1,9 @@
import contextvars import contextvars
import html
import os import os
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import List, Union, Dict from typing import List, Union, Dict, Sequence
from graphviz import Digraph from graphviz import Digraph
@ -36,9 +37,77 @@ def getcluster() -> "Cluster":
def setcluster(cluster: "Cluster"): def setcluster(cluster: "Cluster"):
__cluster.set(cluster) __cluster.set(cluster)
def new_init(cls, init):
def reset_init(*args, **kwargs):
cls.__init__ = init
return reset_init
class Diagram: class _Cluster:
__directions = ("TB", "BT", "LR", "RL") __directions = ("TB", "BT", "LR", "RL")
def __init__(self, name=None, **kwargs):
self.dot = Digraph(name, **kwargs)
self.depth = 0
self.nodes = {}
self.subgraphs = []
try:
self._parent = getcluster() or getdiagram()
except LookupError:
self._parent = None
def __enter__(self):
setcluster(self)
return self
def __exit__(self, *args):
setcluster(self._parent)
if not (self.nodes or self.subgraphs):
return
for node in self.nodes.values():
self.dot.node(node.nodeid, label=node.label, **node._attrs)
for subgraph in self.subgraphs:
self.dot.subgraph(subgraph.dot)
if self._parent:
self._parent.remove_node(self.nodeid)
self._parent.subgraph(self)
def node(self, node: "Node") -> None:
"""Create a new node."""
self.nodes[node.nodeid] = node
def remove_node(self, nodeid: str) -> None:
del self.nodes[nodeid]
def subgraph(self, subgraph: "_Cluster") -> None:
"""Create a subgraph for clustering"""
self.subgraphs.append(subgraph)
@property
def nodes_iter(self):
if self.nodes:
yield from self.nodes.values()
if self.subgraphs:
for subgraph in self.subgraphs:
yield from subgraph.nodes_iter
def _validate_direction(self, direction: str):
direction = direction.upper()
for v in self.__directions:
if v == direction:
return True
return False
def __str__(self) -> str:
return str(self.dot)
class Diagram(_Cluster):
__curvestyles = ("ortho", "curved") __curvestyles = ("ortho", "curved")
__outformats = ("png", "jpg", "svg", "pdf", "dot") __outformats = ("png", "jpg", "svg", "pdf", "dot")
@ -105,15 +174,20 @@ class Diagram:
:param edge_attr: Provide edge_attr dot config attributes. :param edge_attr: Provide edge_attr dot config attributes.
:param strict: Rendering should merge multi-edges. :param strict: Rendering should merge multi-edges.
""" """
self.name = name self.name = name
if not name and not filename: if not name and not filename:
filename = "diagrams_image" filename = "diagrams_image"
elif not filename: elif not filename:
filename = "_".join(self.name.split()).lower() filename = "_".join(self.name.split()).lower()
self.filename = filename self.filename = filename
super().__init__(self.name, filename=self.filename)
self.edges = {}
self.dot = Digraph(self.name, filename=self.filename, strict=strict) self.dot = Digraph(self.name, filename=self.filename, strict=strict)
# Set attributes. # Set attributes.
self.dot.attr(compound="true")
for k, v in self._default_graph_attrs.items(): for k, v in self._default_graph_attrs.items():
self.dot.graph_attr[k] = v self.dot.graph_attr[k] = v
self.dot.graph_attr["label"] = self.name self.dot.graph_attr["label"] = self.name
@ -147,18 +221,29 @@ class Diagram:
self.show = show self.show = show
self.autolabel = autolabel self.autolabel = autolabel
def __str__(self) -> str:
return str(self.dot)
def __enter__(self): def __enter__(self):
setdiagram(self) setdiagram(self)
super().__enter__()
return self return self
def __exit__(self, *args):
super().__exit__(*args)
setdiagram(None)
for (node1, node2), edge in self.edges.items():
cluster_node1 = next(node1.nodes_iter, None)
if cluster_node1:
edge._attrs['ltail'] = node1.nodeid
node1 = cluster_node1
cluster_node2 = next(node2.nodes_iter, None)
if cluster_node2:
edge._attrs['lhead'] = node2.nodeid
node2 = cluster_node2
self.dot.edge(node1.nodeid, node2.nodeid, **edge.attrs)
def __exit__(self, exc_type, exc_value, traceback):
self.render() self.render()
# Remove the graphviz file leaving only the image. # Remove the graphviz file leaving only the image.
os.remove(self.filename) os.remove(self.filename)
setdiagram(None)
def _repr_png_(self): def _repr_png_(self):
return self.dot.pipe(format="png") return self.dot.pipe(format="png")
@ -172,17 +257,9 @@ class Diagram:
def _validate_outformat(self, outformat: str) -> bool: def _validate_outformat(self, outformat: str) -> bool:
return outformat.lower() in self.__outformats return outformat.lower() in self.__outformats
def node(self, nodeid: str, label: str, **attrs) -> None:
"""Create a new node."""
self.dot.node(nodeid, label=label, **attrs)
def connect(self, node: "Node", node2: "Node", edge: "Edge") -> None: def connect(self, node: "Node", node2: "Node", edge: "Edge") -> None:
"""Connect the two Nodes.""" """Connect the two Nodes."""
self.dot.edge(node.nodeid, node2.nodeid, **edge.attrs) self.edges[(node, node2)] = edge
def subgraph(self, dot: Digraph) -> None:
"""Create a subgraph for clustering"""
self.dot.subgraph(dot)
def render(self) -> None: def render(self) -> None:
if isinstance(self.outformat, list): if isinstance(self.outformat, list):
@ -192,8 +269,8 @@ class Diagram:
self.dot.render(format=self.outformat, view=self.show, quiet=True) self.dot.render(format=self.outformat, view=self.show, quiet=True)
class Cluster: class Node(_Cluster):
__directions = ("TB", "BT", "LR", "RL") """Node represents a node for a specific backend service."""
__bgcolors = ("#E5F5FD", "#EBF3E7", "#ECE8F6", "#FDF7E3") __bgcolors = ("#E5F5FD", "#EBF3E7", "#ECE8F6", "#FDF7E3")
# fmt: off # fmt: off
@ -281,14 +358,56 @@ class Node:
_icon_dir = None _icon_dir = None
_icon = None _icon = None
_icon_size = 30
_direction = "LR"
_height = 1.9 _height = 1.9
def __init__(self, label: str = "", *, nodeid: str = None, **attrs: Dict): # fmt: on
def __new__(cls, *args, **kwargs):
instance = object.__new__(cls)
lazy = kwargs.pop('_no_init', False)
if not lazy:
return instance
cls.__init__ = new_init(cls, cls.__init__)
return instance
def __init__(
self,
label: str = "",
direction: str = None,
icon: object = None,
icon_size: int = None,
**attrs: Dict
):
"""Node represents a system component. """Node represents a system component.
:param label: Node label. :param label: Node label.
:param direction: Data flow direction. Default is "LR" (left to right).
:param icon: Custom icon for tihs cluster. Must be a node class or reference.
:param icon_size: The icon size when used as a Cluster. Default is 30.
""" """
# Generates an ID for identifying a node.
self._id = self._rand_id()
if isinstance(label, str):
self.label = label
elif isinstance(label, Sequence):
self.label = "\n".join(label)
else:
self.label = str(label)
super().__init__()
if direction:
if not self._validate_direction(direction):
raise ValueError(f'"{direction}" is not a valid direction')
self._direction = direction
if icon:
_node = icon(_no_init=True)
self._icon = _node._icon
self._icon_dir = _node._icon_dir
if icon_size:
self._icon_size = icon_size
# Generates an ID for identifying a node, unless specified # Generates an ID for identifying a node, unless specified
self._id = nodeid or self._rand_id() self._id = nodeid or self._rand_id()
self.label = label self.label = label
@ -310,11 +429,14 @@ class Node:
# that label being spanned between icon image and white space. # that label being spanned between icon image and white space.
# Increase the height by the number of new lines included in the label. # Increase the height by the number of new lines included in the label.
padding = 0.4 * (self.label.count('\n')) padding = 0.4 * (self.label.count('\n'))
icon_path = self._load_icon()
self._attrs = { self._attrs = {
"shape": "none", "shape": "none",
"height": str(self._height + padding), "height": str(self._height + padding),
"image": self._load_icon(), "image": icon_path,
} if self._icon else {} } if icon_path else {}
self._attrs['tooltip'] = (icon if icon else self).__class__.__name__
# fmt: on # fmt: on
self._attrs.update(attrs) self._attrs.update(attrs)
@ -322,10 +444,43 @@ class Node:
self._cluster = getcluster() self._cluster = getcluster()
# If a node is in the cluster context, add it to cluster. # If a node is in the cluster context, add it to cluster.
if self._cluster: if not self._parent:
self._cluster.node(self._id, self.label, **self._attrs) raise EnvironmentError("Global diagrams context not set up")
self._parent.node(self)
def __enter__(self):
super().__enter__()
# Set attributes.
for k, v in self._default_graph_attrs.items():
self.dot.graph_attr[k] = v
for k, v in self._attrs.items():
self.dot.graph_attr[k] = v
icon = self._load_icon()
if icon:
lines = iter(html.escape(self.label).split("\n"))
self.dot.graph_attr["label"] = '<<TABLE border="0"><TR>' +\
f'<TD fixedsize="true" width="{self._icon_size}" height="{self._icon_size}"><IMG SRC="{icon}"></IMG></TD>' +\
f'<TD align="left">{next(lines)}</TD></TR>' +\
''.join(f'<TR><TD colspan="2" align="left">{line}</TD></TR>' for line in lines) +\
'</TABLE>>'
else: else:
self._diagram.node(self._id, self.label, **self._attrs) self.dot.graph_attr["label"] = self.label
self.dot.graph_attr["rankdir"] = self._direction
# Set cluster depth for distinguishing the background color
self.depth = self._parent.depth + 1
coloridx = self.depth % len(self.__bgcolors)
self.dot.graph_attr["bgcolor"] = self.__bgcolors[coloridx]
return self
def __exit__(self, *args):
super().__exit__(*args)
self._id = "cluster_" + self.nodeid
self.dot.name = self.nodeid
def __repr__(self): def __repr__(self):
_name = self.__class__.__name__ _name = self.__class__.__name__
@ -400,7 +555,7 @@ class Node:
@property @property
def nodeid(self): def nodeid(self):
return self._id return self._id
# TODO: option for adding flow description to the connection edge # TODO: option for adding flow description to the connection edge
def connect(self, node: "Node", edge: "Edge"): def connect(self, node: "Node", edge: "Edge"):
"""Connect to other node. """Connect to other node.
@ -414,7 +569,7 @@ class Node:
if not isinstance(edge, Edge): if not isinstance(edge, Edge):
ValueError(f"{edge} is not a valid Edge") ValueError(f"{edge} is not a valid Edge")
# An edge must be added on the global diagrams, not a cluster. # An edge must be added on the global diagrams, not a cluster.
self._diagram.connect(self, node, edge) getdiagram().connect(self, node, edge)
return node return node
@staticmethod @staticmethod
@ -422,8 +577,10 @@ class Node:
return uuid.uuid4().hex return uuid.uuid4().hex
def _load_icon(self): def _load_icon(self):
basedir = Path(os.path.abspath(os.path.dirname(__file__))) if self._icon and self._icon_dir:
return os.path.join(basedir.parent, self._icon_dir, self._icon) basedir = Path(os.path.abspath(os.path.dirname(__file__)))
return os.path.join(basedir.parent, self._icon_dir, self._icon)
return None
class Edge: class Edge:
@ -472,6 +629,7 @@ class Edge:
# Graphviz complaining about using label for edges, so replace it with xlabel. # Graphviz complaining about using label for edges, so replace it with xlabel.
# Update: xlabel option causes the misaligned label position: https://github.com/mingrammer/diagrams/issues/83 # Update: xlabel option causes the misaligned label position: https://github.com/mingrammer/diagrams/issues/83
self._attrs["label"] = label self._attrs["label"] = label
self._attrs["tooltip"] = label
if color: if color:
self._attrs["color"] = color self._attrs["color"] = color
if style: if style:
@ -544,4 +702,4 @@ class Edge:
return {**self._attrs, "dir": direction} return {**self._attrs, "dir": direction}
Group = Cluster Group = Cluster = Node

@ -2,7 +2,7 @@
AWS provides a set of services for Amazon Web Service provider. AWS provides a set of services for Amazon Web Service provider.
""" """
from diagrams import Node from diagrams import Node, Cluster
class _AWS(Node): class _AWS(Node):

@ -0,0 +1,104 @@
from diagrams import Cluster
from diagrams.aws.compute import EC2, ApplicationAutoScaling
from diagrams.aws.network import VPC, PrivateSubnet, PublicSubnet
class Region(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "dotted",
"labeljust": "l",
"pencolor": "#AEB6BE",
"fontname": "Sans-Serif",
"fontsize": "12",
}
# fmt: on
class AvailabilityZone(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "dashed",
"labeljust": "l",
"pencolor": "#27a0ff",
"fontname": "sans-serif",
"fontsize": "12",
}
# fmt: on
class VirtualPrivateCloud(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "",
"labeljust": "l",
"pencolor": "#00D110",
"fontname": "sans-serif",
"fontsize": "12",
}
# fmt: on
_icon = VPC
class PrivateSubnet(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "",
"labeljust": "l",
"pencolor": "#329CFF",
"fontname": "sans-serif",
"fontsize": "12",
}
# fmt: on
_icon = PrivateSubnet
class PublicSubnet(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "",
"labeljust": "l",
"pencolor": "#00D110",
"fontname": "sans-serif",
"fontsize": "12",
}
# fmt: on
_icon = PublicSubnet
class SecurityGroup(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "dashed",
"labeljust": "l",
"pencolor": "#FF361E",
"fontname": "Sans-Serif",
"fontsize": "12",
}
# fmt: on
class AutoScalling(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "dashed",
"labeljust": "l",
"pencolor": "#FF7D1E",
"fontname": "Sans-Serif",
"fontsize": "12",
}
# fmt: on
_icon = ApplicationAutoScaling
class EC2Contents(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "",
"labeljust": "l",
"pencolor": "#FFB432",
"fontname": "Sans-Serif",
"fontsize": "12",
}
# fmt: on
_icon = EC2

@ -0,0 +1,143 @@
from diagrams import Cluster
from diagrams.azure.compute import VM, VMWindows, VMLinux #, VMScaleSet # Depends on PR-404
from diagrams.azure.network import VirtualNetworks, Subnets, NetworkSecurityGroupsClassic
class Subscription(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "dotted",
"labeljust": "l",
"pencolor": "#AEB6BE",
"fontname": "Sans-Serif",
"fontsize": "12",
}
# fmt: on
class Region(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "dotted",
"labeljust": "l",
"pencolor": "#AEB6BE",
"fontname": "Sans-Serif",
"fontsize": "12",
}
# fmt: on
class AvailabilityZone(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "dashed",
"labeljust": "l",
"pencolor": "#27a0ff",
"fontname": "sans-serif",
"fontsize": "12",
}
# fmt: on
class VirtualNetwork(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "",
"labeljust": "l",
"pencolor": "#00D110",
"fontname": "sans-serif",
"fontsize": "12",
}
# fmt: on
_icon = VirtualNetworks
class SubnetWithNSG(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "",
"labeljust": "l",
"pencolor": "#329CFF",
"fontname": "sans-serif",
"fontsize": "12",
}
# fmt: on
_icon = NetworkSecurityGroupsClassic
class Subnet(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "",
"labeljust": "l",
"pencolor": "#00D110",
"fontname": "sans-serif",
"fontsize": "12",
}
# fmt: on
_icon = Subnets
class SecurityGroup(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "dashed",
"labeljust": "l",
"pencolor": "#FF361E",
"fontname": "Sans-Serif",
"fontsize": "12",
}
# fmt: on
class VMContents(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "",
"labeljust": "l",
"pencolor": "#FFB432",
"fontname": "Sans-Serif",
"fontsize": "12",
}
# fmt: on
_icon = VM
class VMLinuxContents(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "",
"labeljust": "l",
"pencolor": "#FFB432",
"fontname": "Sans-Serif",
"fontsize": "12",
}
# fmt: on
_icon = VMLinux
class VMWindowsContents(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "",
"labeljust": "l",
"pencolor": "#FFB432",
"fontname": "Sans-Serif",
"fontsize": "12",
}
# fmt: on
_icon = VMWindows
# Depends on PR-404
# class VMSS(Cluster):
# # fmt: off
# _default_graph_attrs = {
# "shape": "box",
# "style": "dashed",
# "labeljust": "l",
# "pencolor": "#FF7D1E",
# "fontname": "Sans-Serif",
# "fontsize": "12",
# }
# # fmt: on
# _icon = VMScaleSet

@ -0,0 +1,15 @@
from diagrams import Cluster
from diagrams.onprem.compute import Server
class ServerContents(Cluster):
# fmt: off
_default_graph_attrs = {
"shape": "box",
"style": "rounded,dotted",
"labeljust": "l",
"pencolor": "#A0A0A0",
"fontname": "Sans-Serif",
"fontsize": "12",
}
# fmt: on
_icon = Server

@ -66,6 +66,36 @@ with Diagram("Event Processing", show=False):
handlers >> dw handlers >> dw
``` ```
## Clusters with icons in the label
You can add a Node icon before the cluster label (and specify its size as well). You need to import the used Node class first.
It's also possible to use the node in the `with` context adding `cluster=True` to
make it behave like a cluster.
```python
from diagrams import Cluster, Diagram
from diagrams.aws.compute import ECS
from diagrams.aws.database import RDS, Aurora
from diagrams.aws.network import Route53, VPC
with Diagram("Simple Web Service with DB Cluster", show=False):
dns = Route53("dns")
web = ECS("service")
with Cluster(label='VPC',icon=VPC):
with Cluster("DB Cluster",icon=Aurora,icon_size=30):
db_master = RDS("master")
db_master - [RDS("slave1"),
RDS("slave2")]
with Aurora("DB Cluster", cluster=True):
db_master = RDS("master")
db_master - [RDS("slave1"),
RDS("slave2")]
dns >> web >> db_master
```
![event processing diagram](/img/event_processing_diagram.png) ![event processing diagram](/img/event_processing_diagram.png)
> There is no depth limit of nesting. Feel free to create nested clusters as deep as you want. > There is no depth limit of nesting. Feel free to create nested clusters as deep as you want.

Binary file not shown.

After

Width:  |  Height:  |  Size: 67 KiB

@ -0,0 +1,26 @@
from diagrams import Diagram, Edge
from diagrams.aws.cluster import *
from diagrams.aws.compute import EC2
from diagrams.onprem.container import Docker
from diagrams.onprem.cluster import *
from diagrams.aws.network import ELB
with Diagram(name="", direction="TB", show=True):
with Cluster("AWS"):
with Region("eu-west-1"):
with AvailabilityZone("eu-west-1a"):
with VirtualPrivateCloud(""):
with PrivateSubnet("Private"):
with SecurityGroup("web sg"):
with AutoScalling(""):
with EC2Contents("A"):
d1 = Docker("Container")
with ServerContents("A1"):
d2 = Docker("Container")
with PublicSubnet("Public"):
with SecurityGroup("elb sg"):
lb = ELB()
lb >> Edge(forward=True, reverse=True) >> d1
lb >> Edge(forward=True, reverse=True) >> d2

Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

@ -0,0 +1,24 @@
from diagrams import Diagram, Edge
from diagrams.azure.cluster import *
from diagrams.azure.compute import VM
from diagrams.onprem.container import Docker
from diagrams.onprem.cluster import *
from diagrams.azure.network import LoadBalancers
with Diagram(name="", filename="azure", direction="TB", show=True):
with Cluster("Azure"):
with Region("East US2"):
with AvailabilityZone("Zone 2"):
with VirtualNetwork(""):
with SubnetWithNSG("Private"):
# with VMScaleSet(""): # Depends on PR-404
with VMContents("A"):
d1 = Docker("Container")
with ServerContents("A1"):
d2 = Docker("Container")
with Subnet("Public"):
lb = LoadBalancers()
lb >> Edge(forward=True, reverse=True) >> d1
lb >> Edge(forward=True, reverse=True) >> d2

@ -154,20 +154,20 @@ class ClusterTest(unittest.TestCase):
def test_with_global_context(self): def test_with_global_context(self):
with Diagram(name=os.path.join(self.name, "with_global_context"), show=False): with Diagram(name=os.path.join(self.name, "with_global_context"), show=False):
self.assertIsNone(getcluster()) self.assertEqual(getcluster(), getdiagram())
with Cluster(): with Cluster():
self.assertIsNotNone(getcluster()) self.assertNotEqual(getcluster(), getdiagram())
self.assertIsNone(getcluster()) self.assertEqual(getcluster(), getdiagram())
def test_with_nested_cluster(self): def test_with_nested_cluster(self):
with Diagram(name=os.path.join(self.name, "with_nested_cluster"), show=False): with Diagram(name=os.path.join(self.name, "with_nested_cluster"), show=False):
self.assertIsNone(getcluster()) self.assertEqual(getcluster(), getdiagram())
with Cluster() as c1: with Cluster() as c1:
self.assertEqual(c1, getcluster()) self.assertEqual(c1, getcluster())
with Cluster() as c2: with Cluster() as c2:
self.assertEqual(c2, getcluster()) self.assertEqual(c2, getcluster())
self.assertEqual(c1, getcluster()) self.assertEqual(c1, getcluster())
self.assertIsNone(getcluster()) self.assertEqual(getcluster(), getdiagram())
def test_node_not_in_diagram(self): def test_node_not_in_diagram(self):
# Node must be belong to a diagrams. # Node must be belong to a diagrams.

Loading…
Cancel
Save