Skip to content

Commit 3198871

Browse files
committed
fix: improve invoke test coverage and fix async cancel flakiness
- Fix _InvokeCallableWrapper.on_cancel() to handle class handlers that haven't been instantiated yet (early return instead of calling unbound method) - Replace blocking threading.Event.wait() with sm_runner.sleep() in async tests to avoid freezing the event loop - Add tests for cancel_all(), cancel of terminated invocations, on_cancel exception suppression, StateChartInvoker.on_cancel(), normalize_invoke_callbacks edge cases, and _resolve_handler paths - Coverage: 90% → 96%
1 parent aac44fc commit 3198871

2 files changed

Lines changed: 221 additions & 7 deletions

File tree

statemachine/invoke.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,12 @@ def run(self, ctx: "InvokeContext") -> Any:
8383

8484
def on_cancel(self):
8585
"""Delegate to the live instance's ``on_cancel()`` if available."""
86-
target = self._instance if self._instance is not None else self._invoke_handler
86+
if self._instance is not None:
87+
target = self._instance
88+
elif self._is_class:
89+
return # Handler hasn't been instantiated yet — nothing to cancel
90+
else:
91+
target = self._invoke_handler
8792
if hasattr(target, "on_cancel"):
8893
target.on_cancel()
8994

tests/test_invoke.py

Lines changed: 215 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,12 +189,9 @@ def on_invoke_loading(self, ctx=None, **kwargs):
189189
async def test_cancel_on_exit_with_on_cancel(self, sm_runner):
190190
"""Test that on_cancel() is called when state is exited."""
191191
cancel_called = []
192-
started = threading.Event()
193192

194193
class CancelTracker:
195194
def run(self, ctx):
196-
started.set()
197-
# Poll instead of blocking to work with both sync and async engines
198195
while not ctx.cancelled.is_set():
199196
ctx.cancelled.wait(0.01)
200197

@@ -207,10 +204,10 @@ class SM(StateChart):
207204
cancel = loading.to(cancelled_state)
208205

209206
sm = await sm_runner.start(SM)
210-
# Wait for invoke handler to start (runs in thread for sync, task for async)
211-
await sm_runner.sleep(0.05)
207+
# Give the invoke handler time to start in its background thread
208+
await sm_runner.sleep(0.15)
212209
await sm_runner.send(sm, "cancel")
213-
await sm_runner.sleep(0.05)
210+
await sm_runner.sleep(0.15)
214211

215212
assert cancel_called == [True]
216213
assert "cancelled_state" in sm.configuration_values
@@ -508,3 +505,215 @@ class SM(StateChart):
508505
assert "active" in sm.configuration_values
509506
await sm_runner.send(sm, "finish")
510507
assert "done" in sm.configuration_values
508+
509+
510+
class TestInvokeManagerCancelAll:
511+
"""InvokeManager.cancel_all() cancels every active invocation."""
512+
513+
async def test_cancel_all(self, sm_runner):
514+
class SlowHandler:
515+
def run(self, ctx):
516+
ctx.cancelled.wait(timeout=5.0)
517+
518+
class SM(StateChart):
519+
loading = State(initial=True, invoke=SlowHandler)
520+
stopped = State(final=True)
521+
cancel = loading.to(stopped)
522+
523+
sm = await sm_runner.start(SM)
524+
await sm_runner.sleep(0.15)
525+
sm._engine._invoke_manager.cancel_all()
526+
await sm_runner.sleep(0.15)
527+
528+
# All invocations should be terminated
529+
for inv in sm._engine._invoke_manager._active.values():
530+
assert inv.terminated
531+
532+
533+
class TestInvokeCancelAlreadyTerminated:
534+
"""Cancelling an already-terminated invocation is a no-op."""
535+
536+
async def test_cancel_terminated_invocation(self, sm_runner):
537+
class SM(StateChart):
538+
loading = State(initial=True, invoke=lambda: 42)
539+
ready = State(final=True)
540+
done_invoke_loading = loading.to(ready)
541+
542+
sm = await sm_runner.start(SM)
543+
await sm_runner.sleep(0.15)
544+
await sm_runner.processing_loop(sm)
545+
546+
assert "ready" in sm.configuration_values
547+
# All invocations should be terminated by now
548+
manager = sm._engine._invoke_manager
549+
for inv in manager._active.values():
550+
assert inv.terminated
551+
# Calling cancel on terminated invocations should be a safe no-op
552+
for inv_id in list(manager._active.keys()):
553+
manager._cancel(inv_id)
554+
555+
556+
class TestInvokeOnCancelException:
557+
"""Exception in on_cancel() is caught and logged, not propagated."""
558+
559+
async def test_on_cancel_exception_is_suppressed(self, sm_runner):
560+
class BadCancelHandler:
561+
def run(self, ctx):
562+
ctx.cancelled.wait(timeout=5.0)
563+
564+
def on_cancel(self):
565+
raise RuntimeError("on_cancel exploded")
566+
567+
class SM(StateChart):
568+
loading = State(initial=True, invoke=BadCancelHandler)
569+
stopped = State(final=True)
570+
cancel = loading.to(stopped)
571+
572+
sm = await sm_runner.start(SM)
573+
await sm_runner.sleep(0.15)
574+
# This should NOT raise even though on_cancel() raises
575+
await sm_runner.send(sm, "cancel")
576+
await sm_runner.sleep(0.15)
577+
578+
assert "stopped" in sm.configuration_values
579+
580+
581+
class TestStateChartInvokerOnCancel:
582+
"""StateChartInvoker.on_cancel() cleans up the child reference."""
583+
584+
def test_on_cancel_clears_child(self):
585+
from statemachine.invoke import StateChartInvoker
586+
587+
class ChildMachine(StateChart):
588+
start = State(initial=True, final=True)
589+
590+
invoker = StateChartInvoker(ChildMachine)
591+
ctx = InvokeContext(
592+
invokeid="test.123",
593+
state_id="test",
594+
send=lambda *a, **kw: None,
595+
machine=None,
596+
)
597+
invoker.run(ctx)
598+
assert invoker._child is not None
599+
invoker.on_cancel()
600+
assert invoker._child is None
601+
602+
603+
class TestNormalizeInvokeCallbacks:
604+
"""normalize_invoke_callbacks handles edge cases."""
605+
606+
def test_string_passes_through(self):
607+
from statemachine.invoke import normalize_invoke_callbacks
608+
609+
result = normalize_invoke_callbacks("some_method_name")
610+
assert result == ["some_method_name"]
611+
612+
def test_already_wrapped_passes_through(self):
613+
from statemachine.invoke import _InvokeCallableWrapper
614+
from statemachine.invoke import normalize_invoke_callbacks
615+
616+
class MyHandler:
617+
def run(self, ctx):
618+
pass
619+
620+
wrapper = _InvokeCallableWrapper(MyHandler)
621+
result = normalize_invoke_callbacks(wrapper)
622+
assert len(result) == 1
623+
assert result[0] is wrapper
624+
625+
def test_iinvoke_class_with_run_method(self):
626+
from statemachine.invoke import _InvokeCallableWrapper
627+
from statemachine.invoke import normalize_invoke_callbacks
628+
629+
class CustomHandler:
630+
def run(self, ctx):
631+
return "result"
632+
633+
result = normalize_invoke_callbacks(CustomHandler)
634+
assert len(result) == 1
635+
assert isinstance(result[0], _InvokeCallableWrapper)
636+
637+
def test_plain_callable_passes_through(self):
638+
from statemachine.invoke import _InvokeCallableWrapper
639+
from statemachine.invoke import normalize_invoke_callbacks
640+
641+
def my_func():
642+
return 42
643+
644+
result = normalize_invoke_callbacks(my_func)
645+
assert len(result) == 1
646+
assert result[0] is my_func
647+
assert not isinstance(result[0], _InvokeCallableWrapper)
648+
649+
650+
class TestResolveHandler:
651+
"""InvokeManager._resolve_handler edge cases."""
652+
653+
def test_bare_iinvoke_instance(self):
654+
from statemachine.invoke import InvokeManager
655+
656+
class MyHandler:
657+
def run(self, ctx):
658+
return "result"
659+
660+
handler = MyHandler()
661+
assert isinstance(handler, IInvoke)
662+
resolved = InvokeManager._resolve_handler(handler)
663+
assert resolved is handler
664+
665+
def test_bare_statechart_class(self):
666+
from statemachine.invoke import InvokeManager
667+
from statemachine.invoke import StateChartInvoker
668+
669+
class ChildMachine(StateChart):
670+
start = State(initial=True, final=True)
671+
672+
resolved = InvokeManager._resolve_handler(ChildMachine)
673+
assert isinstance(resolved, StateChartInvoker)
674+
675+
def test_plain_callable_returns_none(self):
676+
from statemachine.invoke import InvokeManager
677+
678+
def my_func():
679+
return 42
680+
681+
assert InvokeManager._resolve_handler(my_func) is None
682+
683+
684+
class TestInvokeCallableWrapperOnCancel:
685+
"""_InvokeCallableWrapper.on_cancel() edge cases."""
686+
687+
def test_on_cancel_non_class_instance_with_on_cancel(self):
688+
"""Non-class handler (already instantiated) delegates on_cancel."""
689+
from statemachine.invoke import _InvokeCallableWrapper
690+
691+
cancel_called = []
692+
693+
class MyHandler:
694+
def run(self, ctx):
695+
return "result"
696+
697+
def on_cancel(self):
698+
cancel_called.append(True)
699+
700+
handler = MyHandler()
701+
wrapper = _InvokeCallableWrapper(handler)
702+
# _instance is None, _is_class is False → falls through to _invoke_handler
703+
wrapper.on_cancel()
704+
assert cancel_called == [True]
705+
706+
def test_on_cancel_class_not_yet_instantiated(self):
707+
"""Class handler not yet instantiated — on_cancel is a no-op."""
708+
from statemachine.invoke import _InvokeCallableWrapper
709+
710+
class MyHandler:
711+
def run(self, ctx):
712+
return "result"
713+
714+
def on_cancel(self):
715+
raise RuntimeError("should not be called")
716+
717+
wrapper = _InvokeCallableWrapper(MyHandler)
718+
# _instance is None, _is_class is True → early return
719+
wrapper.on_cancel() # should not raise

0 commit comments

Comments
 (0)