|
1 | 1 | import threading |
2 | 2 | import time |
| 3 | +from collections import Counter |
3 | 4 |
|
| 5 | +import pytest |
4 | 6 | from statemachine.state import State |
5 | 7 | from statemachine.statemachine import StateChart |
6 | 8 |
|
@@ -115,6 +117,184 @@ def __init__(self, name): |
115 | 117 | assert c3.fsm.statuses_history == ["c3.green", "c3.green", "c3.green", "c3.yellow"] |
116 | 118 |
|
117 | 119 |
|
| 120 | +class TestThreadSafety: |
| 121 | + """Stress tests for concurrent access to a single state machine instance. |
| 122 | +
|
| 123 | + These tests exercise real contention: multiple threads sending events to the |
| 124 | + same SM simultaneously, synchronized via barriers to maximize overlap. |
| 125 | + """ |
| 126 | + |
| 127 | + @pytest.fixture() |
| 128 | + def cycling_machine(self): |
| 129 | + class CyclingMachine(StateChart): |
| 130 | + s1 = State(initial=True) |
| 131 | + s2 = State() |
| 132 | + s3 = State() |
| 133 | + cycle = s1.to(s2) | s2.to(s3) | s3.to(s1) |
| 134 | + |
| 135 | + return CyclingMachine() |
| 136 | + |
| 137 | + @pytest.mark.parametrize("num_threads", [4, 8]) |
| 138 | + def test_concurrent_sends_no_lost_events(self, cycling_machine, num_threads): |
| 139 | + """All events sent concurrently must be processed — none lost.""" |
| 140 | + events_per_thread = 300 |
| 141 | + total_events = num_threads * events_per_thread |
| 142 | + barrier = threading.Barrier(num_threads) |
| 143 | + errors = [] |
| 144 | + |
| 145 | + def sender(): |
| 146 | + try: |
| 147 | + barrier.wait(timeout=5) |
| 148 | + for _ in range(events_per_thread): |
| 149 | + cycling_machine.send("cycle") |
| 150 | + except Exception as e: |
| 151 | + errors.append(e) |
| 152 | + |
| 153 | + threads = [threading.Thread(target=sender) for _ in range(num_threads)] |
| 154 | + for t in threads: |
| 155 | + t.start() |
| 156 | + for t in threads: |
| 157 | + t.join(timeout=30) |
| 158 | + |
| 159 | + assert not errors, f"Thread errors: {errors}" |
| 160 | + |
| 161 | + # The machine cycles s1→s2→s3→s1. After N total cycle events starting |
| 162 | + # from s1, the state is determined by (N % 3). |
| 163 | + expected_states = {0: "s1", 1: "s2", 2: "s3"} |
| 164 | + expected = expected_states[total_events % 3] |
| 165 | + assert cycling_machine.current_state_value == expected |
| 166 | + |
| 167 | + def test_concurrent_sends_state_consistency(self, cycling_machine): |
| 168 | + """State must always be one of the valid states, never corrupted.""" |
| 169 | + valid_values = {"s1", "s2", "s3"} |
| 170 | + num_threads = 6 |
| 171 | + events_per_thread = 500 |
| 172 | + barrier = threading.Barrier(num_threads + 1) # +1 for observer |
| 173 | + stop_event = threading.Event() |
| 174 | + observed_values = [] |
| 175 | + errors = [] |
| 176 | + |
| 177 | + def sender(): |
| 178 | + try: |
| 179 | + barrier.wait(timeout=5) |
| 180 | + for _ in range(events_per_thread): |
| 181 | + cycling_machine.send("cycle") |
| 182 | + except Exception as e: |
| 183 | + errors.append(e) |
| 184 | + |
| 185 | + def observer(): |
| 186 | + barrier.wait(timeout=5) |
| 187 | + while not stop_event.is_set(): |
| 188 | + val = cycling_machine.current_state_value |
| 189 | + observed_values.append(val) |
| 190 | + |
| 191 | + threads = [threading.Thread(target=sender) for _ in range(num_threads)] |
| 192 | + obs_thread = threading.Thread(target=observer) |
| 193 | + |
| 194 | + for t in threads: |
| 195 | + t.start() |
| 196 | + obs_thread.start() |
| 197 | + |
| 198 | + for t in threads: |
| 199 | + t.join(timeout=30) |
| 200 | + |
| 201 | + stop_event.set() |
| 202 | + obs_thread.join(timeout=5) |
| 203 | + |
| 204 | + assert not errors, f"Thread errors: {errors}" |
| 205 | + # None may appear transiently during configuration updates — that's expected. |
| 206 | + invalid = [v for v in observed_values if v not in valid_values and v is not None] |
| 207 | + assert not invalid, f"Observed invalid state values: {set(invalid)}" |
| 208 | + assert len(observed_values) > 100, "Observer didn't collect enough samples" |
| 209 | + |
| 210 | + def test_concurrent_sends_with_callbacks(self): |
| 211 | + """Callbacks must execute exactly once per transition under contention.""" |
| 212 | + call_log = [] |
| 213 | + lock = threading.Lock() |
| 214 | + |
| 215 | + class CallbackMachine(StateChart): |
| 216 | + s1 = State(initial=True) |
| 217 | + s2 = State() |
| 218 | + go = s1.to(s2) | s2.to(s1) |
| 219 | + |
| 220 | + def on_enter_s2(self): |
| 221 | + with lock: |
| 222 | + call_log.append("enter_s2") |
| 223 | + |
| 224 | + def on_enter_s1(self): |
| 225 | + with lock: |
| 226 | + call_log.append("enter_s1") |
| 227 | + |
| 228 | + sm = CallbackMachine() |
| 229 | + num_threads = 4 |
| 230 | + events_per_thread = 200 |
| 231 | + total_events = num_threads * events_per_thread |
| 232 | + barrier = threading.Barrier(num_threads) |
| 233 | + errors = [] |
| 234 | + |
| 235 | + def sender(): |
| 236 | + try: |
| 237 | + barrier.wait(timeout=5) |
| 238 | + for _ in range(events_per_thread): |
| 239 | + sm.send("go") |
| 240 | + except Exception as e: |
| 241 | + errors.append(e) |
| 242 | + |
| 243 | + threads = [threading.Thread(target=sender) for _ in range(num_threads)] |
| 244 | + for t in threads: |
| 245 | + t.start() |
| 246 | + for t in threads: |
| 247 | + t.join(timeout=30) |
| 248 | + |
| 249 | + assert not errors, f"Thread errors: {errors}" |
| 250 | + |
| 251 | + # Each transition fires exactly one on_enter callback. |
| 252 | + # +1 because initial activation also fires on_enter_s1. |
| 253 | + counts = Counter(call_log) |
| 254 | + total_callbacks = counts["enter_s1"] + counts["enter_s2"] |
| 255 | + assert total_callbacks == total_events + 1 |
| 256 | + |
| 257 | + def test_concurrent_send_and_read_configuration(self, cycling_machine): |
| 258 | + """Reading configuration while events are being processed must not raise.""" |
| 259 | + num_senders = 4 |
| 260 | + events_per_sender = 300 |
| 261 | + barrier = threading.Barrier(num_senders + 1) |
| 262 | + stop_event = threading.Event() |
| 263 | + errors = [] |
| 264 | + |
| 265 | + def sender(): |
| 266 | + try: |
| 267 | + barrier.wait(timeout=5) |
| 268 | + for _ in range(events_per_sender): |
| 269 | + cycling_machine.send("cycle") |
| 270 | + except Exception as e: |
| 271 | + errors.append(e) |
| 272 | + |
| 273 | + def reader(): |
| 274 | + barrier.wait(timeout=5) |
| 275 | + while not stop_event.is_set(): |
| 276 | + try: |
| 277 | + _ = cycling_machine.configuration |
| 278 | + _ = cycling_machine.current_state_value |
| 279 | + _ = list(cycling_machine.configuration) |
| 280 | + except Exception as e: |
| 281 | + errors.append(e) |
| 282 | + |
| 283 | + threads = [threading.Thread(target=sender) for _ in range(num_senders)] |
| 284 | + reader_thread = threading.Thread(target=reader) |
| 285 | + |
| 286 | + for t in threads: |
| 287 | + t.start() |
| 288 | + reader_thread.start() |
| 289 | + |
| 290 | + for t in threads: |
| 291 | + t.join(timeout=30) |
| 292 | + stop_event.set() |
| 293 | + reader_thread.join(timeout=5) |
| 294 | + |
| 295 | + assert not errors, f"Thread errors: {errors}" |
| 296 | + |
| 297 | + |
118 | 298 | async def test_regression_443_with_modifications_for_async_engine(): |
119 | 299 | """ |
120 | 300 | Test for https://github.com/fgmacedo/python-statemachine/issues/443 |
|
0 commit comments