Skip to content

Commit d5f134d

Browse files
committed
fix: re-enqueue initial event when deserializing async state machine (#544)
When an async SM is pickled/deepcopied (e.g. via multiprocessing), the engine queue is not preserved. __setstate__ recreated the engine but never called start(), so the __initial__ event was never enqueued and activate_initial_state() would fail with InvalidStateValue. Closes #544
1 parent 37e6c1a commit d5f134d

3 files changed

Lines changed: 54 additions & 5 deletions

File tree

statemachine/statemachine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def __setstate__(self, state):
147147
self._register_callbacks([])
148148
self.add_listener(*listeners.keys())
149149
self._engine = self._get_engine(rtc)
150+
self._engine.start()
150151

151152
def _get_initial_state(self):
152153
initial_state_value = self.start_value if self.start_value else self.initial_state.value

tests/examples/user_machine.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@
1414
from dataclasses import dataclass
1515
from enum import Enum
1616

17-
from statemachine.states import States
18-
1917
from statemachine import State
2018
from statemachine import StateMachine
19+
from statemachine.states import States
2120

2221

2322
class UserStatus(str, Enum):
@@ -88,7 +87,7 @@ class UserStatusMachine(StateMachine):
8887
def on_signup(self, token: str):
8988
if token == "":
9089
raise ValueError("Token is required")
91-
self.model.verified = True
90+
self.model.verified = True # type: ignore[union-attr]
9291

9392

9493
class UserExperienceMachine(StateMachine):

tests/test_copy.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1+
import asyncio
12
import logging
23
import pickle
34
from copy import deepcopy
45
from enum import Enum
56
from enum import auto
67

78
import pytest
8-
from statemachine.exceptions import TransitionNotAllowed
9-
from statemachine.states import States
109

1110
from statemachine import State
1211
from statemachine import StateMachine
12+
from statemachine.exceptions import TransitionNotAllowed
13+
from statemachine.states import States
1314

1415
logger = logging.getLogger(__name__)
1516
DEBUG = logging.DEBUG
@@ -181,3 +182,51 @@ def test_copy_with_custom_init_and_vars(copy_method):
181182
assert sm2.custom == 1
182183
assert sm2.value == [1, 2, 3]
183184
assert sm2.current_state == MyStateMachine.started
185+
186+
187+
class AsyncTrafficLightMachine(StateMachine):
188+
green = State(initial=True)
189+
yellow = State()
190+
red = State()
191+
192+
cycle = green.to(yellow) | yellow.to(red) | red.to(green)
193+
194+
async def on_enter_state(self, target):
195+
pass
196+
197+
198+
def test_copy_async_statemachine_before_activation(copy_method):
199+
"""Regression test for issue #544: async SM fails after pickle/deepcopy.
200+
201+
When an async SM is copied before activation, the copy must still be
202+
activatable because ``__setstate__`` re-enqueues the ``__initial__`` event.
203+
"""
204+
sm = AsyncTrafficLightMachine()
205+
sm_copy = copy_method(sm)
206+
207+
async def verify():
208+
await sm_copy.activate_initial_state()
209+
assert sm_copy.current_state == AsyncTrafficLightMachine.green
210+
await sm_copy.cycle()
211+
assert sm_copy.current_state == AsyncTrafficLightMachine.yellow
212+
213+
asyncio.run(verify())
214+
215+
216+
def test_copy_async_statemachine_after_activation(copy_method):
217+
"""Copying an async SM that is already activated preserves its current state."""
218+
219+
async def setup_and_verify():
220+
sm = AsyncTrafficLightMachine()
221+
await sm.activate_initial_state()
222+
await sm.cycle()
223+
assert sm.current_state == AsyncTrafficLightMachine.yellow
224+
225+
sm_copy = copy_method(sm)
226+
227+
await sm_copy.activate_initial_state()
228+
assert sm_copy.current_state == AsyncTrafficLightMachine.yellow
229+
await sm_copy.cycle()
230+
assert sm_copy.current_state == AsyncTrafficLightMachine.red
231+
232+
asyncio.run(setup_and_verify())

0 commit comments

Comments
 (0)