|
| 1 | +"""Contract tests: observable behavior of public Configuration APIs. |
| 2 | +
|
| 3 | +Documents the exact values returned by each public API across all supported |
| 4 | +topologies (flat, compound, parallel, complex parallel) and lifecycle phases |
| 5 | +(initial state, after transitions, final state). |
| 6 | +
|
| 7 | +APIs under test (StateChart): |
| 8 | + sm.current_state_value -- raw value stored on the model |
| 9 | + sm.configuration_values -- OrderedSet of raw values |
| 10 | + sm.configuration -- OrderedSet[State] |
| 11 | + sm.current_state -- State or OrderedSet[State] (deprecated) |
| 12 | +
|
| 13 | +API under test (Model): |
| 14 | + model.state -- raw attribute on the model object |
| 15 | +""" |
| 16 | + |
| 17 | +import warnings |
| 18 | +from typing import Any |
| 19 | + |
| 20 | +import pytest |
| 21 | +from statemachine.orderedset import OrderedSet |
| 22 | + |
| 23 | +from statemachine import State |
| 24 | +from statemachine import StateChart |
| 25 | + |
| 26 | +# --------------------------------------------------------------------------- |
| 27 | +# Model |
| 28 | +# --------------------------------------------------------------------------- |
| 29 | + |
| 30 | + |
| 31 | +class Model: |
| 32 | + """Explicit model to verify raw state persistence independently.""" |
| 33 | + |
| 34 | + def __init__(self): |
| 35 | + self.state: Any = None |
| 36 | + |
| 37 | + |
| 38 | +# --------------------------------------------------------------------------- |
| 39 | +# Topologies |
| 40 | +# --------------------------------------------------------------------------- |
| 41 | + |
| 42 | + |
| 43 | +class FlatSC(StateChart): |
| 44 | + s1 = State(initial=True) |
| 45 | + s2 = State() |
| 46 | + s3 = State(final=True) |
| 47 | + |
| 48 | + go = s1.to(s2) |
| 49 | + finish = s2.to(s3) |
| 50 | + |
| 51 | + |
| 52 | +class CompoundSC(StateChart): |
| 53 | + class parent(State.Compound): |
| 54 | + child1 = State(initial=True) |
| 55 | + child2 = State() |
| 56 | + move = child1.to(child2) |
| 57 | + |
| 58 | + done = State(final=True) |
| 59 | + leave = parent.to(done) |
| 60 | + |
| 61 | + |
| 62 | +class ParallelSC(StateChart): |
| 63 | + class regions(State.Parallel): |
| 64 | + class region_a(State.Compound): |
| 65 | + a1 = State(initial=True) |
| 66 | + a2 = State() |
| 67 | + go_a = a1.to(a2) |
| 68 | + |
| 69 | + class region_b(State.Compound): |
| 70 | + b1 = State(initial=True) |
| 71 | + b2 = State() |
| 72 | + go_b = b1.to(b2) |
| 73 | + |
| 74 | + |
| 75 | +class ComplexParallelSC(StateChart): |
| 76 | + class top(State.Parallel): |
| 77 | + class left(State.Compound): |
| 78 | + class nested(State.Compound): |
| 79 | + l1 = State(initial=True) |
| 80 | + l2 = State() |
| 81 | + move_l = l1.to(l2) |
| 82 | + |
| 83 | + left_done = State(final=True) |
| 84 | + finish_left = nested.to(left_done) |
| 85 | + |
| 86 | + class right(State.Compound): |
| 87 | + r1 = State(initial=True) |
| 88 | + r2 = State() |
| 89 | + move_r = r1.to(r2) |
| 90 | + |
| 91 | + |
| 92 | +# --------------------------------------------------------------------------- |
| 93 | +# Assertion helper |
| 94 | +# --------------------------------------------------------------------------- |
| 95 | + |
| 96 | + |
| 97 | +def assert_contract(sm, model, expected_ids: set): |
| 98 | + """Assert the full observable API contract. |
| 99 | +
|
| 100 | + When exactly one state is active, the model stores a scalar and |
| 101 | + ``current_state`` returns a single ``State``. When multiple states |
| 102 | + are active (compound/parallel), the model stores an ``OrderedSet`` |
| 103 | + and ``current_state`` returns ``OrderedSet[State]``. |
| 104 | + """ |
| 105 | + scalar = len(expected_ids) == 1 |
| 106 | + |
| 107 | + # model.state and current_state_value point to the same object |
| 108 | + assert model.state is sm.current_state_value |
| 109 | + |
| 110 | + if scalar: |
| 111 | + val = next(iter(expected_ids)) |
| 112 | + assert model.state == val |
| 113 | + assert not isinstance(model.state, OrderedSet) |
| 114 | + else: |
| 115 | + assert isinstance(model.state, OrderedSet) |
| 116 | + assert set(model.state) == expected_ids |
| 117 | + |
| 118 | + # configuration_values -- always OrderedSet of raw values |
| 119 | + assert isinstance(sm.configuration_values, OrderedSet) |
| 120 | + assert set(sm.configuration_values) == expected_ids |
| 121 | + |
| 122 | + # configuration -- always OrderedSet[State] |
| 123 | + assert len(sm.configuration) == len(expected_ids) |
| 124 | + assert {s.id for s in sm.configuration} == expected_ids |
| 125 | + |
| 126 | + # current_state (deprecated) -- unwrapped when single |
| 127 | + with warnings.catch_warnings(): |
| 128 | + warnings.simplefilter("ignore", DeprecationWarning) |
| 129 | + cs = sm.current_state |
| 130 | + if scalar: |
| 131 | + assert not isinstance(cs, OrderedSet) |
| 132 | + assert cs.id == next(iter(expected_ids)) |
| 133 | + else: |
| 134 | + assert isinstance(cs, OrderedSet) |
| 135 | + assert {s.id for s in cs} == expected_ids |
| 136 | + |
| 137 | + |
| 138 | +# --------------------------------------------------------------------------- |
| 139 | +# Main contract matrix: topology x lifecycle x engine |
| 140 | +# --------------------------------------------------------------------------- |
| 141 | + |
| 142 | + |
| 143 | +SCENARIOS = [ |
| 144 | + # -- Flat -- |
| 145 | + pytest.param(FlatSC, [], {"s1"}, id="flat-initial"), |
| 146 | + pytest.param(FlatSC, ["go"], {"s2"}, id="flat-after-go"), |
| 147 | + pytest.param(FlatSC, ["go", "finish"], {"s3"}, id="flat-final"), |
| 148 | + # -- Compound -- |
| 149 | + pytest.param(CompoundSC, [], {"parent", "child1"}, id="compound-initial"), |
| 150 | + pytest.param(CompoundSC, ["move"], {"parent", "child2"}, id="compound-inner-move"), |
| 151 | + pytest.param(CompoundSC, ["leave"], {"done"}, id="compound-exit"), |
| 152 | + # -- Parallel -- |
| 153 | + pytest.param( |
| 154 | + ParallelSC, |
| 155 | + [], |
| 156 | + {"regions", "region_a", "a1", "region_b", "b1"}, |
| 157 | + id="parallel-initial", |
| 158 | + ), |
| 159 | + pytest.param( |
| 160 | + ParallelSC, |
| 161 | + ["go_a"], |
| 162 | + {"regions", "region_a", "a2", "region_b", "b1"}, |
| 163 | + id="parallel-one-region", |
| 164 | + ), |
| 165 | + pytest.param( |
| 166 | + ParallelSC, |
| 167 | + ["go_a", "go_b"], |
| 168 | + {"regions", "region_a", "a2", "region_b", "b2"}, |
| 169 | + id="parallel-both-regions", |
| 170 | + ), |
| 171 | + # -- Complex parallel -- |
| 172 | + pytest.param( |
| 173 | + ComplexParallelSC, |
| 174 | + [], |
| 175 | + {"top", "left", "nested", "l1", "right", "r1"}, |
| 176 | + id="complex-initial", |
| 177 | + ), |
| 178 | + pytest.param( |
| 179 | + ComplexParallelSC, |
| 180 | + ["move_l"], |
| 181 | + {"top", "left", "nested", "l2", "right", "r1"}, |
| 182 | + id="complex-nested-move", |
| 183 | + ), |
| 184 | + pytest.param( |
| 185 | + ComplexParallelSC, |
| 186 | + ["move_r"], |
| 187 | + {"top", "left", "nested", "l1", "right", "r2"}, |
| 188 | + id="complex-other-region", |
| 189 | + ), |
| 190 | + pytest.param( |
| 191 | + ComplexParallelSC, |
| 192 | + ["move_l", "move_r"], |
| 193 | + {"top", "left", "nested", "l2", "right", "r2"}, |
| 194 | + id="complex-both-regions", |
| 195 | + ), |
| 196 | + pytest.param( |
| 197 | + ComplexParallelSC, |
| 198 | + ["finish_left"], |
| 199 | + {"top", "left", "left_done", "right", "r1"}, |
| 200 | + id="complex-exit-nested", |
| 201 | + ), |
| 202 | +] |
| 203 | + |
| 204 | + |
| 205 | +@pytest.mark.parametrize(("sc_class", "events", "expected_ids"), SCENARIOS) |
| 206 | +async def test_configuration_contract(sm_runner, sc_class, events, expected_ids): |
| 207 | + model = Model() |
| 208 | + sm = await sm_runner.start(sc_class, model=model) |
| 209 | + for event in events: |
| 210 | + await sm_runner.send(sm, event) |
| 211 | + assert_contract(sm, model, expected_ids) |
| 212 | + |
| 213 | + |
| 214 | +# --------------------------------------------------------------------------- |
| 215 | +# Model setter contract |
| 216 | +# --------------------------------------------------------------------------- |
| 217 | + |
| 218 | +SETTER_SCENARIOS = [ |
| 219 | + pytest.param(FlatSC, "s2", {"s2"}, id="scalar-on-flat"), |
| 220 | + pytest.param( |
| 221 | + CompoundSC, |
| 222 | + OrderedSet(["parent", "child2"]), |
| 223 | + {"parent", "child2"}, |
| 224 | + id="orderedset-on-compound", |
| 225 | + ), |
| 226 | + pytest.param(CompoundSC, "done", {"done"}, id="scalar-collapses-orderedset"), |
| 227 | +] |
| 228 | + |
| 229 | + |
| 230 | +@pytest.mark.parametrize(("sc_class", "new_value", "expected_ids"), SETTER_SCENARIOS) |
| 231 | +async def test_setter_contract(sm_runner, sc_class, new_value, expected_ids): |
| 232 | + model = Model() |
| 233 | + sm = await sm_runner.start(sc_class, model=model) |
| 234 | + sm.current_state_value = new_value |
| 235 | + assert_contract(sm, model, expected_ids) |
| 236 | + |
| 237 | + |
| 238 | +async def test_set_none_clears_configuration(sm_runner): |
| 239 | + model = Model() |
| 240 | + sm = await sm_runner.start(FlatSC, model=model) |
| 241 | + |
| 242 | + sm.current_state_value = None |
| 243 | + |
| 244 | + assert model.state is None |
| 245 | + assert sm.current_state_value is None |
| 246 | + assert sm.configuration_values == OrderedSet([None]) |
| 247 | + assert sm.configuration == OrderedSet() |
| 248 | + |
| 249 | + |
| 250 | +# --------------------------------------------------------------------------- |
| 251 | +# Uninitialized state (async-only: sync enters initial state in __init__) |
| 252 | +# --------------------------------------------------------------------------- |
| 253 | + |
| 254 | +UNINITIALIZED_SCENARIOS = [ |
| 255 | + pytest.param(FlatSC, {"s1"}, id="flat"), |
| 256 | + pytest.param(CompoundSC, {"parent", "child1"}, id="compound"), |
| 257 | + pytest.param( |
| 258 | + ParallelSC, |
| 259 | + {"regions", "region_a", "a1", "region_b", "b1"}, |
| 260 | + id="parallel", |
| 261 | + ), |
| 262 | +] |
| 263 | + |
| 264 | + |
| 265 | +@pytest.mark.parametrize(("sc_class", "expected_ids"), UNINITIALIZED_SCENARIOS) |
| 266 | +async def test_uninitialized_then_activated(sc_class, expected_ids): |
| 267 | + from tests.conftest import _AsyncListener |
| 268 | + |
| 269 | + model = Model() |
| 270 | + sm = sc_class(model=model, listeners=[_AsyncListener()]) |
| 271 | + |
| 272 | + # Before activation: model.state is None, configuration_values wraps it |
| 273 | + assert model.state is None |
| 274 | + assert sm.current_state_value is None |
| 275 | + assert sm.configuration_values == OrderedSet([None]) |
| 276 | + assert sm.configuration == OrderedSet() |
| 277 | + |
| 278 | + # After activation: full contract holds |
| 279 | + await sm.activate_initial_state() |
| 280 | + assert_contract(sm, model, expected_ids) |
0 commit comments