mirror of https://github.com/mingrammer/diagrams
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
579 lines
19 KiB
579 lines
19 KiB
import contextvars
|
|
import os
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Union
|
|
|
|
from graphviz import Digraph
|
|
|
|
# Module-wide context variables holding the active Diagram and Cluster.
#
# Nodes and clusters read these when they are constructed, so callers never
# have to pass the enclosing diagram or cluster explicitly: whichever
# ``with Diagram(...)`` / ``with Cluster(...)`` block is currently open is the
# one a newly created object attaches itself to. Neither variable has a
# default, so reading one before it is first set raises LookupError.
__diagram, __cluster = (
    contextvars.ContextVar("diagrams"),
    contextvars.ContextVar("cluster"),
)
|
|
|
|
|
|
def getdiagram() -> Optional["Diagram"]:
    """Return the Diagram whose ``with`` block is currently active.

    :return: The active Diagram, or None when no diagram context has been
        entered yet or it was cleared with ``setdiagram(None)``.
    """
    # ContextVar.get accepts a fallback, which covers the "never set" case
    # without a try/except LookupError.
    return __diagram.get(None)
|
|
|
|
|
|
def setdiagram(diagram: Optional["Diagram"]):
    """Make ``diagram`` the active diagram context.

    :param diagram: The Diagram that new nodes and clusters should attach to,
        or None to clear the active context (``Diagram.__exit__`` passes None).
    """
    __diagram.set(diagram)
|
|
|
|
|
|
def getcluster() -> Optional["Cluster"]:
    """Return the Cluster whose ``with`` block is currently active.

    :return: The innermost active Cluster, or None when no cluster context
        has been entered yet or the outermost cluster has exited.
    """
    # ContextVar.get accepts a fallback, which covers the "never set" case
    # without a try/except LookupError.
    return __cluster.get(None)
|
|
|
|
|
|
def setcluster(cluster: Optional["Cluster"]):
    """Make ``cluster`` the active cluster context.

    :param cluster: The Cluster that new nodes and sub-clusters should attach
        to, or None to leave cluster scope (``Cluster.__exit__`` passes its
        parent, which is None for a top-level cluster).
    """
    __cluster.set(cluster)
|
|
|
|
|
|
class Diagram:
    __directions = ("TB", "BT", "LR", "RL")
    __curvestyles = ("ortho", "curved")
    __outformats = ("png", "jpg", "svg", "pdf", "dot")

    # fmt: off
    _default_graph_attrs = {
        "pad": "2.0",
        "splines": "ortho",
        "nodesep": "0.60",
        "ranksep": "0.75",
        "fontname": "Sans-Serif",
        "fontsize": "15",
        "fontcolor": "#2D3436",
    }
    _default_node_attrs = {
        "shape": "box",
        "style": "rounded",
        "fixedsize": "true",
        "width": "1.4",
        "height": "1.4",
        "labelloc": "b",
        # imagepos attribute is not backward compatible
        # TODO: check graphviz version to see if "imagepos" is available >= 2.40
        # https://github.com/xflr6/graphviz/blob/master/graphviz/backend.py#L248
        # "imagepos": "tc",
        "imagescale": "true",
        "fontname": "Sans-Serif",
        "fontsize": "13",
        "fontcolor": "#2D3436",
    }
    _default_edge_attrs = {
        "color": "#7B8894",
    }

    # fmt: on

    # TODO: Label position option
    # TODO: Save directory option (filename + directory?)
    def __init__(
        self,
        name: str = "",
        filename: str = "",
        direction: str = "LR",
        curvestyle: str = "ortho",
        outformat: Union[str, List[str]] = "png",
        autolabel: bool = False,
        show: bool = True,
        strict: bool = False,
        graph_attr: Optional[dict] = None,
        node_attr: Optional[dict] = None,
        edge_attr: Optional[dict] = None,
        save_to_disk: bool = True,
    ):
        """Diagram represents a global diagrams context.

        :param name: Diagram name. It is used as the graph label, and for the
            output filename if ``filename`` isn't given.
        :param filename: The output filename, without the extension (.png).
            If not given, it is generated from the name (whitespace runs
            joined with "_", lower-cased), or "diagrams_image" if both are empty.
        :param direction: Data flow direction, one of "TB", "BT", "LR", "RL"
            (case-insensitive). Default is 'left to right'.
        :param curvestyle: Curve bending style. One of "ortho" or "curved"
            (case-insensitive).
        :param outformat: Output file format, or a list of formats rendered in
            turn. Each must be one of png, jpg, svg, pdf, dot
            (case-insensitive). Default is 'png'.
        :param autolabel: Prefix every node's label with its class name.
        :param show: Open generated image after save if true, just only save otherwise.
        :param strict: Rendering should merge multi-edges.
        :param graph_attr: Provide graph_attr dot config attributes.
        :param node_attr: Provide node_attr dot config attributes.
        :param edge_attr: Provide edge_attr dot config attributes.
        :param save_to_disk: Render the output file(s) when the ``with`` block
            exits. An explicit None is treated as True.
        :raises ValueError: If direction, curvestyle or any outformat is invalid.
        """
        if graph_attr is None:
            graph_attr = {}
        if node_attr is None:
            node_attr = {}
        if edge_attr is None:
            edge_attr = {}
        self.name = name
        if not name and not filename:
            filename = "diagrams_image"
        elif not filename:
            filename = "_".join(self.name.split()).lower()
        self.filename = filename
        self.dot = Digraph(self.name, filename=self.filename, strict=strict)
        # Tolerate an explicit None and treat it like the default.
        if save_to_disk is None:
            save_to_disk = True

        # Set attributes.
        for k, v in self._default_graph_attrs.items():
            self.dot.graph_attr[k] = v
        self.dot.graph_attr["label"] = self.name
        for k, v in self._default_node_attrs.items():
            self.dot.node_attr[k] = v
        for k, v in self._default_edge_attrs.items():
            self.dot.edge_attr[k] = v

        if not self._validate_direction(direction):
            raise ValueError(f'"{direction}" is not a valid direction')
        # Validation is case-insensitive, so store the canonical upper-case
        # spelling that Graphviz documents for rankdir (TB, BT, LR, RL);
        # otherwise e.g. "lr" would be accepted here but not honoured.
        self.dot.graph_attr["rankdir"] = direction.upper()

        if not self._validate_curvestyle(curvestyle):
            raise ValueError(f'"{curvestyle}" is not a valid curvestyle')
        self.dot.graph_attr["splines"] = curvestyle.lower()

        # Validate every requested format, then keep them lower-cased so that
        # render() passes exactly the spelling validation accepted.
        formats = outformat if isinstance(outformat, list) else [outformat]
        for one_format in formats:
            if not self._validate_outformat(one_format):
                raise ValueError(f'"{one_format}" is not a valid output format')
        if isinstance(outformat, list):
            self.outformat = [one_format.lower() for one_format in outformat]
        else:
            self.outformat = outformat.lower()

        # Merge passed in attributes (they override the defaults above).
        self.dot.graph_attr.update(graph_attr)
        self.dot.node_attr.update(node_attr)
        self.dot.edge_attr.update(edge_attr)

        self.show = show
        self.autolabel = autolabel
        self.save_to_disk = save_to_disk

    def __str__(self) -> str:
        """Return the DOT source of the diagram."""
        return str(self.dot)

    def __enter__(self):
        """Make this diagram the active context for new nodes and clusters."""
        setdiagram(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Render (if enabled) and clear the active diagram context."""
        try:
            if self.save_to_disk:  # Only render if save_to_disk is True
                self.render()
                # Remove the graphviz source file leaving only the image.
                os.remove(self.filename)
        finally:
            # Always clear the context, even if rendering failed, so later
            # nodes/clusters cannot silently attach to this stale diagram.
            setdiagram(None)

    def _repr_png_(self):
        """Return the diagram piped through Graphviz as PNG bytes.

        Presumably consumed by IPython/Jupyter rich display (``_repr_png_``
        hook name).
        """
        return self.dot.pipe(format="png")

    def _validate_direction(self, direction: str) -> bool:
        return direction.upper() in self.__directions

    def _validate_curvestyle(self, curvestyle: str) -> bool:
        return curvestyle.lower() in self.__curvestyles

    def _validate_outformat(self, outformat: str) -> bool:
        return outformat.lower() in self.__outformats

    def node(self, nodeid: str, label: str, **attrs) -> None:
        """Create a new node."""
        self.dot.node(nodeid, label=label, **attrs)

    def connect(self, node: "Node", node2: "Node", edge: "Edge") -> None:
        """Connect the two Nodes."""
        self.dot.edge(node.nodeid, node2.nodeid, **edge.attrs)

    def subgraph(self, dot: Digraph) -> None:
        """Create a subgraph for clustering"""
        self.dot.subgraph(dot)

    def render(self) -> None:
        """Render the diagram once for each configured output format."""
        formats = self.outformat if isinstance(self.outformat, list) else [self.outformat]
        for one_format in formats:
            self.dot.render(format=one_format, view=self.show, quiet=True)
|
|
|
|
|
|
class Cluster:
    __directions = ("TB", "BT", "LR", "RL")
    __bgcolors = ("#E5F5FD", "#EBF3E7", "#ECE8F6", "#FDF7E3")

    # fmt: off
    _default_graph_attrs = {
        "shape": "box",
        "style": "rounded",
        "labeljust": "l",
        "pencolor": "#AEB6BE",
        "fontname": "Sans-Serif",
        "fontsize": "12",
    }

    # fmt: on

    # FIXME:
    #  Cluster direction does not work now. Graphviz couldn't render
    #  correctly for a subgraph that has a different rank direction.
    def __init__(
        self,
        label: str = "cluster",
        direction: str = "LR",
        graph_attr: Optional[dict] = None,
    ):
        """Cluster represents a cluster context.

        :param label: Cluster label. The underlying subgraph is named
            ``"cluster_" + label``.
        :param direction: Data flow direction, one of "TB", "BT", "LR", "RL"
            (case-insensitive). Default is 'left to right'.
        :param graph_attr: Provide graph_attr dot config attributes; they
            override the defaults and the depth-based background colour.
        :raises ValueError: If direction is invalid.
        :raises EnvironmentError: If no Diagram context is active.
        """
        if graph_attr is None:
            graph_attr = {}
        self.label = label
        # NOTE(review): clusters sharing a label get the same subgraph name,
        # which Graphviz may merge into one cluster — confirm if that matters.
        self.name = "cluster_" + self.label

        self.dot = Digraph(self.name)

        # Set attributes.
        for k, v in self._default_graph_attrs.items():
            self.dot.graph_attr[k] = v
        self.dot.graph_attr["label"] = self.label

        if not self._validate_direction(direction):
            raise ValueError(f'"{direction}" is not a valid direction')
        # Validation is case-insensitive; store the canonical upper-case
        # spelling that Graphviz documents for rankdir.
        self.dot.graph_attr["rankdir"] = direction.upper()

        # Node must be belong to a diagrams.
        self._diagram = getdiagram()
        if self._diagram is None:
            raise EnvironmentError("Global diagrams context not set up")
        self._parent = getcluster()

        # Set cluster depth for distinguishing the background color
        self.depth = self._parent.depth + 1 if self._parent else 0
        coloridx = self.depth % len(self.__bgcolors)
        self.dot.graph_attr["bgcolor"] = self.__bgcolors[coloridx]

        # Merge passed in attributes
        self.dot.graph_attr.update(graph_attr)

    def __enter__(self):
        """Make this cluster the active context for new nodes and sub-clusters."""
        setcluster(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Attach this cluster to its parent (or the diagram) and pop the context."""
        try:
            if self._parent:
                self._parent.subgraph(self.dot)
            else:
                self._diagram.subgraph(self.dot)
        finally:
            # Restore the enclosing cluster even if attaching the subgraph
            # failed, so the context stack does not stay pointed at us.
            setcluster(self._parent)

    def _validate_direction(self, direction: str) -> bool:
        return direction.upper() in self.__directions

    def node(self, nodeid: str, label: str, **attrs) -> None:
        """Create a new node in the cluster."""
        self.dot.node(nodeid, label=label, **attrs)

    def subgraph(self, dot: Digraph) -> None:
        """Nest another cluster's graph inside this one."""
        self.dot.subgraph(dot)
|
|
|
|
|
|
class Node:
    """Node represents a node for a specific backend service."""

    _provider = None
    _type = None

    _icon_dir = None
    _icon = None

    _height = 1.9

    def __init__(self, label: str = "", *, nodeid: Optional[str] = None, **attrs: Dict):
        """Node represents a system component.

        :param label: Node label.
        :param nodeid: Explicit graphviz node id. A random hex id is generated
            when it is omitted or empty.
        :param attrs: Extra graphviz node attributes; they override the
            icon-derived defaults.
        :raises EnvironmentError: If no Diagram context is active.
        """
        # Generates an ID for identifying a node, unless specified
        self._id = nodeid or self._rand_id()
        self.label = label

        # Node must be belong to a diagrams.
        self._diagram = getdiagram()
        if self._diagram is None:
            raise EnvironmentError("Global diagrams context not set up")

        if self._diagram.autolabel:
            prefix = self.__class__.__name__
            if self.label:
                self.label = prefix + "\n" + self.label
            else:
                self.label = prefix

        # fmt: off
        # If a node has an icon, increase the height slightly to avoid
        # that label being spanned between icon image and white space.
        # Increase the height by the number of new lines included in the label.
        padding = 0.4 * (self.label.count('\n'))
        self._attrs = {
            "shape": "none",
            "height": str(self._height + padding),
            "image": self._load_icon(),
        } if self._icon else {}

        # fmt: on
        self._attrs.update(attrs)

        self._cluster = getcluster()

        # If a node is in the cluster context, add it to cluster.
        if self._cluster:
            self._cluster.node(self._id, self.label, **self._attrs)
        else:
            self._diagram.node(self._id, self.label, **self._attrs)

    def __repr__(self):
        _name = self.__class__.__name__
        return f"<{self._provider}.{self._type}.{_name}>"

    def __sub__(self, other: Union["Node", List["Node"], "Edge"]):
        """Implement Self - Node, Self - [Nodes] and Self - Edge."""
        if isinstance(other, list):
            for node in other:
                self.connect(node, Edge(self))
            return other
        elif isinstance(other, Node):
            return self.connect(other, Edge(self))
        else:
            other.node = self
            return other

    def __rsub__(self, other: Union[List["Node"], List["Edge"]]):
        """Called for [Nodes] and [Edges] - Self because list don't have __sub__ operators."""
        for o in other:
            if isinstance(o, Edge):
                o.connect(self)
            else:
                o.connect(self, Edge(self))
        return self

    def __rshift__(self, other: Union["Node", List["Node"], "Edge"]):
        """Implements Self >> Node, Self >> [Nodes] and Self Edge."""
        if isinstance(other, list):
            for node in other:
                self.connect(node, Edge(self, forward=True))
            return other
        elif isinstance(other, Node):
            return self.connect(other, Edge(self, forward=True))
        else:
            other.forward = True
            other.node = self
            return other

    def __lshift__(self, other: Union["Node", List["Node"], "Edge"]):
        """Implements Self << Node, Self << [Nodes] and Self << Edge."""
        if isinstance(other, list):
            for node in other:
                self.connect(node, Edge(self, reverse=True))
            return other
        elif isinstance(other, Node):
            return self.connect(other, Edge(self, reverse=True))
        else:
            other.reverse = True
            return other.connect(self)

    def __rrshift__(self, other: Union[List["Node"], List["Edge"]]):
        """Called for [Nodes] and [Edges] >> Self because list don't have __rshift__ operators."""
        for o in other:
            if isinstance(o, Edge):
                o.forward = True
                o.connect(self)
            else:
                o.connect(self, Edge(self, forward=True))
        return self

    def __rlshift__(self, other: Union[List["Node"], List["Edge"]]):
        """Called for [Nodes] << Self because list of Nodes don't have __lshift__ operators."""
        for o in other:
            if isinstance(o, Edge):
                o.reverse = True
                o.connect(self)
            else:
                o.connect(self, Edge(self, reverse=True))
        return self

    @property
    def nodeid(self):
        """The graphviz id of this node."""
        return self._id

    # TODO: option for adding flow description to the connection edge
    def connect(self, node: "Node", edge: "Edge"):
        """Connect to other node.

        :param node: Other node instance.
        :param edge: Type of the edge.
        :return: Connected node.
        :raises ValueError: If ``node`` is not a Node or ``edge`` is not an Edge.
        """
        # Actually raise (the errors used to be constructed and discarded), so
        # invalid operands fail here instead of deeper inside Diagram.connect.
        if not isinstance(node, Node):
            raise ValueError(f"{node} is not a valid Node")
        if not isinstance(edge, Edge):
            raise ValueError(f"{edge} is not a valid Edge")
        # An edge must be added on the global diagrams, not a cluster.
        self._diagram.connect(self, node, edge)
        return node

    @staticmethod
    def _rand_id():
        """Return a random 32-character hex id."""
        return uuid.uuid4().hex

    def _load_icon(self):
        """Return the icon path: <parent of this module's dir>/_icon_dir/_icon."""
        basedir = Path(os.path.abspath(os.path.dirname(__file__)))
        return os.path.join(basedir.parent, self._icon_dir, self._icon)
|
|
|
|
|
|
class Edge:
    """Edge represents an edge between two nodes."""

    _default_edge_attrs = {
        "fontcolor": "#2D3436",
        "fontname": "Sans-Serif",
        "fontsize": "13",
    }

    def __init__(
        self,
        node: Optional["Node"] = None,
        forward: bool = False,
        reverse: bool = False,
        label: str = "",
        color: str = "",
        style: str = "",
        **attrs: Dict,
    ):
        """Edge represents an edge between two nodes.

        :param node: Parent node.
        :param forward: Points forward.
        :param reverse: Points backward.
        :param label: Edge label.
        :param color: Edge color.
        :param style: Edge style.
        :param attrs: Other edge attributes
        :raises ValueError: If ``node`` is given but is not a Node.
        """
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if node is not None and not isinstance(node, Node):
            raise ValueError(f"{node} is not a valid Node")

        self.node = node
        self.forward = forward
        self.reverse = reverse

        self._attrs = {}

        # Set attributes.
        for k, v in self._default_edge_attrs.items():
            self._attrs[k] = v

        if label:
            # Graphviz complaining about using label for edges, so replace it with xlabel.
            # Update: xlabel option causes the misaligned label position:
            # https://github.com/mingrammer/diagrams/issues/83
            self._attrs["label"] = label
        if color:
            self._attrs["color"] = color
        if style:
            self._attrs["style"] = style
        self._attrs.update(attrs)

    def __sub__(self, other: Union["Node", "Edge", List["Node"]]):
        """Implement Self - Node or Edge and Self - [Nodes]"""
        return self.connect(other)

    def __rsub__(self, other: Union[List["Node"], List["Edge"]]) -> List["Edge"]:
        """Called for [Nodes] or [Edges] - Self because list don't have __sub__ operators."""
        return self.append(other)

    def __rshift__(self, other: Union["Node", "Edge", List["Node"]]):
        """Implements Self >> Node or Edge and Self >> [Nodes]."""
        self.forward = True
        return self.connect(other)

    def __lshift__(self, other: Union["Node", "Edge", List["Node"]]):
        """Implements Self << Node or Edge and Self << [Nodes]."""
        self.reverse = True
        return self.connect(other)

    def __rrshift__(self, other: Union[List["Node"], List["Edge"]]) -> List["Edge"]:
        """Called for [Nodes] or [Edges] >> Self because list of Edges don't have __rshift__ operators."""
        return self.append(other, forward=True)

    def __rlshift__(self, other: Union[List["Node"], List["Edge"]]) -> List["Edge"]:
        """Called for [Nodes] or [Edges] << Self because list of Edges don't have __lshift__ operators."""
        return self.append(other, reverse=True)

    def append(self, other: Union[List["Node"], List["Edge"]], forward=None, reverse=None) -> List["Edge"]:
        """Turn each element of ``other`` into an Edge and return them as a list.

        Edge elements are kept as-is; their direction flags are overwritten
        only by a truthy ``forward``/``reverse``, and their attributes
        (including the computed "dir") are copied into this edge's attributes.
        Any other element becomes a new Edge whose node is that element and
        whose attributes are this edge's current attributes.
        """
        result = []
        for o in other:
            if isinstance(o, Edge):
                o.forward = forward if forward else o.forward
                o.reverse = reverse if reverse else o.reverse
                self._attrs = o.attrs.copy()
                result.append(o)
            else:
                result.append(
                    Edge(
                        o,
                        forward=forward,
                        reverse=reverse,
                        **self._attrs))
        return result

    def connect(self, other: Union["Node", "Edge", List["Node"]]):
        """Attach ``other`` to this edge.

        - list: connect ``self.node`` to every node in it; return the list.
        - Edge: copy its attributes into this edge; return self.
        - otherwise: if this edge already has a node, connect that node to
          ``other`` and return ``other``; else adopt ``other`` as this edge's
          node and return self.
        """
        if isinstance(other, list):
            for node in other:
                self.node.connect(node, self)
            return other
        elif isinstance(other, Edge):
            self._attrs = other._attrs.copy()
            return self
        else:
            if self.node is not None:
                return self.node.connect(other, self)
            else:
                self.node = other
                return self

    @property
    def attrs(self) -> Dict:
        """Edge attributes plus a graphviz "dir" derived from forward/reverse."""
        if self.forward and self.reverse:
            direction = "both"
        elif self.forward:
            direction = "forward"
        elif self.reverse:
            direction = "back"
        else:
            direction = "none"
        return {**self._attrs, "dir": direction}
|
|
|
|
|
|
# Alias: ``Group`` is the same class as ``Cluster``.
Group = Cluster
|