From 0cb9436426e3fb0fe355b93c7a5ee9549e32e6cb Mon Sep 17 00:00:00 2001 From: Andrew Selivanov Date: Sun, 1 Mar 2020 19:20:30 +0300 Subject: [PATCH] Edges support --- diagrams/__init__.py | 257 +++++++++++++++++++++++++++++------------- tests/test_diagram.py | 144 +++++++++++++++++++---- 2 files changed, 304 insertions(+), 97 deletions(-) diff --git a/diagrams/__init__.py b/diagrams/__init__.py index 5234c0c3..75487b53 100644 --- a/diagrams/__init__.py +++ b/diagrams/__init__.py @@ -3,7 +3,7 @@ import os from hashlib import md5 from pathlib import Path from random import getrandbits -from typing import List, Union +from typing import List, Union, Dict from graphviz import Digraph @@ -75,15 +75,15 @@ class Diagram: # TODO: Label position option # TODO: Save directory option (filename + directory?) def __init__( - self, - name: str = "", - filename: str = "", - direction: str = "LR", - outformat: str = "png", - show: bool = True, - graph_attr: dict = {}, - node_attr: dict = {}, - edge_attr: dict = {}, + self, + name: str = "", + filename: str = "", + direction: str = "LR", + outformat: str = "png", + show: bool = True, + graph_attr: dict = {}, + node_attr: dict = {}, + edge_attr: dict = {}, ): """Diagram represents a global diagrams context. @@ -129,6 +129,9 @@ class Diagram: self.show = show + def __str__(self) -> str: + return str(self.dot) + def __enter__(self): setdiagram(self) return self @@ -160,15 +163,9 @@ class Diagram: """Create a new node.""" self.dot.node(hashid, label=label, **attrs) - def connect(self, node: "Node", node2: "Node", directed=True) -> None: + def connect(self, node: "Node", node2: "Node", edge: "Edge") -> None: """Connect the two Nodes.""" - attrs = {"dir": "none"} if not directed else {} - self.dot.edge(node.hashid, node2.hashid, **attrs) - - def reverse(self, node: "Node", node2: "Node", directed=True) -> None: - """Connect the two Nodes in reverse direction.""" - attrs = {"dir": "none"} if not directed else {"dir": "back"} - self.dot.edge(node.hashid, node2.hashid, **attrs) + self.dot.edge(node.hashid, node2.hashid, **edge.to_dot()) def subgraph(self, dot: Digraph) -> None: """Create a subgraph for clustering""" @@ -302,54 +299,70 @@ class Node: _name = self.__class__.__name__ return f"<{self._provider}.{self._type}.{_name}>" - def __sub__(self, other: Union["Node", List["Node"]]): - """Implement Self - Node and Self - [Nodes]""" - if not isinstance(other, list): - return self.connect(other, directed=False) - for node in other: - self.connect(node, directed=False) - return other - - def __rsub__(self, other: List["Node"]): - """ - Called for [Nodes] - Self because list of Nodes don't have - __sub__ operators. - """ - self.__sub__(other) + def __sub__(self, other: Union["Node", List["Node"], "Edge"]): + """Implement Self - Node, Self - [Nodes] and Self - Edge.""" + if isinstance(other, list): + for node in other: + self.connect(node, Edge(self)) + return other + elif isinstance(other, Node): + return self.connect(other, Edge(self)) + else: + other.node = self + return other + + def __rsub__(self, other: Union[List["Node"], List["Edge"]]): + """ Called for [Nodes] and [Edges] - Self because list don't have __sub__ operators. """ + for o in other: + if isinstance(o, Edge): + o.connect(self) + else: + o.connect(self, Edge(self)) return self - def __rshift__(self, other: Union["Node", List["Node"]]): - """Implements Self >> Node and Self >> [Nodes].""" - if not isinstance(other, list): - return self.connect(other) - for node in other: - self.connect(node) - return other - - def __lshift__(self, other: Union["Node", List["Node"]]): - """Implements Self << Node and Self << [Nodes].""" - if not isinstance(other, list): - return self.reverse(other) - for node in other: - self.reverse(node) - return other - - def __rrshift__(self, other: List["Node"]): - """ - Called for [Nodes] >> Self because list of Nodes don't have - __rshift__ operators. - """ - for node in other: - node.connect(self) + def __rshift__(self, other: Union["Node", List["Node"], "Edge"]): + """Implements Self >> Node, Self >> [Nodes] and Self Edge.""" + if isinstance(other, list): + for node in other: + self.connect(node, Edge(self, forward=True)) + return other + elif isinstance(other, Node): + return self.connect(other, Edge(self, forward=True)) + else: + other.forward = True + other.node = self + return other + + def __lshift__(self, other: Union["Node", List["Node"], "Edge"]): + """Implements Self << Node, Self << [Nodes] and Self << Edge.""" + if isinstance(other, list): + for node in other: + self.connect(node, Edge(self, reverse=True)) + return other + elif isinstance(other, Node): + return self.connect(other, Edge(self, reverse=True)) + else: + other.reverse = True + return other.connect(self) + + def __rrshift__(self, other: Union[List["Node"], List["Edge"]]): + """Called for [Nodes] and [Edges] >> Self because list don't have __rshift__ operators.""" + for o in other: + if isinstance(o, Edge): + o.forward = True + o.connect(self) + else: + o.connect(self, Edge(self, forward=True)) return self - def __rlshift__(self, other: List["Node"]): - """ - Called for [Nodes] << Self because list of Nodes don't have - __lshift__ operators. - """ - for node in other: - node.reverse(self) + def __rlshift__(self, other: Union[List["Node"], List["Edge"]]): + """Called for [Nodes] << Self because list of Nodes don't have __lshift__ operators.""" + for o in other: + if isinstance(o, Edge): + o.reverse = True + o.connect(self) + else: + o.connect(self, Edge(self, reverse=True)) return self @property @@ -357,30 +370,19 @@ class Node: return self._hash # TODO: option for adding flow description to the connection edge - def connect(self, node: "Node", directed=True): + def connect(self, node: "Node", edge: "Edge"): """Connect to other node. :param node: Other node instance. - :param directed: Whether the flow is directed or not. - :return: Connected node. - """ - if not isinstance(node, Node): - ValueError(f"{node} is not a valid Node") - # An edge must be added on the global diagrams, not a cluster. - self._diagram.connect(self, node, directed) - return node - - def reverse(self, node: "Node", directed=True): - """Connect to other node in reverse direction. - - :param node: Other node instance. - :param directed: Whether the flow is directed or not. + :param edge: Type of the edge. :return: Connected node. """ if not isinstance(node, Node): ValueError(f"{node} is not a valid Node") + if not isinstance(node, Edge): + ValueError(f"{node} is not a valid Edge") # An edge must be added on the global diagrams, not a cluster. - self._diagram.reverse(self, node, directed) + self._diagram.connect(self, node, edge) return node @staticmethod @@ -392,4 +394,103 @@ class Node: return os.path.join(basedir.parent, self._icon_dir, self._icon) +class Edge: + """Edge represents an edge between two nodes.""" + + def __init__(self, + node: "Node" = None, + forward: bool = False, + reverse: bool = False, + label: str = "", + color: str = "" + ): + """Edge represents an edge between two nodes. + + :param node: Parent node. + :param forward: Points forward. + :param reverse: Points backward. + :param label: Edge label. + :param color: Edge color. + """ + if node is not None: + assert type(node) is Node + + self.node = node + self.forward = forward + self.reverse = reverse + self.label = label + self.color = color + + def __sub__(self, other: Union["Node", "Edge", List["Node"]]): + """Implement Self - Node or Edge and Self - [Nodes]""" + return self.connect(other) + + def __rsub__(self, other: Union[List["Node"], List["Edge"]]) -> List["Edge"]: + """Called for [Nodes] or [Edges] - Self because list don't have __sub__ operators.""" + return self.append(other) + + def __rshift__(self, other: Union["Node", "Edge", List["Node"]]): + """Implements Self >> Node or Edge and Self >> [Nodes].""" + self.forward = True + return self.connect(other) + + def __lshift__(self, other: Union["Node", "Edge", List["Node"]]): + """Implements Self << Node or Edge and Self << [Nodes].""" + self.reverse = True + return self.connect(other) + + def __rrshift__(self, other: Union[List["Node"], List["Edge"]]) -> List["Edge"]: + """Called for [Nodes] or [Edges] >> Self because list of Edges don't have __rshift__ operators.""" + return self.append(other, forward=True) + + def __rlshift__(self, other: Union[List["Node"], List["Edge"]]) -> List["Edge"]: + """Called for [Nodes] or [Edges] << Self because list of Edges don't have __lshift__ operators.""" + return self.append(other, reverse=True) + + def append(self, other: Union[List["Node"], List["Edge"]], forward=None, reverse=None) -> List["Edge"]: + result = [] + for o in other: + if isinstance(o, Edge): + o.color = self.color + o.label = self.label + o.forward = forward if forward is not None else o.forward + o.reverse = forward if forward is not None else o.reverse + result.append(o) + else: + result.append(Edge(o, forward=forward, reverse=reverse, color=self.color, label=self.label)) + return result + + def connect(self, other: Union["Node", "Edge", List["Node"]]): + if isinstance(other, list): + for node in other: + self.node.connect(node, self) + return other + elif isinstance(other, Edge): + self.label = other.label + self.color = other.color + return self + else: + if self.node is not None: + return self.node.connect(other, self) + else: + self.node = other + return self + + def to_dot(self) -> Dict: + dot = {} + if self.forward and self.reverse: + dot['dir'] = 'both' + elif self.forward: + dot['dir'] = 'forward' + elif self.reverse: + dot['dir'] = 'back' + else: + dot['dir'] = 'none' + if self.label: + dot['label'] = self.label + if self.color: + dot['color'] = self.color + return dot + + Group = Cluster diff --git a/tests/test_diagram.py b/tests/test_diagram.py index 67f73da3..d721bb95 100644 --- a/tests/test_diagram.py +++ b/tests/test_diagram.py @@ -1,22 +1,27 @@ import os +import shutil import unittest -from diagrams import Cluster, Diagram, Node +from diagrams import Cluster, Diagram, Node, Edge from diagrams import getcluster, getdiagram, setcluster, setdiagram class DiagramTest(unittest.TestCase): def setUp(self): - self.name = "test" + self.name = "diagram_test" def tearDown(self): setdiagram(None) setcluster(None) # Only some tests generate the image file. try: - os.remove(self.name + ".png") - except FileNotFoundError: - pass + shutil.rmtree(self.name) + except OSError: + # Consider it file + try: + os.remove(self.name + ".png") + except FileNotFoundError: + pass def test_validate_direction(self): # Normal directions. @@ -40,7 +45,7 @@ class DiagramTest(unittest.TestCase): def test_with_global_context(self): self.assertIsNone(getdiagram()) - with Diagram(name=self.name, show=False): + with Diagram(name=os.path.join(self.name, 'with_global_context'), show=False): self.assertIsNotNone(getdiagram()) self.assertIsNone(getdiagram()) @@ -50,7 +55,7 @@ class DiagramTest(unittest.TestCase): Node("node") def test_node_to_node(self): - with Diagram(name=self.name, show=False): + with Diagram(name=os.path.join(self.name, 'node_to_node'), show=False): node1 = Node("node1") node2 = Node("node2") self.assertEqual(node1 - node2, node2) @@ -58,7 +63,7 @@ class DiagramTest(unittest.TestCase): self.assertEqual(node1 << node2, node2) def test_node_to_nodes(self): - with Diagram(name=self.name, show=False): + with Diagram(name=os.path.join(self.name, 'node_to_nodes'), show=False): node1 = Node("node1") nodes = [Node("node2"), Node("node3")] self.assertEqual(node1 - nodes, nodes) @@ -66,7 +71,7 @@ class DiagramTest(unittest.TestCase): self.assertEqual(node1 << nodes, nodes) def test_nodes_to_node(self): - with Diagram(name=self.name, show=False): + with Diagram(name=os.path.join(self.name, 'nodes_to_node'), show=False): node1 = Node("node1") nodes = [Node("node2"), Node("node3")] self.assertEqual(nodes - node1, node1) @@ -88,38 +93,38 @@ class DiagramTest(unittest.TestCase): class ClusterTest(unittest.TestCase): def setUp(self): - self.name = "test" + self.name = "cluster_test" def tearDown(self): setdiagram(None) setcluster(None) # Only some tests generate the image file. try: - os.remove(self.name + ".png") - except FileNotFoundError: + shutil.rmtree(self.name) + except OSError: pass def test_validate_direction(self): # Normal directions. for dir in ("TB", "BT", "LR", "RL"): - with Diagram(name=self.name, show=False): + with Diagram(name=os.path.join(self.name, 'validate_direction'), show=False): Cluster(direction=dir) # Invalid directions. for dir in ("BR", "TL", "Unknown"): with self.assertRaises(ValueError): - with Diagram(name=self.name, show=False): + with Diagram(name=os.path.join(self.name, 'validate_direction'), show=False): Cluster(direction=dir) def test_with_global_context(self): - with Diagram(name=self.name, show=False): + with Diagram(name=os.path.join(self.name, 'with_global_context'), show=False): self.assertIsNone(getcluster()) with Cluster(): self.assertIsNotNone(getcluster()) self.assertIsNone(getcluster()) def test_with_nested_cluster(self): - with Diagram(name=self.name, show=False): + with Diagram(name=os.path.join(self.name, 'with_nested_cluster'), show=False): self.assertIsNone(getcluster()) with Cluster() as c1: self.assertEqual(c1, getcluster()) @@ -134,7 +139,7 @@ class ClusterTest(unittest.TestCase): Node("node") def test_node_to_node(self): - with Diagram(name=self.name, show=False): + with Diagram(name=os.path.join(self.name, 'node_to_node'), show=False): with Cluster(): node1 = Node("node1") node2 = Node("node2") @@ -143,7 +148,7 @@ class ClusterTest(unittest.TestCase): self.assertEqual(node1 << node2, node2) def test_node_to_nodes(self): - with Diagram(name=self.name, show=False): + with Diagram(name=os.path.join(self.name, 'node_to_nodes'), show=False): with Cluster(): node1 = Node("node1") nodes = [Node("node2"), Node("node3")] @@ -152,10 +157,111 @@ class ClusterTest(unittest.TestCase): self.assertEqual(node1 << nodes, nodes) def test_nodes_to_node(self): - with Diagram(name=self.name, show=False): + with Diagram(name=os.path.join(self.name, 'nodes_to_node'), show=False): with Cluster(): node1 = Node("node1") nodes = [Node("node2"), Node("node3")] self.assertEqual(nodes - node1, node1) self.assertEqual(nodes >> node1, node1) self.assertEqual(nodes << node1, node1) + + +class EdgeTest(unittest.TestCase): + def setUp(self): + self.name = "edge_test" + + def tearDown(self): + setdiagram(None) + setcluster(None) + # Only some tests generate the image file. + try: + shutil.rmtree(self.name) + except OSError: + pass + + def test_node_to_node(self): + with Diagram(name=os.path.join(self.name, 'node_to_node'), show=False): + node1 = Node("node1") + node2 = Node("node2") + self.assertEqual(node1 - Edge(color='red') - node2, node2) + + def test_node_to_nodes(self): + with Diagram(name=os.path.join(self.name, 'node_to_nodes'), show=False): + with Cluster(): + node1 = Node("node1") + nodes = [Node("node2"), Node("node3")] + self.assertEqual(node1 - Edge(color='red') - nodes, nodes) + + def test_nodes_to_node(self): + with Diagram(name=os.path.join(self.name, 'nodes_to_node'), show=False): + with Cluster(): + node1 = Node("node1") + nodes = [Node("node2"), Node("node3")] + self.assertEqual(nodes - Edge(color='red') - node1, node1) + + def test_nodes_to_node_with_additional_attributes(self): + with Diagram(name=os.path.join(self.name, 'nodes_to_node_with_additional_attributes'), show=False): + with Cluster(): + node1 = Node("node1") + nodes = [Node("node2"), Node("node3")] + self.assertEqual(nodes - Edge(color='red') - Edge(color='green') - node1, node1) + + def test_node_to_node_with_attributes(self): + with Diagram(name=os.path.join(self.name, 'node_to_node_with_attributes'), show=False): + with Cluster(): + node1 = Node("node1") + node2 = Node("node2") + self.assertEqual(node1 << Edge(color='red', label='1.1') << node2, node2) + self.assertEqual(node1 >> Edge(color='green', label='1.2') >> node2, node2) + self.assertEqual(node1 << Edge(color='blue', label='1.3') >> node2, node2) + + def test_node_to_node_with_additional_attributes(self): + with Diagram(name=os.path.join(self.name, 'node_to_node_with_additional_attributes'), show=False): + with Cluster(): + node1 = Node("node1") + node2 = Node("node2") + self.assertEqual(node1 << Edge(color='red', label='2.1') << Edge(color='blue') << node2, node2) + self.assertEqual(node1 >> Edge(color='green', label='2.2') >> Edge(color='red') >> node2, node2) + self.assertEqual(node1 << Edge(color='blue', label='2.3') >> Edge(color='black') >> node2, node2) + + def test_nodes_to_node_with_attributes_loop(self): + with Diagram(name=os.path.join(self.name, 'nodes_to_node_with_attributes_loop'), show=False): + with Cluster(): + node = Node("node") + self.assertEqual(node >> Edge(color='red', label='3.1') >> node, node) + self.assertEqual(node << Edge(color='green', label='3.2') << node, node) + self.assertEqual(node >> Edge(color='blue', label='3.3') << node, node) + self.assertEqual(node << Edge(color='pink', label='3.4') >> node, node) + + def test_nodes_to_node_with_attributes_bothdirectional(self): + with Diagram(name=os.path.join(self.name, 'nodes_to_node_with_attributes_bothdirectional'), show=False) as diagram: + with Cluster(): + node1 = Node("node1") + nodes = [Node("node2"), Node("node3")] + self.assertEqual(nodes << Edge(color='green', label='4') >> node1, node1) + + def test_nodes_to_node_with_attributes_bidirectional(self): + with Diagram(name=os.path.join(self.name, 'nodes_to_node_with_attributes_bidirectional'), show=False): + with Cluster(): + node1 = Node("node1") + nodes = [Node("node2"), Node("node3")] + self.assertEqual(nodes << Edge(color='blue', label='5') >> node1, node1) + + def test_nodes_to_node_with_attributes_onedirectional(self): + with Diagram(name=os.path.join(self.name, 'nodes_to_node_with_attributes_onedirectional'), show=False): + with Cluster(): + node1 = Node("node1") + nodes = [Node("node2"), Node("node3")] + self.assertEqual(nodes >> Edge(color='red', label='6.1') >> node1, node1) + self.assertEqual(nodes << Edge(color='green', label='6.2') << node1, node1) + + def test_nodes_to_node_with_additional_attributes_directional(self): + with Diagram(name=os.path.join(self.name, 'nodes_to_node_with_additional_attributes_directional'), show=False): + with Cluster(): + node1 = Node("node1") + nodes = [Node("node2"), Node("node3")] + self.assertEqual(nodes + >> Edge(color='red', label='6.1') >> Edge(color='blue', label='6.2') >> node1, node1) + self.assertEqual(nodes + << Edge(color='green', label='6.3') << Edge(color='pink', label='6.4') << node1, node1) +