Skip to content

Commit 9d7f114

Browse files
committed
feat: add weighted (probabilistic) transitions contrib module
Add `weighted_transitions()` utility that enables probabilistic transition selection based on relative weights. Works entirely through the existing `cond` guard system with zero engine changes. API: weighted_transitions(source, (target, weight), ..., seed=N) to(target, weight, cond=..., on=..., ...) # for transition kwargs Inspired by PR #539 (@bcorfman).
1 parent fde13d9 commit 9d7f114

6 files changed

Lines changed: 959 additions & 0 deletions

File tree

docs/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ async
1717
mixins
1818
integrations
1919
diagram
20+
weighted_transitions
2021
processing_model
2122
statecharts
2223
api

docs/releases/3.0.0.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,35 @@ flag `validate_disconnected_states: bool = True` that can be used to disable thi
346346
It's already disabled when parsing SCXML files.
347347

348348

349+
### Weighted (probabilistic) transitions
350+
351+
A new contrib module `statemachine.contrib.weighted` provides `weighted_transitions()`,
352+
enabling probabilistic transition selection based on relative weights. This works entirely
353+
through the existing condition system — no engine changes required:
354+
355+
```python
356+
from statemachine.contrib.weighted import weighted_transitions
357+
358+
class GameCharacter(StateChart):
359+
standing = State(initial=True)
360+
shift_weight = State()
361+
adjust_hair = State()
362+
bang_shield = State()
363+
364+
idle = weighted_transitions(
365+
standing,
366+
(shift_weight, 70),
367+
(adjust_hair, 20),
368+
(bang_shield, 10),
369+
seed=42,
370+
)
371+
372+
finish = shift_weight.to(standing) | adjust_hair.to(standing) | bang_shield.to(standing)
373+
```
374+
375+
See {ref}`weighted-transitions` for full documentation.
376+
377+
349378
## Bugfixes in 3.0.0
350379

351380
- Fixes [#XXX](https://github.com/fgmacedo/python-statemachine/issues/XXX).

docs/weighted_transitions.md

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
(weighted-transitions)=
2+
3+
# Weighted transitions
4+
5+
```{versionadded} 3.0.0
6+
```
7+
8+
The `weighted_transitions` utility lets you define **probabilistic transitions** — where
9+
each transition from a state has a relative weight that determines how likely it is to be
10+
selected when the event fires.
11+
12+
This is a contrib module that works entirely through the existing {ref}`guards` system.
13+
No engine modifications are needed.
14+
15+
## Installation
16+
17+
The module is included in the `python-statemachine` package. Import it from the contrib
18+
namespace:
19+
20+
```python
21+
from statemachine.contrib.weighted import weighted_transitions
22+
23+
# Only needed when passing transition kwargs (cond, on, etc.)
24+
from statemachine.contrib.weighted import to
25+
```
26+
27+
## Basic usage
28+
29+
Pass a **source state** followed by `(target, weight)` tuples. The result is a regular
30+
{ref}`TransitionList` that you assign to a class attribute as an event:
31+
32+
```{testsetup}
33+
34+
>>> from statemachine import State, StateChart
35+
>>> from statemachine.contrib.weighted import to, weighted_transitions
36+
37+
```
38+
39+
```py
40+
>>> class GameCharacter(StateChart):
41+
... standing = State(initial=True)
42+
... shift_weight = State()
43+
... adjust_hair = State()
44+
... bang_shield = State()
45+
...
46+
... idle = weighted_transitions(
47+
... standing,
48+
... (shift_weight, 70),
49+
... (adjust_hair, 20),
50+
... (bang_shield, 10),
51+
... seed=42,
52+
... )
53+
...
54+
... finish = (
55+
... shift_weight.to(standing)
56+
... | adjust_hair.to(standing)
57+
... | bang_shield.to(standing)
58+
... )
59+
60+
>>> sm = GameCharacter()
61+
>>> sm.send("idle")
62+
>>> any(
63+
... s in sm.configuration_values
64+
... for s in ("shift_weight", "adjust_hair", "bang_shield")
65+
... )
66+
True
67+
68+
```
69+
70+
When `idle` fires, the engine randomly selects one of the three transitions based on
71+
their relative weights: 70% chance for `shift_weight`, 20% for `adjust_hair`,
72+
10% for `bang_shield`.
73+
74+
## Weights
75+
76+
Weights can be any **positive number** — integers, floats, or a mix of both. They are
77+
relative, not absolute percentages:
78+
79+
```python
80+
# These are equivalent (same 70/20/10 ratio):
81+
idle = weighted_transitions(
82+
standing,
83+
(shift_weight, 70),
84+
(adjust_hair, 20),
85+
(bang_shield, 10),
86+
)
87+
88+
idle = weighted_transitions(
89+
standing,
90+
(shift_weight, 7),
91+
(adjust_hair, 2),
92+
(bang_shield, 1),
93+
)
94+
95+
idle = weighted_transitions(
96+
standing,
97+
(shift_weight, 0.7),
98+
(adjust_hair, 0.2),
99+
(bang_shield, 0.1),
100+
)
101+
```
102+
103+
The tuple format `(target, weight)` follows the standard Python pattern used by
104+
{py:func}`random.choices`.
105+
106+
## Reproducibility with `seed`
107+
108+
Pass a `seed` parameter for deterministic, reproducible sequences — useful for testing:
109+
110+
```python
111+
go = weighted_transitions(
112+
s1,
113+
(s2, 50),
114+
(s3, 50),
115+
seed=42, # same seed always produces the same sequence
116+
)
117+
```
118+
119+
```{note}
120+
The seed initializes a per-group `random.Random` instance that is shared across all
121+
instances of the same state machine class. This means the sequence is deterministic
122+
for a given program execution, but different instances advance the same RNG.
123+
```
124+
125+
## Per-transition options
126+
127+
Use the {func}`~statemachine.contrib.weighted.to` helper to pass transition keyword
128+
arguments (``cond``, ``unless``, ``before``, ``on``, ``after``, …) as natural kwargs.
129+
For simple destinations without extra options, a plain ``(target, weight)`` tuple is
130+
enough — ``to()`` is only needed when you want to customize the transition:
131+
132+
```py
133+
>>> class GuardedWeighted(StateChart):
134+
... idle = State(initial=True)
135+
... walk = State()
136+
... run = State()
137+
...
138+
... move = weighted_transitions(
139+
... idle,
140+
... (walk, 70),
141+
... to(run, 30, cond="has_energy"),
142+
... )
143+
... stop = walk.to(idle) | run.to(idle)
144+
...
145+
... has_energy = True
146+
147+
>>> sm = GuardedWeighted()
148+
149+
```
150+
151+
```{important}
152+
**No fallback when a guard fails.** If the weighted selection picks a transition whose
153+
guard evaluates to ``False``, the event fails — the engine does **not** silently fall back
154+
to another transition. This preserves the probability semantics: a 70/30 split means
155+
exactly that, not "70/30 unless the 30% is blocked, in which case always 100% for
156+
the other".
157+
158+
This behavior follows {ref}`conditions` evaluation: the first transition whose **all**
159+
conditions pass is executed.
160+
```
161+
162+
## Combining with callbacks
163+
164+
All standard {ref}`actions` work with weighted events — `before`, `on`, `after` callbacks
165+
and naming conventions like `on_<event>()`:
166+
167+
```python
168+
class WithCallbacks(StateChart):
169+
s1 = State(initial=True)
170+
s2 = State()
171+
s3 = State()
172+
173+
go = weighted_transitions(s1, (s2, 60), (s3, 40))
174+
back = s2.to(s1) | s3.to(s1)
175+
176+
def on_go(self):
177+
print("go event fired!")
178+
179+
def after_go(self):
180+
print("after go!")
181+
```
182+
183+
## Multiple independent groups
184+
185+
Each call to `weighted_transitions()` creates an independent weighted group with its
186+
own RNG. You can have multiple weighted events on the same state machine:
187+
188+
```python
189+
class MultiGroup(StateChart):
190+
idle = State(initial=True)
191+
walk = State()
192+
run = State()
193+
wave = State()
194+
bow = State()
195+
196+
move = weighted_transitions(idle, (walk, 70), (run, 30), seed=1)
197+
greet = weighted_transitions(idle, (wave, 80), (bow, 20), seed=2)
198+
back = walk.to(idle) | run.to(idle) | wave.to(idle) | bow.to(idle)
199+
```
200+
201+
The `move` and `greet` events use separate RNGs and don't interfere with each other.
202+
203+
## Validation
204+
205+
`weighted_transitions()` validates inputs at class definition time:
206+
207+
- The first argument must be a `State` (the source).
208+
- Each destination must be a `(target_state, weight)` or
209+
`(target_state, weight, kwargs_dict)` tuple.
210+
- Weights must be positive numbers (`int` or `float`).
211+
- At least one destination is required.
212+
213+
```py
214+
>>> weighted_transitions(State(initial=True))
215+
Traceback (most recent call last):
216+
...
217+
ValueError: weighted_transitions() requires at least one (target, weight) destination
218+
219+
>>> s1, s2 = State(initial=True), State()
220+
>>> weighted_transitions(s1, (s2, -5))
221+
Traceback (most recent call last):
222+
...
223+
ValueError: Destination 0: weight must be positive, got -5
224+
225+
>>> weighted_transitions(s1, (s2, "ten"))
226+
Traceback (most recent call last):
227+
...
228+
TypeError: Destination 0: weight must be a positive number, got str
229+
230+
```
231+
232+
## How it works
233+
234+
Under the hood, `weighted_transitions()`:
235+
236+
1. Creates a `_WeightedGroup` holding the weights and a `random.Random` instance.
237+
2. Calls `source.to(target, **kwargs)` for each destination, creating standard
238+
transitions.
239+
3. Attaches a lightweight condition callable to each transition's `cond` list.
240+
4. When the event fires, the engine evaluates conditions in order. The first condition
241+
to run rolls the dice (using `random.choices`) and caches the result. Subsequent
242+
conditions check against the cache.
243+
5. Only the selected transition's condition returns `True` — the engine picks it.
244+
245+
This means weighted transitions are fully compatible with all engine features:
246+
{ref}`actions`, {ref}`validators-and-guards`, {ref}`listeners`, async engines,
247+
and {ref}`diagram generation <diagram>`.

0 commit comments

Comments
 (0)