Skip to content

Commit cdb472a

Browse files
committed
perf: eliminate for_instance() cache — store InstanceState in sm.__dict__
Remove the State descriptor protocol (__get__/__set__) and for_instance() cache. InstanceState objects are now created eagerly in StateChart.__init__ and stored directly in sm.__dict__, making sm.<state_id> a plain dict lookup instead of a descriptor call + cache lookup. Configuration no longer holds a weakref to the machine; it receives a dedicated instance_states dict and resolves active states via direct dict lookup. The __setattr__ guard on StateChart preserves the existing protection against accidental state overriding. States whose id collides with an event name are kept out of __dict__ to preserve Event descriptor priority. Fix type mismatches in engines/base.py and event_data.py that were previously masked by the State.__set__ descriptor.
1 parent f1e07d0 commit cdb472a

6 files changed

Lines changed: 47 additions & 51 deletions

File tree

statemachine/configuration.py

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import TYPE_CHECKING
22
from typing import Any
33
from typing import Dict
4+
from typing import Mapping
45
from typing import MutableSet
5-
from weakref import ref
66

77
from .exceptions import InvalidStateValue
88
from .i18n import _
@@ -12,7 +12,6 @@
1212

1313
if TYPE_CHECKING:
1414
from .state import State
15-
from .statemachine import StateChart
1615

1716

1817
class Configuration:
@@ -25,28 +24,25 @@ class Configuration:
2524
"""
2625

2726
__slots__ = (
28-
"_machine_ref",
27+
"_instance_states",
2928
"_model",
3029
"_state_field",
3130
"_states_map",
32-
"_for_instance",
3331
"_cached",
3432
"_cached_value",
3533
)
3634

3735
def __init__(
3836
self,
39-
machine: "StateChart",
37+
instance_states: "Mapping[str, State]",
4038
model: Any,
4139
state_field: str,
4240
states_map: "Dict[Any, State]",
43-
for_instance_cache: "Dict[State, State]",
4441
):
45-
self._machine_ref: "ref[StateChart]" = ref(machine)
42+
self._instance_states = instance_states
4643
self._model = model
4744
self._state_field = state_field
4845
self._states_map = states_map
49-
self._for_instance = for_instance_cache
5046
self._cached: "OrderedSet[State] | None" = None
5147
self._cached_value: Any = _SENTINEL
5248

@@ -83,20 +79,11 @@ def states(self) -> "OrderedSet[State]":
8379
if csv is None:
8480
return OrderedSet()
8581

86-
machine = self._machine_ref()
87-
assert machine is not None
88-
82+
instance_states = self._instance_states
8983
if not isinstance(csv, MutableSet):
90-
result = OrderedSet(
91-
[self._states_map[csv].for_instance(machine=machine, cache=self._for_instance)]
92-
)
84+
result = OrderedSet([instance_states[self._states_map[csv].id]])
9385
else:
94-
result = OrderedSet(
95-
[
96-
self._states_map[v].for_instance(machine=machine, cache=self._for_instance)
97-
for v in csv
98-
]
99-
)
86+
result = OrderedSet([instance_states[self._states_map[v].id] for v in csv])
10087

10188
self._cached = result
10289
self._cached_value = csv

statemachine/engines/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,8 @@ def add_descendant_states_to_enter( # noqa: C901
804804
)
805805

806806
for transition in state.transitions:
807-
info_history = StateTransition(transition=transition, state=transition.target)
807+
target = cast(State, transition.target)
808+
info_history = StateTransition(transition=transition, state=target)
808809
default_history_content[parent_id].append(info_history)
809810
self.add_descendant_states_to_enter(
810811
info_history,
@@ -813,7 +814,8 @@ def add_descendant_states_to_enter( # noqa: C901
813814
default_history_content,
814815
) # noqa: E501
815816
for transition in state.transitions:
816-
info_history = StateTransition(transition=transition, state=transition.target)
817+
target = cast(State, transition.target)
818+
info_history = StateTransition(transition=transition, state=target)
817819

818820
self.add_ancestor_states_to_enter(
819821
info_history,

statemachine/event_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ class EventData:
6363
source: "State" = field(init=False)
6464
"""The :ref:`State` which :ref:`statemachine` was in when the Event started."""
6565

66-
target: "State" = field(init=False)
67-
"""The destination :ref:`State` of the :ref:`transition`."""
66+
target: "State | None" = field(init=False)
67+
"""The destination :ref:`State` of the :ref:`transition`, or ``None`` for targetless."""
6868

6969
def __post_init__(self):
7070
self.state = self.transition.source

statemachine/state.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from enum import Enum
22
from typing import TYPE_CHECKING
33
from typing import Any
4-
from typing import Dict
54
from typing import Generator
65
from typing import List
76
from typing import cast
@@ -12,7 +11,6 @@
1211
from .callbacks import CallbackSpecList
1312
from .event import _expand_event_id
1413
from .exceptions import InvalidDefinition
15-
from .exceptions import StateMachineError
1614
from .i18n import _
1715
from .invoke import normalize_invoke_callbacks
1816
from .transition import Transition
@@ -295,22 +293,6 @@ def __repr__(self):
295293
def __str__(self):
296294
return self.name
297295

298-
def __get__(self, machine, owner):
299-
if machine is None:
300-
return self
301-
return self.for_instance(machine=machine, cache=machine._config._for_instance)
302-
303-
def __set__(self, instance, value):
304-
raise StateMachineError(
305-
_("State overriding is not allowed. Trying to add '{}' to {}").format(value, self.id)
306-
)
307-
308-
def for_instance(self, machine: "StateChart", cache: Dict["State", "State"]) -> "State":
309-
if self not in cache:
310-
cache[self] = InstanceState(self, machine)
311-
312-
return cache[self]
313-
314296
@property
315297
def id(self) -> str:
316298
return self._id

statemachine/statemachine.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@
2525
from .event_data import TriggerData
2626
from .exceptions import InvalidDefinition
2727
from .exceptions import InvalidStateValue
28+
from .exceptions import StateMachineError
2829
from .exceptions import TransitionNotAllowed
2930
from .factory import StateMachineMetaclass
3031
from .graph import iterate_states_and_transitions
3132
from .i18n import _
3233
from .model import Model
3334
from .signature import SignatureAdapter
35+
from .state import InstanceState
3436
from .utils import run_async_from_sync
3537

3638
if TYPE_CHECKING:
@@ -151,12 +153,18 @@ def __init__(
151153
[start_value] if start_value is not None else list(self.start_configuration_values)
152154
)
153155
self._callbacks = CallbacksRegistry()
156+
instance_states: Dict[str, Any] = {}
157+
events = self.__class__._events
158+
for state in self.states_map.values():
159+
ist = InstanceState(state, self)
160+
instance_states[state.id] = ist
161+
if state.id not in events:
162+
vars(self)[state.id] = ist
154163
self._config = Configuration(
155-
machine=self,
164+
instance_states=instance_states,
156165
model=self.model,
157166
state_field=self.state_field,
158167
states_map=self.states_map,
159-
for_instance_cache={},
160168
)
161169
self._listeners: Dict[int, Any] = {}
162170
"""Listeners that provides attributes to be used as callbacks."""
@@ -212,6 +220,17 @@ def _processing_loop(self, caller_future: "Any | None" = None) -> Any:
212220
return result
213221
return run_async_from_sync(result)
214222

223+
def __setattr__(self, name, value):
224+
# Fast path: internal/private attributes are never state IDs.
225+
if not name.startswith("_") and name in self.__class__.states_map:
226+
if not isinstance(value, InstanceState):
227+
raise StateMachineError(
228+
_("State overriding is not allowed. Trying to add '{}' to {}").format(
229+
value, name
230+
)
231+
)
232+
super().__setattr__(name, value)
233+
215234
def __repr__(self):
216235
configuration_ids = [s.id for s in self.configuration]
217236
return (
@@ -220,7 +239,7 @@ def __repr__(self):
220239
)
221240

222241
def __getstate__(self):
223-
state = self.__dict__.copy()
242+
state = {k: v for k, v in self.__dict__.items() if not isinstance(v, InstanceState)}
224243
del state["_callbacks"]
225244
del state["_config"]
226245
del state["_engine"]
@@ -230,12 +249,18 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
230249
listeners = state.pop("_listeners")
231250
self.__dict__.update(state) # type: ignore[attr-defined]
232251
self._callbacks = CallbacksRegistry()
252+
instance_states: Dict[str, Any] = {}
253+
events = self.__class__._events
254+
for sm_state in self.states_map.values():
255+
ist = InstanceState(sm_state, self)
256+
instance_states[sm_state.id] = ist
257+
if sm_state.id not in events:
258+
vars(self)[sm_state.id] = ist
233259
self._config = Configuration(
234-
machine=self,
260+
instance_states=instance_states,
235261
model=self.model,
236262
state_field=self.state_field,
237263
states_map=self.states_map,
238-
for_instance_cache={},
239264
)
240265
self._listeners = {}
241266

tests/test_configuration.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def test_set_empty_configuration(self):
3333

3434
def test_set_multi_element_configuration(self):
3535
sm = ParallelSM()
36-
s1_inst = ParallelSM.s1.for_instance(machine=sm, cache=sm._config._for_instance)
37-
s2_inst = ParallelSM.s2.for_instance(machine=sm, cache=sm._config._for_instance)
36+
s1_inst = sm.s1
37+
s2_inst = sm.s2
3838

3939
sm.configuration = OrderedSet([s1_inst, s2_inst])
4040
assert isinstance(sm.current_state_value, OrderedSet)
@@ -55,8 +55,8 @@ def test_discard_nonmatching_scalar(self):
5555
class TestConfigurationCurrentState:
5656
def test_current_state_with_multiple_active_states(self):
5757
sm = ParallelSM()
58-
s1_inst = ParallelSM.s1.for_instance(machine=sm, cache=sm._config._for_instance)
59-
s2_inst = ParallelSM.s2.for_instance(machine=sm, cache=sm._config._for_instance)
58+
s1_inst = sm.s1
59+
s2_inst = sm.s2
6060
sm.configuration = OrderedSet([s1_inst, s2_inst])
6161

6262
with warnings.catch_warnings():

0 commit comments

Comments
 (0)