Skip to content

Commit 68f8825

Browse files
committed
perf: remove weakref indirection on engine.sm and extract _first_transition_that_matches
- Replace weakref + property with direct reference: eliminates 145k weakref deref + assert calls per benchmark cycle. The engine lifetime is tied to the SM — no leak risk with CPython's cyclic GC (PEP 442). - Extract `first_transition_that_matches` closure into a proper method `_first_transition_that_matches` on both BaseEngine and AsyncEngine, avoiding re-creation of the function object on every _select_transitions call.
1 parent 9708a09 commit 68f8825

2 files changed

Lines changed: 37 additions & 42 deletions

File tree

statemachine/engines/async_.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from .base import BaseEngine
1717

1818
if TYPE_CHECKING:
19-
from ..event import Event
2019
from ..transition import Transition
2120

2221
# ContextVar to distinguish reentrant calls (from within callbacks) from
@@ -105,29 +104,32 @@ async def _conditions_match(self, transition: "Transition", trigger_data: Trigge
105104
transition.cond.key, *args, on_error=on_error, **kwargs
106105
)
107106

107+
async def _first_transition_that_matches( # type: ignore[override]
108+
self,
109+
state: State,
110+
trigger_data: TriggerData,
111+
predicate: Callable,
112+
) -> "Transition | None":
113+
for s in chain([state], state.ancestors()):
114+
transition: "Transition"
115+
for transition in s.transitions:
116+
if (
117+
not transition.initial
118+
and predicate(transition, trigger_data.event)
119+
and await self._conditions_match(transition, trigger_data)
120+
):
121+
return transition
122+
return None
123+
108124
async def _select_transitions( # type: ignore[override]
109125
self, trigger_data: TriggerData, predicate: Callable
110126
) -> "OrderedSet[Transition]":
111127
enabled_transitions: "OrderedSet[Transition]" = OrderedSet()
112128

113129
atomic_states = (state for state in self.sm.configuration if state.is_atomic)
114130

115-
async def first_transition_that_matches(
116-
state: State, event: "Event | None"
117-
) -> "Transition | None":
118-
for s in chain([state], state.ancestors()):
119-
transition: "Transition"
120-
for transition in s.transitions:
121-
if (
122-
not transition.initial
123-
and predicate(transition, event)
124-
and await self._conditions_match(transition, trigger_data)
125-
):
126-
return transition
127-
return None
128-
129131
for state in atomic_states:
130-
transition = await first_transition_that_matches(state, trigger_data.event)
132+
transition = await self._first_transition_that_matches(state, trigger_data, predicate)
131133
if transition is not None:
132134
enabled_transitions.add(transition)
133135

statemachine/engines/base.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,8 @@
1111
from typing import Dict
1212
from typing import List
1313
from typing import cast
14-
from weakref import ReferenceType
15-
from weakref import ref
1614

1715
from ..event import BoundEvent
18-
from ..event import Event
1916
from ..event_data import EventData
2017
from ..event_data import TriggerData
2118
from ..exceptions import InvalidDefinition
@@ -88,7 +85,7 @@ def remove(self, send_id: str):
8885

8986
class BaseEngine:
9087
def __init__(self, sm: "StateChart"):
91-
self._sm: ReferenceType["StateChart"] = ref(sm)
88+
self.sm: "StateChart" = sm
9289
self.external_queue = EventQueue()
9390
self.internal_queue = EventQueue()
9491
self._sentinel = object()
@@ -105,12 +102,6 @@ def __init__(self, sm: "StateChart"):
105102
def empty(self): # pragma: no cover
106103
return self.external_queue.is_empty()
107104

108-
@property
109-
def sm(self) -> "StateChart":
110-
sm = self._sm()
111-
assert sm, "StateMachine has been destroyed"
112-
return sm
113-
114105
def clear_cache(self):
115106
"""Clears the cache. Should be called at the start of each processing loop."""
116107
self._cache.clear()
@@ -347,6 +338,23 @@ def select_transitions(self, trigger_data: TriggerData) -> OrderedSet[Transition
347338
"""
348339
return self._select_transitions(trigger_data, lambda t, e: t.match(e))
349340

341+
def _first_transition_that_matches(
342+
self,
343+
state: State,
344+
trigger_data: TriggerData,
345+
predicate: Callable,
346+
) -> "Transition | None":
347+
for s in chain([state], state.ancestors()):
348+
transition: Transition
349+
for transition in s.transitions:
350+
if (
351+
not transition.initial
352+
and predicate(transition, trigger_data.event)
353+
and self._conditions_match(transition, trigger_data)
354+
):
355+
return transition
356+
return None
357+
350358
def _select_transitions(
351359
self, trigger_data: TriggerData, predicate: Callable
352360
) -> OrderedSet[Transition]:
@@ -356,23 +364,8 @@ def _select_transitions(
356364
# Get atomic states, TODO: sorted by document order
357365
atomic_states = (state for state in self.sm.configuration if state.is_atomic)
358366

359-
def first_transition_that_matches(
360-
state: State, event: "Event | None"
361-
) -> "Transition | None":
362-
for s in chain([state], state.ancestors()):
363-
transition: Transition
364-
for transition in s.transitions:
365-
if (
366-
not transition.initial
367-
and predicate(transition, event)
368-
and self._conditions_match(transition, trigger_data)
369-
):
370-
return transition
371-
372-
return None
373-
374367
for state in atomic_states:
375-
transition = first_transition_that_matches(state, trigger_data.event)
368+
transition = self._first_transition_that_matches(state, trigger_data, predicate)
376369
if transition is not None:
377370
enabled_transitions.add(transition)
378371

0 commit comments

Comments
 (0)