Cpaniaguam/fix rl likelihood builder pymc6#975
Pull request overview
This PR improves the PyTensor↔JAX integration by preventing PyTensor from constant-folding JAX-backed Ops (which can lead to incorrect/unsupported compile-time evaluation), and adjusts tests to be more robust under recent PyTensor/PyMC behavior changes.
Changes:
- Disabled PyTensor constant folding for the JAX-backed `LANLogpOp` and `LANLogpVJPOp`.
- Updated the RL likelihood builder gradient test to evaluate gradients via a compiled `pytensor.function(..., mode="FAST_COMPILE")`.
- Marked `test_predictive_idata_to_dataframe` as `xfail` due to upstream PyMC changes.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `src/hssm/distribution_utils/jax.py` | Prevents constant folding of JAX-backed Ops to avoid incorrect compile-time evaluation. |
| `tests/rl/test_rl_likelihood_builder.py` | Uses a compiled PyTensor function in `FAST_COMPILE` mode for gradient evaluation reliability. |
| `tests/test_utils.py` | Marks a known-broken test as `xfail` to avoid CI false negatives. |
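For readers unfamiliar with the marker, here is a minimal sketch of how a known-broken test can be flagged as an expected failure. The test name, reason text, and body are hypothetical placeholders, not the actual contents of `tests/test_utils.py`.

```python
import pytest


# Hypothetical stand-in for a test that breaks because of an upstream change.
# With strict=False, pytest reports a failure as "xfailed" and an unexpected
# pass as "xpassed"; neither outcome fails the CI run.
@pytest.mark.xfail(
    reason="Hypothetical: upstream change breaks this conversion",
    strict=False,
)
def test_known_broken_conversion():
    # Simulate the breakage; a real test would exercise the affected code.
    raise AssertionError("simulated upstream breakage")
```

Once the upstream issue is fixed, the marker should be removed so the test can fail normally again.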
    "h5netcdf>=1.6.3",
    "h5py>=3.14.0",
Why do we need these packages? Are these absolutely necessary to maintain basic functionalities?
I think so. HSSM still restores .nc traces through `az.from_netcdf(...)` (lines 1490, 1508, and 1665 in 95b27df).
What seems to have changed is that newer xarray/ArviZ installs no longer guarantee a working NetCDF backend transitively, and CI started failing when pytest tried to load the .nc fixture (line 91 in 95b27df).
Yeah, it seems that both xarray and ArviZ now list them as optional and/or dev dependencies. I think these packages do that for a good reason, so we should probably do the same and add them to our dev dependencies. In code, we can check whether these packages are installed and raise an informative error if not.
I like to keep the dependencies as slim as possible and not add additional packages for things that are occasionally used. @cpaniaguam @AlexanderFengler @krishnbera thoughts?
@cpaniaguam these .nc files might need to be generated again since they came from old InferenceData saves
Additionally, I think it could be a good idea to have an io optional group and have cloudpickle, h5py and h5netcdf in that group
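The "check if these packages are installed and throw an informative error" idea could look roughly like the sketch below. The helper name `require_netcdf_backend` and the error wording are hypothetical, not part of HSSM, and the default package list is only assumed from the two dependencies discussed in this thread.

```python
from importlib.util import find_spec


def require_netcdf_backend(packages=("h5netcdf", "h5py")):
    """Raise an informative ImportError if any optional I/O package is missing.

    Hypothetical helper: the name and message are illustrative only.
    """
    # find_spec returns None for a top-level package that cannot be imported.
    missing = [name for name in packages if find_spec(name) is None]
    if missing:
        raise ImportError(
            "Loading .nc traces requires the optional package(s): "
            + ", ".join(missing)
            + ". Please install them before calling this function."
        )
```

A loader would call this helper just before `az.from_netcdf(...)`, so users who never touch NetCDF files do not need the extra packages.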
    def do_constant_folding(self, fgraph, node):
        """Keep PyTensor from trying to precompute opaque JAX-backed outputs."""
        return False
Wow... this pattern from pytensor is really bad...
Prevent PyTensor from incorrectly optimizing JAX-backed operations.
JAX-PyTensor Integration Improvements:
- Added `do_constant_folding` methods to both the `LANLogpOp` and `LANLogpVJPOp` classes in `src/hssm/distribution_utils/jax.py` to prevent PyTensor from attempting to precompute (constant fold) outputs of JAX-backed operations, ensuring correct runtime behavior.

Test Suite Updates:
- Updated `test_make_rl_logp_op` in `tests/rl/test_rl_likelihood_builder.py` to use `pytensor.function` with `mode="FAST_COMPILE"` for gradient evaluation, improving test reliability and compatibility.
- Marked `test_predictive_idata_to_dataframe` in `tests/test_utils.py` with `@pytest.mark.xfail` to indicate it is expected to fail due to recent changes in PyMC, preventing it from causing false negatives in CI. See Cpaniaguam/fix predictive idata to dataframe datatree #974 with a fix.