Skip to content

Commit 991f4ed

Browse files
committed
fix: diagram CLI accepts module-only paths (auto-discovers SM class)
1 parent 9d7f114 commit 991f4ed

2 files changed

Lines changed: 44 additions & 2 deletions

File tree

statemachine/contrib/diagram.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,12 +296,37 @@ def quickchart_write_svg(sm: StateChart, path: str):
296296
f.write(data)
297297

298298

299+
def _find_sm_class(module):
300+
"""Find the first StateChart subclass defined in a module."""
301+
import inspect
302+
303+
for _name, obj in inspect.getmembers(module, inspect.isclass):
304+
if (
305+
issubclass(obj, StateChart)
306+
and obj is not StateChart
307+
and obj.__module__ == module.__name__
308+
):
309+
return obj
310+
return None
311+
312+
299313
def import_sm(qualname):
300314
module_name, class_name = qualname.rsplit(".", 1)
301315
module = importlib.import_module(module_name)
302316
smclass = getattr(module, class_name, None)
303-
if not smclass or not issubclass(smclass, StateChart):
304-
raise ValueError(f"{class_name} is not a subclass of StateMachine")
317+
if smclass is not None and isinstance(smclass, type) and issubclass(smclass, StateChart):
318+
return smclass
319+
320+
# qualname may be a module path without a class name — try importing
321+
# the whole path as a module and find the first StateChart subclass.
322+
try:
323+
module = importlib.import_module(qualname)
324+
except ImportError as err:
325+
raise ValueError(f"{class_name} is not a subclass of StateMachine") from err
326+
327+
smclass = _find_sm_class(module)
328+
if smclass is None:
329+
raise ValueError(f"No StateMachine subclass found in module {qualname!r}")
305330

306331
return smclass
307332

tests/test_contrib_diagram.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,16 @@ def test_generate_image(self, tmp_path):
6666
'<?xml version="1.0" encoding="UTF-8" standalone="no"?>\n<!DOCTYPE svg'
6767
)
6868

69+
def test_generate_image_from_module_path(self, tmp_path):
70+
"""Accept a module path without the class name and auto-discover the SM class."""
71+
out = tmp_path / "sm.svg"
72+
73+
main(["tests.examples.traffic_light_machine", str(out)])
74+
75+
assert out.read_text().startswith(
76+
'<?xml version="1.0" encoding="UTF-8" standalone="no"?>\n<!DOCTYPE svg'
77+
)
78+
6979
def test_generate_complain_about_bad_sm_path(self, capsys, tmp_path):
7080
out = tmp_path / "sm.svg"
7181

@@ -78,6 +88,13 @@ def test_generate_complain_about_bad_sm_path(self, capsys, tmp_path):
7888
]
7989
)
8090

91+
def test_generate_complain_about_module_without_sm(self, tmp_path):
92+
out = tmp_path / "sm.svg"
93+
94+
expected_error = "No StateMachine subclass found in module"
95+
with pytest.raises(ValueError, match=expected_error):
96+
main(["tests.examples", str(out)])
97+
8198

8299
class TestQuickChart:
83100
@contextmanager

0 commit comments

Comments
 (0)