Enable TBO Support & Fix Accuracy Regressions for Kimi K2.5 #1369

Open
jpy794 wants to merge 9 commits into ROCm:main from RadeonFlow:rf-dpa-tbo-rebase

Conversation

jpy794 commented Jun 26, 2026

Motivation

Kimi K2.5 inference under Data-Parallel Attention (DPA) combined with Two-Batch Overlap (TBO) exposed several gaps that either crashed the engine or left performance on the table. This PR enables the DPA + TBO path end-to-end: it fixes the fused-MoE fallback and tensor lifetimes in the TBO overlap, aligns cross-DP prefill admission with TBO's two-batch requirement, and extends persistent MLA to multi-rank DP.

Technical Details

  • Persistent MLA for DP attention (attention_mla.py): relax use_persistent_mode from "single-rank only" (not (dp_size > 1)) to dp_size <= 8, so persistent MLA also runs in the multi-rank DP configuration used by Kimi K2.5.

  • Fix fused MoE on the DPA fallback path (moe.py, topK.py):

    • In the DP-attn fallback (dp_size > 1, no MORI all2all), MoE runs after all_gather_with_padding, so the token dim can grow to dp_size × the per-rank max. Scale max_num_tokens for the topK / fused-MoE metadata accordingly to avoid undersized buffers.
    • Only select the MORI all2all path when expert parallel is actually enabled (enable_expert_parallel), so DPA-without-EP correctly falls back instead of assuming all2all.
  • Fix TBO tensor live range (moe.py): add a per-(role, ubatch) _TBO_KEEPALIVE holder around the all-gather and reduce-scatter comm/compute switches. Under TBO the source/output tensors of in-flight collectives could be freed before the overlapping ubatch waited on the comm; the keepalive defers release to the next same-role hold, which is past the wait point.

  • Two-batch-aware prefill alignment (scheduler.py, prefill_delayer.py): TBO prefill splitting needs at least two local prefill requests per DP rank. Replace _can_admit_head_prefill (boolean) with _count_admittable_head_prefills(limit) and a _prefill_delayer_readiness() helper that reports both "has any prefill" and "alignment-ready" (>= 2 requests when TBO is on, >= 1 otherwise). PrefillDelayer gains a 4th MAX-reduce slot (local_alignment_ready) so prefill is delayed until every DP rank can launch a full two-batch, not just until one rank has a request.

  • Fix TBO prefill ubatch DP offsets by propagating per-ubatch per-rank token counts through ForwardContext, then rebuilding ubatch-local DPMetadata inside UBatchWrapper. This prevents DP all_gatherv/reduce_scatterv from using full-batch offsets for individual ubatches.

  • Zero MoE all-gather padding rows before fused-MoE routing/sort/dispatch. Padding rows are later sliced away, but they still participate in fused MoE internals; leaving them uninitialized can introduce NaN/Inf garbage, perturb expert buckets/shared scratch, and corrupt real tokens.
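
The padding fix in the last item can be illustrated with a small pure-Python sketch. It assumes the gathered buffer is laid out as one contiguous slot of `max_tokens` rows per DP rank, with each rank's real tokens at the front of its slot. That layout and the names `zero_padding_rows` and `tokens_per_rank` are illustrative assumptions, not the code in moe.py.

```python
import math


def zero_padding_rows(gathered, tokens_per_rank, max_tokens):
    """Zero every padding row of an all-gathered buffer in place.

    gathered:        list of rows (each a list of floats); its length is
                     len(tokens_per_rank) * max_tokens (assumed layout).
    tokens_per_rank: number of real tokens contributed by each DP rank.
    max_tokens:      padded slot size per rank.
    """
    assert len(gathered) == len(tokens_per_rank) * max_tokens
    for rank, n_real in enumerate(tokens_per_rank):
        assert 0 <= n_real <= max_tokens
        start = rank * max_tokens + n_real  # first padding row of this slot
        end = (rank + 1) * max_tokens       # one past the end of this slot
        for row in gathered[start:end]:
            for j in range(len(row)):
                row[j] = 0.0
    return gathered


# Two ranks, slot size 3, hidden size 2. NaN stands in for uninitialized memory.
nan = float("nan")
buf = [
    [1.0, 2.0], [3.0, 4.0], [nan, nan],  # rank 0: 2 real rows, 1 padding row
    [5.0, 6.0], [nan, nan], [nan, nan],  # rank 1: 1 real row, 2 padding rows
]
zero_padding_rows(buf, tokens_per_rank=[2, 1], max_tokens=3)
has_nan = any(math.isnan(x) for row in buf for x in row)
```

After the call, the real rows are untouched and no NaN remains in the buffer, so routing and sorting over the full padded range only ever see finite values.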

Bugfix Validation

Bad TBO run before MoE padding fix: GSM8K flexible 0.8999, with 125 invalid responses and 238 corrupted outputs.
After fix: GSM8K flexible 0.9742, invalid down to 1, corrupted outputs down to 0.

Perf Benchmark Plan

We tested Kimi K2.5 MXFP4 end-to-end inference on MI355X with ROCm 7.2.3, TP4.

The comparison includes:

  • baseline without DPA / TBO
  • DPA only
  • DPA + TBO

Test Results

Numbers in parentheses are throughput/GPU changes relative to the baseline.

| Conc | Baseline throughput/GPU | Baseline interactivity | DPA throughput/GPU | DPA interactivity | DPA+TBO throughput/GPU | DPA+TBO interactivity |
| --- | --- | --- | --- | --- | --- | --- |
| 4 | 898.9 | 104.95 | 611.7 (-32.0%) | 73.53 | 610.2 (-32.1%) | 73.63 |
| 8 | 1502.8 | 89.61 | 1061.4 (-29.4%) | 65.48 | 1077.1 (-28.3%) | 65.67 |
| 16 | 2045.6 | 61.85 | 1653.5 (-19.2%) | 51.93 | 1677.6 (-18.0%) | 52.86 |
| 32 | 2964.5 | 43.46 | 2542.4 (-14.2%) | 38.51 | 2588.0 (-12.7%) | 39.21 |
| 64 | 3946.8 | 29.06 | 3761.2 (-4.7%) | 28.52 | 3961.7 (+0.4%) | 30.09 |
| 128 | 4984.5 | 19.81 | 5133.1 (+3.0%) | 21.04 | 5745.4 (+15.3%) | 23.55 |

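Below conc=64, both DPA configurations trail the baseline, by as much as 32% at conc=4. DPA+TBO breaks even at conc=64 (+0.4%) and reaches +15.3% throughput/GPU over the baseline at conc=128, where DPA alone gains +3.0%, so DPA and TBO show a higher throughput ceiling at high concurrency.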
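
The parenthesized percentages can be recomputed from the throughput/GPU columns. The snippet below is a small sketch of that arithmetic; the helper name `relative_change_pct` is made up for illustration, and the numbers are copied from the results table.

```python
def relative_change_pct(value, baseline):
    """Percentage change of value relative to baseline, rounded to 0.1."""
    return round((value / baseline - 1.0) * 100.0, 1)


# conc -> (baseline, DPA, DPA+TBO) throughput/GPU, copied from the table.
throughput = {
    4: (898.9, 611.7, 610.2),
    8: (1502.8, 1061.4, 1077.1),
    16: (2045.6, 1653.5, 1677.6),
    32: (2964.5, 2542.4, 2588.0),
    64: (3946.8, 3761.2, 3961.7),
    128: (4984.5, 5133.1, 5745.4),
}

changes = {
    conc: (relative_change_pct(dpa, base), relative_change_pct(tbo, base))
    for conc, (base, dpa, tbo) in throughput.items()
}

for conc, (dpa_pct, tbo_pct) in changes.items():
    print(f"conc={conc}: DPA {dpa_pct:+.1f}%, DPA+TBO {tbo_pct:+.1f}%")
```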

Submission Checklist

Copilot AI review requested due to automatic review settings June 26, 2026 08:44

Copilot AI left a comment

Contributor

Pull request overview

This PR enables end-to-end Data-Parallel Attention (DPA) + Two-Batch Overlap (TBO) for Kimi K2.5 by fixing DP/TBO micro-batch metadata, strengthening MoE correctness on DP fallback paths, and improving cross-DP prefill admission alignment for TBO’s two-batch requirement.

Changes:

  • Propagates per-ubatch per-rank token counts (ub_tokens_across_dp) through DP sync and ForwardContext, and rebuilds ubatch-local DPMetadata in UBatchWrapper.
  • Fixes MoE correctness/stability under DP fallback + TBO (zero padding rows; scale max token metadata; keepalive tensors across overlapped collectives).
  • Updates prefill alignment logic so PrefillDelayer delays until all DP ranks are “alignment-ready” (>=2 local head prefills when TBO is enabled).
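
The alignment rule in the last bullet can be sketched in a few lines of Python. This is a simplified model under stated assumptions, not the scheduler code: waiting requests are reduced to `(num_new_tokens, kv_blocks_needed)` tuples, KV pressure is modeled as a single free-block count, the cross-DP reduction is replaced by inspecting a plain list, and a rank with no prefill at all is treated as not ready. All function names here are hypothetical.

```python
def count_admittable_head_prefills(waiting, max_num_batched_tokens,
                                   free_kv_blocks, limit):
    """Count head-of-queue prefills that could be admitted, up to limit."""
    count = 0
    for num_new_tokens, kv_blocks_needed in waiting:
        if count >= limit:
            break
        if num_new_tokens > max_num_batched_tokens:
            continue  # too large for one batch: skip this request
        if kv_blocks_needed > free_kv_blocks:
            break  # KV-pressured: cannot prefill more now
        count += 1
    return count


def prefill_delayer_readiness(num_admittable, tbo_enabled):
    """Return (has_prefill, alignment_ready) for one DP rank."""
    required = 2 if tbo_enabled else 1
    return num_admittable >= 1, num_admittable >= required


def all_ranks_alignment_ready(per_rank_admittable, tbo_enabled):
    """Stand-in for the cross-DP reduction: every rank must be ready."""
    return all(
        prefill_delayer_readiness(n, tbo_enabled)[1]
        for n in per_rank_admittable
    )


waiting = [(100, 2), (9000, 1), (50, 1), (10, 99)]
n = count_admittable_head_prefills(
    waiting, max_num_batched_tokens=8192, free_kv_blocks=4, limit=2
)
ready_with_tbo = all_ranks_alignment_ready([n, 1], tbo_enabled=True)
ready_without_tbo = all_ranks_alignment_ready([n, 1], tbo_enabled=False)
```

In this example the first rank can admit two prefills but the second rank only one, so prefill would be delayed with TBO enabled and allowed without it.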

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| atom/utils/tbo/ubatching.py | Extends DP sync result to include per-ubatch per-rank token counts for TBO/DP variable-length collectives. |
| atom/utils/tbo/ubatch_wrapper.py | Rebuilds DPMetadata per ubatch using per-ubatch token counts; propagates dp_uniform_decode into ubatch Context. |
| atom/utils/forward_context.py | Adds ub_tokens_across_dp plumbing into ForwardContext / set_forward_context. |
| atom/model_ops/topK.py | Adjusts MORI/all2all gating intended for DPA fallback vs EP mode (but currently has a logic issue). |
| atom/model_ops/moe.py | Zeroes DP all-gather padding rows; scales MoE max token metadata for DP fallback; adds TBO keepalive to prevent premature tensor frees. |
| atom/model_ops/attention_mla.py | Enables persistent MLA for multi-rank DP up to dp_size <= 8. |
| atom/model_engine/scheduler.py | Replaces boolean “prefillable” with counted head-prefill admission and exports both presence + alignment readiness signals. |
| atom/model_engine/prefill_delayer.py | Adds local_alignment_ready and expands MAX-reduce buffer to gate prefill on cross-DP alignment readiness. |
| atom/model_engine/model_runner.py | Threads ub_tokens_across_dp from DP sync into set_forward_context for downstream ubatch/DP metadata. |
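
To see why the per-ubatch token counts matter for the variable-length collectives, the sketch below derives gather offsets from them. It assumes `ub_tokens_across_dp` is indexed as `[ubatch][rank]`; that layout and the helper `exclusive_offsets` are illustrative assumptions, not the structure used in ubatching.py.

```python
from itertools import accumulate


def exclusive_offsets(counts):
    """Exclusive prefix sum: the row where each rank's slice starts."""
    return list(accumulate(counts, initial=0))[:-1]


# Assumed layout: ub_tokens_across_dp[ubatch][rank] = tokens on that rank.
ub_tokens_across_dp = [
    [3, 5],  # ubatch 0: rank 0 has 3 tokens, rank 1 has 5
    [2, 4],  # ubatch 1: rank 0 has 2 tokens, rank 1 has 4
]

# Full-batch view: per-rank totals over both ubatches.
full_counts = [sum(per_rank) for per_rank in zip(*ub_tokens_across_dp)]
full_offsets = exclusive_offsets(full_counts)

# Ubatch-local view: what each ubatch's collective should use instead.
ubatch_offsets = [exclusive_offsets(counts) for counts in ub_tokens_across_dp]
```

With these numbers the full-batch offsets are `[0, 5]`, while the two ubatches need `[0, 3]` and `[0, 2]`. A ubatch that reused the full-batch offsets would place rank 1's rows at the wrong position in the gathered buffer.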

Comment thread atom/model_ops/topK.py Outdated
Comment on lines 70 to 74
    and config.enable_expert_parallel
)
if use_mori_all2all:
    return False
return True

valarLip requested a review from ZhangLirong-amd June 26, 2026 10:07

Comment thread atom/model_ops/topK.py Outdated
return False
break

dp_size = config.parallel_config.data_parallel_size

Collaborator

Is this a duplicate? See line 24:
if dp_size > 1 and _has_module("mori") and config.enable_dp_attention: return False

Author

Thanks for pointing that out. That was a rebase error (fixed now). The intent here is to enable shared-expert fusion for DPA on the all-gather/reduce-scatter MoE path (not MORI all2all).

Comment thread atom/model_ops/moe.py
    and not self.moe_parallel_config.use_all2all_kernels
    and atom_config.enable_dp_attention
):
    moe_max_num_tokens *= self.moe_parallel_config.dp_size

Collaborator

I don't understand why we need moe_max_num_tokens *= self.moe_parallel_config.dp_size here. The all_gather and the model runner already pad, so multiplying by dp_size here will make the batch size large and hurt kernel performance.

Author

Here, we only increase the size of the preallocated internal buffer in FusedMoE, not the actual batch size used in the forward pass. This internal buffer needs to be large enough to accommodate tokens from all DP ranks, so we multiply by dp_size, similar to what we've already done for the all-gather / reduce-scatter buffers.
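
The distinction between buffer capacity and forward-pass batch size can be sketched as follows. This is a hedged illustration: the function `moe_buffer_capacity` and its parameters are hypothetical, and the condition only mirrors what this thread and the PR description state (dp_size > 1, DP attention enabled, no all2all kernels).

```python
def moe_buffer_capacity(per_rank_max_tokens, dp_size,
                        enable_dp_attention, use_all2all_kernels):
    """Number of token rows the preallocated MoE buffers must hold."""
    if dp_size > 1 and enable_dp_attention and not use_all2all_kernels:
        # Fallback path: MoE runs after the all-gather, so it may see
        # up to dp_size * per_rank_max_tokens rows.
        return per_rank_max_tokens * dp_size
    return per_rank_max_tokens


capacity = moe_buffer_capacity(
    per_rank_max_tokens=4096, dp_size=4,
    enable_dp_attention=True, use_all2all_kernels=False,
)

# The batch actually processed in a step is whatever was gathered; it is
# bounded by the capacity but not changed by it.
tokens_this_step = [100, 37, 512, 8]  # hypothetical per-rank counts
gathered_rows = sum(tokens_this_step)
fits = gathered_rows <= capacity
```

Only the upper bound changes; a step that gathers 657 rows still runs the kernels on 657 rows (plus whatever padding the gather itself adds).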

jpy794 force-pushed the rf-dpa-tbo-rebase branch from 31ab320 to 426f176 on June 26, 2026 12:49
Copilot AI review requested due to automatic review settings June 26, 2026 13:32

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

Comment on lines 537 to +541
     num_new_tokens = seq.num_tokens - seq.num_cached_tokens
     if num_new_tokens > self.max_num_batched_tokens:
         continue
     if self.block_manager.can_allocate(seq) < 0:
-        return False  # KV-pressured: definitely cannot prefill
-    return True
-return False
+        break  # KV-pressured: definitely cannot prefill more now.

Comment thread atom/model_ops/moe.py
Comment on lines 3406 to +3408
if _tbo:
    tbo_switch_to_compute_sync()
    self._hold_tbo_keepalive("ag_output", hidden_states, router_logits)

Comment thread atom/model_ops/moe.py
Comment on lines 3443 to +3445
if _tbo:
    tbo_switch_to_compute_sync()
    self._hold_tbo_keepalive("rs_output", final_hidden_states)

zufayu requested a review from ZhangLirong-amd June 26, 2026 14:01