Skip to content

Commit d6ebb61

Browse files
committed
test: add thread safety stress tests and document thread safety
Add TestThreadSafety with stress tests exercising real contention — multiple threads sending events to the same SM simultaneously via barriers. Tests verify no lost events, state consistency, correct callback counts, and safe concurrent reads. Document thread safety guarantees in docs/processing_model.md (linking to atomic_configuration_update for transient None behavior) and AGENTS.md, noting the PriorityQueue-based event queue must remain thread-safe.
1 parent d62d650 commit d6ebb61

3 files changed

Lines changed: 237 additions & 0 deletions

File tree

AGENTS.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ current event.
7777
- `on_error_execution()` works via naming convention but **only** when a transition for
7878
`error.execution` is declared — it is NOT a generic callback.
7979

80+
### Thread safety
81+
82+
- The sync engine is **thread-safe**: multiple threads can send events to the same SM instance
83+
concurrently. The processing loop uses a `threading.Lock` so at most one thread executes
84+
transitions at a time. Event queues use `PriorityQueue` (stdlib, thread-safe).
85+
- **Do not replace `PriorityQueue`** with non-thread-safe alternatives (e.g., `collections.deque`,
86+
plain `list`) — this would break concurrent access guarantees.
87+
- Stress tests in `tests/test_threading.py::TestThreadSafety` exercise real contention with
88+
barriers and multiple sender threads. Any change to queue or locking internals must pass these.
89+
8090
### Invoke (`<invoke>`)
8191

8292
- `invoke.py``InvokeManager` on the engine manages the lifecycle: `mark_for_invoke()`,

docs/processing_model.md

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,50 @@ The machine starts, enters `trying` (attempt 1), and the eventless
315315
self-transition keeps firing as long as `can_retry()` returns `True`. Once
316316
the limit is reached, the second eventless transition fires — all within a
317317
single macrostep triggered by initialization.
318+
319+
320+
(thread-safety)=
321+
322+
## Thread safety
323+
324+
State machines are **thread-safe** for concurrent event sending. Multiple threads
325+
can call `send()` or trigger events on the **same state machine instance**
326+
simultaneously — the engine guarantees correct behavior through its internal
327+
locking mechanism.
328+
329+
### How it works
330+
331+
The processing loop uses a non-blocking lock (`threading.Lock`). When a thread
332+
sends an event:
333+
334+
1. The event is placed on the **external queue** (backed by a thread-safe
335+
`PriorityQueue` from the standard library).
336+
2. If no other thread is currently running the processing loop, the sending
337+
thread acquires the lock and processes all queued events.
338+
3. If another thread is already processing, the event is simply enqueued and
339+
will be processed by the thread that holds the lock — no event is lost.
340+
341+
This means that **at most one thread executes transitions at any time**, preserving
342+
the run-to-completion (RTC) guarantee while allowing safe concurrent access.
343+
344+
### What is safe
345+
346+
- **Multiple threads sending events** to the same state machine instance.
347+
- **Reading state** (`current_state_value`, `configuration`) from any thread
348+
while events are being processed. Note that transient `None` values may be
349+
observed for `current_state_value` during configuration updates when using
350+
[`atomic_configuration_update`](behaviour.md#atomic_configuration_update) `= False`
351+
(the default on `StateChart`, SCXML-compliant). With `atomic_configuration_update = True`
352+
(the default on `StateMachine`), the configuration is updated atomically at
353+
the end of the microstep, so `None` is not observed.
354+
- **Invoke handlers** running in background threads or thread executors
355+
communicate with the parent machine via the thread-safe event queue.
356+
357+
### What to avoid
358+
359+
- **Do not share a state machine instance across threads with the async engine**
360+
unless you ensure only one event loop drives the machine. The async engine is
361+
designed for `asyncio` concurrency, not thread-based concurrency.
362+
- **Callbacks execute in the processing thread**, not in the thread that sent
363+
the event. Design callbacks accordingly (e.g., use locks if they access
364+
shared external state).

tests/test_threading.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import threading
22
import time
3+
from collections import Counter
34

5+
import pytest
46
from statemachine.state import State
57
from statemachine.statemachine import StateChart
68

@@ -115,6 +117,184 @@ def __init__(self, name):
115117
assert c3.fsm.statuses_history == ["c3.green", "c3.green", "c3.green", "c3.yellow"]
116118

117119

120+
class TestThreadSafety:
121+
"""Stress tests for concurrent access to a single state machine instance.
122+
123+
These tests exercise real contention: multiple threads sending events to the
124+
same SM simultaneously, synchronized via barriers to maximize overlap.
125+
"""
126+
127+
@pytest.fixture()
128+
def cycling_machine(self):
129+
class CyclingMachine(StateChart):
130+
s1 = State(initial=True)
131+
s2 = State()
132+
s3 = State()
133+
cycle = s1.to(s2) | s2.to(s3) | s3.to(s1)
134+
135+
return CyclingMachine()
136+
137+
@pytest.mark.parametrize("num_threads", [4, 8])
138+
def test_concurrent_sends_no_lost_events(self, cycling_machine, num_threads):
139+
"""All events sent concurrently must be processed — none lost."""
140+
events_per_thread = 300
141+
total_events = num_threads * events_per_thread
142+
barrier = threading.Barrier(num_threads)
143+
errors = []
144+
145+
def sender():
146+
try:
147+
barrier.wait(timeout=5)
148+
for _ in range(events_per_thread):
149+
cycling_machine.send("cycle")
150+
except Exception as e:
151+
errors.append(e)
152+
153+
threads = [threading.Thread(target=sender) for _ in range(num_threads)]
154+
for t in threads:
155+
t.start()
156+
for t in threads:
157+
t.join(timeout=30)
158+
159+
assert not errors, f"Thread errors: {errors}"
160+
161+
# The machine cycles s1→s2→s3→s1. After N total cycle events starting
162+
# from s1, the state is determined by (N % 3).
163+
expected_states = {0: "s1", 1: "s2", 2: "s3"}
164+
expected = expected_states[total_events % 3]
165+
assert cycling_machine.current_state_value == expected
166+
167+
def test_concurrent_sends_state_consistency(self, cycling_machine):
168+
"""State must always be one of the valid states, never corrupted."""
169+
valid_values = {"s1", "s2", "s3"}
170+
num_threads = 6
171+
events_per_thread = 500
172+
barrier = threading.Barrier(num_threads + 1) # +1 for observer
173+
stop_event = threading.Event()
174+
observed_values = []
175+
errors = []
176+
177+
def sender():
178+
try:
179+
barrier.wait(timeout=5)
180+
for _ in range(events_per_thread):
181+
cycling_machine.send("cycle")
182+
except Exception as e:
183+
errors.append(e)
184+
185+
def observer():
186+
barrier.wait(timeout=5)
187+
while not stop_event.is_set():
188+
val = cycling_machine.current_state_value
189+
observed_values.append(val)
190+
191+
threads = [threading.Thread(target=sender) for _ in range(num_threads)]
192+
obs_thread = threading.Thread(target=observer)
193+
194+
for t in threads:
195+
t.start()
196+
obs_thread.start()
197+
198+
for t in threads:
199+
t.join(timeout=30)
200+
201+
stop_event.set()
202+
obs_thread.join(timeout=5)
203+
204+
assert not errors, f"Thread errors: {errors}"
205+
# None may appear transiently during configuration updates — that's expected.
206+
invalid = [v for v in observed_values if v not in valid_values and v is not None]
207+
assert not invalid, f"Observed invalid state values: {set(invalid)}"
208+
assert len(observed_values) > 100, "Observer didn't collect enough samples"
209+
210+
def test_concurrent_sends_with_callbacks(self):
211+
"""Callbacks must execute exactly once per transition under contention."""
212+
call_log = []
213+
lock = threading.Lock()
214+
215+
class CallbackMachine(StateChart):
216+
s1 = State(initial=True)
217+
s2 = State()
218+
go = s1.to(s2) | s2.to(s1)
219+
220+
def on_enter_s2(self):
221+
with lock:
222+
call_log.append("enter_s2")
223+
224+
def on_enter_s1(self):
225+
with lock:
226+
call_log.append("enter_s1")
227+
228+
sm = CallbackMachine()
229+
num_threads = 4
230+
events_per_thread = 200
231+
total_events = num_threads * events_per_thread
232+
barrier = threading.Barrier(num_threads)
233+
errors = []
234+
235+
def sender():
236+
try:
237+
barrier.wait(timeout=5)
238+
for _ in range(events_per_thread):
239+
sm.send("go")
240+
except Exception as e:
241+
errors.append(e)
242+
243+
threads = [threading.Thread(target=sender) for _ in range(num_threads)]
244+
for t in threads:
245+
t.start()
246+
for t in threads:
247+
t.join(timeout=30)
248+
249+
assert not errors, f"Thread errors: {errors}"
250+
251+
# Each transition fires exactly one on_enter callback.
252+
# +1 because initial activation also fires on_enter_s1.
253+
counts = Counter(call_log)
254+
total_callbacks = counts["enter_s1"] + counts["enter_s2"]
255+
assert total_callbacks == total_events + 1
256+
257+
def test_concurrent_send_and_read_configuration(self, cycling_machine):
258+
"""Reading configuration while events are being processed must not raise."""
259+
num_senders = 4
260+
events_per_sender = 300
261+
barrier = threading.Barrier(num_senders + 1)
262+
stop_event = threading.Event()
263+
errors = []
264+
265+
def sender():
266+
try:
267+
barrier.wait(timeout=5)
268+
for _ in range(events_per_sender):
269+
cycling_machine.send("cycle")
270+
except Exception as e:
271+
errors.append(e)
272+
273+
def reader():
274+
barrier.wait(timeout=5)
275+
while not stop_event.is_set():
276+
try:
277+
_ = cycling_machine.configuration
278+
_ = cycling_machine.current_state_value
279+
_ = list(cycling_machine.configuration)
280+
except Exception as e:
281+
errors.append(e)
282+
283+
threads = [threading.Thread(target=sender) for _ in range(num_senders)]
284+
reader_thread = threading.Thread(target=reader)
285+
286+
for t in threads:
287+
t.start()
288+
reader_thread.start()
289+
290+
for t in threads:
291+
t.join(timeout=30)
292+
stop_event.set()
293+
reader_thread.join(timeout=5)
294+
295+
assert not errors, f"Thread errors: {errors}"
296+
297+
118298
async def test_regression_443_with_modifications_for_async_engine():
119299
"""
120300
Test for https://github.com/fgmacedo/python-statemachine/issues/443

0 commit comments

Comments
 (0)