Skip to content

Initial migration to pymc6#973

Open
digicosmos86 wants to merge 5 commits into
migration-pymc6from
970-compatibility-with-pymc6
Open

Initial migration to pymc6#973
digicosmos86 wants to merge 5 commits into
migration-pymc6from
970-compatibility-with-pymc6

Conversation

@digicosmos86

Copy link
Copy Markdown
Collaborator

This is the first of many PRs required to get HSSM to be compatible with PyMC6. All future PRs need to be merged to migration-pymc6 branch

Changes in this PR:

  1. updated dependencies in pyproject.toml (still many things need to be decided, such as whether to drop numpyro in favor of nutpie
  2. Updated a few utilities that post-process InferenceData object. Now that the InferenceData is replace with a xr.DataTree object, most functions that deals with the InferenceData object need to be updated
  3. I dug a little bit into the main tutorial. Many arviz functions need to be updated

@digicosmos86 digicosmos86 changed the title 970 compatibility with pymc6 [WIP] Compatibility with pymc6 Jun 3, 2026
@digicosmos86 digicosmos86 changed the title [WIP] Compatibility with pymc6 Initial compatibility fix for pymc6 compatibility Jun 5, 2026
@digicosmos86 digicosmos86 changed the title Initial compatibility fix for pymc6 compatibility Initial migration to pymc6 Jun 5, 2026
@digicosmos86 digicosmos86 requested review from AlexanderFengler, Copilot, cpaniaguam and krishnbera and removed request for AlexanderFengler June 5, 2026 19:08
@digicosmos86 digicosmos86 marked this pull request as ready for review June 5, 2026 19:08

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR starts the migration of HSSM toward compatibility with the newer PyMC/Bambi/ArviZ stack by updating dependencies, switching parts of the trace/log-likelihood pipeline from ArviZ InferenceData to xarray.DataTree, and adjusting a test import for the updated PyTensor module layout.

Changes:

  • Update project dependency constraints and Python minimum version in pyproject.toml.
  • Refactor log-likelihood computation helper to operate on xr.DataTree instead of az.InferenceData.
  • Update sampling post-processing to clean the posterior group using DataTree-style access, plus a PyTensor test import adjustment.

Reviewed changes

Copilot reviewed 4 out of 6 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
tests/test_likelihoods.py Updates NanGuardMode import path for PyTensor’s new module location.
src/hssm/utils.py Refactors _compute_log_likelihood to use xr.DataTree and updates log-likelihood group assignment.
src/hssm/base.py Updates sample() return typing and switches posterior cleanup to DataTree-style manipulation.
pyproject.toml Adjusts Python requirement and dependency set for the migration.
.gitignore Ignores uv.lock.
Comments suppressed due to low confidence (2)

src/hssm/base.py:557

  • The return-type annotation for sample() was updated to xr.DataTree | pm.Approximation, but the docstring still states that model.traces is an ArviZ InferenceData for most samplers. This is misleading for users and should be updated to match the new trace container type.
        xr.DataTree | pm.Approximation
            A reference to the `model.traces` object, which stores the traces of the
            last call to `model.sample()`. `model.traces` is an ArviZ `InferenceData`
            instance if `sampler` is `"pymc"` (default), `"numpyro"`,
            `"blackjax"` or "`laplace".

pyproject.toml:38

  • pymc and arviz were removed from the project dependencies, but they are imported throughout the package (e.g., src/hssm/base.py, src/hssm/utils.py). Relying on transitive dependencies from bambi is fragile and can break installs if upstream changes extras/optional deps; declare these as direct dependencies again.
dependencies = [
    "absl-py>=2.3.1",
    "bambi>=0.18.0",
    "cloudpickle>=3.0.0",
    "hddm-wfpt>=0.1.6",
    "huggingface-hub>=1.17.0",
    "jaxonnxruntime>=0.3.0",
    "numpyro>=0.19",
    "onnx>=1.16.0",
    "pandas>=2.2,<3",
    "seaborn>=0.13.2",
    "ssm-simulators>=0.12.2",
    "tqdm>=4.66.0",
]

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/hssm/base.py
"posterior",
idata["posterior"][vars_to_keep_clean],
)
dt["posterior"] = dt["posterior"].to_dataset()[vars_to_keep_clean]
Comment thread src/hssm/base.py
# We drop all distributional components, IF they are deterministics
# (in which case they will be trial wise systematically)
# and we keep distributional components, IF they are
# basic random-variabels (in which case they should never
Comment thread src/hssm/utils.py
)

required_kwargs = {"model": model, "posterior": idata["posterior"], "data": data}
required_kwargs = {"model": model, "posterior": dt["posterior"], "data": data}

@cpaniaguam cpaniaguam left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider marking the yet-to-address broken tests as failing.

@digicosmos86

Copy link
Copy Markdown
Collaborator Author

Consider marking the yet-to-address broken tests as failing.

Through what mechanism? Commenting in code?

@cpaniaguam

Copy link
Copy Markdown
Collaborator

@digicosmos86

Through what mechanism? Commenting in code?

You can use pytest.xfail.

@cpaniaguam

cpaniaguam commented Jun 5, 2026

Copy link
Copy Markdown
Collaborator

@digicosmos86

Through what mechanism? Commenting in code?

See https://docs.pytest.org/en/stable/how-to/skipping.html#xfail

@digicosmos86

Copy link
Copy Markdown
Collaborator Author

@digicosmos86

Through what mechanism? Commenting in code?

See https://docs.pytest.org/en/stable/how-to/skipping.html#xfail

I think that's fair only when there's not a whole lot of tests breaking. Otherwise adding the marks itself could be labor intensive. Why don''t we evaluate after merging your recent PRs? Looks like a few broken tests were fixed in them

@cpaniaguam

Copy link
Copy Markdown
Collaborator

Why don''t we evaluate after merging your recent PRs? Looks like a few broken tests were fixed in them

Those were just two of the fast tests. There are many other slow ones that are failing (see #982).

I think that's fair only when there's not a whole lot of tests breaking.

I understand your point. I still think there’s value in marking the failures we already understand, since it keeps CI easier to read and helps distinguish known issues from new regressions. Once a fix is introduced, those tests will show up as xpassed in the summary, so we’ll have a clear signal to remove the markers.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants