Fix unit tests

pull/823/head
Bruno Meneguello 5 years ago
parent 1067cf50db
commit 854fff5947

@ -17,10 +17,7 @@ __cluster = contextvars.ContextVar("cluster")
def getdiagram(): def getdiagram():
try: return __diagram.get()
return __diagram.get()
except LookupError:
raise EnvironmentError("Global diagrams context not set up")
def setdiagram(diagram): def setdiagram(diagram):
@ -53,7 +50,7 @@ class _Cluster:
try: try:
self._parent = getcluster() or getdiagram() self._parent = getcluster() or getdiagram()
except EnvironmentError: except LookupError:
self._parent = None self._parent = None
@ -348,6 +345,8 @@ class Node(_Cluster):
self._attrs.update(attrs) self._attrs.update(attrs)
# If a node is in the cluster context, add it to cluster. # If a node is in the cluster context, add it to cluster.
if not self._parent:
raise EnvironmentError("Global diagrams context not set up")
self._parent.node(self) self._parent.node(self)
def __enter__(self): def __enter__(self):

@ -135,20 +135,20 @@ class ClusterTest(unittest.TestCase):
def test_with_global_context(self): def test_with_global_context(self):
with Diagram(name=os.path.join(self.name, "with_global_context"), show=False): with Diagram(name=os.path.join(self.name, "with_global_context"), show=False):
self.assertIsNone(getcluster()) self.assertEqual(getcluster(), getdiagram())
with Cluster(): with Cluster():
self.assertIsNotNone(getcluster()) self.assertNotEqual(getcluster(), getdiagram())
self.assertIsNone(getcluster()) self.assertEqual(getcluster(), getdiagram())
def test_with_nested_cluster(self): def test_with_nested_cluster(self):
with Diagram(name=os.path.join(self.name, "with_nested_cluster"), show=False): with Diagram(name=os.path.join(self.name, "with_nested_cluster"), show=False):
self.assertIsNone(getcluster()) self.assertEqual(getcluster(), getdiagram())
with Cluster() as c1: with Cluster() as c1:
self.assertEqual(c1, getcluster()) self.assertEqual(c1, getcluster())
with Cluster() as c2: with Cluster() as c2:
self.assertEqual(c2, getcluster()) self.assertEqual(c2, getcluster())
self.assertEqual(c1, getcluster()) self.assertEqual(c1, getcluster())
self.assertIsNone(getcluster()) self.assertEqual(getcluster(), getdiagram())
def test_node_not_in_diagram(self): def test_node_not_in_diagram(self):
# Node must be belong to a diagrams. # Node must be belong to a diagrams.

Loading…
Cancel
Save