Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ on:
env:
CI_PATH: "${HOME}/GitHub/${{ github.repository }}/${GITHUB_RUN_NUMBER}"
LMDEPLOY_PATH: "${HOME}/GitHub/lmdeploy"
LMDEPLOY_COMMIT_OR_BRANCH: 'main'
LMDEPLOY_COMMIT_OR_BRANCH: 'qwen3_5_mtp_final_2'
REPORT_DIR: "${HOME}/GitHub/ci_log/test_reports"
TEST_LMDEPLOY_E2E_LOG_PATH: "${HOME}/Github/ci_log/logs"
TEST_LMDEPLOY_E2E_MODEL_PATH: "${HOME}/Github/model"
Expand Down Expand Up @@ -73,7 +73,7 @@ jobs:
- name: Clone lmdeploy
run: |
set -ex
git clone https://github.com/InternLM/lmdeploy.git ${{ env.LMDEPLOY_PATH }}
git clone https://github.com/wanfengcxz/lmdeploy.git ${{ env.LMDEPLOY_PATH }}
cd ${{ env.LMDEPLOY_PATH }} && git checkout ${{ env.LMDEPLOY_COMMIT_OR_BRANCH }}
# git apply ${{env.CI_PATH }}/.github/ci/fix-exit-multi-npu.patch

Expand Down
118 changes: 85 additions & 33 deletions dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@ def aclgraph_use_torch_npu_update():


# AscendCudaGraphMixin methods for cudagraph buffer management.
def AscendCudaGraphMixin_support_cuda_graph(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nv上是否有这个函数

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvidia上没有。这个patch本质上是因为昇腾增加了参数is_multi_token_decoding。
CUDA只有两种状态:

  • False → prefill(prompt 首次处理)
  • True → decode(自回归推理,含 MTP/spec decode)

Ascend 实际上将 decode 分支细分为两种:
场景 | is_decoding (局部) | is_multi_token_decoding
prefill | False | False
单 token 解码 | True | False
MTP 多 token 解码 | False | True

is_multi_token_decoding 必须存在的根本原因——Ascend 的 GatedDeltaRule(Qwen3.5 linear attention 层)需要三路 kernel 分发。不同kernel依赖的参数、buffer不一样,因此增加了细分状态flag is_multi_token_decoding。
CUDA 为何不需要? CUDA 侧用的是 FLA 库(fla.ops.utils),chunk_gated_delta_rule 通过 cu_seqlens_q 可变长接口天然支持 prefill 和 MTP 走同一套代码路径

self,
input_ids: Tensor,
position_ids: Tensor,
past_key_values: List[List[Tensor]],
attn_metadata: Any = None,
inputs_embeds: Tensor = None,
**kwargs,
):
"""Allow multi-token decode graph only when runtime length updates exist."""
if attn_metadata is None:
return False

is_decoding = getattr(attn_metadata, "is_decoding", False)
is_multi_token = getattr(attn_metadata, "is_multi_token_decoding", False)
if is_multi_token and not aclgraph_use_torch_npu_update():
return False
return is_decoding or is_multi_token


def AscendCudaGraphMixin_make_buffers_cudagraph(
self, graph_meta: CudaGraphMeta, *args, **kwargs
) -> BuffType:
Expand All @@ -58,29 +78,35 @@ def AscendCudaGraphMixin_make_buffers_cudagraph(
(max_batches, num_blocks), dtype=torch.int32, device=device
)

input_buffers["q_seqlens"] = torch.ones(
max_batches, dtype=torch.int32, device=device
)

input_buffers["kv_seqlens"] = torch.ones(max_batches, dtype=torch.int32)

input_buffers["q_start_loc"] = torch.arange(
max_batches + 1, dtype=torch.int32, device=device
)

input_buffers["kv_start_indices"] = -torch.ones(
(max_batches), dtype=torch.int32, device=device
(max_tokens), dtype=torch.int32, device=device
)

input_buffers["x_active_mask"] = torch.zeros(
(max_batches), dtype=torch.bool, device=device
(max_tokens), dtype=torch.bool, device=device
)

input_buffers["attention_mask"] = torch.triu(torch.ones(2048, 2048, dtype=torch.bool, device=device), diagonal=1)

# ssm
if graph_meta.is_ssm:
input_buffers["state_ids"] = torch.full(
(max_batches,), -1, dtype=torch.int64, device=device
)
input_buffers["cache_seqlens"] = torch.zeros(
max_batches, dtype=torch.int32, device=device
)

if max_batches != max_tokens:
max_q_seq_len = max_tokens // max_batches
input_buffers["q_seqlens"] = torch.arange(1, max_batches+1, dtype=torch.int32) * max_q_seq_len
input_buffers["q_start_loc"] = torch.arange(max_batches + 1, dtype=torch.int32, device=device) * max_q_seq_len

else:
input_buffers["q_seqlens"] = torch.arange(1, max_batches + 1, dtype=torch.int32)
input_buffers["q_start_loc"] = torch.arange(max_batches + 1, dtype=torch.int32, device=device)

# mrope
if graph_meta.use_mrope:
Expand Down Expand Up @@ -108,6 +134,9 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(
moe_metadata = get_step_ctx_manager().current_context().moe_metadata
x_active_mask: Tensor = moe_metadata.x_active_mask
q_start_loc: Tensor = attn_metadata.q_start_loc
cache_seqlens: Tensor = attn_metadata.cache_seqlens

is_multi_token_decoding = attn_metadata.is_multi_token_decoding

input_buffers: BuffType = graph_meta.input_buffers

Expand All @@ -121,27 +150,34 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(
input_buffers["input_ids"][:, num_tokens:max_num_tokens].random_(
0, graph_meta.vocab_size
)

input_buffers["input_ids"][:, :num_tokens] = input_ids
input_buffers["position_ids"].zero_()
input_buffers["position_ids"][:, :num_tokens] = position_ids
input_buffers["block_offsets"].zero_()
input_buffers["block_offsets"][:batch_size, :num_blocks] = block_offsets

input_buffers["kv_seqlens"].fill_(0)
input_buffers["kv_seqlens"][:batch_size] = kv_seqlens
input_buffers["kv_start_indices"].fill_(-1)
input_buffers["kv_start_indices"][:batch_size] = kv_start_indices
input_buffers["kv_start_indices"][:kv_start_indices.size(0)] = kv_start_indices
if x_active_mask is not None:
input_buffers["x_active_mask"].fill_(0)
input_buffers["x_active_mask"][:batch_size] = x_active_mask
input_buffers["x_active_mask"][:x_active_mask.size(0)] = x_active_mask

# ssm
if graph_meta.is_ssm:
input_buffers["q_start_loc"][: batch_size + 1] = q_start_loc
input_buffers["q_start_loc"][batch_size + 1 :] = q_start_loc[-1]

state_ids = kwargs["state_ids"]
input_buffers["state_ids"].fill_(-1)
input_buffers["state_ids"][: state_ids.size(0)].copy_(state_ids)
input_buffers["state_ids"].fill_(0)
input_buffers["state_ids"][: batch_size].copy_(state_ids)

if is_multi_token_decoding:
input_buffers["cache_seqlens"].fill_(0)
input_buffers["cache_seqlens"][: batch_size].copy_(cache_seqlens)

attn_metadata.cache_seqlens = input_buffers["cache_seqlens"]

if is_multi_token_decoding:
attn_metadata.attention_mask = [input_buffers["attention_mask"]]

if inputs_embeds is not None:
emb_size = inputs_embeds.size(-1)
Expand All @@ -151,15 +187,13 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(
1, max_num_tokens, emb_size
)
input_buffers["inputs_embeds"][:, :num_tokens] = inputs_embeds
# create inputs
# Use compatible size but cap at graph's max_batchs to avoid buffer overflow
new_batch_size = min(get_ascend_compatible_size(batch_size), graph_meta.max_batchs)


attn_metadata.block_offsets = input_buffers["block_offsets"]
attn_metadata.kv_seqlens = input_buffers["kv_seqlens"]
attn_metadata.kv_start_indices = input_buffers["kv_start_indices"]
moe_metadata.x_active_mask = input_buffers["x_active_mask"]
attn_metadata.q_start_loc = input_buffers["q_start_loc"]
attn_metadata.q_seqlens = input_buffers["q_seqlens"]

new_inputs = dict(
past_key_values=past_key_values,
Expand All @@ -175,7 +209,6 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(

new_inputs.update(kwargs)

# ssm: override kwargs' variable-length state_ids with the fixed-size buffer
if graph_meta.is_ssm:
new_inputs["state_ids"] = input_buffers["state_ids"]

Expand All @@ -194,7 +227,6 @@ def AscendCudaGraphMixin_update_context_cudagraph(self, graph_meta, context):
"""update step context with input buffers."""
input_buffers = graph_meta.input_buffers
context.block_offsets = input_buffers["block_offsets"]
context.q_seqlens = input_buffers["q_seqlens"]
context.kv_seqlens = input_buffers["kv_seqlens"]
context.q_start_loc = input_buffers["q_start_loc"]
context.kv_start_indices = input_buffers["kv_start_indices"]
Expand All @@ -209,6 +241,7 @@ def AscendCudaGraphMixin_update_context_cudagraph(self, graph_meta, context):
context.mrope_position_ids = input_buffers["mrope_position_ids"]


CudaGraphMixin.support_cuda_graph = AscendCudaGraphMixin_support_cuda_graph
CudaGraphMixin.make_buffers_cudagraph = AscendCudaGraphMixin_make_buffers_cudagraph
CudaGraphMixin.fill_buffers_cudagraph = AscendCudaGraphMixin_fill_buffers_cudagraph
CudaGraphMixin.update_context_cudagraph = AscendCudaGraphMixin_update_context_cudagraph
Expand Down Expand Up @@ -358,7 +391,7 @@ def forward(self, **kwargs):
]
)
else:
update_attn_params(self.update_stream, self.meta, self.max_tokens)
update_attn_params(self.update_stream, self.meta, self.max_batches)
self._graph.replay()
output_buffers = self.meta.output_buffers
output = self.model.get_outputs_cudagraph(output_buffers, **kwargs)
Expand Down Expand Up @@ -427,19 +460,33 @@ def _get_capture_tokens(self, batch_size: int):
def get_graph_key(
self,
input_ids: torch.Tensor,
attn_metadata: Any,
**kwargs,
):
"""Get graph key."""
context = self.ctx_mgr.current_context()
is_decoding = context.is_decoding
num_tokens = input_ids.numel()
is_decoding = attn_metadata.is_decoding
is_multi_token_decoding = attn_metadata.is_multi_token_decoding
meta = self.get_meta()
enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch

if is_multi_token_decoding:
q_seqlens = attn_metadata.q_seqlens
max_q_seq_len = attn_metadata.max_q_seq_len
batch_size = q_seqlens.size(0)
if meta.padding_batch_size is None:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确认meta.padding_batch_size 的值,以及相关逻辑

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

padding_batch_size 是 DP(Data Parallel)场景下,所有 DP rank 中最大的 token 数,用于对齐各 rank 的 CudaGraph key,使所有 rank 复用同一个图。
MTP 下 padding_batch_size 仍然是从 agent.py 传来的 token 总数,因此,将跨 rank 最大 token 数除以当前步的 draft 长度,还原成"对齐后的序列数"

new_batch_size = self._get_capture_tokens(batch_size)
else:
padding_num_tokens = meta.padding_batch_size
padding_batch_size = (padding_num_tokens + max_q_seq_len - 1) // max_q_seq_len
new_batch_size = self._get_capture_tokens(padding_batch_size)
return (new_batch_size, is_multi_token_decoding, enable_microbatch, max_q_seq_len)

num_tokens = input_ids.numel()
if meta.padding_batch_size is None:
new_num_tokens = self._get_capture_tokens(num_tokens)
else:
new_num_tokens = self._get_capture_tokens(meta.padding_batch_size)
return (new_num_tokens, is_decoding, enable_microbatch)
return (new_num_tokens, is_decoding, enable_microbatch, 1)

def __call__(self, **kwargs):
"""call."""
Expand All @@ -451,16 +498,21 @@ def __call__(self, **kwargs):
return self.model.make_output_buffers(ret)

graph_key = self.get_graph_key(**kwargs)
max_tokens = graph_key[0]
is_decoding = graph_key[1]
max_batches = graph_key[0]
is_decoding_or_multi_token_decoding = graph_key[1]
max_q_seq_len = graph_key[3]
if is_decoding_or_multi_token_decoding:
max_tokens = max_batches * max_q_seq_len
else:
max_tokens = max_batches
max_batches = self.max_batches
if graph_key not in self._runner_map:
max_batches = max_tokens if is_decoding else self.max_batches
runner = AscendSingleGraphRunner(
self.model,
max_batches=max_batches,
max_tokens=max_tokens,
num_blocks=self.num_blocks,
is_decoding=is_decoding,
is_decoding=is_decoding_or_multi_token_decoding,
pool=self.graph_pool_handle,
model_config=self.model_config,
device=self.device,
Expand Down
Loading
Loading