Skip to content

feat(ctx): define data attribution for AnalogContext with read-only protection#765

Open
Zhaoxian-Wu wants to merge 6 commits into
IBM:masterfrom
Zhaoxian-Wu:feat/AnalogCtx-attribution
Open

feat(ctx): define data attribution for AnalogContext with read-only protection#765
Zhaoxian-Wu wants to merge 6 commits into
IBM:masterfrom
Zhaoxian-Wu:feat/AnalogCtx-attribution

Conversation

@Zhaoxian-Wu

@Zhaoxian-Wu Zhaoxian-Wu commented Mar 26, 2026

Copy link
Copy Markdown

Note: Supersedes #717. The original PR was developed on the fork's master branch, which made it difficult to keep clean and rebase. This PR continues the same work on a dedicated feature branch (feat/AnalogCtx-attribution) for a cleaner history. All reviewer feedback from #717 has been addressed — see the detailed response below.

Problem

AnalogContext is exposed as an nn.Parameter, but its .data is a dummy scalar tensor. This means standard tensor operations produce wrong or meaningless results:

from aihwkit.nn import AnalogLinear
model = AnalogLinear(6, 4, False)

ctx = next(model.parameters())
ctx.size()      # torch.Size([])     — expected (out, in)
ctx.norm()      # tensor(1.)         — meaningless
ctx > 0         # tensor(False)      — wrong
ctx.nonzero()   # tensor([], size=(1, 0)) — wrong

This makes it impossible to inspect analog weights through the standard PyTorch parameter interface, which breaks compatibility with many training tools and analysis scripts.

Solution

Bind analog_ctx.data to the actual tile weights, so all read operations work naturally. At the same time, block in-place mutations by default to respect the physical constraints of analog devices.

Key changes

  1. as_ref parameter for get_weights() — Python tiles return a direct reference when as_ref=True; default (False) preserves the existing convention (detached CPU copy).

  2. _bind_shared_weights() for C++ tiles — Allocates a shared torch.Tensor and passes it to the C++ tile via set_shared_weights(), so both Python and C++ operate on the same memory with no explicit sync needed.

  3. ReadOnlyWeightView — A torch.Tensor subclass that blocks all in-place ops using PyTorch's trailing-underscore naming convention (future-proof, zero maintenance).

  4. Three-level readonly controlrpu_config.mapping.readonly_weights (per-layer), convert_to_analog(readonly=) (global), ctx.writable() (runtime).

Test results

  • test_analog_ctx: 123 passed, 27 skipped (pre-existing)
  • Full suite: 3763 passed, 0 regressions
  • A few pre-existing failures on master (Conv3d, RNN ~1e-4 numerical mismatches) are not introduced by this PR. Environment: 2x NVIDIA RTX PRO 6000 Blackwell (driver 580.95), CUDA 12.0, PyTorch 2.8.0+cu128, cuDNN 91002.

Response to #717 review

Hi @maljoras-sony and @PabloCarmona, thanks for your patience — I know it's been a while since the original review, and I really appreciate you coming back to this discussion. I've taken the time to carefully address all the concerns raised.

@maljoras-sony, your review raised three key concerns, and this update addresses all of them. Feel free to let me know if there remains any concerns.

Concern 1: Out-of-sync weights for C++ tiles

"Note that this will only be a copy of the current weights. So if you update the weights (using RPUCuda) the analog_ctx.data will not be synchronized correctly with the actual weight."

Root cause: C++ tiles store weights in C++ memory. get_weights() can only return a copy — any as_ref=True approach that works for Python tiles does not apply here.

Solution — _bind_shared_weights(): At tile construction time, we allocate a torch.Tensor on the Python side and pass it to the C++ tile via set_shared_weights(). After this call, both Python and C++ operate on the same memory:

Python:  analog_ctx.data ──→ _shared_weight_tensor ←── C++ tile internal storage
                              (same data_ptr)
  • tile.update() / tile.set_weights() modify the tensor in-place — no explicit sync needed during normal training
  • Device moves (cpu() / cuda()) invalidate the old pointer and rebind via _bind_shared_weights()
  • __getstate__ / __setstate__ skip the shared tensor (rebuilt on load)

Concern 2: Unintended weight modification

"we do not encourage users to fiddle with the analog weights, which are not accessible directly in reality."

Solution — ReadOnlyWeightView: A torch.Tensor subclass that blocks all in-place mutations. Instead of a fragile blocklist, we use PyTorch's own naming convention — all in-place ops end with _ — so the guard is simply func_name.endswith('_'). This automatically covers any new ops PyTorch adds in the future.

Three levels of control (all optional, default is read-only):

Level API Use case
Per-layer rpu_config.mapping.readonly_weights Fine-grained via specific_rpu_config_fun
Global convert_to_analog(readonly=False) Quick toggle for research
Runtime ctx.writable() context manager Temporary access in a code block
from aihwkit.nn import AnalogLinear
model = AnalogLinear(6, 4, False)
ctx = next(model.parameters())
ctx.data.norm()        # ✅ reads work
ctx.data.add_(1.0)     # ❌ RuntimeError

with ctx.writable():
    ctx.data.add_(1.0) # ✅ explicit opt-in

Concern 3: Breaking the get_weights convention

"the current convention is that get_weights always returns CPU weights... Moreover, get_weights will always produce a copy (without the backward trace) by design, to avoid implicit things that cannot be done with analog weights."

We preserve this convention fully. The public get_weights() API still returns a detached CPU copy by default (as_ref=False). The as_ref=True path is strictly internal — used only by _get_tile_weights_ref() and _bind_shared_weights() to set up the shared storage between analog_ctx.data and the tile. Users calling tile.get_weights() see no behavior change.

# Public API — unchanged, returns detached CPU copy
w, b = analog_tile.get_weights()     # as_ref=False (default)
w.add_(1.0)                          # safe — this is a copy, tile is unaffected

# Internal only — used to bind analog_ctx.data
ref = tile.tile.get_weights(as_ref=True)  # direct reference to tile storage

@PabloCarmona, could you take a look at these changes when you have a chance? I'd really appreciate any comments or suggestions. If the direction looks good, it would be great to move toward merging — happy to make any further adjustments needed. Thanks!

@Zhaoxian-Wu Zhaoxian-Wu force-pushed the feat/AnalogCtx-attribution branch 2 times, most recently from d4695cb to da36bc9 Compare March 29, 2026 03:09
@PabloCarmona

PabloCarmona commented Mar 30, 2026

Copy link
Copy Markdown
Collaborator

@Zhaoxian-Wu, please check the lint errors here: https://github.com/IBM/aihwkit/actions/runs/23700170869/job/69154808771?pr=765 and push again. Run the make commands related to it in your local env so you don't face those after push, thanks!

And don't forget to sign off your commits to pass the DCO check! 😉

@PabloCarmona

Copy link
Copy Markdown
Collaborator

Hello @maljoras @maljoras-sony can you look at this one? Thanks!

@Zhaoxian-Wu Zhaoxian-Wu force-pushed the feat/AnalogCtx-attribution branch from da36bc9 to abd3578 Compare April 1, 2026 22:38
@Zhaoxian-Wu

Zhaoxian-Wu commented Apr 6, 2026

Copy link
Copy Markdown
Author

Hi @PabloCarmona, it's weird that my local test result varied from the online one. All tests and code style checks have passed on my end. Could you please provide a short testing guide on how to reproduce the test by myself? It would be better if the guide could start with the environment setup, since I suspect it's due to a version mismatch.

@PabloCarmona

PabloCarmona commented Apr 6, 2026

Copy link
Copy Markdown
Collaborator

@Zhaoxian-Wu as I can see this PR is changing code related to our custom simulator tiles involving the CUDA, Python and C code and that can lead to fails in the tests related to them like the one pointed here in the logs:

FAILED tests/test_torch_tiles.py::test_discretization_behavior[BoundManagementType.NONE--1--1-10.0-10.0] - assert False
 +  where False = allclose(tensor([[10.]], grad_fn=<ClampBackward1>), tensor([[-8.9111]], grad_fn=<AnalogFunctionBackward>), atol=1e-05)

Can you review if that could be the case? Since you are introducing changes in the methods arguments or/and behaviour related to it? If you have any more doubts we can help you with let me know again, thanks!

@PabloCarmona

Copy link
Copy Markdown
Collaborator

Hello @Zhaoxian-Wu! Please update this branch with the latest commits on master so we can check everything runs ok on the CICD side since I fixed the problem with the linting. Thanks!

@PabloCarmona

Copy link
Copy Markdown
Collaborator

Hello @maljoras @maljoras-sony did you have any chance to look at this? Thanks in advance!

maljoras
maljoras previously approved these changes Jun 7, 2026

@maljoras maljoras left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HI @Zhaoxian-Wu ,
looks very nice, sorry for the delay. Strictly speaking, convenient perfectly reading weights is a bit dangerous, at least when assuming it during an algorithm that supposedly runs on the analog chip. Of course, we always had the perfect read available already. But now it's easier for the user to make a wrong assumptions if some external tool uses the weight in some wrong way, implicitly using a perfect readout etc.

But still, I agree that debugging becomes easier with this change and some external info tools should work out of the box. If we warn the users for paying attention to not use it wrongly, it might be fine. It looks like a very nice addition overall. Many thanks.

@Zhaoxian-Wu

Copy link
Copy Markdown
Author

Hi @maljoras,

Thanks for your review and comments! I agree with you that convenient perfectly reading weights is a bit dangerous. To ensure a safer reading, here I design a limited access scheme. As I annotate in the code of AnalogContext (also attached below for quick review), there are three types of access modes: PLACEHOLDER (default), DATA_VIEW, BUFFER. By default, only meta-data, like size, device are available, which avoids the accidental perfect-reading. If users intend to do so for diagnostic purposes, they need to explicitly call analog_ctx.enable_data_view(). I have added a large number of test cases to ensure the coverage of most usage cases. I hope this design makes sense.

My bad for accidentally dismissing your review when I re-pushed the code. All updates are inside the latest commit 463214c1. I hope this won't cause too much trouble for your review.

Let me know if there are any concerns or comments. Thanks for your attention to that matter!

Hi @PabloCarmona,
I have synced the code with the master, and all tests passed from my side. Could you trigger an auto-test again to see whether everything works correctly? Thank!

"""
For diagnostic purposes, `AnalogContext` provides three public data view modes.
    Consider the code:
        ---
        layer = AnalogLinear(4, 3, bias=False, rpu_config=rpu_config)
        analog_tile = layer.analog_module
        analog_ctx = analog_tile.analog_ctx
        weight = analog_tile.get_weights()[0]
        ---
    where `weight` is the logical weight view, which is already ``physical weights x scaling``

    Data view modes are controlled by `analog_ctx.data_view_mode` and the corresponding methods:
        ---
        analog_ctx.enable_placeholder().    # PLACEHOLDER mode (default)
        analog_ctx.enable_data_view().      # DATA_VIEW mode
        analog_ctx.enable_buffer().         # BUFFER mode
        ---

    * PLACEHOLDER (default): only metadata, such as ``size()``, ``shape``.
        Since the RPU conductance values is not directly accessible in physic, the weight values,
            as well as value-based operations, such as ``norm()``, are blocked by default
            Access them raises ``RuntimeError``.
        ---
        # inspect metadata without reading values:
        analog_ctx.size()         # [4, 3]
        analog_ctx.device()       # 'cpu'
        analog_ctx.norm()         # RuntimeError
        ---
    * DATA_VIEW: exposes a read-only logical weight view through the `data` attribute,
        which is equivalent to `analog_tile.get_weights()[0]`.
        This allows users to inspect the effective weights.
        Since the changes of both weights and scaling affect the logical weights,
            we adopt the convetion that this logical view is read-only
        Therefore, in-place operations, such as ``add_``, ``mul_``, etc, are blocked
        ---
        # The following three lines will print the same value:
        analog_ctx.size()
        analog_ctx.data.size()
        weight.size()
        # Accessing values is allowed, but they are read-only:
        analog_ctx.norm()                   # Successfully returns the norm
        analog_ctx.norm() == weight.norm()  # True
        analog_ctx.add_(1.0)                # RuntimeError
        ---
    * BUFFER: exposes a zero-initialized tensor with the logical weight shape through the `data`
        At that mode, `data` is an independent buffer that is not connected to the analog tile.
        It is intended for optimizers with digital auxiliary state,
            such as mixed-precision training or TT-v2.
        ---
        analog_ctx.norm() == weight.norm()  # Typically False, since the buffer is independent
        analog_ctx.add_(1.0)                # Successfully adds 1.0 to the buffer, but does not
                                              affect the analog tile weights
        ---

    To update the internal analog weights, use the following update methods instead of
        writing `data` directly in the analog optimizer:
        ---
        analog_ctx.analog_tile.update(...)
        analog_ctx.analog_tile.update_indexed(...)
        ---

    Caution: Even though DATA_VIEW mode allows us to access the weights directly,
        always keep in mind that it is used only for diagnostic purposes.
        To simulate the real reading, call the `read_weights` method
        instead, i.e. given `analog_ctx: AnalogContext`,
        estimated_weights, estimated_bias = analog_ctx.analog_tile.read_weights()
"""

@PabloCarmona

Copy link
Copy Markdown
Collaborator

Sorry for that @Zhaoxian-Wu, but since I saw this errors on linting coming up, I address them and merge the fix on master. Could you sync with master one more time? Thanks and sorry for the inconvenience.

@Zhaoxian-Wu Zhaoxian-Wu force-pushed the feat/AnalogCtx-attribution branch from 463214c to 82002eb Compare June 15, 2026 15:29
Zhaoxian-Wu and others added 6 commits June 20, 2026 11:31
Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
…r AnalogContext

- weight data management
- support checkpoint loading for old version toolkit
- sync analog_ctx and tile when recreating

Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
…nalogCtx tests

- Fix TileModuleArray.get_weights() returning self.bias (Parameter) instead of
  self.bias.data (Tensor), which caused TypeError when Conv layers re-assign
  bias as a bool in reset_parameters.
- Add test_analog_ctx.py verifying PR IBM#717 AnalogContext data attribution:
  correct shape, norm, nonzero, comparison ops, CUDA support, backward
  compatibility with old checkpoints, and convert_to_analog.

Signed-off-by: Zhaoxian-Wu <wuzhaoxian97@gmail.com>
Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
…weights

Add as_ref parameter to tile-level get_weights() across all simulator
tiles. When as_ref=True, return a direct reference to internal weight
storage; when False (default), return an independent copy.

For C++ tiles (RPUCuda) where get_weights() can only return a copy,
add _bind_shared_weights() to allocate a torch.Tensor and pass it to
the C++ tile via set_shared_weights() — both Python and C++ then
operate on the same memory with no explicit sync needed.

Add ReadOnlyWeightView (Tensor subclass) to prevent accidental
in-place modification of analog weights. Uses PyTorch's trailing-
underscore naming convention to block all in-place ops (future-proof).
Configurable via AnalogContext.readonly flag, writable() context
manager, MappingParameter.readonly_weights, and convert_to_analog
readonly parameter.

Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
replacing readonly flag with always-on __torch_function__ interception

AnalogContext now enforces the invariant ctx.data == ctx.analog_tile.get_weights()[0]:
the public data attribute always returns logical weights (physical × scales, bias column
hidden), and all in-place modifications — add_(), mul_(), copy_(), out= writes, item
assignment, and data rebinding — are unconditionally blocked via RuntimeError.

This eliminates the mutable readonly flag and all escape hatches: writable(),
set_data(), readonly property, convert_to_analog(readonly=...), and
mapping.readonly_weights config field.

Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
… to AnalogContext

Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
@Zhaoxian-Wu Zhaoxian-Wu force-pushed the feat/AnalogCtx-attribution branch from 82002eb to e5413b8 Compare June 20, 2026 15:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants