Enable TBO Support & Fix Accuracy Regressions for Kimi K2.5#1369
jpy794 wants to merge 9 commits into
Pull request overview
This PR enables end-to-end Data-Parallel Attention (DPA) + Two-Batch Overlap (TBO) for Kimi K2.5 by fixing DP/TBO micro-batch metadata, strengthening MoE correctness on DP fallback paths, and improving cross-DP prefill admission alignment for TBO’s two-batch requirement.
Changes:
- Propagates per-ubatch per-rank token counts (`ub_tokens_across_dp`) through DP sync and `ForwardContext`, and rebuilds ubatch-local `DPMetadata` in `UBatchWrapper`.
- Fixes MoE correctness/stability under DP fallback + TBO (zero padding rows; scale max token metadata; keepalive tensors across overlapped collectives).
- Updates prefill alignment logic so `PrefillDelayer` delays until all DP ranks are “alignment-ready” (>=2 local head prefills when TBO is enabled).
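The alignment rule in the last bullet can be sketched in plain Python. This is a minimal illustration, not the PR's implementation: `local_readiness` and `should_delay_prefill` are hypothetical names, the cross-rank MAX-reduce is simulated with `max()` over a list, and encoding "not ready" as 1 so that a MAX-reduce detects any lagging rank is an assumption about how such a check could be built.

```python
from typing import List, Tuple


def local_readiness(num_head_prefills: int, tbo_enabled: bool) -> Tuple[bool, bool]:
    """Return (has_prefill, alignment_ready) for one DP rank (hypothetical helper)."""
    required = 2 if tbo_enabled else 1  # TBO needs two local prefills to split
    return num_head_prefills >= 1, num_head_prefills >= required


def should_delay_prefill(per_rank_counts: List[int], tbo_enabled: bool) -> bool:
    """Simulate the cross-DP decision; max() stands in for a MAX all-reduce."""
    flags = [local_readiness(n, tbo_enabled) for n in per_rank_counts]
    any_has_prefill = max(int(has) for has, _ in flags)
    # Assumption: "not ready" is encoded as 1, so MAX is 1 if any rank lags.
    any_not_ready = max(int(not ready) for _, ready in flags)
    return bool(any_has_prefill and any_not_ready)


# Rank 1 has only one head prefill, so a TBO two-batch cannot start everywhere.
delay_when_unaligned = should_delay_prefill([2, 1], tbo_enabled=True)
# Every rank has two head prefills, so prefill can proceed.
delay_when_aligned = should_delay_prefill([2, 2], tbo_enabled=True)
```

A real delayer would also need a bound on how long it waits; that is outside this sketch.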
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| atom/utils/tbo/ubatching.py | Extends DP sync result to include per-ubatch per-rank token counts for TBO/DP variable-length collectives. |
| atom/utils/tbo/ubatch_wrapper.py | Rebuilds DPMetadata per ubatch using per-ubatch token counts; propagates dp_uniform_decode into ubatch Context. |
| atom/utils/forward_context.py | Adds ub_tokens_across_dp plumbing into ForwardContext / set_forward_context. |
| atom/model_ops/topK.py | Adjusts MORI/all2all gating intended for DPA fallback vs EP mode (but currently has a logic issue). |
| atom/model_ops/moe.py | Zeroes DP all-gather padding rows; scales MoE max token metadata for DP fallback; adds TBO keepalive to prevent premature tensor frees. |
| atom/model_ops/attention_mla.py | Enables persistent MLA for multi-rank DP when dp_size <= 8. |
| atom/model_engine/scheduler.py | Replaces boolean “prefillable” with counted head-prefill admission and exports both presence + alignment readiness signals. |
| atom/model_engine/prefill_delayer.py | Adds local_alignment_ready and expands MAX-reduce buffer to gate prefill on cross-DP alignment readiness. |
| atom/model_engine/model_runner.py | Threads ub_tokens_across_dp from DP sync into set_forward_context for downstream ubatch/DP metadata. |
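Why per-ubatch token counts matter for the variable-length collectives can be shown with a small offset calculation. This is a sketch under stated assumptions: `rank_offsets` is a hypothetical helper, and the nested-list layout `ub_tokens_across_dp[u][r]` (tokens of ubatch `u` on DP rank `r`) is assumed for illustration rather than taken from the PR.

```python
from itertools import accumulate
from typing import List


def rank_offsets(tokens_across_dp: List[int]) -> List[int]:
    """Exclusive prefix sum: start row of each DP rank in a gathered buffer."""
    return [0] + list(accumulate(tokens_across_dp))[:-1]


# Assumed layout: ub_tokens_across_dp[u][r] = tokens of ubatch u on DP rank r.
ub_tokens_across_dp = [[3, 5], [4, 2]]

# Full-batch counts per rank are the column sums over both ubatches.
full_batch = [sum(col) for col in zip(*ub_tokens_across_dp)]
full_offsets = rank_offsets(full_batch)

# Ubatch-local offsets, which is what each ubatch's collective should use.
ubatch_offsets = [rank_offsets(ub) for ub in ub_tokens_across_dp]
```

In this example rank 1 starts at row 7 in the full batch but at row 3 in ubatch 0, so reusing full-batch offsets for a single ubatch would address the wrong rows.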
| and config.enable_expert_parallel | ||
| ) | ||
| if use_mori_all2all: | ||
| return False | ||
| return True |
…5X (this perf bug seems MI355X only)" This reverts commit 0a7feab.
| return False | ||
| break | ||
| dp_size = config.parallel_config.data_parallel_size |
Is this a duplicate? You can check line 24:
`if dp_size > 1 and _has_module("mori") and config.enable_dp_attention: return False`
Thanks for pointing that out. That was a rebase error (fixed now). Here we are trying to enable shared expert fusion for DPA on the all-gather/reduce-scatter MoE path (not MORI all2all).
| and not self.moe_parallel_config.use_all2all_kernels | ||
| and atom_config.enable_dp_attention | ||
| ): | ||
| moe_max_num_tokens *= self.moe_parallel_config.dp_size |
I don't understand why we need `moe_max_num_tokens *= self.moe_parallel_config.dp_size` here. The all_gather path and the model runner already pad, so multiplying by dp_size here will make the batch size large and hurt kernel performance.
Here, we only increase the size of the preallocated internal buffer in FusedMoE, not the actual batch size used in the forward pass. This internal buffer needs to be large enough to accommodate tokens from all DP ranks, so we multiply by dp_size, similar to what we've already done for the all-gather / reduce-scatter buffers.
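The sizing argument in this reply can be sketched as follows. This is illustrative only: `fused_moe_buffer_tokens` and `gathered_rows` are hypothetical helpers, the condition mirrors the flags visible in the diff hunk plus `dp_size > 1` from the PR description, and padding every rank to the largest rank before the all-gather is an assumption about the padding scheme.

```python
from typing import List


def fused_moe_buffer_tokens(
    max_num_tokens: int,
    dp_size: int,
    use_all2all_kernels: bool,
    enable_dp_attention: bool,
) -> int:
    """Capacity (in rows) of the preallocated buffer, not the forward batch size."""
    if dp_size > 1 and not use_all2all_kernels and enable_dp_attention:
        return max_num_tokens * dp_size
    return max_num_tokens


def gathered_rows(tokens_per_rank: List[int]) -> int:
    """Rows after an all-gather that pads every rank to the largest rank (assumed)."""
    return len(tokens_per_rank) * max(tokens_per_rank)


per_rank_cap = 8           # per-rank max tokens
tokens_per_rank = [3, 8]   # actual tokens on two DP ranks this step
rows = gathered_rows(tokens_per_rank)
unscaled = per_rank_cap
scaled = fused_moe_buffer_tokens(per_rank_cap, 2, False, True)
```

Here the gathered input has 16 rows, which overflows an 8-row buffer but fits the scaled one; the number of rows actually processed is still set by the gathered input, not by the buffer capacity.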
31ab320 to 426f176
| num_new_tokens = seq.num_tokens - seq.num_cached_tokens | ||
| if num_new_tokens > self.max_num_batched_tokens: | ||
| continue | ||
| if self.block_manager.can_allocate(seq) < 0: | ||
| return False # KV-pressured: definitely cannot prefill | ||
| return True | ||
| return False | ||
| break # KV-pressured: definitely cannot prefill more now. |
| if _tbo: | ||
| tbo_switch_to_compute_sync() | ||
| self._hold_tbo_keepalive("ag_output", hidden_states, router_logits) |
| if _tbo: | ||
| tbo_switch_to_compute_sync() | ||
| self._hold_tbo_keepalive("rs_output", final_hidden_states) |
Motivation
Kimi K2.5 inference under Data-Parallel Attention (DPA) combined with Two-Batch Overlap (TBO) exposed several gaps that either crashed the engine or left performance on the table. This PR enables the DPA + TBO path end-to-end: it fixes the fused-MoE fallback and tensor lifetimes in the TBO overlap, aligns cross-DP prefill admission with TBO's two-batch requirement, and extends persistent MLA to multi-rank DP.
Technical Details
- Persistent MLA for DP attention (`attention_mla.py`): relax `use_persistent_mode` from "single-rank only" (`not (dp_size > 1)`) to `dp_size <= 8`, so persistent MLA also runs in the multi-rank DP configuration used by Kimi K2.5.
- Fix fused MoE on the DPA fallback path (`moe.py`, `topK.py`):
  - On the fallback path (`dp_size > 1`, no MORI all2all), MoE runs after `all_gather_with_padding`, so the token dim can grow to `dp_size ×` the per-rank max. Scale `max_num_tokens` for the topK / fused-MoE metadata accordingly to avoid undersized buffers.
  - Gate MORI all2all on expert parallelism (`enable_expert_parallel`), so DPA-without-EP correctly falls back instead of assuming all2all.
- Fix TBO tensor live range (`moe.py`): add a per-(role, ubatch) `_TBO_KEEPALIVE` holder around the all-gather and reduce-scatter comm/compute switches. Under TBO the source/output tensors of in-flight collectives could be freed before the overlapping ubatch waited on the comm; the keepalive defers release to the next same-role hold, which is past the wait point.
- Two-batch-aware prefill alignment (`scheduler.py`, `prefill_delayer.py`): TBO prefill splitting needs at least two local prefill requests per DP rank. Replace `_can_admit_head_prefill` (boolean) with `_count_admittable_head_prefills(limit)` and a `_prefill_delayer_readiness()` helper that reports both "has any prefill" and "alignment-ready" (`>= 2` requests when TBO is on, `>= 1` otherwise). `PrefillDelayer` gains a 4th MAX-reduce slot (`local_alignment_ready`) so prefill is delayed until every DP rank can launch a full two-batch, not just until one rank has a request.
- Fix TBO prefill ubatch DP offsets by propagating per-ubatch per-rank token counts through `ForwardContext`, then rebuilding ubatch-local `DPMetadata` inside `UBatchWrapper`. This prevents DP all_gatherv/reduce_scatterv from using full-batch offsets for individual ubatches.
- Zero MoE all-gather padding rows before fused-MoE routing/sort/dispatch. Padding rows are later sliced away, but they still participate in fused MoE internals; leaving them uninitialized can introduce NaN/Inf garbage, perturb expert buckets/shared scratch, and corrupt real tokens.
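The padding-row hazard can be reproduced with a toy example. This is a sketch, not the fused-MoE kernel: `pad_rows` and `column_sums` are hypothetical helpers, and the column sum is only a stand-in for any internal step that touches every row, padding included, before the padding is sliced away.

```python
import math
from typing import List

Matrix = List[List[float]]


def pad_rows(rows: Matrix, padded_len: int, fill: float) -> Matrix:
    """Append rows of `fill` until the matrix has `padded_len` rows."""
    width = len(rows[0])
    return rows + [[fill] * width for _ in range(padded_len - len(rows))]


def column_sums(rows: Matrix) -> List[float]:
    """Stand-in for a cross-row step that reads every row, padding included."""
    return [sum(col) for col in zip(*rows)]


real = [[1.0, 2.0], [3.0, 4.0]]

# NaN models whatever garbage an uninitialized padding row might contain.
garbage_padded = pad_rows(real, 4, float("nan"))
zero_padded = pad_rows(real, 4, 0.0)

sums_with_garbage = column_sums(garbage_padded)
sums_with_zeros = column_sums(zero_padded)
```

With NaN padding every column sum becomes NaN even though the real rows are intact; with zeroed padding the sums match the unpadded data.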
Bugfix Validation
Bad TBO run before MoE padding fix: GSM8K flexible 0.8999, with 125 invalid responses and 238 corrupted outputs.
After fix: GSM8K flexible 0.9742, invalid down to 1, corrupted outputs down to 0.
Perf Benchmark Plan
We tested Kimi K2.5 MXFP4 end-to-end inference on MI355X with ROCm 7.2.3, TP4.
The comparison includes:
Test Results
Numbers in parentheses are throughput/GPU changes relative to the baseline.
At higher concurrency, DPA and TBO show a higher throughput ceiling, with DPA+TBO reaching +15.3% throughput/GPU over the baseline at conc=128.