feat(ctx): define data attribution for AnalogContext with read-only protection#765
feat(ctx): define data attribution for AnalogContext with read-only protection#765Zhaoxian-Wu wants to merge 6 commits into
Conversation
d4695cb to
da36bc9
Compare
|
@Zhaoxian-Wu, please check the lint errors here: https://github.com/IBM/aihwkit/actions/runs/23700170869/job/69154808771?pr=765 and push again. Run the make commands related to it in your local env so you don't face those after push, thanks! And don't forget to sign off your commits to pass the DCO check! 😉 |
|
Hello @maljoras @maljoras-sony can you look at this one? Thanks! |
da36bc9 to
abd3578
Compare
|
Hi @PabloCarmona, it's weird that my local test result varied from the online one. All tests and code style checks have passed on my end. Could you please provide a short testing guide on how to reproduce the test by myself? It would be better if the guide could start with the environment setup, since I suspect it's due to a version mismatch. |
|
@Zhaoxian-Wu as I can see this PR is changing code related to our custom simulator tiles involving the CUDA, Python and C code and that can lead to fails in the tests related to them like the one pointed here in the logs: Can you review if that could be the case? Since you are introducing changes in the methods arguments or/and behaviour related to it? If you have any more doubts we can help you with let me know again, thanks! |
|
Hello @Zhaoxian-Wu! Please update this branch with the latest commits on master so we can check everything runs ok on the CICD side since I fixed the problem with the linting. Thanks! |
|
Hello @maljoras @maljoras-sony did you have any chance to look at this? Thanks in advance! |
maljoras
left a comment
There was a problem hiding this comment.
HI @Zhaoxian-Wu ,
looks very nice, sorry for the delay. Strictly speaking, convenient perfectly reading weights is a bit dangerous, at least when assuming it during an algorithm that supposedly runs on the analog chip. Of course, we always had the perfect read available already. But now it's easier for the user to make a wrong assumptions if some external tool uses the weight in some wrong way, implicitly using a perfect readout etc.
But still, I agree that debugging becomes easier with this change and some external info tools should work out of the box. If we warn the users for paying attention to not use it wrongly, it might be fine. It looks like a very nice addition overall. Many thanks.
abd3578 to
463214c
Compare
|
Hi @maljoras, Thanks for your review and comments! I agree with you that convenient perfectly reading weights is a bit dangerous. To ensure a safer reading, here I design a limited access scheme. As I annotate in the code of
Let me know if there are any concerns or comments. Thanks for your attention to that matter! Hi @PabloCarmona, """
For diagnostic purposes, `AnalogContext` provides three public data view modes.
Consider the code:
---
layer = AnalogLinear(4, 3, bias=False, rpu_config=rpu_config)
analog_tile = layer.analog_module
analog_ctx = analog_tile.analog_ctx
weight = analog_tile.get_weights()[0]
---
where `weight` is the logical weight view, which is already ``physical weights x scaling``
Data view modes are controlled by `analog_ctx.data_view_mode` and the corresponding methods:
---
analog_ctx.enable_placeholder(). # PLACEHOLDER mode (default)
analog_ctx.enable_data_view(). # DATA_VIEW mode
analog_ctx.enable_buffer(). # BUFFER mode
---
* PLACEHOLDER (default): only metadata, such as ``size()``, ``shape``.
Since the RPU conductance values is not directly accessible in physic, the weight values,
as well as value-based operations, such as ``norm()``, are blocked by default
Access them raises ``RuntimeError``.
---
# inspect metadata without reading values:
analog_ctx.size() # [4, 3]
analog_ctx.device() # 'cpu'
analog_ctx.norm() # RuntimeError
---
* DATA_VIEW: exposes a read-only logical weight view through the `data` attribute,
which is equivalent to `analog_tile.get_weights()[0]`.
This allows users to inspect the effective weights.
Since the changes of both weights and scaling affect the logical weights,
we adopt the convetion that this logical view is read-only
Therefore, in-place operations, such as ``add_``, ``mul_``, etc, are blocked
---
# The following three lines will print the same value:
analog_ctx.size()
analog_ctx.data.size()
weight.size()
# Accessing values is allowed, but they are read-only:
analog_ctx.norm() # Successfully returns the norm
analog_ctx.norm() == weight.norm() # True
analog_ctx.add_(1.0) # RuntimeError
---
* BUFFER: exposes a zero-initialized tensor with the logical weight shape through the `data`
At that mode, `data` is an independent buffer that is not connected to the analog tile.
It is intended for optimizers with digital auxiliary state,
such as mixed-precision training or TT-v2.
---
analog_ctx.norm() == weight.norm() # Typically False, since the buffer is independent
analog_ctx.add_(1.0) # Successfully adds 1.0 to the buffer, but does not
affect the analog tile weights
---
To update the internal analog weights, use the following update methods instead of
writing `data` directly in the analog optimizer:
---
analog_ctx.analog_tile.update(...)
analog_ctx.analog_tile.update_indexed(...)
---
Caution: Even though DATA_VIEW mode allows us to access the weights directly,
always keep in mind that it is used only for diagnostic purposes.
To simulate the real reading, call the `read_weights` method
instead, i.e. given `analog_ctx: AnalogContext`,
estimated_weights, estimated_bias = analog_ctx.analog_tile.read_weights()
""" |
|
Sorry for that @Zhaoxian-Wu, but since I saw this errors on linting coming up, I address them and merge the fix on master. Could you sync with master one more time? Thanks and sorry for the inconvenience. |
463214c to
82002eb
Compare
Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
…r AnalogContext - weight data management - support checkpoint loading for old version toolkit - sync analog_ctx and tile when recreating Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
…nalogCtx tests - Fix TileModuleArray.get_weights() returning self.bias (Parameter) instead of self.bias.data (Tensor), which caused TypeError when Conv layers re-assign bias as a bool in reset_parameters. - Add test_analog_ctx.py verifying PR IBM#717 AnalogContext data attribution: correct shape, norm, nonzero, comparison ops, CUDA support, backward compatibility with old checkpoints, and convert_to_analog. Signed-off-by: Zhaoxian-Wu <wuzhaoxian97@gmail.com> Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
…weights Add as_ref parameter to tile-level get_weights() across all simulator tiles. When as_ref=True, return a direct reference to internal weight storage; when False (default), return an independent copy. For C++ tiles (RPUCuda) where get_weights() can only return a copy, add _bind_shared_weights() to allocate a torch.Tensor and pass it to the C++ tile via set_shared_weights() — both Python and C++ then operate on the same memory with no explicit sync needed. Add ReadOnlyWeightView (Tensor subclass) to prevent accidental in-place modification of analog weights. Uses PyTorch's trailing- underscore naming convention to block all in-place ops (future-proof). Configurable via AnalogContext.readonly flag, writable() context manager, MappingParameter.readonly_weights, and convert_to_analog readonly parameter. Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
replacing readonly flag with always-on __torch_function__ interception AnalogContext now enforces the invariant ctx.data == ctx.analog_tile.get_weights()[0]: the public data attribute always returns logical weights (physical × scales, bias column hidden), and all in-place modifications — add_(), mul_(), copy_(), out= writes, item assignment, and data rebinding — are unconditionally blocked via RuntimeError. This eliminates the mutable readonly flag and all escape hatches: writable(), set_data(), readonly property, convert_to_analog(readonly=...), and mapping.readonly_weights config field. Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
… to AnalogContext Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
82002eb to
e5413b8
Compare
Problem
AnalogContextis exposed as annn.Parameter, but its.datais a dummy scalar tensor. This means standard tensor operations produce wrong or meaningless results:This makes it impossible to inspect analog weights through the standard PyTorch parameter interface, which breaks compatibility with many training tools and analysis scripts.
Solution
Bind
analog_ctx.datato the actual tile weights, so all read operations work naturally. At the same time, block in-place mutations by default to respect the physical constraints of analog devices.Key changes
as_refparameter forget_weights()— Python tiles return a direct reference whenas_ref=True; default (False) preserves the existing convention (detached CPU copy)._bind_shared_weights()for C++ tiles — Allocates a sharedtorch.Tensorand passes it to the C++ tile viaset_shared_weights(), so both Python and C++ operate on the same memory with no explicit sync needed.ReadOnlyWeightView— Atorch.Tensorsubclass that blocks all in-place ops using PyTorch's trailing-underscore naming convention (future-proof, zero maintenance).Three-level
readonlycontrol —rpu_config.mapping.readonly_weights(per-layer),convert_to_analog(readonly=)(global),ctx.writable()(runtime).Test results
test_analog_ctx: 123 passed, 27 skipped (pre-existing)master(Conv3d, RNN ~1e-4 numerical mismatches) are not introduced by this PR. Environment: 2x NVIDIA RTX PRO 6000 Blackwell (driver 580.95), CUDA 12.0, PyTorch 2.8.0+cu128, cuDNN 91002.Response to #717 review
Hi @maljoras-sony and @PabloCarmona, thanks for your patience — I know it's been a while since the original review, and I really appreciate you coming back to this discussion. I've taken the time to carefully address all the concerns raised.
@maljoras-sony, your review raised three key concerns, and this update addresses all of them. Feel free to let me know if there remains any concerns.
Concern 1: Out-of-sync weights for C++ tiles
Root cause: C++ tiles store weights in C++ memory.
get_weights()can only return a copy — anyas_ref=Trueapproach that works for Python tiles does not apply here.Solution —
_bind_shared_weights(): At tile construction time, we allocate atorch.Tensoron the Python side and pass it to the C++ tile viaset_shared_weights(). After this call, both Python and C++ operate on the same memory:tile.update()/tile.set_weights()modify the tensor in-place — no explicit sync needed during normal trainingcpu()/cuda()) invalidate the old pointer and rebind via_bind_shared_weights()__getstate__/__setstate__skip the shared tensor (rebuilt on load)Concern 2: Unintended weight modification
Solution —
ReadOnlyWeightView: Atorch.Tensorsubclass that blocks all in-place mutations. Instead of a fragile blocklist, we use PyTorch's own naming convention — all in-place ops end with_— so the guard is simplyfunc_name.endswith('_'). This automatically covers any new ops PyTorch adds in the future.Three levels of control (all optional, default is read-only):
rpu_config.mapping.readonly_weightsspecific_rpu_config_funconvert_to_analog(readonly=False)ctx.writable()context managerConcern 3: Breaking the
get_weightsconventionWe preserve this convention fully. The public
get_weights()API still returns a detached CPU copy by default (as_ref=False). Theas_ref=Truepath is strictly internal — used only by_get_tile_weights_ref()and_bind_shared_weights()to set up the shared storage betweenanalog_ctx.dataand the tile. Users callingtile.get_weights()see no behavior change.@PabloCarmona, could you take a look at these changes when you have a chance? I'd really appreciate any comments or suggestions. If the direction looks good, it would be great to move toward merging — happy to make any further adjustments needed. Thanks!