Skip to content

Add foundation types and renames for NN Builder refactor#1872

Open
satwiksps wants to merge 10 commits into
sbi-dev:gsoc-2026from
satwiksps:nn-builder-refactor-pr-one
Open

Add foundation types and renames for NN Builder refactor#1872
satwiksps wants to merge 10 commits into
sbi-dev:gsoc-2026from
satwiksps:nn-builder-refactor-pr-one

Conversation

@satwiksps

Copy link
Copy Markdown
Contributor

What does this PR do?

This PR sets up the groundwork for the Neural Network (NN) Builder refactor project under GSoC 2026. It doesn't add the final builders yet, but it provides a clean, reviewable foundation for the upcoming PRs.

Here are the main changes:

  1. New build_context.py File:
    Adds the core pieces needed to set up a neural network:

    • ZScoreConfig: Tracks how the user wants to preprocess data (e.g., "none", "independent", "structured").
    • ZScoreStats: Holds the calculated mean and standard deviation.
    • BuildContext: Bundles everything needed to build network (shapes, device, dtype, and z-score stats) into one place.
    • compute_z_score_stats(): A simple helper function to calculate the z-score stats from the data.
  2. Renamed ConditionalEstimatorBuilder to ConditionalEstimatorBuildFn:

    • This clarifies that this protocol is actually a function, not an object.
    • It also frees up the "Builder" name for the new classes coming in the next PR.
    • This was a simple find-and-replace across 16 trainer files.
  3. Renamed _EstimatorConfigBase to _EstimatorBuilderBase:

    • Renamed the base class and added an empty build() method to it.
    • The old from_kwargs() and to_dict() methods are untouched and work as usual.

Files Changed

  • sbi/neural_nets/build_context.py (NEW): Added ZScoreConfig, ZScoreStats, BuildContext, and compute_z_score_stats as suggested in updated plan.md
  • sbi/neural_nets/estimators/base.py: Renamed ConditionalEstimatorBuilder to ConditionalEstimatorBuildFn
  • sbi/neural_nets/net_builders/estimator_configs.py: Renamed _EstimatorConfigBase to _EstimatorBuilderBase and added an empty build() method
  • sbi/neural_nets/net_builders/vector_field_nets.py: Updated imports and docstrings
  • Trainer files (npe/, nle/, nre/, vfpe/): Updated Protocol imports across 16 files
  • tests/build_context_test.py (NEW): Added 26 passing tests to fully cover the new types, much of these codes in this file is suggested by AI, we can trim out if any tests are unnecessary

AI Usage

  • Claude Opus 4.6 (via Antigravity 2.0): Used to improve code docstrings and comments, suggest additional docstring changes, and write additional tests for build_context_test.py.
  • Grammarly: Used to refine and improve the grammar of this PR description.
  • VS Code auto-complete: Used some suggested docstrings and comments.

Does this close any issues?

N/A

Any relevant code examples, logs, or error messages?

N/A

@satwiksps satwiksps marked this pull request as ready for review May 29, 2026 05:36
@codecov

codecov Bot commented May 29, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 98.55072% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 87.92%. Comparing base (f8037af) to head (e3ecd5a).
⚠️ Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
sbi/neural_nets/build_context.py 97.56% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1872      +/-   ##
==========================================
+ Coverage   87.89%   87.92%   +0.03%     
==========================================
  Files         143      144       +1     
  Lines       13353    13403      +50     
==========================================
+ Hits        11736    11785      +49     
- Misses       1617     1618       +1     
Flag Coverage Δ
fast 81.45% <98.55%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines Coverage Δ
sbi/inference/trainers/nle/mnle.py 95.83% <100.00%> (ø)
sbi/inference/trainers/nle/nle_a.py 100.00% <ø> (ø)
sbi/inference/trainers/nle/nle_base.py 96.87% <100.00%> (ø)
sbi/inference/trainers/npe/mnpe.py 95.83% <100.00%> (ø)
sbi/inference/trainers/npe/npe_a.py 79.35% <ø> (ø)
sbi/inference/trainers/npe/npe_b.py 100.00% <ø> (ø)
sbi/inference/trainers/npe/npe_base.py 92.74% <100.00%> (ø)
sbi/inference/trainers/npe/npe_c.py 93.75% <ø> (ø)
sbi/inference/trainers/npe/npe_pfn.py 85.05% <ø> (ø)
sbi/inference/trainers/nre/bnre.py 97.22% <100.00%> (ø)
... and 11 more

@satwiksps satwiksps force-pushed the nn-builder-refactor-pr-one branch from cd4f349 to d31f0cf Compare May 29, 2026 05:49

@janfb janfb left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot @satwiksps for this clean PR! it's great to receive these well-scoped PR.

As discussed in our meeting, I added a couple of comments and suggestions. Let me know if you have any questions.

Comment on lines +21 to +23
"""Build an estimator from theta and x, used for shape inference and
z-scoring. The returned object should be a ``ConditionalEstimator``
subclass.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Comment on lines 29 to +30
@dataclass
class _EstimatorConfigBase:
"""Shared base providing ``from_kwargs()`` and ``to_dict()`` for all configs."""
class _EstimatorBuilderBase:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest to make this an abstract class as well to then make build and abstract method:

from abc import ABC, abstractmethod
  
  @dataclass
  class _EstimatorBuilderBase(ABC):
      @abstractmethod
      def build(self, context: BuildContext) -> Any: ...

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, this would break the child classes that don't have build() yet. so better keep as is now and do this in the upcoming PR 👍

theta_shape=torch.Size(theta.shape[1:]),
x_shape=torch.Size(x.shape[1:]),
z_score_stats=z_score_stats or ZScoreStats(),
device=theta.device,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it could happen that x and theta are on different devices by mistake, then this could fail at runtime.
better add a quick check here to ensure x and theta are on the same device.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then we should also cover these device scenarios in the new tests using a parametrization in pytest. do you have access to a GPU to test this in action?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I have GPU on my machine, I will add new GPU tests and verify using my GPU before pushing

Comment thread sbi/neural_nets/build_context.py Outdated
Comment on lines +132 to +133
structured = config.theta == "structured"
theta_mean, theta_std = z_standardization(theta, structured)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to make this and the one below more readable, I suggest to make the kwarg explicit and do it in one line:

if config.theta != "none":
      theta_mean, theta_std = z_standardization(
          theta, structured_dims=(config.theta == "structured")
      )

Comment thread tests/build_context_test.py Outdated
assert issubclass(MarginalFlowConfig, _EstimatorBuilderBase)


class TestProtocolRename:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this one can be dropped. we have a test for the protocol implicitly in the suite and through the type checker.

what would be a meaningful test here is a deprecation test on the rename:

def test_deprecated_alias_works_with_future_warning():
     with pytest.warns(FutureWarning, match="renamed to.*ConditionalEstimatorBuildFn"):
         from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
     assert ConditionalEstimatorBuilder is ConditionalEstimatorBuildFn

Comment thread tests/build_context_test.py Outdated
Comment on lines +211 to +215
def test_subclasses_inherit(self):
"""All existing config subclasses now inherit from _EstimatorBuilderBase."""
assert issubclass(ConditionalFlowConfig, _EstimatorBuilderBase)
assert issubclass(ClassifierConfig, _EstimatorBuilderBase)
assert issubclass(MarginalFlowConfig, _EstimatorBuilderBase)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this type of tests can also be dropped. we have full control over this inheritance and it will be implicitly tested in other tests.

Comment thread sbi/neural_nets/build_context.py Outdated
)


@dataclass(frozen=True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as discussed, this can lead to subtle bugs when we later want to compare zscore_stats_a == zscore_stats_b because of the way torch implements __equ__.

For now, I suggest we just add here

@dataclass(frozen=True, eq=False)

Thus, equality becomes identity-based (a == b iff a is b).

@dataclass(frozen=True)
class ZScoreStats:
"""Computed z-score statistics derived from training data.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

additionally, we should add sth like this here in the docstring.

.. note::
         Equality and hashing use object identity (``a == b`` iff ``a is b``).
         For *content* comparison of stats, compare each tensor field directly
         with ``torch.equal`` or use ``torch.testing.assert_close`` in tests.
         Custom value-equality may be added later if a concrete consumer
         (e.g., a build cache) requires it.

Comment thread sbi/neural_nets/build_context.py Outdated
x_std: Optional[Tensor] = None


@dataclass(frozen=True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, because it holds dataclasses with tensor fields: add eq=False

@@ -14,22 +14,20 @@
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest to add a deprecation future warning here, sth like:

def __getattr__(name: str):
      """Module-level __getattr__ (PEP 562) for deprecated import names."""
      if name == "ConditionalEstimatorBuilder":
          import warnings
          warnings.warn(
              "`ConditionalEstimatorBuilder` has been renamed to "
              "`ConditionalEstimatorBuildFn`. The old name still works but will be "
              "removed in a future release. Update your import to: "
              "`from sbi.neural_nets.estimators.base import ConditionalEstimatorBuildFn`.",
              FutureWarning,
              stacklevel=2,
          )
          return ConditionalEstimatorBuildFn
      raise AttributeError(
          f"module {__name__!r} has no attribute {name!r}"
      )

this will give users an informative error message and announce the change.

@satwiksps satwiksps changed the base branch from main to gsoc-2026 June 9, 2026 12:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants