@@ -296,12 +296,37 @@ def quickchart_write_svg(sm: StateChart, path: str):
296296 f .write (data )
297297
298298
299+ def _find_sm_class (module ):
300+ """Find the first StateChart subclass defined in a module."""
301+ import inspect
302+
303+ for _name , obj in inspect .getmembers (module , inspect .isclass ):
304+ if (
305+ issubclass (obj , StateChart )
306+ and obj is not StateChart
307+ and obj .__module__ == module .__name__
308+ ):
309+ return obj
310+ return None
311+
312+
299313def import_sm (qualname ):
300314 module_name , class_name = qualname .rsplit ("." , 1 )
301315 module = importlib .import_module (module_name )
302316 smclass = getattr (module , class_name , None )
303- if not smclass or not issubclass (smclass , StateChart ):
304- raise ValueError (f"{ class_name } is not a subclass of StateMachine" )
317+ if smclass is not None and isinstance (smclass , type ) and issubclass (smclass , StateChart ):
318+ return smclass
319+
320+ # qualname may be a module path without a class name — try importing
321+ # the whole path as a module and find the first StateChart subclass.
322+ try :
323+ module = importlib .import_module (qualname )
324+ except ImportError as err :
325+ raise ValueError (f"{ class_name } is not a subclass of StateMachine" ) from err
326+
327+ smclass = _find_sm_class (module )
328+ if smclass is None :
329+ raise ValueError (f"No StateMachine subclass found in module { qualname !r} " )
305330
306331 return smclass
307332
0 commit comments