Add foundation types and renames for NN Builder refactor#1872
Add foundation types and renames for NN Builder refactor#1872satwiksps wants to merge 10 commits into
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1872 +/- ##
==========================================
+ Coverage 87.89% 87.92% +0.03%
==========================================
Files 143 144 +1
Lines 13353 13403 +50
==========================================
+ Hits 11736 11785 +49
- Misses 1617 1618 +1
Flags with carried forward coverage won't be shown. Click here to find out more.
|
cd4f349 to
d31f0cf
Compare
janfb
left a comment
There was a problem hiding this comment.
Thanks a lot @satwiksps for this clean PR! it's great to receive these well-scoped PR.
As discussed in our meeting, I added a couple of comments and suggestions. Let me know if you have any questions.
| """Build an estimator from theta and x, used for shape inference and | ||
| z-scoring. The returned object should be a ``ConditionalEstimator`` | ||
| subclass. |
| @dataclass | ||
| class _EstimatorConfigBase: | ||
| """Shared base providing ``from_kwargs()`` and ``to_dict()`` for all configs.""" | ||
| class _EstimatorBuilderBase: |
There was a problem hiding this comment.
I suggest to make this an abstract class as well to then make build and abstract method:
from abc import ABC, abstractmethod
@dataclass
class _EstimatorBuilderBase(ABC):
@abstractmethod
def build(self, context: BuildContext) -> Any: ...
There was a problem hiding this comment.
ah, this would break the child classes that don't have build() yet. so better keep as is now and do this in the upcoming PR 👍
| theta_shape=torch.Size(theta.shape[1:]), | ||
| x_shape=torch.Size(x.shape[1:]), | ||
| z_score_stats=z_score_stats or ZScoreStats(), | ||
| device=theta.device, |
There was a problem hiding this comment.
it could happen that x and theta are on different devices by mistake, then this could fail at runtime.
better add a quick check here to ensure x and theta are on the same device.
There was a problem hiding this comment.
then we should also cover these device scenarios in the new tests using a parametrization in pytest. do you have access to a GPU to test this in action?
There was a problem hiding this comment.
Yes, I have GPU on my machine, I will add new GPU tests and verify using my GPU before pushing
| structured = config.theta == "structured" | ||
| theta_mean, theta_std = z_standardization(theta, structured) |
There was a problem hiding this comment.
to make this and the one below more readable, I suggest to make the kwarg explicit and do it in one line:
if config.theta != "none":
theta_mean, theta_std = z_standardization(
theta, structured_dims=(config.theta == "structured")
)
| assert issubclass(MarginalFlowConfig, _EstimatorBuilderBase) | ||
|
|
||
|
|
||
| class TestProtocolRename: |
There was a problem hiding this comment.
this one can be dropped. we have a test for the protocol implicitly in the suite and through the type checker.
what would be a meaningful test here is a deprecation test on the rename:
def test_deprecated_alias_works_with_future_warning():
with pytest.warns(FutureWarning, match="renamed to.*ConditionalEstimatorBuildFn"):
from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
assert ConditionalEstimatorBuilder is ConditionalEstimatorBuildFn| def test_subclasses_inherit(self): | ||
| """All existing config subclasses now inherit from _EstimatorBuilderBase.""" | ||
| assert issubclass(ConditionalFlowConfig, _EstimatorBuilderBase) | ||
| assert issubclass(ClassifierConfig, _EstimatorBuilderBase) | ||
| assert issubclass(MarginalFlowConfig, _EstimatorBuilderBase) |
There was a problem hiding this comment.
this type of tests can also be dropped. we have full control over this inheritance and it will be implicitly tested in other tests.
| ) | ||
|
|
||
|
|
||
| @dataclass(frozen=True) |
There was a problem hiding this comment.
as discussed, this can lead to subtle bugs when we later want to compare zscore_stats_a == zscore_stats_b because of the way torch implements __equ__.
For now, I suggest we just add here
@dataclass(frozen=True, eq=False)
Thus, equality becomes identity-based (a == b iff a is b).
| @dataclass(frozen=True) | ||
| class ZScoreStats: | ||
| """Computed z-score statistics derived from training data. | ||
|
|
There was a problem hiding this comment.
additionally, we should add sth like this here in the docstring.
.. note::
Equality and hashing use object identity (``a == b`` iff ``a is b``).
For *content* comparison of stats, compare each tensor field directly
with ``torch.equal`` or use ``torch.testing.assert_close`` in tests.
Custom value-equality may be added later if a concrete consumer
(e.g., a build cache) requires it.
| x_std: Optional[Tensor] = None | ||
|
|
||
|
|
||
| @dataclass(frozen=True) |
There was a problem hiding this comment.
same here, because it holds dataclasses with tensor fields: add eq=False
| @@ -14,22 +14,20 @@ | |||
| ) | |||
|
|
|||
There was a problem hiding this comment.
I suggest to add a deprecation future warning here, sth like:
def __getattr__(name: str):
"""Module-level __getattr__ (PEP 562) for deprecated import names."""
if name == "ConditionalEstimatorBuilder":
import warnings
warnings.warn(
"`ConditionalEstimatorBuilder` has been renamed to "
"`ConditionalEstimatorBuildFn`. The old name still works but will be "
"removed in a future release. Update your import to: "
"`from sbi.neural_nets.estimators.base import ConditionalEstimatorBuildFn`.",
FutureWarning,
stacklevel=2,
)
return ConditionalEstimatorBuildFn
raise AttributeError(
f"module {__name__!r} has no attribute {name!r}"
)this will give users an informative error message and announce the change.
What does this PR do?
This PR sets up the groundwork for the Neural Network (NN) Builder refactor project under GSoC 2026. It doesn't add the final builders yet, but it provides a clean, reviewable foundation for the upcoming PRs.
Here are the main changes:
New
build_context.pyFile:Adds the core pieces needed to set up a neural network:
ZScoreConfig: Tracks how the user wants to preprocess data (e.g., "none", "independent", "structured").ZScoreStats: Holds the calculated mean and standard deviation.BuildContext: Bundles everything needed to build network (shapes, device, dtype, and z-score stats) into one place.compute_z_score_stats(): A simple helper function to calculate the z-score stats from the data.Renamed
ConditionalEstimatorBuildertoConditionalEstimatorBuildFn:Renamed
_EstimatorConfigBaseto_EstimatorBuilderBase:build()method to it.from_kwargs()andto_dict()methods are untouched and work as usual.Files Changed
sbi/neural_nets/build_context.py(NEW): AddedZScoreConfig,ZScoreStats,BuildContext, andcompute_z_score_statsas suggested in updated plan.mdsbi/neural_nets/estimators/base.py: RenamedConditionalEstimatorBuildertoConditionalEstimatorBuildFnsbi/neural_nets/net_builders/estimator_configs.py: Renamed_EstimatorConfigBaseto_EstimatorBuilderBaseand added an emptybuild()methodsbi/neural_nets/net_builders/vector_field_nets.py: Updated imports and docstringsnpe/,nle/,nre/,vfpe/): Updated Protocol imports across 16 filestests/build_context_test.py(NEW): Added 26 passing tests to fully cover the new types, much of these codes in this file is suggested by AI, we can trim out if any tests are unnecessaryAI Usage
build_context_test.py.Does this close any issues?
N/A
Any relevant code examples, logs, or error messages?
N/A