-
Notifications
You must be signed in to change notification settings - Fork 15
Qwen3 5 mtp final 2 #326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Qwen3 5 mtp final 2 #326
Changes from all commits
66297e5
dadf3ca
d210d82
3365b78
2db7e8b
0ebea3e
a7b86f7
40534cb
fbfe32c
19ef49b
6d850df
a56c416
1445aba
1d2c4f1
41e8371
e0f0731
c245897
c3bc3bd
32383a3
02c9835
d5466aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,6 +37,26 @@ def aclgraph_use_torch_npu_update(): | |
|
|
||
|
|
||
| # AscendCudaGraphMixin methods for cudagraph buffer management. | ||
| def AscendCudaGraphMixin_support_cuda_graph( | ||
| self, | ||
| input_ids: Tensor, | ||
| position_ids: Tensor, | ||
| past_key_values: List[List[Tensor]], | ||
| attn_metadata: Any = None, | ||
| inputs_embeds: Tensor = None, | ||
| **kwargs, | ||
| ): | ||
| """Allow multi-token decode graph only when runtime length updates exist.""" | ||
| if attn_metadata is None: | ||
| return False | ||
|
|
||
| is_decoding = getattr(attn_metadata, "is_decoding", False) | ||
| is_multi_token = getattr(attn_metadata, "is_multi_token_decoding", False) | ||
| if is_multi_token and not aclgraph_use_torch_npu_update(): | ||
| return False | ||
| return is_decoding or is_multi_token | ||
|
|
||
|
|
||
| def AscendCudaGraphMixin_make_buffers_cudagraph( | ||
| self, graph_meta: CudaGraphMeta, *args, **kwargs | ||
| ) -> BuffType: | ||
|
|
@@ -58,29 +78,35 @@ def AscendCudaGraphMixin_make_buffers_cudagraph( | |
| (max_batches, num_blocks), dtype=torch.int32, device=device | ||
| ) | ||
|
|
||
| input_buffers["q_seqlens"] = torch.ones( | ||
| max_batches, dtype=torch.int32, device=device | ||
| ) | ||
|
|
||
| input_buffers["kv_seqlens"] = torch.ones(max_batches, dtype=torch.int32) | ||
|
|
||
| input_buffers["q_start_loc"] = torch.arange( | ||
| max_batches + 1, dtype=torch.int32, device=device | ||
| ) | ||
|
|
||
| input_buffers["kv_start_indices"] = -torch.ones( | ||
| (max_batches), dtype=torch.int32, device=device | ||
| (max_tokens), dtype=torch.int32, device=device | ||
| ) | ||
|
|
||
| input_buffers["x_active_mask"] = torch.zeros( | ||
| (max_batches), dtype=torch.bool, device=device | ||
| (max_tokens), dtype=torch.bool, device=device | ||
| ) | ||
|
|
||
| input_buffers["attention_mask"] = torch.triu(torch.ones(2048, 2048, dtype=torch.bool, device=device), diagonal=1) | ||
|
|
||
| # ssm | ||
| if graph_meta.is_ssm: | ||
| input_buffers["state_ids"] = torch.full( | ||
| (max_batches,), -1, dtype=torch.int64, device=device | ||
| ) | ||
| input_buffers["cache_seqlens"] = torch.zeros( | ||
| max_batches, dtype=torch.int32, device=device | ||
| ) | ||
|
|
||
| if max_batches != max_tokens: | ||
| max_q_seq_len = max_tokens // max_batches | ||
| input_buffers["q_seqlens"] = torch.arange(1, max_batches+1, dtype=torch.int32) * max_q_seq_len | ||
| input_buffers["q_start_loc"] = torch.arange(max_batches + 1, dtype=torch.int32, device=device) * max_q_seq_len | ||
|
|
||
| else: | ||
| input_buffers["q_seqlens"] = torch.arange(1, max_batches + 1, dtype=torch.int32) | ||
| input_buffers["q_start_loc"] = torch.arange(max_batches + 1, dtype=torch.int32, device=device) | ||
|
|
||
| # mrope | ||
| if graph_meta.use_mrope: | ||
|
|
@@ -108,6 +134,9 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( | |
| moe_metadata = get_step_ctx_manager().current_context().moe_metadata | ||
| x_active_mask: Tensor = moe_metadata.x_active_mask | ||
| q_start_loc: Tensor = attn_metadata.q_start_loc | ||
| cache_seqlens: Tensor = attn_metadata.cache_seqlens | ||
|
|
||
| is_multi_token_decoding = attn_metadata.is_multi_token_decoding | ||
|
|
||
| input_buffers: BuffType = graph_meta.input_buffers | ||
|
|
||
|
|
@@ -121,27 +150,34 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( | |
| input_buffers["input_ids"][:, num_tokens:max_num_tokens].random_( | ||
| 0, graph_meta.vocab_size | ||
| ) | ||
|
|
||
| input_buffers["input_ids"][:, :num_tokens] = input_ids | ||
| input_buffers["position_ids"].zero_() | ||
| input_buffers["position_ids"][:, :num_tokens] = position_ids | ||
| input_buffers["block_offsets"].zero_() | ||
| input_buffers["block_offsets"][:batch_size, :num_blocks] = block_offsets | ||
|
|
||
| input_buffers["kv_seqlens"].fill_(0) | ||
| input_buffers["kv_seqlens"][:batch_size] = kv_seqlens | ||
| input_buffers["kv_start_indices"].fill_(-1) | ||
| input_buffers["kv_start_indices"][:batch_size] = kv_start_indices | ||
| input_buffers["kv_start_indices"][:kv_start_indices.size(0)] = kv_start_indices | ||
| if x_active_mask is not None: | ||
| input_buffers["x_active_mask"].fill_(0) | ||
| input_buffers["x_active_mask"][:batch_size] = x_active_mask | ||
| input_buffers["x_active_mask"][:x_active_mask.size(0)] = x_active_mask | ||
|
|
||
| # ssm | ||
| if graph_meta.is_ssm: | ||
| input_buffers["q_start_loc"][: batch_size + 1] = q_start_loc | ||
| input_buffers["q_start_loc"][batch_size + 1 :] = q_start_loc[-1] | ||
|
|
||
| state_ids = kwargs["state_ids"] | ||
| input_buffers["state_ids"].fill_(-1) | ||
| input_buffers["state_ids"][: state_ids.size(0)].copy_(state_ids) | ||
| input_buffers["state_ids"].fill_(0) | ||
| input_buffers["state_ids"][: batch_size].copy_(state_ids) | ||
|
|
||
| if is_multi_token_decoding: | ||
| input_buffers["cache_seqlens"].fill_(0) | ||
| input_buffers["cache_seqlens"][: batch_size].copy_(cache_seqlens) | ||
|
|
||
| attn_metadata.cache_seqlens = input_buffers["cache_seqlens"] | ||
|
|
||
| if is_multi_token_decoding: | ||
| attn_metadata.attention_mask = [input_buffers["attention_mask"]] | ||
|
|
||
| if inputs_embeds is not None: | ||
| emb_size = inputs_embeds.size(-1) | ||
|
|
@@ -151,15 +187,13 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( | |
| 1, max_num_tokens, emb_size | ||
| ) | ||
| input_buffers["inputs_embeds"][:, :num_tokens] = inputs_embeds | ||
| # create inputs | ||
| # Use compatible size but cap at graph's max_batchs to avoid buffer overflow | ||
| new_batch_size = min(get_ascend_compatible_size(batch_size), graph_meta.max_batchs) | ||
|
|
||
|
|
||
| attn_metadata.block_offsets = input_buffers["block_offsets"] | ||
| attn_metadata.kv_seqlens = input_buffers["kv_seqlens"] | ||
| attn_metadata.kv_start_indices = input_buffers["kv_start_indices"] | ||
| moe_metadata.x_active_mask = input_buffers["x_active_mask"] | ||
| attn_metadata.q_start_loc = input_buffers["q_start_loc"] | ||
| attn_metadata.q_seqlens = input_buffers["q_seqlens"] | ||
|
|
||
| new_inputs = dict( | ||
| past_key_values=past_key_values, | ||
|
|
@@ -175,7 +209,6 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph( | |
|
|
||
| new_inputs.update(kwargs) | ||
|
|
||
| # ssm: override kwargs' variable-length state_ids with the fixed-size buffer | ||
| if graph_meta.is_ssm: | ||
| new_inputs["state_ids"] = input_buffers["state_ids"] | ||
|
|
||
|
|
@@ -194,7 +227,6 @@ def AscendCudaGraphMixin_update_context_cudagraph(self, graph_meta, context): | |
| """update step context with input buffers.""" | ||
| input_buffers = graph_meta.input_buffers | ||
| context.block_offsets = input_buffers["block_offsets"] | ||
| context.q_seqlens = input_buffers["q_seqlens"] | ||
| context.kv_seqlens = input_buffers["kv_seqlens"] | ||
| context.q_start_loc = input_buffers["q_start_loc"] | ||
| context.kv_start_indices = input_buffers["kv_start_indices"] | ||
|
|
@@ -209,6 +241,7 @@ def AscendCudaGraphMixin_update_context_cudagraph(self, graph_meta, context): | |
| context.mrope_position_ids = input_buffers["mrope_position_ids"] | ||
|
|
||
|
|
||
| CudaGraphMixin.support_cuda_graph = AscendCudaGraphMixin_support_cuda_graph | ||
| CudaGraphMixin.make_buffers_cudagraph = AscendCudaGraphMixin_make_buffers_cudagraph | ||
| CudaGraphMixin.fill_buffers_cudagraph = AscendCudaGraphMixin_fill_buffers_cudagraph | ||
| CudaGraphMixin.update_context_cudagraph = AscendCudaGraphMixin_update_context_cudagraph | ||
|
|
@@ -358,7 +391,7 @@ def forward(self, **kwargs): | |
| ] | ||
| ) | ||
| else: | ||
| update_attn_params(self.update_stream, self.meta, self.max_tokens) | ||
| update_attn_params(self.update_stream, self.meta, self.max_batches) | ||
| self._graph.replay() | ||
| output_buffers = self.meta.output_buffers | ||
| output = self.model.get_outputs_cudagraph(output_buffers, **kwargs) | ||
|
|
@@ -427,19 +460,33 @@ def _get_capture_tokens(self, batch_size: int): | |
| def get_graph_key( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| attn_metadata: Any, | ||
| **kwargs, | ||
| ): | ||
| """Get graph key.""" | ||
| context = self.ctx_mgr.current_context() | ||
| is_decoding = context.is_decoding | ||
| num_tokens = input_ids.numel() | ||
| is_decoding = attn_metadata.is_decoding | ||
| is_multi_token_decoding = attn_metadata.is_multi_token_decoding | ||
| meta = self.get_meta() | ||
| enable_microbatch = get_step_ctx_manager().current_context().enable_microbatch | ||
|
|
||
| if is_multi_token_decoding: | ||
| q_seqlens = attn_metadata.q_seqlens | ||
| max_q_seq_len = attn_metadata.max_q_seq_len | ||
| batch_size = q_seqlens.size(0) | ||
| if meta.padding_batch_size is None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 请确认 meta.padding_batch_size 的取值,以及相关逻辑。
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. padding_batch_size 是 DP(Data Parallel)场景下所有 DP rank 中最大的 token 数,用于对齐各 rank 的 CudaGraph key,使所有 rank 复用同一个图。 |
||
| new_batch_size = self._get_capture_tokens(batch_size) | ||
| else: | ||
| padding_num_tokens = meta.padding_batch_size | ||
| padding_batch_size = (padding_num_tokens + max_q_seq_len - 1) // max_q_seq_len | ||
| new_batch_size = self._get_capture_tokens(padding_batch_size) | ||
| return (new_batch_size, is_multi_token_decoding, enable_microbatch, max_q_seq_len) | ||
|
|
||
| num_tokens = input_ids.numel() | ||
| if meta.padding_batch_size is None: | ||
| new_num_tokens = self._get_capture_tokens(num_tokens) | ||
| else: | ||
| new_num_tokens = self._get_capture_tokens(meta.padding_batch_size) | ||
| return (new_num_tokens, is_decoding, enable_microbatch) | ||
| return (new_num_tokens, is_decoding, enable_microbatch, 1) | ||
|
|
||
| def __call__(self, **kwargs): | ||
| """call.""" | ||
|
|
@@ -451,16 +498,21 @@ def __call__(self, **kwargs): | |
| return self.model.make_output_buffers(ret) | ||
|
|
||
| graph_key = self.get_graph_key(**kwargs) | ||
| max_tokens = graph_key[0] | ||
| is_decoding = graph_key[1] | ||
| max_batches = graph_key[0] | ||
| is_decoding_or_multi_token_decoding = graph_key[1] | ||
| max_q_seq_len = graph_key[3] | ||
| if is_decoding_or_multi_token_decoding: | ||
| max_tokens = max_batches * max_q_seq_len | ||
| else: | ||
| max_tokens = max_batches | ||
| max_batches = self.max_batches | ||
| if graph_key not in self._runner_map: | ||
| max_batches = max_tokens if is_decoding else self.max_batches | ||
| runner = AscendSingleGraphRunner( | ||
| self.model, | ||
| max_batches=max_batches, | ||
| max_tokens=max_tokens, | ||
| num_blocks=self.num_blocks, | ||
| is_decoding=is_decoding, | ||
| is_decoding=is_decoding_or_multi_token_decoding, | ||
| pool=self.graph_pool_handle, | ||
| model_config=self.model_config, | ||
| device=self.device, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NVIDIA(CUDA)上是否有这个函数?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NVIDIA 上没有。这个 patch 之所以存在,本质上是因为昇腾增加了参数 is_multi_token_decoding。
CUDA 只有两种状态:prefill(is_decoding=False)和 decode(is_decoding=True)。
Ascend 实际上将 decode 分支进一步细分为两种(单 token 解码与 MTP 多 token 解码):
| 场景 | is_decoding (局部) | is_multi_token_decoding |
|---|---|---|
| prefill | False | False |
| 单 token 解码 | True | False |
| MTP 多 token 解码 | False | True |
is_multi_token_decoding 必须存在的根本原因是:Ascend 的 GatedDeltaRule(Qwen3.5 linear attention 层)需要三路 kernel 分发。不同 kernel 依赖的参数和 buffer 不一样,因此增加了细分状态的 flag is_multi_token_decoding。
CUDA 为何不需要?CUDA 侧用的是 FLA 库(fla.ops.utils),chunk_gated_delta_rule 通过 cu_seqlens_q 可变长接口,天然支持 prefill 和 MTP 走同一套代码路径。