Skip to content

Commit 608bdbe

Browse files
committed
feat: propagate event kwargs to invoke handlers
Forward keyword arguments from the triggering event to invoke handlers: - Plain callables receive them via SignatureAdapter dependency injection - IInvoke handlers receive them via ctx.kwargs This allows patterns like sm.send("start", file_name="config.json") where the invoke handler reads file_name as a parameter.
1 parent 3ff527a commit 608bdbe

5 files changed

Lines changed: 148 additions & 20 deletions

File tree

docs/invoke.md

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,8 @@ matching any invoke completion for that state regardless of the specific invoke
174174
## IInvoke protocol
175175

176176
For advanced use cases, implement the `IInvoke` protocol. This gives you access to
177-
the `InvokeContext` — with the invoke ID, cancellation signal, and a reference to the
178-
parent machine:
177+
the `InvokeContext` — with the invoke ID, cancellation signal, event kwargs, and a
178+
reference to the parent machine:
179179

180180
```py
181181
>>> from statemachine.invoke import IInvoke, InvokeContext
@@ -188,6 +188,7 @@ parent machine:
188188
... # ctx.cancelled — threading.Event, set when state exits
189189
... # ctx.send — send events to parent machine
190190
... # ctx.machine — reference to parent machine
191+
... # ctx.kwargs — keyword arguments from the triggering event
191192
... path = ctx.machine.file_path
192193
... return Path(path).read_text()
193194
...
@@ -263,6 +264,46 @@ Events from cancelled invocations are silently ignored.
263264

264265
```
265266

267+
## Event data propagation
268+
269+
When a state with invoke handlers is entered via an event, the keyword arguments from
270+
that event are forwarded to the invoke handlers. Plain callables receive them via
271+
{ref}`SignatureAdapter <actions>` dependency injection; `IInvoke` handlers receive them
272+
via `ctx.kwargs`:
273+
274+
```py
275+
>>> config_file = Path(tempfile.mktemp(suffix=".json"))
276+
>>> _ = config_file.write_text('{"debug": true}')
277+
278+
>>> class ConfigByName(StateChart):
279+
... idle = State(initial=True)
280+
... loading = State()
281+
... ready = State(final=True)
282+
... start = idle.to(loading)
283+
... done_invoke_loading = loading.to(ready)
284+
...
285+
... def on_invoke_loading(self, file_name=None, **kwargs):
286+
... """file_name comes from send('start', file_name=...)."""
287+
... return json.loads(Path(file_name).read_text())
288+
...
289+
... def on_enter_ready(self, data=None, **kwargs):
290+
... self.config = data
291+
292+
>>> sm = ConfigByName()
293+
>>> sm.send("start", file_name=str(config_file))
294+
>>> time.sleep(0.2)
295+
296+
>>> "ready" in sm.configuration_values
297+
True
298+
>>> sm.config
299+
{'debug': True}
300+
301+
>>> config_file.unlink()
302+
303+
```
304+
305+
For initial states (entered automatically, not via an event), `kwargs` is empty.
306+
266307
## Error handling
267308

268309
If an invoke handler raises an exception, `error.execution` is sent to the machine's

statemachine/engines/async_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ async def _enter_states( # noqa: C901
248248

249249
# Mark state for invocation if it has invoke callbacks registered
250250
if target.invoke.key in self.sm._callbacks:
251-
self._invoke_manager.mark_for_invoke(target)
251+
self._invoke_manager.mark_for_invoke(target, trigger_data.kwargs)
252252

253253
# Handle final states
254254
if target.final:

statemachine/engines/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def _enter_states( # noqa: C901
653653

654654
# Mark state for invocation if it has invoke callbacks registered
655655
if target.invoke.key in self.sm._callbacks:
656-
self._invoke_manager.mark_for_invoke(target)
656+
self._invoke_manager.mark_for_invoke(target, trigger_data.kwargs)
657657

658658
# Handle final states
659659
if target.final:

statemachine/invoke.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@
1919
from typing import Callable
2020
from typing import Dict
2121
from typing import List
22+
from typing import Tuple
2223
from typing import runtime_checkable
2324

24-
from .orderedset import OrderedSet
25-
2625
try:
2726
from typing import Protocol
2827
except ImportError: # pragma: no cover
@@ -150,6 +149,9 @@ class InvokeContext:
150149
cancelled: threading.Event = field(default_factory=threading.Event)
151150
"""Set when the owning state is exited; handlers should check this to stop early."""
152151

152+
kwargs: dict = field(default_factory=dict)
153+
"""Keyword arguments from the event that triggered the state entry."""
154+
153155

154156
@dataclass
155157
class Invocation:
@@ -263,24 +265,32 @@ class InvokeManager:
263265
def __init__(self, engine: "BaseEngine"):
264266
self._engine = engine
265267
self._active: Dict[str, Invocation] = {}
266-
self._states_to_invoke: "OrderedSet[State]" = OrderedSet()
268+
self._pending: "List[Tuple[State, dict]]" = []
267269

268270
@property
269271
def sm(self) -> "StateChart":
270272
return self._engine.sm
271273

272274
# --- Engine hooks ---
273275

274-
def mark_for_invoke(self, state: "State"):
275-
"""Called by ``_enter_states()`` after entering a state with invoke callbacks."""
276-
self._states_to_invoke.add(state)
276+
def mark_for_invoke(self, state: "State", event_kwargs: "dict | None" = None):
277+
"""Called by ``_enter_states()`` after entering a state with invoke callbacks.
278+
279+
Args:
280+
state: The state that was entered.
281+
event_kwargs: Keyword arguments from the event that triggered the
282+
state entry. These are forwarded to invoke handlers via
283+
dependency injection (plain callables) and ``InvokeContext.kwargs``
284+
(IInvoke handlers).
285+
"""
286+
self._pending.append((state, event_kwargs or {}))
277287

278288
def cancel_for_state(self, state: "State"):
279289
"""Called by ``_exit_states()`` before exiting a state."""
280290
for inv_id, inv in list(self._active.items()):
281291
if inv.state_id == state.id and not inv.terminated:
282292
self._cancel(inv_id)
283-
self._states_to_invoke.discard(state)
293+
self._pending = [(s, kw) for s, kw in self._pending if s is not state]
284294

285295
def cancel_all(self):
286296
"""Cancel all active invocations."""
@@ -291,17 +301,20 @@ def cancel_all(self):
291301

292302
def spawn_pending_sync(self):
293303
"""Spawn invoke handlers for all states marked for invocation (sync engine)."""
294-
for state in sorted(self._states_to_invoke, key=lambda s: s.document_order):
304+
pending = sorted(self._pending, key=lambda p: p[0].document_order)
305+
self._pending.clear()
306+
for state, event_kwargs in pending:
295307
self.sm._callbacks.visit(
296308
state.invoke.key,
297309
self._spawn_one_sync,
298310
state=state,
311+
event_kwargs=event_kwargs,
299312
)
300-
self._states_to_invoke.clear()
301313

302314
def _spawn_one_sync(self, callback: "CallbackWrapper", **kwargs):
303315
state: "State" = kwargs["state"]
304-
ctx = self._make_context(state)
316+
event_kwargs: dict = kwargs.get("event_kwargs", {})
317+
ctx = self._make_context(state, event_kwargs)
305318
invocation = Invocation(invokeid=ctx.invokeid, state_id=state.id, ctx=ctx)
306319

307320
# Use meta.func to find the original (unwrapped) handler; the callback
@@ -329,7 +342,7 @@ def _run_sync_handler(
329342
if handler is not None:
330343
result = handler.run(ctx)
331344
else:
332-
result = callback.call(ctx=ctx, machine=ctx.machine)
345+
result = callback.call(ctx=ctx, machine=ctx.machine, **ctx.kwargs)
333346
if not ctx.cancelled.is_set():
334347
self.sm.send(
335348
f"done.invoke.{ctx.invokeid}",
@@ -346,17 +359,20 @@ def _run_sync_handler(
346359

347360
async def spawn_pending_async(self):
348361
"""Spawn invoke handlers for all states marked for invocation (async engine)."""
349-
for state in sorted(self._states_to_invoke, key=lambda s: s.document_order):
362+
pending = sorted(self._pending, key=lambda p: p[0].document_order)
363+
self._pending.clear()
364+
for state, event_kwargs in pending:
350365
await self.sm._callbacks.async_visit(
351366
state.invoke.key,
352367
self._spawn_one_async,
353368
state=state,
369+
event_kwargs=event_kwargs,
354370
)
355-
self._states_to_invoke.clear()
356371

357372
def _spawn_one_async(self, callback: "CallbackWrapper", **kwargs):
358373
state: "State" = kwargs["state"]
359-
ctx = self._make_context(state)
374+
event_kwargs: dict = kwargs.get("event_kwargs", {})
375+
ctx = self._make_context(state, event_kwargs)
360376
invocation = Invocation(invokeid=ctx.invokeid, state_id=state.id, ctx=ctx)
361377

362378
handler = self._resolve_handler(callback.meta.func)
@@ -382,7 +398,7 @@ async def _run_async_handler(
382398
result = await loop.run_in_executor(None, handler.run, ctx)
383399
else:
384400
result = await loop.run_in_executor(
385-
None, lambda: callback.call(ctx=ctx, machine=ctx.machine)
401+
None, lambda: callback.call(ctx=ctx, machine=ctx.machine, **ctx.kwargs)
386402
)
387403
if not ctx.cancelled.is_set():
388404
self.sm.send(
@@ -418,13 +434,14 @@ def _cancel(self, invokeid: str):
418434

419435
# --- Helpers ---
420436

421-
def _make_context(self, state: "State") -> InvokeContext:
437+
def _make_context(self, state: "State", event_kwargs: "dict | None" = None) -> InvokeContext:
422438
invokeid = f"{state.id}.{uuid.uuid4().hex[:8]}"
423439
return InvokeContext(
424440
invokeid=invokeid,
425441
state_id=state.id,
426442
send=self.sm.send,
427443
machine=self.sm,
444+
kwargs=event_kwargs or {},
428445
)
429446

430447
@staticmethod

tests/test_invoke.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,76 @@ def on_enter_ready(self, data=None, **kwargs):
489489
assert all_results[1] == [2]
490490

491491

492+
class TestInvokeEventKwargs:
493+
"""Event kwargs from send() are forwarded to invoke handlers."""
494+
495+
async def test_plain_callable_receives_event_kwargs(self, sm_runner):
496+
"""Plain callable invoke handler receives event kwargs via SignatureAdapter."""
497+
received = []
498+
499+
class SM(StateChart):
500+
idle = State(initial=True)
501+
loading = State()
502+
ready = State(final=True)
503+
start = idle.to(loading)
504+
done_invoke_loading = loading.to(ready)
505+
506+
def on_invoke_loading(self, file_name=None, **kwargs):
507+
received.append(file_name)
508+
return f"loaded:{file_name}"
509+
510+
def on_enter_ready(self, data=None, **kwargs):
511+
received.append(data)
512+
513+
sm = await sm_runner.start(SM)
514+
await sm_runner.send(sm, "start", file_name="config.json")
515+
await sm_runner.sleep(0.15)
516+
await sm_runner.processing_loop(sm)
517+
518+
assert "ready" in sm.configuration_values
519+
assert received == ["config.json", "loaded:config.json"]
520+
521+
async def test_iinvoke_handler_receives_event_kwargs_via_ctx(self, sm_runner):
522+
"""IInvoke handler receives event kwargs via ctx.kwargs."""
523+
received = []
524+
525+
class FileLoader:
526+
def run(self, ctx: InvokeContext):
527+
received.append(ctx.kwargs.get("file_name"))
528+
return f"loaded:{ctx.kwargs['file_name']}"
529+
530+
class SM(StateChart):
531+
idle = State(initial=True)
532+
loading = State(invoke=FileLoader)
533+
ready = State(final=True)
534+
start = idle.to(loading)
535+
done_invoke_loading = loading.to(ready)
536+
537+
def on_enter_ready(self, data=None, **kwargs):
538+
received.append(data)
539+
540+
sm = await sm_runner.start(SM)
541+
await sm_runner.send(sm, "start", file_name="data.csv")
542+
await sm_runner.sleep(0.15)
543+
await sm_runner.processing_loop(sm)
544+
545+
assert "ready" in sm.configuration_values
546+
assert received == ["data.csv", "loaded:data.csv"]
547+
548+
async def test_initial_state_invoke_has_empty_kwargs(self, sm_runner):
549+
"""Invoke on initial state gets empty kwargs (no triggering event)."""
550+
551+
class SM(StateChart):
552+
loading = State(initial=True, invoke=lambda: 42)
553+
ready = State(final=True)
554+
done_invoke_loading = loading.to(ready)
555+
556+
sm = await sm_runner.start(SM)
557+
await sm_runner.sleep(0.15)
558+
await sm_runner.processing_loop(sm)
559+
assert "ready" in sm.configuration_values
560+
561+
492562
class TestInvokeNotTriggeredOnNonInvokeState:
493563
"""States without invoke handlers should not be affected."""
494564

0 commit comments

Comments
 (0)