
Auto-validate state transitions; make log_level the runtime-validation policy#360

Merged
hmgaudecker merged 42 commits into main from feat/phase-1b-auto-state-transition-validation
May 25, 2026

@hmgaudecker

@hmgaudecker hmgaudecker commented May 18, 2026

Member

Stacked on top of #359. Merge that one first.

Summary

pylcm ran almost all validation checks internally. The exception was state transition probabilities: pylcm exposed validate_transition_probs and left it to the user to call it, even though regime transition probabilities already had an internal mechanism. Furthermore, which checks ran depended on an obscure, ad-hoc mix of settings.

This PR unifies validation and control over it by adding runtime-validation of state transition probabilities and controlling all costly validation via log_level (off < warning < progress < debug) and log_path:

| log_level | log_path | Runtime validation | Console output | Snapshots to disk |
|---|---|---|---|---|
| "off" | (ignored) | not run | silent | none |
| "warning" | None | runs → failures warn | warnings | none |
| "warning" | set | runs → failures warn | warnings | one per warned failure, capped at log_keep_n_latest |
| "progress" | None | runs → failures warn | warnings + timing | none |
| "progress" | set | runs → failures warn | warnings + timing | one per warned failure, capped at log_keep_n_latest |
| "debug" | None | runs → failures raise | warnings + timing + V_arr stats | none |
| "debug" | set | runs → failures raise | warnings + timing + V_arr stats | one per solve and on raise, capped at log_keep_n_latest |

validate_transition_probs is removed from the public API.

log_level governs only the costly runtime numerical validation — the transition-probability sweeps, the initial-conditions check, and the NaN/Inf check on the value-function arrays. The cheap construction-time sanity checks (regime/model structure, grid definitions, function signatures, the probs_array subscript-order check) always run when a Regime or Model is built, regardless of log_level.
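As a reading aid, the behaviour table can be restated as a small lookup. This is only a sketch: `RuntimePolicy` and `runtime_policy` are hypothetical names invented for illustration and are not part of pylcm; the strings simply mirror the table cells (without the `log_keep_n_latest` cap).

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class RuntimePolicy:
    """One row of the log_level x log_path table (hypothetical helper)."""

    validation: str  # "not run", "warn", or "raise"
    console: str
    snapshots: str


def runtime_policy(log_level: str, log_path: str | None) -> RuntimePolicy:
    """Return the table row for one (log_level, log_path) pair."""
    if log_level == "off":
        # log_path is ignored when logging is off.
        return RuntimePolicy("not run", "silent", "none")

    console = {
        "warning": "warnings",
        "progress": "warnings + timing",
        "debug": "warnings + timing + V_arr stats",
    }
    if log_level not in console:
        raise ValueError(f"unknown log_level: {log_level!r}")

    # Only "debug" raises; "warning" and "progress" warn and continue.
    validation = "raise" if log_level == "debug" else "warn"

    if log_path is None:
        snapshots = "none"
    elif log_level == "debug":
        snapshots = "one per solve and on raise"
    else:
        snapshots = "one per warned failure"
    return RuntimePolicy(validation, console[log_level], snapshots)
```

For instance, `runtime_policy("progress", None)` describes a run that validates, warns on failure, prints warnings and timing, and writes nothing to disk.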

log_level is a required argument

solve() and simulate() no longer default log_level — the caller picks it deliberately. Start every project at "debug" (fail early, gather full diagnostics) and ease to "warning" / "off" only once the model is trusted and the run needs the speed or the non-raising behaviour. A loose default would hide that "debug" exists; a "debug" default would make pylcm look slow.

simulate()'s check_initial_conditions flag is removed too: initial-conditions validation now follows the same log_level policy as every other runtime check.

Runtime validation

log_level is the single knob. The logger it produces carries the policy — there is no separate validation_mode value threaded around the engine. It governs four checks, run before backward induction:

  • State transition probabilities — every MarkovTransition state transition is swept over the regime's grid; outcome-axis size, [0, 1] range, and sum-to-1 are checked.
  • Regime transition probabilities — finiteness, [0, 1] range, sum-to-1, no probability mass to inactive regimes, and no mass to targets with incomplete stochastic transitions.
  • Initial conditions (simulate() only) — states on-grid, regime IDs valid, at least one feasible action combination per subject.
  • Value function — NaN/Inf check on each period's V_arr, with the offending (regime, period) localised.
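The three state-transition checks listed above (outcome-axis size, [0, 1] range, sum-to-1) can be illustrated on a single probability row. This is a plain-Python sketch, not the pylcm implementation; the function name `check_probability_row` and the tolerance `atol=1e-6` are assumptions made for the example.

```python
from __future__ import annotations

import math
from typing import Sequence


def check_probability_row(
    row: Sequence[float], n_outcomes: int, atol: float = 1e-6
) -> list[str]:
    """Return a list of problems found in one row of transition probabilities."""
    problems = []
    # Outcome-axis size: one probability per outcome of the state's grid.
    if len(row) != n_outcomes:
        problems.append(f"expected {n_outcomes} outcomes, got {len(row)}")
    # Range: every entry must lie in [0, 1]; NaN fails this comparison too.
    if not all(0.0 <= p <= 1.0 for p in row):
        problems.append("values outside [0, 1]")
    # Sum-to-1 within an absolute tolerance (assumed value).
    if not math.isclose(math.fsum(row), 1.0, rel_tol=0.0, abs_tol=atol):
        problems.append("row does not sum to 1")
    return problems


print(check_probability_row([0.2, 0.8], n_outcomes=2))  # no problems
print(check_probability_row([0.5, 0.6], n_outcomes=2))  # sum violation
```

An empty list means the row passes; what happens with a non-empty list (warn or raise) is decided by the log_level policy.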

"off" skips all four; "warning" / "progress" log a warning and let the run continue (the returned solution may carry NaN); "debug" raises on the first failure.

Construction-time guard

batch_size > 0 paired with distributed=True on a single axis is rejected at grid init. Each Python-level batch is its own jax.jit dispatch in the solve loop, and on a distributed axis every dispatch carries a cross-device collective; batching therefore multiplies the per-period collective count by ceil(n_per_device / batch_size), and for small batch_size the collective overhead per kernel dwarfs the compute per kernel — sharding becomes a regression rather than a speedup. _fail_if_batch_size_combined_with_distributed in grids/base.py fires from _init_uniform_grid (covers Lin/LogSpacedGrid), IrregSpacedGrid.__init__, and DiscreteGrid.__init__; piecewise grids inherit batch_size=0, distributed=False defaults from ContinuousGrid and need no change. Error message points at the right escape valves (more devices or another distributed axis) rather than restoring batch_size. Construction-time tests cover all four grid types.
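The arithmetic behind the guard can be sketched as follows. The helper names below are hypothetical stand-ins (the real guard is `_fail_if_batch_size_combined_with_distributed` in `grids/base.py`), the `ValueError` type and message are assumptions, and the sketch assumes that `batch_size=0` means "no batching", i.e. a single dispatch.

```python
import math


def collectives_per_period(n_per_device: int, batch_size: int) -> int:
    """Dispatches (and hence collectives) per period on one distributed axis."""
    if batch_size == 0:
        # Assumption: 0 disables batching, so there is one dispatch.
        return 1
    return math.ceil(n_per_device / batch_size)


def fail_if_batched_and_distributed(*, batch_size: int, distributed: bool) -> None:
    """Reject batch_size > 0 together with distributed=True on the same axis."""
    if distributed and batch_size > 0:
        # Exception type and wording are illustrative only.
        raise ValueError(
            "batch_size > 0 cannot be combined with distributed=True on one "
            "axis; use more devices or distribute another axis instead."
        )


# 250 grid points per device: unbatched vs. batch_size=1.
print(collectives_per_period(250, 0))
print(collectives_per_period(250, 1))
```

With 250 points per device, `batch_size=1` turns one collective per period into 250, which is the multiplication the paragraph above describes.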

Test plan

  • pixi run -e tests-cpu tests
  • pixi run ty
  • prek run --all-files

🤖 Generated with Claude Code

hmgaudecker and others added 8 commits May 18, 2026 18:57
Move `collect_state_transitions`, `_make_identity_fn`, and
`_add_raw_transition` from `regime_building/validation.py` to a new
focused module. Update imports in 5 callers. Drop unused imports from
`validation.py`. Part of the Phase 1 effort to delete the
"validation" and "error_handling" umbrellas; see
`Phase 1 — Validation Cleanup.md`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Move 8 validators from `regime_building/validation.py` into
`user_regime.py` and privatize the two formerly-public ones:
- `validate_mapping_contents` -> `_validate_mapping_contents`
- `validate_logical_consistency` -> `_validate_logical_consistency`
- `_validate_distributed_grids`, `_validate_function_output_grid_indexing`,
  `_find_function_output_grid_indexing`, `_validate_active`,
  `_validate_state_transitions`, `_validate_per_target_dict`

The validators are sole-called from `UserRegime.__post_init__`;
co-locating them with the class eliminates a misleading umbrella
module and the cross-module delayed import in `__post_init__`. Delete
`regime_building/validation.py`. Part of Phase 1 — see
`Phase 1 — Validation Cleanup.md`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Move `_get_func_indexing_params`, `_slice_references_params`,
`_collect_subscripts`, `_extract_bare_names` from
`utils/error_handling.py` into a new focused module. Update imports
in `pandas_utils.py` and `tests/test_validate_array_indexing.py`.
`error_handling.py::validate_transition_probs` still uses
`_get_func_indexing_params` and now imports it from the new module
(deferred to M5' for full extraction). Part of Phase 1 — see
`Phase 1 — Validation Cleanup.md`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Create `regime_building/runtime_checks.py` and absorb two families
from `utils/error_handling.py`:
- V family (`validate_V`, `_enrich_with_diagnostics`,
  `_summarize_diagnostics`, `_format_diagnostic_summary`)
- regime-prob family (`validate_regime_transition_probs` ->
  `_validate_regime_transition_probs`, `_format_sum_violation`,
  `validate_regime_transitions_all_periods`,
  `_validate_regime_transition_single`,
  `_validate_no_reachable_incomplete_targets`)

Both families fit the unifying concept "defensive checks on JAX
arrays produced during solve/simulate." Privatize
`validate_regime_transition_probs` (only tests call it directly).
Update imports in 6 callers (3 src, 3 test). `diagnostics.py` keeps
its name. Part of Phase 1 — see `Phase 1 — Validation Cleanup.md`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_regime.py

Drop the regime-probs overload — that mode is redundant with
`validate_regime_transitions_all_periods` which runs unconditionally
during `model.solve()` and `model.simulate()` and additionally
checks inactive-regime probability and reachability.

Keep the state-probs mode. It is the only defence against four
silent-correctness bug classes in user-written MarkovTransition
functions for states: wrong-shape broadcasting, values outside
[0, 1], rows not summing to 1, and subscript-order swaps relative
to the function signature.

Move slimmed function + helpers (`_extract_markov_transition`,
`_build_grids`, `_build_expected_shape`) to `user_regime.py` next to
the `Regime` class they operate on. Update `lcm/__init__.py` import.
Drop three regime-probs tests from `tests/test_pandas_utils.py`.
Update `docs/user_guide/pandas_interop.md` accordingly. Part of
Phase 1 — see `Phase 1 — Validation Cleanup.md`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
All contents have been absorbed: V family + regime-prob family into
`regime_building/runtime_checks.py` (M4); AST helpers into
`utils/ast_inspection.py` (M3); slimmed `validate_transition_probs`
into `user_regime.py` (M5'). The "error_handling" umbrella was
misleading from the start — three unrelated concerns under one
name. Part of Phase 1 — see `Phase 1 — Validation Cleanup.md`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…me.py

Three renames driven by Phase 1's src moves:
- `tests/test_error_handling_invalid_vf.py` → `tests/test_invalid_vf.py`
  (the "error_handling" concept is gone from src).
- `tests/test_validate_array_indexing.py` → `tests/test_ast_inspection.py`
  (AST helpers moved to `lcm/utils/ast_inspection.py`).
- Extract the four `validate_transition_probs` state-probs tests from
  `tests/test_pandas_utils.py` into `tests/test_regime.py` (the
  function now lives in `lcm/user_regime.py`). Duplicate the
  three-line `_make_partner_probs_array` helper into the new
  location rather than imposing a cross-file import.

Closes the Phase 1 plan in `Phase 1 — Validation Cleanup.md`.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Validate state transition probability functions automatically — both
statically at process time and numerically at solve time — so users no
longer need to call `lcm.validate_transition_probs` manually for state
transitions. Plan: `Phase 1b — Automatic State Transition Validation.md`.

What runs when:
- **Process time** (during `process_regimes`, always on, cheap):
  AST subscript-order check on every `MarkovTransition.func` —
  permissive: skipped when the function doesn't use the
  `probs_array[...]` pattern. Outcome-axis size is derived from the
  state's `DiscreteGrid` and cached on the canonical `Regime` via the
  new `stochastic_state_transitions` field. For per-target dicts, the
  target regime's grid wins (cross-grid state spaces).
- **Solve / simulate time** (gated by `log_level != "off"`):
  new `validate_state_transitions_all_periods` evaluates each
  `MarkovTransition` function on the Cartesian product of the
  function's accepted grid args (via vmap) and checks outcome-axis
  size, [0, 1] range, and sum-to-1 along the last axis. Raises a new
  `InvalidStateTransitionProbabilitiesError` on failure.

Fast-exits when no regime has any `MarkovTransition` state transition.

The slimmed `lcm.validate_transition_probs` (Phase 1) is deprecated
with a `DeprecationWarning` pointing at the automatic validator. It
will be removed in a subsequent phase.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
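The solve-time sweep described in this commit evaluates a transition function on the Cartesian product of its grid arguments. The sketch below imitates that with a plain Python loop instead of vmap; `sweep_transition`, `next_health`, the grids, and the tolerance are all made up for illustration.

```python
import math
from itertools import product


def sweep_transition(func, grids, n_outcomes, atol=1e-6):
    """Call func on every grid point and return the points with invalid rows."""
    names = list(grids)
    failures = []
    for values in product(*(grids[name] for name in names)):
        point = dict(zip(names, values))
        row = func(**point)
        valid = (
            len(row) == n_outcomes
            and all(0.0 <= p <= 1.0 for p in row)
            and math.isclose(math.fsum(row), 1.0, rel_tol=0.0, abs_tol=atol)
        )
        if not valid:
            failures.append(point)
    return failures


def next_health(health, wealth):
    """Hypothetical transition: valid for small wealth, invalid for large."""
    p_good = 0.5 + 0.2 * wealth if health == 1 else 0.3
    return [1.0 - p_good, p_good]


grids = {"health": [0, 1], "wealth": [0.0, 1.0, 2.0, 3.0]}
print(sweep_transition(next_health, grids, n_outcomes=2))
```

The example function is valid at most grid points and only leaves [0, 1] at `health=1, wealth=3.0`, which is the kind of hidden invalidity a spot check at a single point would miss.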
@read-the-docs-community

read-the-docs-community Bot commented May 18, 2026


@github-actions

github-actions Bot commented May 18, 2026


Benchmark comparison (main → HEAD)

Comparing 339b8ffc (main) → ded3a44d (HEAD)

| Benchmark | Statistic | before | after | Ratio | Alert |
|---|---|---|---|---|---|
| aca-baseline | execution time | 27.024 s | 14.603 s | 0.54 | |
| | peak GPU mem | 4.33 GB | 581 MB | 0.13 | |
| | compilation time | 297.54 s | 274.09 s | 0.92 | |
| | peak CPU mem | 6.96 GB | 6.56 GB | 0.94 | |
| aca-baseline-debug | execution time | | 79.754 s | | |
| | peak GPU mem | | 581 MB | | |
| | compilation time | | 376.37 s | | |
| | peak CPU mem | | 7.54 GB | | |
| Mahler-Yum | execution time | 4.602 s | 4.268 s | 0.93 | |
| | peak GPU mem | 529 MB | 529 MB | 1.00 | |
| | compilation time | 13.84 s | 12.78 s | 0.92 | |
| | peak CPU mem | 1.72 GB | 1.69 GB | 0.98 | |
| Precautionary Savings - Solve | execution time | 46.4 ms | 25.9 ms | 0.56 | |
| | peak GPU mem | 101 MB | 101 MB | 1.00 | |
| | compilation time | 2.69 s | 2.14 s | 0.79 | |
| | peak CPU mem | 1.15 GB | 1.12 GB | 0.98 | |
| Precautionary Savings - Simulate | execution time | 123.0 ms | 103.4 ms | 0.84 | |
| | peak GPU mem | 349 MB | 349 MB | 1.00 | |
| | compilation time | 4.80 s | 4.83 s | 1.01 | |
| | peak CPU mem | 1.34 GB | 1.32 GB | 0.99 | |
| Precautionary Savings - Solve & Simulate | execution time | 147.6 ms | 125.3 ms | 0.85 | |
| | peak GPU mem | 586 MB | 586 MB | 1.00 | |
| | compilation time | 7.28 s | 6.27 s | 0.86 | |
| | peak CPU mem | 1.31 GB | 1.28 GB | 0.98 | |
| Precautionary Savings - Solve & Simulate (irreg) | execution time | 294.2 ms | 265.3 ms | 0.90 | |
| | peak GPU mem | 2.20 GB | 2.20 GB | 1.00 | |
| | compilation time | 7.28 s | 6.66 s | 0.92 | |
| | peak CPU mem | 1.36 GB | 1.34 GB | 0.99 | |

@hmgaudecker hmgaudecker changed the title Phase 1b: auto-validate stochastic state transition probabilities Auto-validate stochastic state transition probabilities May 19, 2026
hmgaudecker and others added 5 commits May 19, 2026 14:40
…cks.py

regime_building/ is the build-time pipeline (UserRegime → engine.Regime).
The runtime checks fire on solve / simulate, not at build, so they don't
belong here. Split by caller:

- validate_V (+ its diagnostic-attachment helpers) is a tight subroutine
  of the backward-induction loop in solve_brute.py and the V handed to
  simulate.py. Moves to solution/validate_V.py (solve owns producing V;
  simulation imports from there).
- The regime-transition-probability pre-flight sweep is called from
  Model.solve()/simulate() before backward induction runs. Moves to a
  top-level lcm/_transition_checks.py — orthogonal to solution/ and
  simulation/, sibling of engine.py and model_processing.py.

Imports updated in model.py, solve_brute.py, simulate.py and three tests.
Docstring on user_regime.validate_transition_probs updated to point at the
new location.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…checks.py

phase-1 split runtime_checks.py into solution/validate_V.py and
lcm/_transition_checks.py. phase-1b added a third runtime family (state
transition probability validation) to the old runtime_checks.py. After
this merge, that family lives in lcm/_transition_checks.py beside the
regime-prob family.

Resolution details:
- runtime_checks.py: take the phase-1 deletion; state-prob functions land
  in _transition_checks.py alongside the regime-prob family.
- model.py: import both validate_regime_transitions_all_periods and
  validate_state_transitions_all_periods from lcm._transition_checks.
- Docstrings in exceptions.py, interfaces.py, regime_building/static_checks.py,
  and user_regime.py updated to reference lcm/_transition_checks.py.
- Test file renamed: tests/regime_building/test_state_transition_validation.py
  → tests/test_transition_checks.py (the source it covers is no longer in
  regime_building/).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- `_transition_checks.py` module docstring no longer references a
  non-existent `regime_building/static_checks.py`; it describes the
  actual distinction (runtime numerical checks vs. construction-time
  regime-spec validators).
- `regime_building/transitions.py`: the `collect_state_transitions`
  docstring said ShockGrid states get a `lambda: None` stub; the code
  skips them entirely. Docstring corrected.
- `solution/validate_V.py`: `Dict` → `dict` in a Returns section to
  match the `dict[str, Any]` annotation.
- `_transition_checks.py` `_validate_no_reachable_incomplete_targets`:
  drop the override that, when a target regime was absent from the
  source's `state_transitions`, listed every state of the target as
  "missing" — including non-stochastic states that need no explicit
  entry. The preceding line already computes the correct stochastic-
  only set for that case.
- `user_regime.py`: add a scope-boundary note to
  `_validate_function_output_grid_indexing` — the AST check is
  deliberately best-effort and should be deleted rather than hardened
  if it ever produces false positives.
- Replace the two contentless happy-path validator tests with boundary
  inputs: values at the inclusive [0, 1] bounds and row sums just
  inside the sum-to-1 tolerance, so "does not raise" pins the
  tolerance/bound logic instead of being a bare smoke check.
…-auto-state-transition-validation

# Conflicts:
#	tests/test_regime.py
Resolves the PR review's main finding: `log_level="off"` silently
disabled state-transition validation while regime-transition validation
still ran unconditionally — an asymmetric footgun.

`log_level` now governs all runtime validation uniformly:

- `"off"` — validation does not run.
- `"warning"` / `"progress"` — validation runs; failures are logged as
  warnings and the run continues.
- `"debug"` — validation runs and raises on the first failure.

The default `log_level` moves from `"progress"` to `"debug"`, so the
default `solve()` / `simulate()` validates and raises (secure default).
The mode applies to state-transition checks, regime-transition checks
(`validate_regime_transitions_all_periods` — previously unconditional),
and the `validate_V` NaN check.

`debug` no longer requires `log_path`: the `_validate_log_args` rule is
removed (it would make the new default unusable without a path).
`log_path` is optional everywhere; snapshots are written only when set.

Warn-mode disk safety: in warn mode an invalid model keeps running, so a
diagnostic snapshot is written on each warned NaN failure (when log_path
is set), retention-capped at `log_keep_n_latest`. `_enforce_retention`
now orders snapshot directories by parsed integer counter rather than
lexically, so retention stays correct past 999 iterations.

Review fixes:

- `_check_subscript_order` runs after the `DiscreteGrid` guard, so a
  continuous-state `MarkovTransition` no longer gets a spurious
  process-time raise.
- `_find_state_grid` returns `None` for a per-target dict whose target
  lacks the state, rather than sizing `n_outcomes` off the source grid.
- `_validate_state_transition_single` warns instead of silently
  skipping a transition with an unrecognized parameter.
- Docstrings drop "now" history wording, the rST `.. deprecated::`
  directive, and hard-coded internal module paths.

Tests: a hidden-invalidity test (valid at some grid points, invalid at
others, swept via the continuous grid), warn/raise-per-level coverage,
and a parametrized check pinning the `log_level` x `log_path` snapshot
table. Docs updated with the full behaviour table.
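The retention fix mentioned above (ordering snapshot directories by parsed integer counter rather than lexically) can be illustrated like this. The directory naming scheme `snapshot_<counter>` and the function `snapshots_to_delete` are assumptions for the sketch; the actual `_enforce_retention` may differ.

```python
import re


def snapshots_to_delete(dir_names, keep_n_latest):
    """Return the directory names to remove, keeping the newest keep_n_latest."""

    def counter(name):
        # Parse the trailing integer; names without one sort first.
        match = re.search(r"(\d+)$", name)
        return int(match.group(1)) if match else -1

    ordered = sorted(dir_names, key=counter)
    if keep_n_latest <= 0:
        return ordered
    return ordered[:-keep_n_latest]


names = ["snapshot_999", "snapshot_1000", "snapshot_998"]

# Integer ordering drops the oldest snapshot.
print(snapshots_to_delete(names, keep_n_latest=2))

# A lexical sort would instead drop the newest one.
print(sorted(names)[:-2])
```

Under lexical ordering, `snapshot_1000` sorts before `snapshot_998`, so the newest snapshot would be the first one deleted once the counter passes 999.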
@hmgaudecker hmgaudecker changed the title Auto-validate stochastic state transition probabilities Auto-validate state transitions; make log_level the runtime-validation policy May 20, 2026
State and regime transition probabilities are validated automatically
during solve()/simulate(), gated by log_level. The standalone
validate_transition_probs entry point and its helpers are redundant, so
drop them along with their tests and doc references.

Also trim a redundant diagnostics_enabled guard in solve_brute.py: the
"raise" validation mode already implies diagnostics are enabled.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

hmgaudecker and others added 2 commits May 20, 2026 11:50
The per-period NaN/Inf tracking in solve() exists to feed runtime
validation. Gating it on logger.isEnabledFor(WARNING) duplicated the
log-level partition that validation_mode already encodes. Derive the
gate from validation_mode != "off" so its source matches its purpose;
behaviour is unchanged.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- Re-array notebook cell sources in stochastic_transitions.ipynb so each
  line is its own JSON element (one-string sources produce noisy diffs).
- Drop the stale per-level table from debugging.md; it duplicated and
  had drifted from the canonical log_level x log_path table in
  solving_and_simulating.md, which debugging.md already links to.
- Trim "per-period timing" to "timing" in the behaviour table.
- Document the notebook cell-formatting check in AGENTS.md.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
hmgaudecker and others added 6 commits May 20, 2026 13:52
The file tests the user-facing `Regime`; `user_regime` disambiguates it
from the canonical `Regime` and the `regime_building` / regime-template
modules.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
`derive_stochastic_state_transitions` becomes
`collect_stochastic_state_transitions`, mirroring `collect_state_transitions`
(the structurally identical walk over `state_transitions`). Both collectors
now live in `regime_building/transitions.py`; `static_checks.py` is removed.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The validation policy (off / warn / raise) was threaded as a separate
`validation_mode` argument alongside `logger` through solve(), both
transition validators, and _solve_compiled — carrying the same
information twice, since both derive from `log_level`.

The logger is now the single source of truth. Two named predicates,
`validation_enabled()` and `validation_raises()`, read the policy off
the logger's level; `raise_or_warn()` drops its `mode` parameter.
`ValidationMode`, `_VALIDATION_MODE_MAP`, and `get_validation_mode` are
removed.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
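A minimal sketch of how two predicates can read the validation policy off a standard-library logger. The names `validation_enabled`, `validation_raises`, and `raise_or_warn` come from the commit message, but their bodies here are guesses; in particular the mapping of `"progress"` to `logging.INFO` and of `"off"` to a level above `CRITICAL` is an assumption, and `make_logger` is a hypothetical helper.

```python
import logging

# Assumed mapping from pylcm's log_level names to stdlib logging levels.
_LEVELS = {
    "off": logging.CRITICAL + 1,
    "warning": logging.WARNING,
    "progress": logging.INFO,
    "debug": logging.DEBUG,
}


def make_logger(log_level: str) -> logging.Logger:
    """Create a logger whose level encodes the validation policy."""
    logger = logging.getLogger(f"validation_sketch.{log_level}")
    logger.setLevel(_LEVELS[log_level])
    return logger


def validation_enabled(logger: logging.Logger) -> bool:
    """Validation runs whenever warnings would be emitted."""
    return logger.isEnabledFor(logging.WARNING)


def validation_raises(logger: logging.Logger) -> bool:
    """Only the most verbose level turns failures into exceptions."""
    return logger.isEnabledFor(logging.DEBUG)


def raise_or_warn(logger: logging.Logger, exc: Exception) -> None:
    """Raise at debug level, warn at warning/progress, stay silent when off."""
    if validation_raises(logger):
        raise exc
    if validation_enabled(logger):
        logger.warning(str(exc))
```

Because both predicates depend only on the logger, no separate mode value has to be passed alongside it.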
`log_level` no longer defaults to `"debug"`. Forcing the caller to pass
it makes the choice deliberate: start at `"debug"` (fail early, full
diagnostics) and ease to `"warning"` / `"off"` only once the model is
trusted and the run needs the speed or non-raising behaviour. A loose
default would hide that `"debug"` exists; a `"debug"` default would make
pylcm look slow.

Sweeps every solve/simulate call site in the test suite and docs to
pass `log_level` explicitly.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Drop the `check_initial_conditions` parameter of `simulate()`. Initial-
conditions validation now follows the same `log_level` policy as the
transition checks: `"off"` skips it, `"warning"` / `"progress"` warn,
`"debug"` raises. One knob governs all runtime validation.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
hmgaudecker and others added 3 commits May 22, 2026 17:50
- CI pixi 0.68.1 → 0.69.0, setup-pixi 0.9.5 → 0.9.6
- ruff-pre-commit 0.15.12 → 0.15.14, pyproject-fmt 2.21.1 → 2.21.2
- bump the .ai-instructions submodule to latest main

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- CI pixi 0.68.1 → 0.69.0, setup-pixi 0.9.5 → 0.9.6
- ruff-pre-commit 0.15.12 → 0.15.14, pyproject-fmt 2.21.1 → 2.21.2
- bump the .ai-instructions submodule to latest main

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The GPU benchmark runner now ships pixi 0.69, which writes lockfile
format v7; regenerate the lock so it matches.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@hmgaudecker hmgaudecker force-pushed the feat/phase-1b-auto-state-transition-validation branch from a8e0202 to e1be354 Compare May 22, 2026 16:02
The GPU benchmark runner now ships pixi 0.69, which writes lockfile
format v7; regenerate the lock so it matches.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@hmgaudecker hmgaudecker force-pushed the refactor/phase-1-validation-cleanup branch from 923c419 to 957ba87 Compare May 22, 2026 16:02
hmgaudecker and others added 9 commits May 22, 2026 18:20
The GPU benchmark runner has pixi installed at ~/.pixi/bin, which a
non-login Actions shell does not pick up. Prepend it to GITHUB_PATH so
the bare `pixi install` step resolves.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The GPU benchmark runner's pixi 0.69 install lives in `~/.pixi/bin`,
which a non-login Actions shell does not pick up. Prepend it to
`$GITHUB_PATH` so the bare `pixi` invocations resolve.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The GPU benchmark runner's pixi 0.69 install resolves on the bare
`pixi` invocations without a $GITHUB_PATH prepend, same as before the
lockfile bump. Restore the workflows to that state.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The GPU benchmark runner's pixi 0.69 install resolves on the bare
`pixi` invocations without a $GITHUB_PATH prepend, same as before the
lockfile bump. Restore the workflows to that state.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Renames `AcaBaselineDebugLog`'s display label to `aca-baseline-debug`
and places it second in the PR-comment table, right after the
`aca-baseline` block. Adds `AcaBaselineDebugLogGpuPeakMem` so the
debug-mode block carries a peak GPU mem row symmetric with
`aca-baseline`.

`AcaBaselineDebugLog.setup_for_gpu_measurement` mirrors `setup`'s
`log_path` setup so the cold-measurement subprocess exercises
snapshot writing too. The tmpdir leaks at subprocess exit — `/tmp`
gets OS-cleaned, and the subprocess doesn't run ASV's teardown path.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Adds `paths-ignore` to the `pull_request` triggers of `main` and
`benchmark-pr`. Doc-only PRs (Markdown, notebooks under `docs/`)
no longer spin up the GPU runner pool or the self-hosted benchmark
runner. `main` also skips when the diff is benchmark-only — the
benchmark workflow covers that surface. Pushes to `main` still
exercise the full matrix; src/test changes still trigger everything.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ransition validator

`regime.resolved_fixed_params` and the per-iteration `flat_params` for a
regime both key their entries by qualified names (`next_<state>__<param>`,
or `next_<state>__<target>__<param>` for per-target dicts). The validator
calls the `MarkovTransition`'s user function with the raw parameter names
from its signature, so without the strip every transition-function param
that isn't a grid axis falls through to the "not numerically validated"
skip branch and the per-transition numerical check never runs.

Adds a `_params_callable_for_state_transition` helper that merges fixed
and flat params (same merge order as `solve`) and returns a
`FlatRegimeParams` keyed by the raw signature names accepted by one
specific transition. The state-transition validator calls into it before
dispatching to `_validate_state_transition_single`.

Adds two regression tests on a model whose `health` `MarkovTransition`
reads a parameter from `fixed_params`:
- one asserts no "not numerically validated" warning fires;
- one asserts that an invalid probability *is* surfaced at log_level=debug,
  proving the validator actually ran rather than silently skipping.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Each Python-level batch is its own `jax.jit` dispatch in the solve
loop, and on a distributed axis every dispatch carries a cross-device
collective. Batching therefore multiplies the per-period collective
count by `ceil(n_per_device / batch_size)`; for small `batch_size`
the collective overhead per kernel dwarfs the compute per kernel and
sharding becomes a regression rather than a speedup.

Adds `_fail_if_batch_size_combined_with_distributed` in `grids/base.py`
and calls it from `_init_uniform_grid` (covers Lin/LogSpacedGrid),
`IrregSpacedGrid.__init__`, and `DiscreteGrid.__init__`. Piecewise
grids inherit `batch_size=0, distributed=False` defaults from
`ContinuousGrid` and don't expose them in `__init__`, so they need
no change.

Error message points users at the right escape valves — more devices
or another distributed axis — rather than restoring batch_size.

Adds construction-time tests across all four grid types: the (bs=1,
distributed=True) combo raises, the (bs=0, distributed=True) combo
constructs cleanly.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
hmgaudecker added a commit that referenced this pull request May 23, 2026
The cascade merge from #360 brought in the call sites under
`src/_lcm/grids/{continuous,discrete}.py` but the existing `from
_lcm.grids.base import Grid` lines on #361 weren't extended to also
import the helper. ty caught it on the post-cascade run; now both
modules import the helper alongside `Grid`.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@hmgaudecker

Member Author

@mj023, in case you are wondering about 867a362, here are some timings. We always distribute the assets grid.

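| GPUs | n_assets_bs | n_aime_bs | 3-reg per age | ratio | 4-reg per age | ratio |
|---|---|---|---|---|---|---|
| 4 | 1 | 0 | 1.2 s | 7.7× | OOM(*) | |
| 4 | 1 | 1 | 2.4 s | 15.5× | 16.2 m | 8.5× |
| 1 | 1 | 1 | | | 2.1 m | 1.10× |
| 1 | 0 | 1 | 155 ms | | 1.9 m | |
| 4 | 0 | 1 | 94 ms | 0.61× | 0.41 m | 0.22× |
| 4 | 0 | 0 | 68 ms | 0.44× | 0.40 m | 0.21× |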

(*) died at solve running_any_nan.block_until_ready() (RESOURCE_EXHAUSTED: 81.53 GiB on device 0).

I certainly did not expect this...

@hmgaudecker

Member Author

But ofc the great thing here is that we are supra-linear in the number of GPUs once configured correctly! 🎉

The pre-flight numerical validator's `_params_callable_for_state_transition`
strips the qualified prefix from `regime.resolved_fixed_params` /
`flat_params` so the user's transition function can be called with its
raw parameter names. The prefix for per-target dict transitions used
`next_<state>__<target>__`, but `create_regime_params_template` builds
the canonical key as `to_<target>_next_<state>__<param>`. The mismatch
made every per-target MarkovTransition with a custom param fall through
to the "not numerically validated" skip-and-warn branch — the per-target
numerical check never ran in production.

Aligning the validator's prefix with the template builder's key lets
per-target transitions exercise the same numerical-validation path that
already covers simple `next_<state>__<param>` transitions.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
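The prefix stripping described in this commit can be sketched as below, using the two key formats quoted in the message (`next_<state>__<param>` and `to_<target>_next_<state>__<param>`). The function `params_for_transition` and the example parameter names are hypothetical; the real helper is `_params_callable_for_state_transition` and also merges fixed and flat params.

```python
def params_for_transition(flat_params, state, target=None):
    """Select one transition's params and strip the qualified prefix."""
    if target is None:
        prefix = f"next_{state}__"
    else:
        # Per-target key format as built by the params template.
        prefix = f"to_{target}_next_{state}__"
    return {
        key[len(prefix):]: value
        for key, value in flat_params.items()
        if key.startswith(prefix)
    }


flat_params = {
    "next_health__p_good": 0.8,
    "to_retired_next_health__p_good": 0.6,
    "discount_factor": 0.95,
}

print(params_for_transition(flat_params, "health"))
print(params_for_transition(flat_params, "health", target="retired"))
```

With the older `next_<state>__<target>__` prefix, the per-target lookup in this example would return an empty dict, which corresponds to the silent skip the commit fixes.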
@hmgaudecker

hmgaudecker commented May 24, 2026

Member Author

The check just did its job:

● I see the root cause: _build_health_trans_cross only emits age 64 in its series;
  pylcm fills other ages with NaN, so the validator sees NaN rows at ages 51-63 where the
  source regime is active.

  Fix: broadcast the age-64 row across the source regimes' full pre-65 active range (51-64)
  so every row sums to 1. Functional behavior unchanged (only age-64 values are ever
  invoked at runtime via the regime-transition gate).

Well, turns out it was too greedy in that case, but still... 😆

`solve()` and `simulate()` only dispatch a per-target MarkovTransition
for the targets in `active_regimes_next_period` at the source's period;
targets that deactivate before the source can reach them never fire at
runtime. The pre-solve validator mirrors that gate so a per-target
function whose output shape only needs to match the (always-zero-
weighted) target's outcome grid in principle is not numerically
evaluated against the source's state grid.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
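The gate described in this commit amounts to filtering a per-target transition dict by the targets that are active in the next period. A minimal sketch, with hypothetical names (`targets_to_validate`, the regime names, and the placeholder functions are invented for the example):

```python
def targets_to_validate(per_target_funcs, active_regimes_next_period):
    """Keep only the per-target transitions that can fire at runtime."""
    return {
        target: func
        for target, func in per_target_funcs.items()
        if target in active_regimes_next_period
    }


per_target_funcs = {
    "working": lambda: [0.1, 0.9],
    "retired": lambda: [0.4, 0.6],
}

# At this source period only "retired" is active next period.
selected = targets_to_validate(per_target_funcs, {"retired"})
print(sorted(selected))
```

Transitions towards targets that are inactive next period are simply left out of the numerical check, matching what solve() and simulate() dispatch.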

@mj023 mj023 left a comment

Collaborator


Good changes. I can see how batched + distributed would be problematic, good catch! I am wondering if the axis order might also be important, e.g. if it is better to have the distributed axes on the outside in productmap.

Base automatically changed from refactor/phase-1-validation-cleanup to main May 25, 2026 05:35
hmgaudecker and others added 2 commits May 25, 2026 07:35
The merge of main into this branch brought in 4 tests that called
`validate_transition_probs` — a function this branch deleted as part
of the auto-validate refactor. Those tests are dead; the auto-validator
covers the same ground. Also drops the corresponding unused imports
(`jnp`, `_get_func_indexing_params`, `TYPE_CHECKING`) and the dangling
comment in `user_regime.py` that referenced the removed function. The
prior amend missed `pixi.lock` (the jaxtyping 0.3.10 bump) — re-locked.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@hmgaudecker hmgaudecker force-pushed the feat/phase-1b-auto-state-transition-validation branch from bb35a85 to 4741d53 Compare May 25, 2026 05:59
Adds distributed-first ordering to `_ordered_state_action_names` so the
sharded axis becomes the outermost productmap axis within its topology
group. XLA can then place the cross-device collective at the outer
loop, wrapping a purely per-device kernel.

Sort key per state is `(not distributed, batch_size)` with 0 last.
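One possible reading of this sort key, sketched in plain Python. It assumes that "0 last" means unbatched axes (`batch_size == 0`) sort after batched ones within the same distributed/non-distributed group, and it ignores the topology grouping; `order_key` and the example states are hypothetical.

```python
import math


def order_key(distributed: bool, batch_size: int):
    """Distributed axes first; within a group, batch_size ascending with 0 last."""
    effective = math.inf if batch_size == 0 else batch_size
    return (not distributed, effective)


# name -> (distributed, batch_size); values are made up.
states = {
    "assets": (True, 0),
    "aime": (False, 1),
    "health": (False, 0),
    "wage": (False, 4),
}

ordered = sorted(states, key=lambda name: order_key(*states[name]))
print(ordered)
```

The distributed `assets` axis comes out first, so it would become the outermost axis in this simplified ordering.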
@hmgaudecker hmgaudecker merged commit 9f9483f into main May 25, 2026
10 checks passed
@hmgaudecker hmgaudecker deleted the feat/phase-1b-auto-state-transition-validation branch May 25, 2026 08:00
2 participants