Initial migration to pymc6#973
Conversation
There was a problem hiding this comment.
Pull request overview
This PR starts the migration of HSSM toward compatibility with the newer PyMC/Bambi/ArviZ stack by updating dependencies, switching parts of the trace/log-likelihood pipeline from ArviZ InferenceData to xarray.DataTree, and adjusting a test import for the updated PyTensor module layout.
Changes:
- Update project dependency constraints and Python minimum version in
pyproject.toml. - Refactor log-likelihood computation helper to operate on
xr.DataTreeinstead ofaz.InferenceData. - Update sampling post-processing to clean the posterior group using
DataTree-style access, plus a PyTensor test import adjustment.
Reviewed changes
Copilot reviewed 4 out of 6 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
tests/test_likelihoods.py |
Updates NanGuardMode import path for PyTensor’s new module location. |
src/hssm/utils.py |
Refactors _compute_log_likelihood to use xr.DataTree and updates log-likelihood group assignment. |
src/hssm/base.py |
Updates sample() return typing and switches posterior cleanup to DataTree-style manipulation. |
pyproject.toml |
Adjusts Python requirement and dependency set for the migration. |
.gitignore |
Ignores uv.lock. |
Comments suppressed due to low confidence (2)
src/hssm/base.py:557
- The return-type annotation for
sample()was updated toxr.DataTree | pm.Approximation, but the docstring still states thatmodel.tracesis an ArviZInferenceDatafor most samplers. This is misleading for users and should be updated to match the new trace container type.
xr.DataTree | pm.Approximation
A reference to the `model.traces` object, which stores the traces of the
last call to `model.sample()`. `model.traces` is an ArviZ `InferenceData`
instance if `sampler` is `"pymc"` (default), `"numpyro"`,
`"blackjax"` or "`laplace".
pyproject.toml:38
pymcandarvizwere removed from the project dependencies, but they are imported throughout the package (e.g.,src/hssm/base.py,src/hssm/utils.py). Relying on transitive dependencies frombambiis fragile and can break installs if upstream changes extras/optional deps; declare these as direct dependencies again.
dependencies = [
"absl-py>=2.3.1",
"bambi>=0.18.0",
"cloudpickle>=3.0.0",
"hddm-wfpt>=0.1.6",
"huggingface-hub>=1.17.0",
"jaxonnxruntime>=0.3.0",
"numpyro>=0.19",
"onnx>=1.16.0",
"pandas>=2.2,<3",
"seaborn>=0.13.2",
"ssm-simulators>=0.12.2",
"tqdm>=4.66.0",
]
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| "posterior", | ||
| idata["posterior"][vars_to_keep_clean], | ||
| ) | ||
| dt["posterior"] = dt["posterior"].to_dataset()[vars_to_keep_clean] |
| # We drop all distributional components, IF they are deterministics | ||
| # (in which case they will be trial wise systematically) | ||
| # and we keep distributional components, IF they are | ||
| # basic random-variabels (in which case they should never |
| ) | ||
|
|
||
| required_kwargs = {"model": model, "posterior": idata["posterior"], "data": data} | ||
| required_kwargs = {"model": model, "posterior": dt["posterior"], "data": data} |
cpaniaguam
left a comment
There was a problem hiding this comment.
Consider marking the yet-to-address broken tests as failing.
Through what mechanism? Commenting in code? |
You can use pytest.xfail. |
See https://docs.pytest.org/en/stable/how-to/skipping.html#xfail |
I think that's fair only when there's not a whole lot of tests breaking. Otherwise adding the marks itself could be labor intensive. Why don''t we evaluate after merging your recent PRs? Looks like a few broken tests were fixed in them |
Those were just two of the fast tests. There are many other slow ones that are failing (see #982).
I understand your point. I still think there’s value in marking the failures we already understand, since it keeps CI easier to read and helps distinguish known issues from new regressions. Once a fix is introduced, those tests will show up as |
This is the first of many PRs required to get HSSM to be compatible with PyMC6. All future PRs need to be merged to
migration-pymc6branchChanges in this PR:
pyproject.toml(still many things need to be decided, such as whether to dropnumpyroin favor ofnutpieInferenceDataobject. Now that theInferenceDatais replace with axr.DataTreeobject, most functions that deals with theInferenceDataobject need to be updated