Skip to content

Cpaniaguam/fix rl likelihood builder pymc6#975

Draft
cpaniaguam wants to merge 4 commits into
970-compatibility-with-pymc6from
cpaniaguam/fix-rl-likelihood-builder-pymc6
Draft

Cpaniaguam/fix rl likelihood builder pymc6#975
cpaniaguam wants to merge 4 commits into
970-compatibility-with-pymc6from
cpaniaguam/fix-rl-likelihood-builder-pymc6

Conversation

@cpaniaguam

@cpaniaguam cpaniaguam commented Jun 3, 2026

Copy link
Copy Markdown
Collaborator

Prevent PyTensor from incorrectly optimizing JAX-backed operations.

JAX-PyTensor Integration Improvements:

  • Added do_constant_folding methods to both LANLogpOp and LANLogpVJPOp classes in src/hssm/distribution_utils/jax.py to prevent PyTensor from attempting to precompute (constant fold) outputs of JAX-backed operations, ensuring correct runtime behavior. [1] [2]

Test Suite Updates:

  • Updated test_make_rl_logp_op in tests/rl/test_rl_likelihood_builder.py to use pytensor.function with mode="FAST_COMPILE" for gradient evaluation, improving test reliability and compatibility.
  • Marked test_predictive_idata_to_dataframe in tests/test_utils.py with @pytest.mark.xfail to indicate it is expected to fail due to recent changes in PyMC, preventing it from causing false negatives in CI. See Cpaniaguam/fix predictive idata to dataframe datatree #974 with a fix.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR improves the PyTensor↔JAX integration by preventing PyTensor from constant-folding JAX-backed Ops (which can lead to incorrect/unsupported compile-time evaluation), and adjusts tests to be more robust under recent PyTensor/PyMC behavior changes.

Changes:

  • Disabled PyTensor constant folding for the JAX-backed LANLogpOp and LANLogpVJPOp.
  • Updated the RL likelihood builder gradient test to evaluate gradients via a compiled pytensor.function(..., mode="FAST_COMPILE").
  • Marked test_predictive_idata_to_dataframe as xfail due to upstream PyMC changes.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File Description
src/hssm/distribution_utils/jax.py Prevents constant folding of JAX-backed Ops to avoid incorrect compile-time evaluation.
tests/rl/test_rl_likelihood_builder.py Uses a compiled PyTensor function in FAST_COMPILE mode for gradient evaluation reliability.
tests/test_utils.py Marks a known-broken test as xfail to avoid CI false negatives.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread tests/test_utils.py Outdated

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

Comment thread tests/test_utils.py
Comment thread pyproject.toml
Comment on lines +29 to +30
"h5netcdf>=1.6.3",
"h5py>=3.14.0",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need these packages? Are these absolutely necessary to maintain basic functionalities?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. HSSM still restores .nc traces through az.from_netcdf(...)

traces = az.from_netcdf(traces)

traces = az.from_netcdf(traces)

idata_dict["idata_mcmc"] = az.from_netcdf(traces_path)

What seems to have changed is that newer xarray/ArviZ installs no longer seem to guarantee a working NetCDF backend transitively, and CI started failing when pytest tried to load the .nc fixture in

return az.from_netcdf("tests/fixtures/cavanagh_idata.nc")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it seems that both xarray and arviz now have them as optional and/or dev dependencies. I think these packages do that for a good reason, and we probably should do the same and then add these packages to dev dependencies. In code, we can probably check if these packages are installed and throw an informative error if not.

I like to keep the dependencies as slim as possible and not add additional packages for things that are occasionally used. @cpaniaguam @AlexanderFengler @krishnbera thoughts?

@cpaniaguam these .nc files might need to be generated again since they came from old InferenceData saves

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additionally, I think it could be a good idea to have an io optional group and have cloudpickle, h5py and h5netcdf in that group

Comment on lines +49 to +51
def do_constant_folding(self, fgraph, node):
"""Keep PyTensor from trying to precompute opaque JAX-backed outputs."""
return False

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow... this pattern from pytensor is really bad...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[pymc6 migration] Fix likelihood-related issues with RL due to change of internals

3 participants