1919from typing import Callable
2020from typing import Dict
2121from typing import List
22+ from typing import Tuple
2223from typing import runtime_checkable
2324
24- from .orderedset import OrderedSet
25-
2625try :
2726 from typing import Protocol
2827except ImportError : # pragma: no cover
@@ -150,6 +149,9 @@ class InvokeContext:
150149 cancelled : threading .Event = field (default_factory = threading .Event )
151150 """Set when the owning state is exited; handlers should check this to stop early."""
152151
152+ kwargs : dict = field (default_factory = dict )
153+ """Keyword arguments from the event that triggered the state entry."""
154+
153155
154156@dataclass
155157class Invocation :
@@ -263,24 +265,32 @@ class InvokeManager:
263265 def __init__ (self , engine : "BaseEngine" ):
264266 self ._engine = engine
265267 self ._active : Dict [str , Invocation ] = {}
266- self ._states_to_invoke : "OrderedSet[ State] " = OrderedSet ()
268+ self ._pending : "List[Tuple[ State, dict]] " = []
267269
268270 @property
269271 def sm (self ) -> "StateChart" :
270272 return self ._engine .sm
271273
272274 # --- Engine hooks ---
273275
274- def mark_for_invoke (self , state : "State" ):
275- """Called by ``_enter_states()`` after entering a state with invoke callbacks."""
276- self ._states_to_invoke .add (state )
276+ def mark_for_invoke (self , state : "State" , event_kwargs : "dict | None" = None ):
277+ """Called by ``_enter_states()`` after entering a state with invoke callbacks.
278+
279+ Args:
280+ state: The state that was entered.
281+ event_kwargs: Keyword arguments from the event that triggered the
282+ state entry. These are forwarded to invoke handlers via
283+ dependency injection (plain callables) and ``InvokeContext.kwargs``
284+ (IInvoke handlers).
285+ """
286+ self ._pending .append ((state , event_kwargs or {}))
277287
278288 def cancel_for_state (self , state : "State" ):
279289 """Called by ``_exit_states()`` before exiting a state."""
280290 for inv_id , inv in list (self ._active .items ()):
281291 if inv .state_id == state .id and not inv .terminated :
282292 self ._cancel (inv_id )
283- self ._states_to_invoke . discard ( state )
293+ self ._pending = [( s , kw ) for s , kw in self . _pending if s is not state ]
284294
285295 def cancel_all (self ):
286296 """Cancel all active invocations."""
@@ -291,17 +301,20 @@ def cancel_all(self):
291301
292302 def spawn_pending_sync (self ):
293303 """Spawn invoke handlers for all states marked for invocation (sync engine)."""
294- for state in sorted (self ._states_to_invoke , key = lambda s : s .document_order ):
304+ pending = sorted (self ._pending , key = lambda p : p [0 ].document_order )
305+ self ._pending .clear ()
306+ for state , event_kwargs in pending :
295307 self .sm ._callbacks .visit (
296308 state .invoke .key ,
297309 self ._spawn_one_sync ,
298310 state = state ,
311+ event_kwargs = event_kwargs ,
299312 )
300- self ._states_to_invoke .clear ()
301313
302314 def _spawn_one_sync (self , callback : "CallbackWrapper" , ** kwargs ):
303315 state : "State" = kwargs ["state" ]
304- ctx = self ._make_context (state )
316+ event_kwargs : dict = kwargs .get ("event_kwargs" , {})
317+ ctx = self ._make_context (state , event_kwargs )
305318 invocation = Invocation (invokeid = ctx .invokeid , state_id = state .id , ctx = ctx )
306319
307320 # Use meta.func to find the original (unwrapped) handler; the callback
@@ -329,7 +342,7 @@ def _run_sync_handler(
329342 if handler is not None :
330343 result = handler .run (ctx )
331344 else :
332- result = callback .call (ctx = ctx , machine = ctx .machine )
345+ result = callback .call (ctx = ctx , machine = ctx .machine , ** ctx . kwargs )
333346 if not ctx .cancelled .is_set ():
334347 self .sm .send (
335348 f"done.invoke.{ ctx .invokeid } " ,
@@ -346,17 +359,20 @@ def _run_sync_handler(
346359
347360 async def spawn_pending_async (self ):
348361 """Spawn invoke handlers for all states marked for invocation (async engine)."""
349- for state in sorted (self ._states_to_invoke , key = lambda s : s .document_order ):
362+ pending = sorted (self ._pending , key = lambda p : p [0 ].document_order )
363+ self ._pending .clear ()
364+ for state , event_kwargs in pending :
350365 await self .sm ._callbacks .async_visit (
351366 state .invoke .key ,
352367 self ._spawn_one_async ,
353368 state = state ,
369+ event_kwargs = event_kwargs ,
354370 )
355- self ._states_to_invoke .clear ()
356371
357372 def _spawn_one_async (self , callback : "CallbackWrapper" , ** kwargs ):
358373 state : "State" = kwargs ["state" ]
359- ctx = self ._make_context (state )
374+ event_kwargs : dict = kwargs .get ("event_kwargs" , {})
375+ ctx = self ._make_context (state , event_kwargs )
360376 invocation = Invocation (invokeid = ctx .invokeid , state_id = state .id , ctx = ctx )
361377
362378 handler = self ._resolve_handler (callback .meta .func )
@@ -382,7 +398,7 @@ async def _run_async_handler(
382398 result = await loop .run_in_executor (None , handler .run , ctx )
383399 else :
384400 result = await loop .run_in_executor (
385- None , lambda : callback .call (ctx = ctx , machine = ctx .machine )
401+ None , lambda : callback .call (ctx = ctx , machine = ctx .machine , ** ctx . kwargs )
386402 )
387403 if not ctx .cancelled .is_set ():
388404 self .sm .send (
@@ -418,13 +434,14 @@ def _cancel(self, invokeid: str):
418434
419435 # --- Helpers ---
420436
421- def _make_context (self , state : "State" ) -> InvokeContext :
437+ def _make_context (self , state : "State" , event_kwargs : "dict | None" = None ) -> InvokeContext :
422438 invokeid = f"{ state .id } .{ uuid .uuid4 ().hex [:8 ]} "
423439 return InvokeContext (
424440 invokeid = invokeid ,
425441 state_id = state .id ,
426442 send = self .sm .send ,
427443 machine = self .sm ,
444+ kwargs = event_kwargs or {},
428445 )
429446
430447 @staticmethod
0 commit comments