Skip to content

Commit fbccc5b

Browse files
committed
feat: add done_invoke_ naming convention, State(invoke=...) API, and unit tests
- factory.py: done_invoke_<state> registers both done_invoke_<state> and done.invoke.<state> event names (same pattern as error_ and done_state_) - state.py: _normalize_invoke() converts StateChart classes to InvokeConfig - invoke.py: fix async spawn to handle sync child engines (await-guard on activate_initial_state and send results) - tests/test_invoke.py: unit tests for normalization, naming convention, spawn/cancel lifecycle, and multiple invocations using sm_runner fixture
1 parent 5267b95 commit fbccc5b

4 files changed

Lines changed: 217 additions & 6 deletions

File tree

statemachine/factory.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,9 @@ def add_from_attributes(cls, attrs): # noqa: C901
251251
elif key.startswith("done_state_"):
252252
suffix = key[len("done_state_") :]
253253
event_id = f"{key} done.state.{suffix}"
254+
elif key.startswith("done_invoke_"):
255+
suffix = key[len("done_invoke_") :]
256+
event_id = f"{key} done.invoke.{suffix}"
254257
cls.add_event(event=Event(transitions=value, id=event_id, name=key))
255258
elif isinstance(value, (Event,)):
256259
if value._has_real_id:
@@ -260,6 +263,9 @@ def add_from_attributes(cls, attrs): # noqa: C901
260263
elif key.startswith("done_state_"):
261264
suffix = key[len("done_state_") :]
262265
event_id = f"{key} done.state.{suffix}"
266+
elif key.startswith("done_invoke_"):
267+
suffix = key[len("done_invoke_") :]
268+
event_id = f"{key} done.invoke.{suffix}"
263269
else:
264270
event_id = key
265271
new_event = Event(

statemachine/invoke.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,18 @@ async def spawn_async(self, state: "State", config: InvokeConfig, trigger_data:
148148

149149
async def run_child():
150150
try:
151-
await child_sm.activate_initial_state()
151+
result = child_sm.activate_initial_state()
152+
if result is not None:
153+
await result
152154

153155
while not child_sm.is_terminated and not invocation.cancelled:
154156
await asyncio.sleep(0.01)
155157

156158
if not invocation.cancelled:
157159
logger.debug("Child %s terminated, sending done.invoke.%s", invokeid, invokeid)
158-
self.sm.send(f"done.invoke.{invokeid}", invokeid=invokeid)
160+
result = self.sm.send(f"done.invoke.{invokeid}", invokeid=invokeid)
161+
if result is not None:
162+
await result
159163
except Exception:
160164
logger.exception("Error in child session %s", invokeid)
161165

statemachine/state.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,30 @@ def __init__(
239239

240240
@staticmethod
241241
def _normalize_invoke(invoke: Any) -> list:
242-
"""Normalize the invoke parameter into a list of InvokeConfig."""
242+
"""Normalize the invoke parameter into a list of InvokeConfig.
243+
244+
Accepts:
245+
- None → []
246+
- A StateChart subclass → [InvokeConfig(child_class=...)]
247+
- An InvokeConfig → [config]
248+
- A list of the above → [InvokeConfig(...), ...]
249+
"""
250+
from .invoke import InvokeConfig
251+
243252
if invoke is None:
244253
return []
245-
if isinstance(invoke, list):
246-
return invoke
247-
return [invoke]
254+
255+
items = invoke if isinstance(invoke, list) else [invoke]
256+
result = []
257+
for item in items:
258+
if isinstance(item, InvokeConfig):
259+
result.append(item)
260+
elif isinstance(item, type):
261+
# Assume it's a StateChart subclass
262+
result.append(InvokeConfig(child_class=item))
263+
else:
264+
result.append(item)
265+
return result
248266

249267
def _init_states(self):
250268
for state in self.states:

tests/test_invoke.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
"""Unit tests for the Python invoke API (State(invoke=...) and done_invoke_ convention)."""
2+
3+
import asyncio
4+
import time
5+
6+
import pytest
7+
from statemachine.invoke import InvokeConfig
8+
9+
from statemachine import State
10+
from statemachine import StateChart
11+
12+
13+
class ChildMachine(StateChart):
14+
"""A simple child that immediately reaches its final state via eventless transition."""
15+
16+
s1 = State(initial=True)
17+
done = State(final=True)
18+
19+
# Eventless transition — fires automatically on entry
20+
s1.to(done)
21+
22+
23+
class SlowChild(StateChart):
24+
"""A child that waits for an external event before finishing."""
25+
26+
waiting = State(initial=True)
27+
done = State(final=True)
28+
29+
go = waiting.to(done)
30+
31+
32+
# --- Helpers ---
33+
34+
35+
def _wait_for_state(sm, state_id, timeout=3.0, poll=0.02):
36+
"""Poll until the state machine reaches the given state or timeout."""
37+
deadline = time.monotonic() + timeout
38+
while time.monotonic() < deadline:
39+
if state_id in sm.configuration_values:
40+
return True
41+
time.sleep(poll)
42+
return False
43+
44+
45+
async def _async_wait_for_state(sm, state_id, timeout=3.0, poll=0.02):
46+
"""Async poll until the state machine reaches the given state or timeout."""
47+
deadline = time.monotonic() + timeout
48+
while time.monotonic() < deadline:
49+
if state_id in sm.configuration_values:
50+
return True
51+
await asyncio.sleep(poll)
52+
return False
53+
54+
55+
async def _wait_for(sm_runner, sm, state_id, timeout=3.0):
56+
"""Wait for state using sync or async polling based on runner."""
57+
if sm_runner.is_async:
58+
return await _async_wait_for_state(sm, state_id, timeout=timeout)
59+
return _wait_for_state(sm, state_id, timeout=timeout)
60+
61+
62+
async def _sleep(sm_runner, seconds=0.1):
63+
if sm_runner.is_async:
64+
await asyncio.sleep(seconds)
65+
else:
66+
time.sleep(seconds)
67+
68+
69+
# --- Tests ---
70+
71+
72+
class TestStateInvokeNormalization:
73+
"""Test that State(invoke=...) normalizes various input types."""
74+
75+
def test_invoke_none(self):
76+
s = State(invoke=None)
77+
assert s.invocations == []
78+
79+
def test_invoke_class_becomes_invoke_config(self):
80+
s = State(invoke=ChildMachine)
81+
assert len(s.invocations) == 1
82+
assert isinstance(s.invocations[0], InvokeConfig)
83+
assert s.invocations[0].child_class is ChildMachine
84+
85+
def test_invoke_config_passthrough(self):
86+
config = InvokeConfig(child_class=ChildMachine)
87+
s = State(invoke=config)
88+
assert len(s.invocations) == 1
89+
assert s.invocations[0] is config
90+
91+
def test_invoke_list_of_classes(self):
92+
s = State(invoke=[ChildMachine, SlowChild])
93+
assert len(s.invocations) == 2
94+
assert s.invocations[0].child_class is ChildMachine
95+
assert s.invocations[1].child_class is SlowChild
96+
97+
98+
class TestDoneInvokeNamingConvention:
99+
"""Test that done_invoke_<state> registers the done.invoke.<state> event."""
100+
101+
def test_done_invoke_event_registered(self):
102+
class Parent(StateChart):
103+
active = State(initial=True, invoke=ChildMachine)
104+
completed = State(final=True)
105+
106+
done_invoke_active = active.to(completed)
107+
108+
event_ids = {e.id for e in Parent.events}
109+
assert any("done.invoke.active" in eid for eid in event_ids)
110+
111+
@pytest.mark.asyncio()
112+
async def test_done_invoke_transition_fires(self, sm_runner):
113+
"""When the child terminates, done.invoke.<state> fires on the parent."""
114+
115+
class Parent(StateChart):
116+
active = State(initial=True, invoke=ChildMachine)
117+
completed = State(final=True)
118+
119+
done_invoke_active = active.to(completed)
120+
121+
sm = await sm_runner.start(Parent)
122+
reached = await _wait_for(sm_runner, sm, "completed")
123+
assert reached, f"Parent did not reach 'completed'. Config: {sm.configuration_values}"
124+
125+
126+
class TestInvokeSpawnAndCancel:
127+
"""Test basic invoke lifecycle: spawn on entry, cancel on exit."""
128+
129+
@pytest.mark.asyncio()
130+
async def test_child_spawned_on_state_entry(self, sm_runner):
131+
"""When entering a state with invoke, a child session should be spawned."""
132+
133+
class Parent(StateChart):
134+
idle = State(initial=True)
135+
active = State(invoke=SlowChild)
136+
done = State(final=True)
137+
138+
start = idle.to(active)
139+
done_invoke_active = active.to(done)
140+
141+
sm = await sm_runner.start(Parent)
142+
await sm_runner.send(sm, "start")
143+
await _sleep(sm_runner)
144+
145+
active = sm._engine.invoke_manager._active
146+
assert len(active) > 0, "No active invocations found after entering invoke state"
147+
148+
@pytest.mark.asyncio()
149+
async def test_child_cancelled_on_state_exit(self, sm_runner):
150+
"""When the parent exits the invoking state, the child is cancelled."""
151+
152+
class Parent(StateChart):
153+
idle = State(initial=True)
154+
active = State(invoke=SlowChild)
155+
other = State()
156+
done = State(final=True)
157+
158+
start = idle.to(active)
159+
abort = active.to(other)
160+
finish = other.to(done)
161+
162+
sm = await sm_runner.start(Parent)
163+
await sm_runner.send(sm, "start")
164+
await _sleep(sm_runner)
165+
166+
active_before = list(sm._engine.invoke_manager._active.values())
167+
assert len(active_before) > 0
168+
169+
await sm_runner.send(sm, "abort")
170+
assert active_before[0].cancelled, "Child was not cancelled after exiting invoke state"
171+
172+
173+
class TestMultipleInvocations:
174+
"""Test invoking multiple children from the same state."""
175+
176+
def test_multiple_invoke_configs(self):
177+
class Parent(StateChart):
178+
active = State(initial=True, invoke=[ChildMachine, ChildMachine])
179+
done = State(final=True)
180+
181+
done_invoke_active = active.to(done)
182+
183+
assert len(Parent.states_map["active"].invocations) == 2

0 commit comments

Comments
 (0)