fix(cuda): retune PWU kernel when m_batch grows after initial m=1 update #763
Zhaoxian-Wu wants to merge 1 commit into
Hello @maljoras @maljoras-sony, can you help us and take a look to see if everything is OK with this?
Hello @Zhaoxian-Wu! Please update this branch with the latest commits on master so we can check that everything runs OK on the CI/CD side, since I fixed the problem with the linting. Thanks!
Hello @maljoras @maljoras-sony, did you have a chance to look at this? Thanks in advance!
maljoras left a comment
Many thanks @Zhaoxian-Wu. Nice catch!
Hello @Zhaoxian-Wu, the same applies here as in PR #764: please sync up with master so we can check that everything passes. Thanks!
Hi @PabloCarmona, thanks for following up and for your effort on this. I've synced with the latest code.
Sorry about that, @Zhaoxian-Wu, but since I saw these linting errors coming up, I addressed them and merged the fix into master. Could you sync with master one more time? Thanks, and sorry for the inconvenience.
Hi @PabloCarmona, thanks for fixing this existing issue. I have synced and pushed the merged code for all my PRs again.
When a tile's first update uses m_batch=1, tuneUpdate() benchmarks all kernels valid for that batch size, which includes SingleFunctor, a CUDA kernel with no inner batch loop that processes only batch item 0. Due to GPU timing jitter, SingleFunctor can win the benchmark race (~1/5 cold-start runs). If a subsequent update uses m_batch=M>>1, the cached kernel_pars_ is silently reused, producing a weight change of ~1/M instead of the correct value (~99% relative error).
Fix: add tuned_m_batch_ to PulsedWeightUpdater to track the m_batch used during the last tuneUpdate() call. When m_batch grows beyond this value, invalidate kernel_pars_ and force-retune with the new batch size. SingleFunctor is marked invalid for m_batch>1 (via SingleBase), so a correct batch-aware kernel (BatchShared*, BatchSum, ...) is selected.
Add regression test in AnalogTileTest that primes a CUDA tile with m_batch=1 then updates with m_batch=128, comparing the result against a reference tile with no priming. The test fails with the old code (~99% error) and passes with the fix (~0% error).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
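The ~99% figure follows from simple arithmetic: if only batch item 0 contributes, the accumulated change is about 1/M of the full sum, so the relative error is about 1 - 1/M. The sketch below illustrates this with an idealized, deterministic outer-product update in plain Python. It is a toy model under that assumption, not aihwkit code, and the helper names (batch_update, single_item_update, relative_error) are hypothetical.

```python
# Toy model (not aihwkit code): an idealized update that sums the outer
# products d_b * x_b^T over the batch, versus a kernel that only ever
# processes batch item 0. Real pulsed updates are stochastic; this only
# illustrates the expected magnitude of the error.

def batch_update(x_batch, d_batch):
    """Sum of outer products over every batch item."""
    rows, cols = len(d_batch[0]), len(x_batch[0])
    dw = [[0.0] * cols for _ in range(rows)]
    for x, d in zip(x_batch, d_batch):
        for i in range(rows):
            for j in range(cols):
                dw[i][j] += d[i] * x[j]
    return dw

def single_item_update(x_batch, d_batch):
    """Mimics a kernel without a batch loop: only item 0 contributes."""
    return batch_update(x_batch[:1], d_batch[:1])

def relative_error(reference, actual):
    """Sum of absolute differences divided by sum of absolute reference values."""
    num = sum(abs(r - a) for rr, ar in zip(reference, actual) for r, a in zip(rr, ar))
    den = sum(abs(r) for rr in reference for r in rr)
    return num / den

M = 128
x_batch = [[1.0, 1.0, 1.0]] * M   # identical inputs keep the arithmetic exact
d_batch = [[1.0, 1.0]] * M
err = relative_error(batch_update(x_batch, d_batch),
                     single_item_update(x_batch, d_batch))
print(f"relative error for M={M}: {err:.4f}")  # 127/128, prints 0.9922
```

With M=128, the batch size used in the regression test, the toy error is 127/128, which is consistent with the ~99% quoted above.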
Background
PulsedWeightUpdater::tuneUpdate() benchmarks all valid CUDA kernels on the first update() call and permanently caches the winner in kernel_pars_. Among the candidates is SingleFunctor (kernel class SingleBase), which has no inner batch loop and is only correct when m_batch=1.
Due to GPU cold-start timing jitter, SingleFunctor (~0.025 ms) and batch-aware kernels like BatchSharedBase (~0.026 ms) have nearly identical benchmark times. In approximately 1 in 5 cold-start runs, SingleFunctor wins the race and is permanently selected.
This becomes a silent correctness bug when:
1. The first update() call uses m_batch=1 (e.g. a priming update, a warm-up step, or a gradient accumulation flush), causing tuneUpdate() to potentially select SingleFunctor.
2. A subsequent update() call uses m_batch=M >> 1. SingleFunctor is reused without re-tuning and silently processes only batch item 0, producing a weight change of ~1/M of the correct value (roughly 99% relative error).
The bug affects all pulsed leaf devices (ConstantStepDevice, LinearStepDevice, SoftBoundsDevice, ExpStepDevice, PowStepDevice, PiecewiseStepDevice, etc.), as all of them include SingleFunctor in their valid kernel list.
Fix
Add tuned_m_batch_ to PulsedWeightUpdater to track the m_batch used during the last tuneUpdate() call. When a subsequent update() arrives with a larger m_batch, invalidate kernel_pars_ and force-retune with the new batch size. SingleFunctor is then correctly excluded (its SingleBase validity check rejects m_batch > 1), and a batch-aware kernel is selected instead.
Minimal Working Example
The following self-contained script reproduces the bug (pre-fix) and verifies the fix.
Pre-fix output (when SingleFunctor wins the cold-start benchmark, ~1/5 runs):
Post-fix output (guaranteed, all runs):
Changes
src/rpucuda/cuda/pulsed_weight_updater.h: tuned_m_batch_ field with explanatory comment
src/rpucuda/cuda/pulsed_weight_updater.cu: tuned_m_batch_ on device-type change; add retune guard when m_batch grows; record tuned_m_batch_ after tuneUpdate()
tests/test_bindings_tiles.py: AnalogTileTest::test_update_mbatch_change regression test for CUDA ConstantStepDevice
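The retune guard described under Fix can be illustrated with a small toy model of the caching logic. This is a hedged sketch in plain Python, not the actual CUDA implementation. The class ToyUpdater, its flags, and the "single"/"batch" kernel labels are hypothetical. Only the idea of tracking the tuned m_batch and invalidating the cached kernel choice when m_batch grows is taken from the description. The benchmark race is replaced by a flag that assumes the single-item kernel wins whenever it is a valid candidate.

```python
# Toy model of the kernel cache and retune guard. Not aihwkit code:
# every name except the idea of "tuned m_batch" and "cached kernel
# parameters" is hypothetical.

class ToyUpdater:
    def __init__(self, retune_guard, single_wins_benchmark=True):
        self.retune_guard = retune_guard
        self.single_wins_benchmark = single_wins_benchmark
        self.kernel_pars = None   # cached kernel choice, None means "not tuned"
        self.tuned_m_batch = 0    # m_batch used during the last tune

    def _tune(self, m_batch):
        # The single-item kernel is only a valid candidate for m_batch == 1.
        if m_batch == 1 and self.single_wins_benchmark:
            self.kernel_pars = "single"
        else:
            self.kernel_pars = "batch"
        self.tuned_m_batch = m_batch

    def update(self, per_item_deltas):
        m_batch = len(per_item_deltas)
        # Guard: a larger batch than the one tuned for invalidates the cache.
        if self.retune_guard and m_batch > self.tuned_m_batch:
            self.kernel_pars = None
        if self.kernel_pars is None:
            self._tune(m_batch)
        if self.kernel_pars == "single":
            return per_item_deltas[0]   # only batch item 0 is processed
        return sum(per_item_deltas)     # batch-aware kernel sums all items

results = {}
for guard in (False, True):
    updater = ToyUpdater(retune_guard=guard)
    updater.update([1.0])                # priming update with m_batch=1
    delta = updater.update([1.0] * 128)  # subsequent update with m_batch=128
    results[guard] = (updater.kernel_pars, delta)
    print(guard, updater.kernel_pars, delta)
# prints: False single 1.0
# prints: True batch 128.0
```

Without the guard the toy updater keeps the single-item kernel and returns 1/128 of the expected change. With the guard it retunes on the larger batch and returns the full sum. The guard only triggers when m_batch grows, so repeated updates at the same or a smaller batch size keep the cached choice.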