Skip to content

Commit 95ac619

Browse files
committed
fix: await async predicates in condition expressions (#535)
The boolean expression combinators (custom_not, custom_and, custom_or, build_custom_operator) called predicates synchronously. When predicates were async, they returned unawaited coroutine objects which are always truthy, causing `not` to always return False, `and` to skip evaluation, and `or` to short-circuit incorrectly. Each combinator now checks `isawaitable()` on predicate results and returns a coroutine when needed, which CallbackWrapper.__call__ already knows how to await. Closes #535
1 parent d5f134d commit 95ac619

3 files changed

Lines changed: 235 additions & 9 deletions

File tree

statemachine/spec_parser.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import operator
33
import re
44
from functools import reduce
5+
from inspect import isawaitable
56
from typing import Callable
67

78
replacements = {"!": "not ", "^": " and ", "v": " or "}
@@ -33,8 +34,15 @@ def match_func(match):
3334

3435

3536
def custom_not(predicate: Callable) -> Callable:
36-
def decorated(*args, **kwargs) -> bool:
37-
return not predicate(*args, **kwargs)
37+
def decorated(*args, **kwargs):
38+
result = predicate(*args, **kwargs)
39+
if isawaitable(result):
40+
41+
async def _negate():
42+
return not await result
43+
44+
return _negate()
45+
return not result
3846

3947
decorated.__name__ = f"not({predicate.__name__})"
4048
unique_key = getattr(predicate, "unique_key", "")
@@ -43,17 +51,53 @@ def decorated(*args, **kwargs) -> bool:
4351

4452

4553
def custom_and(left: Callable, right: Callable) -> Callable:
46-
def decorated(*args, **kwargs) -> bool:
47-
return left(*args, **kwargs) and right(*args, **kwargs) # type: ignore[no-any-return]
54+
def decorated(*args, **kwargs):
55+
left_result = left(*args, **kwargs)
56+
if isawaitable(left_result):
57+
58+
async def _async_and():
59+
lr = await left_result
60+
if not lr:
61+
return lr
62+
rr = right(*args, **kwargs)
63+
if isawaitable(rr):
64+
return await rr
65+
return rr
66+
67+
return _async_and()
68+
if not left_result:
69+
return left_result
70+
right_result = right(*args, **kwargs)
71+
if isawaitable(right_result):
72+
return right_result
73+
return right_result
4874

4975
decorated.__name__ = f"({left.__name__} and {right.__name__})"
5076
decorated.unique_key = _unique_key(left, right, "and") # type: ignore[attr-defined]
5177
return decorated
5278

5379

5480
def custom_or(left: Callable, right: Callable) -> Callable:
55-
def decorated(*args, **kwargs) -> bool:
56-
return left(*args, **kwargs) or right(*args, **kwargs) # type: ignore[no-any-return]
81+
def decorated(*args, **kwargs):
82+
left_result = left(*args, **kwargs)
83+
if isawaitable(left_result):
84+
85+
async def _async_or():
86+
lr = await left_result
87+
if lr:
88+
return lr
89+
rr = right(*args, **kwargs)
90+
if isawaitable(rr):
91+
return await rr
92+
return rr
93+
94+
return _async_or()
95+
if left_result:
96+
return left_result
97+
right_result = right(*args, **kwargs)
98+
if isawaitable(right_result):
99+
return right_result
100+
return right_result
57101

58102
decorated.__name__ = f"({left.__name__} or {right.__name__})"
59103
decorated.unique_key = _unique_key(left, right, "or") # type: ignore[attr-defined]
@@ -73,8 +117,18 @@ def build_custom_operator(operator) -> Callable:
73117
operator_repr = comparison_repr[operator]
74118

75119
def custom_comparator(left: Callable, right: Callable) -> Callable:
76-
def decorated(*args, **kwargs) -> bool:
77-
return bool(operator(left(*args, **kwargs), right(*args, **kwargs)))
120+
def decorated(*args, **kwargs):
121+
left_result = left(*args, **kwargs)
122+
right_result = right(*args, **kwargs)
123+
if isawaitable(left_result) or isawaitable(right_result):
124+
125+
async def _async_compare():
126+
lr = (await left_result) if isawaitable(left_result) else left_result
127+
rr = (await right_result) if isawaitable(right_result) else right_result
128+
return bool(operator(lr, rr))
129+
130+
return _async_compare()
131+
return bool(operator(left_result, right_result))
78132

79133
decorated.__name__ = f"({left.__name__} {operator_repr} {right.__name__})"
80134
decorated.unique_key = _unique_key(left, right, operator_repr) # type: ignore[attr-defined]

tests/test_async.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import re
22

33
import pytest
4-
from statemachine.exceptions import InvalidStateValue
54

65
from statemachine import State
76
from statemachine import StateMachine
7+
from statemachine.exceptions import InvalidStateValue
88

99

1010
@pytest.fixture()
@@ -96,6 +96,86 @@ def test_async_state_from_sync_context(async_order_control_machine):
9696
assert sm.completed.is_active
9797

9898

99+
class AsyncConditionExpressionMachine(StateMachine):
100+
"""Regression test for issue #535: async conditions in boolean expressions."""
101+
102+
s1 = State(initial=True)
103+
104+
go_not = s1.to.itself(cond="not cond_false")
105+
go_and = s1.to.itself(cond="cond_true and cond_true")
106+
go_or_false_first = s1.to.itself(cond="cond_false or cond_true")
107+
go_or_true_first = s1.to.itself(cond="cond_true or cond_false")
108+
go_blocked = s1.to.itself(cond="not cond_true")
109+
go_and_blocked = s1.to.itself(cond="cond_true and cond_false")
110+
go_or_both_false = s1.to.itself(cond="cond_false or cond_false")
111+
112+
async def cond_true(self):
113+
return True
114+
115+
async def cond_false(self):
116+
return False
117+
118+
async def on_enter_state(self, target):
119+
pass
120+
121+
122+
async def test_async_condition_not(recwarn):
123+
"""Issue #535: 'not cond_false' should allow the transition."""
124+
sm = AsyncConditionExpressionMachine()
125+
await sm.activate_initial_state()
126+
await sm.go_not()
127+
assert sm.s1.is_active
128+
assert not any("coroutine" in str(w.message) for w in recwarn.list)
129+
130+
131+
async def test_async_condition_not_blocked():
132+
"""Issue #535: 'not cond_true' should block the transition."""
133+
sm = AsyncConditionExpressionMachine()
134+
await sm.activate_initial_state()
135+
with pytest.raises(sm.TransitionNotAllowed):
136+
await sm.go_blocked()
137+
138+
139+
async def test_async_condition_and():
140+
"""Issue #535: 'cond_true and cond_true' should allow the transition."""
141+
sm = AsyncConditionExpressionMachine()
142+
await sm.activate_initial_state()
143+
await sm.go_and()
144+
assert sm.s1.is_active
145+
146+
147+
async def test_async_condition_and_blocked():
148+
"""Issue #535: 'cond_true and cond_false' should block the transition."""
149+
sm = AsyncConditionExpressionMachine()
150+
await sm.activate_initial_state()
151+
with pytest.raises(sm.TransitionNotAllowed):
152+
await sm.go_and_blocked()
153+
154+
155+
async def test_async_condition_or_false_first():
156+
"""Issue #535: 'cond_false or cond_true' should allow the transition."""
157+
sm = AsyncConditionExpressionMachine()
158+
await sm.activate_initial_state()
159+
await sm.go_or_false_first()
160+
assert sm.s1.is_active
161+
162+
163+
async def test_async_condition_or_true_first():
164+
"""'cond_true or cond_false' should allow the transition."""
165+
sm = AsyncConditionExpressionMachine()
166+
await sm.activate_initial_state()
167+
await sm.go_or_true_first()
168+
assert sm.s1.is_active
169+
170+
171+
async def test_async_condition_or_both_false():
172+
"""'cond_false or cond_false' should block the transition."""
173+
sm = AsyncConditionExpressionMachine()
174+
await sm.activate_initial_state()
175+
with pytest.raises(sm.TransitionNotAllowed):
176+
await sm.go_or_both_false()
177+
178+
99179
async def test_async_state_should_be_initialized(async_order_control_machine):
100180
"""Test that the state machine is initialized before any event is triggered
101181

tests/test_spec_parser.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import asyncio
12
import logging
23

34
import pytest
5+
46
from statemachine.spec_parser import operator_mapping
57
from statemachine.spec_parser import parse_boolean_expr
68

@@ -247,6 +249,96 @@ def variable_hook(var_name):
247249
("height > 1 and height < 2", True, ["height"]),
248250
],
249251
)
252+
def async_variable_hook(var_name):
253+
"""Variable hook that returns async callables, for testing issue #535."""
254+
values = {
255+
"cond_true": True,
256+
"cond_false": False,
257+
"val_10": 10,
258+
"val_20": 20,
259+
}
260+
261+
async def decorated(*args, **kwargs):
262+
return values.get(var_name, False)
263+
264+
decorated.__name__ = var_name
265+
return decorated
266+
267+
268+
@pytest.mark.parametrize(
269+
("expression", "expected"),
270+
[
271+
("not cond_false", True),
272+
("not cond_true", False),
273+
("cond_true and cond_true", True),
274+
("cond_true and cond_false", False),
275+
("cond_false and cond_true", False),
276+
("cond_false or cond_true", True),
277+
("cond_true or cond_false", True),
278+
("cond_false or cond_false", False),
279+
("not cond_false and cond_true", True),
280+
("not (cond_true and cond_false)", True),
281+
("not (cond_false or cond_false)", True),
282+
("cond_true and not cond_false", True),
283+
("val_10 == 10", True),
284+
("val_10 != 20", True),
285+
("val_10 < val_20", True),
286+
("val_20 > val_10", True),
287+
("val_10 >= 10", True),
288+
("val_10 <= val_20", True),
289+
],
290+
)
291+
def test_async_expressions(expression, expected):
292+
"""Issue #535: condition expressions with async predicates must await results."""
293+
parsed_expr = parse_boolean_expr(expression, async_variable_hook, operator_mapping)
294+
result = parsed_expr()
295+
assert asyncio.iscoroutine(result), f"Expected coroutine for async expression: {expression}"
296+
assert asyncio.run(result) is expected, expression
297+
298+
299+
def mixed_variable_hook(var_name):
300+
"""Variable hook where some vars are sync and some are async."""
301+
sync_values = {"sync_true": True, "sync_false": False, "sync_10": 10}
302+
async_values = {"async_true": True, "async_false": False, "async_20": 20}
303+
304+
if var_name in async_values:
305+
306+
async def async_decorated(*args, **kwargs):
307+
return async_values[var_name]
308+
309+
async_decorated.__name__ = var_name
310+
return async_decorated
311+
312+
def sync_decorated(*args, **kwargs):
313+
return sync_values.get(var_name, False)
314+
315+
sync_decorated.__name__ = var_name
316+
return sync_decorated
317+
318+
319+
@pytest.mark.parametrize(
320+
("expression", "expected"),
321+
[
322+
# async left, sync right
323+
("async_true and sync_true", True),
324+
("async_false or sync_true", True),
325+
# sync left, async right
326+
("sync_true and async_true", True),
327+
("sync_false or async_true", True),
328+
("sync_true and async_false", False),
329+
("sync_false or async_false", False),
330+
],
331+
)
332+
def test_mixed_sync_async_expressions(expression, expected):
333+
"""Expressions mixing sync and async predicates must handle both correctly."""
334+
parsed_expr = parse_boolean_expr(expression, mixed_variable_hook, operator_mapping)
335+
result = parsed_expr()
336+
if asyncio.iscoroutine(result):
337+
assert asyncio.run(result) is expected, expression
338+
else:
339+
assert result is expected, expression
340+
341+
250342
@pytest.mark.xfail(reason="TODO: Optimize so that expressios are evaluated only once")
251343
def test_should_evaluate_values_only_once(expression, expected, caplog, hooks_called):
252344
caplog.set_level(logging.DEBUG, logger="tests")

0 commit comments

Comments
 (0)