diff --git a/atom/__init__.py b/atom/__init__.py index 7c1c75eb3a..9049532b61 100644 --- a/atom/__init__.py +++ b/atom/__init__.py @@ -4,11 +4,10 @@ from atom.model_engine.llm_engine import LLMEngine from atom.sampling_params import SamplingParams -# interface for upper framework to construct the model from ATOM -from atom.plugin import prepare_model +from atom.plugin.sglang import prepare_model_for_sglang __all__ = [ "LLMEngine", "SamplingParams", - "prepare_model", + "prepare_model_for_sglang", ] diff --git a/atom/plugin/__init__.py b/atom/plugin/__init__.py index 27c855e511..315b40cf75 100644 --- a/atom/plugin/__init__.py +++ b/atom/plugin/__init__.py @@ -1,12 +1,6 @@ -from .prepare import ( - prepare_model, - is_sglang, - is_vllm, - is_plugin_mode, -) +from .prepare import is_plugin_mode, is_sglang, is_vllm __all__ = [ - "prepare_model", "is_sglang", "is_vllm", "is_plugin_mode", diff --git a/atom/plugin/prepare.py b/atom/plugin/prepare.py index 3c7b722d29..ede7c9de64 100644 --- a/atom/plugin/prepare.py +++ b/atom/plugin/prepare.py @@ -1,8 +1,3 @@ -from typing import Any -import logging - -logger = logging.getLogger("atom") - # all of the supported frameworks, including server mode and plugin mode _SUPPORTED_FRAMEWORKS = ["vllm", "sglang", "sgl", "atom"] @@ -33,79 +28,3 @@ def _set_framework_backbone(framework: str) -> None: raise ValueError(f"Unsupported framework {framework} for ATOM to plug in") global _CURRENT_FRAMEWORK _CURRENT_FRAMEWORK = framework - - -def prepare_model(config: Any, engine: str): - """ - Prepare the model to upper framework SGLang - """ - logger.info(f"Prepare model for plugin mode, the upper engine is {engine}") - - _set_framework_backbone(engine) - - if is_sglang(): - model_arch = config.architectures[0] - else: - raise ValueError( - f"prepare_model does not support engine {engine!r} " - f"with config type {type(config)}" - ) - - # import here to avoid partial initialization - from .register import ( - _ATOM_SUPPORTED_MODELS, - # register_ops_to_vllm, - register_ops_to_sglang, - init_aiter_dist, - set_attn_cls, - ) - - if model_arch not in _ATOM_SUPPORTED_MODELS: - supported_archs = list(_ATOM_SUPPORTED_MODELS.keys()) - raise ValueError( - f"ATOM does not support the required model architecture: {model_arch}. " - f"For now supported model architectures: {supported_archs}" - ) - - from atom.plugin.config import generate_atom_config_for_plugin_mode - - atom_config = generate_atom_config_for_plugin_mode(config) - - model_cls = _ATOM_SUPPORTED_MODELS[model_arch] - logger.info(f"ATOM model class for {model_arch} is {model_cls}") - - if model_arch in { - "Qwen3_5ForConditionalGeneration", - "Qwen3_5MoeForConditionalGeneration", - }: - from atom.plugin.sglang.models.qwen3_5 import ( - apply_prepare_model_adaptations, - ) - - apply_prepare_model_adaptations(atom_config, model_arch) - - register_ops_to_sglang(atom_config=atom_config) - set_attn_cls() - - # init aiter dist for using aiter custom collective ops - init_aiter_dist(config=atom_config) - - # Patch SGLang graph_capture to also enter aiter's ca_comm.capture(), - # avoiding hipMemcpyAsync in aiter collectives when model uses aiter's - # custom all_reduce (same fix as atom/plugin/vllm/graph_capture_patch.py) - from atom.plugin.sglang.graph_capture_patch import apply_graph_capture_patch - - apply_graph_capture_patch() - - try: - model = model_cls(atom_config=atom_config) - except TypeError as exc: - # Some SGLang plugin models keep SGLang's native wrapper constructor - # and only swap their internal language_model with an ATOM model. - # Those classes accept `config=...` instead of `atom_config=...`. - if "atom_config" not in str(exc): - raise - model = model_cls(config=config) - if not hasattr(model, "atom_config"): - model.atom_config = atom_config - return model diff --git a/atom/plugin/register.py b/atom/plugin/register.py index f5db8a59a6..90ac34e71e 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -45,7 +45,7 @@ def _register_custom_attention_to_sglang() -> None: from sglang.srt.layers.attention.attention_registry import ( register_attention_backend, ) - from atom.plugin.sglang.attention_backend.sgl_attn_backend import ( + from atom.plugin.sglang.attention_backend.full_attention.full_attention_backend import ( ATOMAttnBackendForSgl, ) diff --git a/atom/plugin/sglang/__init__.py b/atom/plugin/sglang/__init__.py index e69de29bb2..d03f99a045 100644 --- a/atom/plugin/sglang/__init__.py +++ b/atom/plugin/sglang/__init__.py @@ -0,0 +1,3 @@ +from atom.plugin.sglang.prepare import prepare_model_for_sglang + +__all__ = ["prepare_model_for_sglang"] diff --git a/atom/plugin/sglang/attention.py b/atom/plugin/sglang/attention.py index f1474c45b1..44a810c152 100644 --- a/atom/plugin/sglang/attention.py +++ b/atom/plugin/sglang/attention.py @@ -1,4 +1,6 @@ -from atom.plugin.sglang.attention_backend.radix_attention import RadixAttention +from atom.plugin.sglang.attention_backend.full_attention.radix_attention import ( + RadixAttention, +) class AttentionForSGLang(RadixAttention): diff --git a/atom/plugin/sglang/attention_backend/attention_gdn.py b/atom/plugin/sglang/attention_backend/attention_gdn.py index 2403efaa9e..f29a43b1b7 100644 --- a/atom/plugin/sglang/attention_backend/attention_gdn.py +++ b/atom/plugin/sglang/attention_backend/attention_gdn.py @@ -164,7 +164,7 @@ def _build_gdn_metadata( def build( cls, forward_batch_or_metadata: Any ) -> Optional["SGLangGDNForwardContext"]: - from atom.plugin.sglang.models.base_model_wrapper import ( + from atom.plugin.sglang.runtime import ( SGLangForwardBatchMetadata, ) diff --git a/atom/plugin/sglang/attention_backend/full_attention/__init__.py b/atom/plugin/sglang/attention_backend/full_attention/__init__.py new file mode 100644 index 0000000000..2b9f6f2726 --- /dev/null +++ b/atom/plugin/sglang/attention_backend/full_attention/__init__.py @@ -0,0 +1,8 @@ +from .radix_attention import RadixAttention +from .full_attention_backend import ATOMAttnBackendForSgl, ForwardMetadata + +__all__ = [ + "RadixAttention", + "ATOMAttnBackendForSgl", + "ForwardMetadata", +] diff --git a/atom/plugin/sglang/attention_backend/sgl_attn_backend.py b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py similarity index 87% rename from atom/plugin/sglang/attention_backend/sgl_attn_backend.py rename to atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py index 9fab7eaff5..26ff848340 100644 --- a/atom/plugin/sglang/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py @@ -1,22 +1,19 @@ from __future__ import annotations -# sglang-specific attention backend replacing sglang's built-in AiterAttnBackend. -# Shared by ALL models (DeepSeek, Qwen3, etc.) — handles KV cache writes, -# page-table fixup, pa_persistent_fwd decode path, and MLA prefill kernels. -# Sits at the lowest layer of the attention stack: sglang's RadixAttention -# delegates the actual kernel dispatch here. +# SGLang full-attention backend replacing sglang's built-in AiterAttnBackend. +# Shared by ALL full-attention models (DeepSeek, Qwen3, etc.) — handles KV +# cache writes, page-table fixup, pa_persistent_fwd decode path, and MLA +# prefill kernels. Sits at the lowest layer of the attention stack: +# sglang's RadixAttention delegates the actual kernel dispatch here. # # TODO: rewrite this file once sglang's attention flow is unified into ATOM's # attention layer — KV cache management and attention kernel dispatch will then # be handled by ATOM's native backend, making sglang-specific overrides # unnecessary. -from dataclasses import dataclass from typing import TYPE_CHECKING, Optional import torch -import triton -import triton.language as tl import sglang.srt.layers.attention.aiter_backend as _sglang_aiter from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend @@ -28,6 +25,16 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import get_bool_env_var +from atom.plugin.sglang.attention_backend.full_attention.kv_cache import ( + set_kv_buffer_with_layout_shuffle as _set_kv_buffer_with_layout_shuffle, +) +from atom.plugin.sglang.attention_backend.full_attention.metadata import ForwardMetadata +from atom.plugin.sglang.attention_backend.full_attention.pa_metadata import ( + allocate_pa_metadata_buffers as _allocate_pa_metadata_buffers, + build_pa_metadata_for_decode as _build_pa_metadata_for_decode, + build_pa_metadata_for_prefill as _build_pa_metadata_for_prefill, +) + if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner @@ -38,7 +45,6 @@ flash_attn_varlen_func, dtypes, get_pa_metadata_info_v1, - get_pa_metadata_v1, mha_batch_prefill_func, pa_fwd_asm, pa_persistent_fwd, @@ -47,7 +53,7 @@ except ImportError as e: raise ImportError( "Failed to import 'aiter', which provides AMD-specific attention kernels " - "required by sgl_attn_backend. Please ensure 'aiter' is installed and " + "required by full_attention_backend. Please ensure 'aiter' is installed and " f"available on your AMD system. Original import error: {e}" ) from e @@ -71,140 +77,6 @@ pass -@triton.jit -def reshape_and_cache_shuffle_kernel( - key_ptr, # [num_tokens, num_kv_heads, head_size] - value_ptr, # [num_tokens, num_kv_heads, head_size] - key_cache_ptr, # [num_blocks, num_kv_heads, head_size // x, block_size, x] - value_cache_ptr, # [num_blocks, num_kv_heads, block_size // x, head_size, x] - slot_mapping_ptr, # [num_tokens] - k_scale_ptr, - v_scale_ptr, - x, - k_stride0, - v_stride0, - block_size, - head_size, - num_kv_heads, - BLOCK_SIZE: tl.constexpr, - QUANT: tl.constexpr, -): - tid = tl.program_id(0) - head_id = tl.program_id(1) - offset = tl.arange(0, BLOCK_SIZE) - src_offset_k = tid * k_stride0 + head_id * head_size - src_offset_v = tid * v_stride0 + head_id * head_size - slot_id = tl.load(slot_mapping_ptr + tid) - if slot_id < 0: - return - block_id = slot_id // block_size - block_offset = slot_id % block_size - dst_offset = ( - block_id * num_kv_heads * head_size * block_size - + head_id * head_size * block_size - ) - dst_k_shuffle_offset = ( - dst_offset + offset // x * block_size * x + block_offset * x + offset % x - ) - dst_v_shuffle_offset = ( - dst_offset + block_offset // x * head_size * x + offset * x + block_offset % x - ) - k_val = tl.load(key_ptr + src_offset_k + offset) - v_val = tl.load(value_ptr + src_offset_v + offset) - if QUANT: - k_scale = tl.load(k_scale_ptr) - v_scale = tl.load(v_scale_ptr) - k_dtype = key_cache_ptr.type.element_ty - v_dtype = value_cache_ptr.type.element_ty - k_val = (k_val.to(tl.float32) / k_scale).to(k_dtype) - v_val = (v_val.to(tl.float32) / v_scale).to(v_dtype) - tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val) - tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val) - - -def reshape_and_cache_shuffle_triton( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scales: torch.Tensor, - v_scales: torch.Tensor, -): - num_tokens = slot_mapping.shape[0] - _, num_kv_heads, head_size = key.shape - num_blocks, block_size, _, _ = key_cache.shape - x = 16 // key_cache.element_size() - k_cache_template = torch.empty( - [num_blocks, num_kv_heads, head_size // x, block_size, x], - dtype=key_cache.dtype, - device="meta", - ) - v_cache_template = torch.empty( - [num_blocks, num_kv_heads, block_size // x, head_size, x], - dtype=value_cache.dtype, - device="meta", - ) - new_key_cache = key_cache.view_as(k_cache_template) - new_value_cache = value_cache.view_as(v_cache_template) - QUANT = False - if kv_cache_dtype.startswith("fp8"): - QUANT = True - grid = ( - num_tokens, - num_kv_heads, - ) - reshape_and_cache_shuffle_kernel[grid]( - key, - value, - new_key_cache, - new_value_cache, - slot_mapping, - k_scales, - v_scales, - x, - key.stride(0), - value.stride(0), - block_size, - head_size, - num_kv_heads, - BLOCK_SIZE=head_size, - QUANT=QUANT, - ) - - -@dataclass -class ForwardMetadata: - """Per-batch metadata consumed by ATOM's attention kernels (pa_fwd_asm, mla_decode_fwd, etc.).""" - - # kv_indptr and kv_indices are only used in MLA mode, optional for non-MLA mode - kv_indptr: Optional[torch.Tensor] - kv_indices: Optional[torch.Tensor] - qo_indptr: Optional[torch.Tensor] - kv_last_page_len: Optional[torch.Tensor] - max_q_len: Optional[int] - max_kv_len: Optional[int] - page_table: Optional[torch.Tensor] - kv_lens: Optional[torch.Tensor] - # mla - work_metadata: Optional[torch.Tensor] = None - work_info_set: Optional[torch.Tensor] = None - work_indptr: Optional[torch.Tensor] = None - reduce_indptr: Optional[torch.Tensor] = None - reduce_final_map: Optional[torch.Tensor] = None - reduce_partial_map: Optional[torch.Tensor] = None - fp8_prefill_kv_indices: Optional[torch.Tensor] = None - num_kv_splits: Optional[int] = None - run_graph: Optional[bool] = True - # PA metadata for pa_persistent_fwd (only used in decode mode, non-MLA) - pa_metadata_qo_indptr: Optional[torch.Tensor] = None - pa_metadata_pages_kv_indptr: Optional[torch.Tensor] = None - pa_metadata_kv_indices: Optional[torch.Tensor] = None - pa_metadata_context_lens: Optional[torch.Tensor] = None - pa_metadata_max_qlen: Optional[int] = None - - class ATOMAttnBackendForSgl(AiterAttnBackend): """ATOM's custom attention backend for sglang plugin mode. @@ -410,7 +282,7 @@ def _init_decode_mha(self, bs, kv_indptr, kv_indices, forward_batch): page_table, seq_lens, ) - self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + _build_pa_metadata_for_decode(self, bs, tp_q_head_num=self.num_head) else: page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : @@ -826,7 +698,7 @@ def _fixup_page_table(self, forward_batch: ForwardBatch): // self.page_size ) if self.decode_using_pa_ps: - self._build_pa_metadata_for_prefill(forward_batch.batch_size) + _build_pa_metadata_for_prefill(self, forward_batch.batch_size) if ( not self.decode_using_pa_ps and self.page_size > 1 @@ -837,151 +709,6 @@ def _fixup_page_table(self, forward_batch: ForwardBatch): // self.page_size ) - def _ensure_buffer(self, name, size, dtype, zero=True): - """Allocate or reuse a pa_metadata buffer, growing if needed.""" - if self.pa_metadata_buffers is None: - self.pa_metadata_buffers = {} - size_val = size[0] if isinstance(size, (tuple, list)) else size - buf = self.pa_metadata_buffers.get(name) - needs_alloc = ( - buf is None - or buf.shape[0] < size_val - or (isinstance(size, (tuple, list)) and len(buf.shape) < len(size)) - ) - if needs_alloc: - factory = torch.zeros if zero else torch.empty - self.pa_metadata_buffers[name] = factory( - size, dtype=dtype, device=self.device - ) - elif zero: - self.pa_metadata_buffers[name].zero_() - - def _allocate_pa_metadata_buffers(self, buffer_specs): - """Allocate or reuse pa_metadata buffers. - - Args: - buffer_specs: sequence of ((size, dtype), ...) tuples from get_pa_metadata_info_v1, - in order: work_metadata_ptrs, work_indptr, work_info, - reduce_indptr, reduce_final_map, reduce_partial_map. - """ - names = [ - "work_metadata_ptrs", - "work_indptr", - "work_info", - "reduce_indptr", - "reduce_final_map", - "reduce_partial_map", - ] - zero_flags = [False, True, True, True, True, True] - for name, (size, dtype), zero in zip(names, buffer_specs, zero_flags): - self._ensure_buffer(name, size, dtype, zero=zero) - - def _build_pa_metadata_for_decode( - self, - batch_size: int, - tp_q_head_num: Optional[int] = None, - ): - """Build pa_metadata buffers for pa_persistent_fwd in decode mode. - - This method prepares all metadata buffers needed for pa_persistent_fwd kernel. - The metadata can be reused across multiple layers in the same forward pass. - - Args: - batch_size: Batch size for the current forward pass - tp_q_head_num: Number of Q heads per TP rank. If None, uses self.num_head. - """ - max_qlen = 1 - - # Use provided tp_q_head_num or default to self.num_head - if tp_q_head_num is None: - tp_q_head_num = self.num_head - - buffer_specs = get_pa_metadata_info_v1(batch_size, self.num_kv_head) - self._allocate_pa_metadata_buffers(buffer_specs) - qo_indptr = self.pa_decode_qo_indptr[: batch_size + 1] - - # Get context_lens (kv_lens is always set before calling _build_pa_metadata_for_decode) - # Note: kv_lens comes from self.seq_lens which is already int32 - context_lens = self.forward_metadata.kv_lens - - kernel_block_size = self.page_size - num_blocks_per_seq = (context_lens + kernel_block_size - 1) // kernel_block_size - # Use dedicated pa_kv_indptr buffer (similar to self.kv_indptr, but for pa_persistent_fwd) - pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] - pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) - - # Convert page_table to kv_indices (block indices) using Triton kernel to avoid sync - # page_table shape: [batch_size, max_num_blocks_per_seq] - # Note: page_table comes from self.page_table which is already int32 and always set before this call - page_table = self.forward_metadata.page_table - - # Use Triton kernel to gather kv_indices from page_table (avoids high-level indexing sync) - create_flashinfer_kv_indices_triton[(batch_size,)]( - page_table, - self.pa_batch_indices[:batch_size], # [0, 1, 2, ..., batch_size-1] - num_blocks_per_seq, - pages_kv_indptr, - None, # kv_start_idx - self.pa_kv_indices, - page_table.stride(0), - ) - # Use the full buffer - pa_persistent_fwd reads only valid elements based on pages_kv_indptr - kv_indices = self.pa_kv_indices - - get_pa_metadata_v1( - seqlens_qo_indptr=qo_indptr, - pages_kv_indptr=pages_kv_indptr, - context_lens=context_lens.int(), - num_heads_per_head_k=tp_q_head_num // self.num_kv_head, - num_heads_k=self.num_kv_head, - is_causal=True, - work_metadata_ptrs=self.pa_metadata_buffers["work_metadata_ptrs"], - work_indptr=self.pa_metadata_buffers["work_indptr"], - work_info=self.pa_metadata_buffers["work_info"], - reduce_indptr=self.pa_metadata_buffers["reduce_indptr"], - reduce_final_map=self.pa_metadata_buffers["reduce_final_map"], - reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"], - kv_granularity=max(kernel_block_size, 16), - block_size=kernel_block_size, - max_seqlen_qo=max_qlen, - uni_seqlen_qo=max_qlen, - fast_mode=True, - topk=-1, - max_split_per_batch=-1, - ) - # Store computed values in ForwardMetadata for reuse in forward_decode - self.forward_metadata.pa_metadata_qo_indptr = qo_indptr - self.forward_metadata.pa_metadata_pages_kv_indptr = pages_kv_indptr - self.forward_metadata.pa_metadata_kv_indices = kv_indices - self.forward_metadata.pa_metadata_context_lens = context_lens - self.forward_metadata.pa_metadata_max_qlen = max_qlen - - def _build_pa_metadata_for_prefill(self, batch_size: int): - """Build metadata for mha_batch_prefill_func in prefill mode. - - This method prepares page-level metadata needed for mha_batch_prefill_func. - The metadata is computed once per forward pass and reused across all layers. - """ - block_size = self.page_size - context_lens = self.forward_metadata.kv_lens - num_blocks_per_seq = (context_lens + block_size - 1) // block_size - - # Page-level kv_indptr (reuse pa_kv_indptr buffer) - pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] - pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) - - # Build kv_indices from page_table using triton kernel - page_table = self.forward_metadata.page_table - create_flashinfer_kv_indices_triton[(batch_size,)]( - page_table, - self.pa_batch_indices[:batch_size], - num_blocks_per_seq, - pages_kv_indptr, - None, # kv_start_idx - self.pa_kv_indices, - page_table.stride(0), - ) - def init_cuda_graph_state( self, max_bs: int, @@ -1036,7 +763,7 @@ def init_cuda_graph_state( if self.decode_using_pa_ps and not self.use_mla: buffer_specs = get_pa_metadata_info_v1(max_bs, self.num_kv_head) - self._allocate_pa_metadata_buffers(buffer_specs) + _allocate_pa_metadata_buffers(self, buffer_specs) def _init_mla_cuda_graph_metadata(self, bs, req_pool_indices, seq_lens): """Shared MLA decode metadata setup for CUDA graph capture/replay.""" @@ -1165,7 +892,7 @@ def init_forward_metadata_capture_cuda_graph( seq_lens_persistent, ) if self.decode_using_pa_ps: - self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + _build_pa_metadata_for_decode(self, bs, tp_q_head_num=self.num_head) elif forward_mode.is_target_verify(): qo_indptr = self.qo_indptr[: bs + 1] qo_indptr[: bs + 1] = torch.arange( @@ -1501,7 +1228,7 @@ def init_forward_metadata_replay_cuda_graph( seq_lens_persistent[:bs], ) if self.decode_using_pa_ps: - self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + _build_pa_metadata_for_decode(self, bs, tp_q_head_num=self.num_head) elif forward_mode.is_target_verify(): bs = len(req_pool_indices) qo_indptr = self.qo_indptr[: bs + 1] @@ -1879,27 +1606,15 @@ def set_kv_buffer_with_layout_shuffle( v_scale, block_size, ): - num_slots, num_kv_heads, head_dim = k_buffer.shape - num_blocks = num_slots // block_size - num_slots_with_block = num_blocks * block_size - k_buffer = k_buffer[:num_slots_with_block].view( - num_blocks, block_size, num_kv_heads, head_dim - ) - v_buffer = v_buffer[:num_slots_with_block].view( - num_blocks, block_size, num_kv_heads, head_dim - ) - kv_cache_dtype = "auto" - if k_buffer.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): - kv_cache_dtype = "fp8" - reshape_and_cache_shuffle_triton( + _set_kv_buffer_with_layout_shuffle( + cache_loc, k, v, k_buffer, v_buffer, - cache_loc, - kv_cache_dtype, k_scale, v_scale, + block_size, ) def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True): diff --git a/atom/plugin/sglang/attention_backend/full_attention/kv_cache.py b/atom/plugin/sglang/attention_backend/full_attention/kv_cache.py new file mode 100644 index 0000000000..00c2f0db0e --- /dev/null +++ b/atom/plugin/sglang/attention_backend/full_attention/kv_cache.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + + +@triton.jit +def reshape_and_cache_shuffle_kernel( + key_ptr, # [num_tokens, num_kv_heads, head_size] + value_ptr, # [num_tokens, num_kv_heads, head_size] + key_cache_ptr, # [num_blocks, num_kv_heads, head_size // x, block_size, x] + value_cache_ptr, # [num_blocks, num_kv_heads, block_size // x, head_size, x] + slot_mapping_ptr, # [num_tokens] + k_scale_ptr, + v_scale_ptr, + x, + k_stride0, + v_stride0, + block_size, + head_size, + num_kv_heads, + BLOCK_SIZE: tl.constexpr, + QUANT: tl.constexpr, +): + tid = tl.program_id(0) + head_id = tl.program_id(1) + offset = tl.arange(0, BLOCK_SIZE) + src_offset_k = tid * k_stride0 + head_id * head_size + src_offset_v = tid * v_stride0 + head_id * head_size + slot_id = tl.load(slot_mapping_ptr + tid) + if slot_id < 0: + return + block_id = slot_id // block_size + block_offset = slot_id % block_size + dst_offset = ( + block_id * num_kv_heads * head_size * block_size + + head_id * head_size * block_size + ) + dst_k_shuffle_offset = ( + dst_offset + offset // x * block_size * x + block_offset * x + offset % x + ) + dst_v_shuffle_offset = ( + dst_offset + block_offset // x * head_size * x + offset * x + block_offset % x + ) + k_val = tl.load(key_ptr + src_offset_k + offset) + v_val = tl.load(value_ptr + src_offset_v + offset) + if QUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + k_dtype = key_cache_ptr.type.element_ty + v_dtype = value_cache_ptr.type.element_ty + k_val = (k_val.to(tl.float32) / k_scale).to(k_dtype) + v_val = (v_val.to(tl.float32) / v_scale).to(v_dtype) + tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val) + tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val) + + +def reshape_and_cache_shuffle_triton( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scales: torch.Tensor, + v_scales: torch.Tensor, +): + num_tokens = slot_mapping.shape[0] + _, num_kv_heads, head_size = key.shape + num_blocks, block_size, _, _ = key_cache.shape + x = 16 // key_cache.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=key_cache.dtype, + device="meta", + ) + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=value_cache.dtype, + device="meta", + ) + new_key_cache = key_cache.view_as(k_cache_template) + new_value_cache = value_cache.view_as(v_cache_template) + quant = kv_cache_dtype.startswith("fp8") + grid = ( + num_tokens, + num_kv_heads, + ) + reshape_and_cache_shuffle_kernel[grid]( + key, + value, + new_key_cache, + new_value_cache, + slot_mapping, + k_scales, + v_scales, + x, + key.stride(0), + value.stride(0), + block_size, + head_size, + num_kv_heads, + BLOCK_SIZE=head_size, + QUANT=quant, + ) + + +def set_kv_buffer_with_layout_shuffle( + cache_loc, + k, + v, + k_buffer, + v_buffer, + k_scale, + v_scale, + block_size, +): + num_slots, num_kv_heads, head_dim = k_buffer.shape + num_blocks = num_slots // block_size + num_slots_with_block = num_blocks * block_size + k_buffer = k_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) + v_buffer = v_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) + kv_cache_dtype = "auto" + if k_buffer.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): + kv_cache_dtype = "fp8" + reshape_and_cache_shuffle_triton( + k, + v, + k_buffer, + v_buffer, + cache_loc, + kv_cache_dtype, + k_scale, + v_scale, + ) diff --git a/atom/plugin/sglang/attention_backend/full_attention/metadata.py b/atom/plugin/sglang/attention_backend/full_attention/metadata.py new file mode 100644 index 0000000000..b66feaa756 --- /dev/null +++ b/atom/plugin/sglang/attention_backend/full_attention/metadata.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class ForwardMetadata: + """Per-batch metadata consumed by SGLang full-attention backend kernels.""" + + # kv_indptr and kv_indices are only used in MLA mode, optional for non-MLA mode + kv_indptr: Optional[torch.Tensor] + kv_indices: Optional[torch.Tensor] + qo_indptr: Optional[torch.Tensor] + kv_last_page_len: Optional[torch.Tensor] + max_q_len: Optional[int] + max_kv_len: Optional[int] + page_table: Optional[torch.Tensor] + kv_lens: Optional[torch.Tensor] + # MLA metadata + work_metadata: Optional[torch.Tensor] = None + work_info_set: Optional[torch.Tensor] = None + work_indptr: Optional[torch.Tensor] = None + reduce_indptr: Optional[torch.Tensor] = None + reduce_final_map: Optional[torch.Tensor] = None + reduce_partial_map: Optional[torch.Tensor] = None + fp8_prefill_kv_indices: Optional[torch.Tensor] = None + num_kv_splits: Optional[int] = None + run_graph: Optional[bool] = True + # PA metadata for pa_persistent_fwd (only used in decode mode, non-MLA) + pa_metadata_qo_indptr: Optional[torch.Tensor] = None + pa_metadata_pages_kv_indptr: Optional[torch.Tensor] = None + pa_metadata_kv_indices: Optional[torch.Tensor] = None + pa_metadata_context_lens: Optional[torch.Tensor] = None + pa_metadata_max_qlen: Optional[int] = None diff --git a/atom/plugin/sglang/attention_backend/full_attention/pa_metadata.py b/atom/plugin/sglang/attention_backend/full_attention/pa_metadata.py new file mode 100644 index 0000000000..5ab66056e0 --- /dev/null +++ b/atom/plugin/sglang/attention_backend/full_attention/pa_metadata.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from typing import Optional + +import torch +from aiter import get_pa_metadata_info_v1, get_pa_metadata_v1 +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton + + +def _ensure_buffer(backend, name, size, dtype, zero=True): + """Allocate or reuse a pa_metadata buffer, growing if needed.""" + if backend.pa_metadata_buffers is None: + backend.pa_metadata_buffers = {} + size_val = size[0] if isinstance(size, (tuple, list)) else size + buf = backend.pa_metadata_buffers.get(name) + needs_alloc = ( + buf is None + or buf.shape[0] < size_val + or (isinstance(size, (tuple, list)) and len(buf.shape) < len(size)) + ) + if needs_alloc: + factory = torch.zeros if zero else torch.empty + backend.pa_metadata_buffers[name] = factory( + size, dtype=dtype, device=backend.device + ) + elif zero: + backend.pa_metadata_buffers[name].zero_() + + +def allocate_pa_metadata_buffers(backend, buffer_specs): + """Allocate or reuse pa_metadata buffers for the backend.""" + names = [ + "work_metadata_ptrs", + "work_indptr", + "work_info", + "reduce_indptr", + "reduce_final_map", + "reduce_partial_map", + ] + zero_flags = [False, True, True, True, True, True] + for name, (size, dtype), zero in zip(names, buffer_specs, zero_flags): + _ensure_buffer(backend, name, size, dtype, zero=zero) + + +def build_pa_metadata_for_decode( + backend, + batch_size: int, + tp_q_head_num: Optional[int] = None, +): + """Build pa_metadata buffers for pa_persistent_fwd in decode mode.""" + max_qlen = 1 + + if tp_q_head_num is None: + tp_q_head_num = backend.num_head + + buffer_specs = get_pa_metadata_info_v1(batch_size, backend.num_kv_head) + allocate_pa_metadata_buffers(backend, buffer_specs) + qo_indptr = backend.pa_decode_qo_indptr[: batch_size + 1] + + context_lens = backend.forward_metadata.kv_lens + + kernel_block_size = backend.page_size + num_blocks_per_seq = (context_lens + kernel_block_size - 1) // kernel_block_size + pages_kv_indptr = backend.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + page_table = backend.forward_metadata.page_table + + create_flashinfer_kv_indices_triton[(batch_size,)]( + page_table, + backend.pa_batch_indices[:batch_size], + num_blocks_per_seq, + pages_kv_indptr, + None, + backend.pa_kv_indices, + page_table.stride(0), + ) + kv_indices = backend.pa_kv_indices + + get_pa_metadata_v1( + seqlens_qo_indptr=qo_indptr, + pages_kv_indptr=pages_kv_indptr, + context_lens=context_lens.int(), + num_heads_per_head_k=tp_q_head_num // backend.num_kv_head, + num_heads_k=backend.num_kv_head, + is_causal=True, + work_metadata_ptrs=backend.pa_metadata_buffers["work_metadata_ptrs"], + work_indptr=backend.pa_metadata_buffers["work_indptr"], + work_info=backend.pa_metadata_buffers["work_info"], + reduce_indptr=backend.pa_metadata_buffers["reduce_indptr"], + reduce_final_map=backend.pa_metadata_buffers["reduce_final_map"], + reduce_partial_map=backend.pa_metadata_buffers["reduce_partial_map"], + kv_granularity=max(kernel_block_size, 16), + block_size=kernel_block_size, + max_seqlen_qo=max_qlen, + uni_seqlen_qo=max_qlen, + fast_mode=True, + topk=-1, + max_split_per_batch=-1, + ) + backend.forward_metadata.pa_metadata_qo_indptr = qo_indptr + backend.forward_metadata.pa_metadata_pages_kv_indptr = pages_kv_indptr + backend.forward_metadata.pa_metadata_kv_indices = kv_indices + backend.forward_metadata.pa_metadata_context_lens = context_lens + backend.forward_metadata.pa_metadata_max_qlen = max_qlen + + +def build_pa_metadata_for_prefill(backend, batch_size: int): + """Build page-level metadata for non-MLA prefill mode.""" + block_size = backend.page_size + context_lens = backend.forward_metadata.kv_lens + num_blocks_per_seq = (context_lens + block_size - 1) // block_size + + pages_kv_indptr = backend.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + page_table = backend.forward_metadata.page_table + create_flashinfer_kv_indices_triton[(batch_size,)]( + page_table, + backend.pa_batch_indices[:batch_size], + num_blocks_per_seq, + pages_kv_indptr, + None, + backend.pa_kv_indices, + page_table.stride(0), + ) diff --git a/atom/plugin/sglang/attention_backend/radix_attention.py b/atom/plugin/sglang/attention_backend/full_attention/radix_attention.py similarity index 98% rename from atom/plugin/sglang/attention_backend/radix_attention.py rename to atom/plugin/sglang/attention_backend/full_attention/radix_attention.py index 8e20eecb14..613458ee37 100644 --- a/atom/plugin/sglang/attention_backend/radix_attention.py +++ b/atom/plugin/sglang/attention_backend/full_attention/radix_attention.py @@ -129,7 +129,7 @@ def forward_impl_plugin_mode( # for sglang, forward_batch is required forward_batch = kwargs.get("forward_batch", None) if forward_batch is None: - from atom.plugin.sglang.models.base_model_wrapper import ( + from atom.plugin.sglang.runtime import ( get_current_forward_batch, ) diff --git a/atom/plugin/sglang/attention_backend/sgl_attention_mla.py b/atom/plugin/sglang/attention_backend/sgl_attention_mla.py deleted file mode 100644 index 8a9cbb1f48..0000000000 --- a/atom/plugin/sglang/attention_backend/sgl_attention_mla.py +++ /dev/null @@ -1,1375 +0,0 @@ -"""Sglang-specific MLA forward and weight processing for DeepseekV2/V3. - -DeepSeek MLA (Multi-Latent Attention) forward logic for sglang plugin mode: -absorbed BMM computation, MHA/MLA path dispatch (prefill -> MHA, decode -> MLA), -kv_b_proj weight splitting (w_kc/w_vc), and monkey-patch setup via -setup_deepseek_for_sglang(). - -This module is lazily imported from base_model_wrapper.py only when running in -sglang plugin mode (``is_sglang() == True``). Keeping all sglang-dependent -imports here avoids crashing when sglang is not installed. - -TODO: rewrite this file once sglang's attention flow is unified into ATOM's -attention layer — the MLA absorbed path and MHA dispatch will then be handled -natively by ATOM's attention ops, making this sglang-specific module unnecessary. -""" - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING, Any, NamedTuple, Optional - -import torch -from aiter import dtypes -from aiter.dist.parallel_state import get_tensor_model_parallel_world_size, get_tp_group -from atom.model_ops.base_attention import Attention -from atom.model_ops.attention_mla import ( - concat_and_cache_mla, - dynamic_per_batched_tensor_quant, - fused_qk_rope_concat_and_cache_mla, -) -from atom.models.utils import maybe_prefix -from atom.models.deepseek_v2 import _fuse_rmsnorm_quant - -# sglang imports -from sglang.srt.layers.communicator import AttentionInputs, get_attn_tp_context -from sglang.srt.layers.attention.nsa.utils import nsa_use_prefill_cp -from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode -from sglang.srt.models.deepseek_common.utils import ( - _use_aiter_gfx95, - _is_hip, - _is_cpu, - _is_cpu_amx_available, - _is_cuda, - _is_fp8_fnuz, - _is_npu, - awq_dequantize_func, -) -from sglang.srt.layers.quantization.rocm_mxfp4_utils import ( - batched_gemm_afp4wfp4_pre_quant, -) -from aiter.utility.fp4_utils import e8m0_to_f32, mxfp4_to_f32 -from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import ( - batched_gemm_a16wfp4, -) -from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( - batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant, -) -from sglang.srt.layers.quantization.fp8_kernel import ( - per_tensor_quant_mla_fp8, - per_token_group_quant_mla_deep_gemm_masked_fp8, -) -from sglang.srt.utils import bind_or_assign, get_bool_env_var - -if TYPE_CHECKING: - from atom.models.deepseek_v2 import DeepseekV2MLAAttention - - -logger = logging.getLogger(__name__) - - -# bmm_fp8 custom-op wrapper (adapted from sglang forward_mla.py) -if _is_cuda: - from sgl_kernel import bmm_fp8 as _raw_bmm_fp8 - from sglang.srt.utils.custom_op import register_custom_op - - @register_custom_op(mutates_args=["out"]) - def _bmm_fp8_op( - A: torch.Tensor, - B: torch.Tensor, - out: torch.Tensor, - A_scale: torch.Tensor, - B_scale: torch.Tensor, - ) -> None: - _raw_bmm_fp8(A, B, A_scale, B_scale, out.dtype, out) - - def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): - if out is None: - out = torch.empty( - (A.shape[0], A.shape[1], B.shape[2]), - device=A.device, - dtype=dtype, - ) - _bmm_fp8_op(A, B, out, A_scale, B_scale) - return out - -else: - - def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): - raise RuntimeError("bmm_fp8 requires CUDA (sgl_kernel)") - - -# NamedTuple for prepare → core data flow -class SglPrepareResult(NamedTuple): - q_pe: torch.Tensor - k_pe: torch.Tensor - q_nope_out: torch.Tensor - k_nope: torch.Tensor - forward_batch: Any - zero_allocator: Any - positions: torch.Tensor - topk_indices: Optional[torch.Tensor] - llama_4_scaling: Optional[Any] - - -class SglMhaPrepareResult(NamedTuple): - q: torch.Tensor - k: torch.Tensor - v: torch.Tensor - forward_batch: Any - - -def _unwrap_linear_output(output: Any) -> torch.Tensor: - """Normalize ATOM/public-SGLang linear outputs to a tensor.""" - if isinstance(output, tuple): - return output[0] - return output - - -def _linear_quant_type_value(linear: Any) -> Optional[int]: - quant_type = getattr(linear, "quant_type", None) - return None if quant_type is None else getattr(quant_type, "value", quant_type) - - -def _fuse_qk_rmsnorm_and_q_quant( - attn: DeepseekV2MLAAttention, - q: torch.Tensor, - k_nope: torch.Tensor, - *, - output_unquantized_q: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]: - """Fuse q/k RMSNorm and q quant using ATOM's DeepSeek-V2 path.""" - - (q_quantized, q_scale), q_normed, k_nope_normed, _ = _fuse_rmsnorm_quant( - q, - attn.q_a_layernorm.weight, - attn.q_a_layernorm.eps, - k_nope, - attn.kv_a_layernorm.weight, - attn.kv_a_layernorm.eps, - None, - dtype_quant=attn.quant_dtype, - shuffle=False, - scale_shuffle_padding=False, - group_size=128, - quant_type=_linear_quant_type_value(attn.q_b_proj), - output_unquantized_inp1=output_unquantized_q, - transpose_scale=True, - ) - return q_quantized, q_scale, q_normed, k_nope_normed - - -def _fuse_qk_rmsnorm( - attn: DeepseekV2MLAAttention, - q: torch.Tensor, - k_nope: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - """Fuse q/k RMSNorm without quantizing q.""" - (q_normed, _), _, k_nope_normed, _ = _fuse_rmsnorm_quant( - q, - attn.q_a_layernorm.weight, - attn.q_a_layernorm.eps, - k_nope, - attn.kv_a_layernorm.weight, - attn.kv_a_layernorm.eps, - None, - dtype_quant=torch.bfloat16, - shuffle=False, - scale_shuffle_padding=False, - group_size=128, - quant_type=None, - output_unquantized_inp1=False, - transpose_scale=False, - ) - return q_normed, k_nope_normed - - -def _prepare_weight_for_bmm( - weight: torch.Tensor, in_dim: int, out_dim: int -) -> torch.Tensor: - """Normalize absorbed weight layout for torch.bmm fallback.""" - if weight.shape[1] == in_dim and weight.shape[2] == out_dim: - return weight - if weight.shape[1] == out_dim and weight.shape[2] == in_dim: - return weight.transpose(-2, -1) - raise RuntimeError( - "Unexpected absorbed weight shape for bmm fallback: " - f"{tuple(weight.shape)} with in_dim={in_dim}, out_dim={out_dim}" - ) - - -# Init helpers -def init_sgl_attrs( - attn: DeepseekV2MLAAttention, - config, - kv_cache_dtype: str = "bf16", -) -> None: - """Initialise sglang-only attributes on DeepseekV2MLAAttention.""" - from sglang.srt.configs.model_config import is_deepseek_nsa - - attn.use_nsa = is_deepseek_nsa(config) - attn.use_deep_gemm_bmm = False - attn.alt_stream = None - attn.kv_cache_dtype = kv_cache_dtype - attn.use_fused_qk_rope_concat_and_cache_mla = _use_aiter_gfx95 - attn.current_sgl_plugin_attn_path = None - attn.w_kc, attn.w_vc = None, None - attn.w_scale = None - attn.w_scale_k = None - attn.w_scale_v = None - attn.attn_mha = Attention( - num_heads=attn.num_local_heads, - head_dim=attn.qk_head_dim, - scale=attn.scaling, - num_kv_heads=attn.num_local_heads, - kv_cache_dtype=kv_cache_dtype, - layer_num=attn.layer_num, - use_mla=False, - v_head_dim=attn.v_head_dim, - prefix=maybe_prefix(attn.prefix, "attn_mha"), - ) - if hasattr(attn.attn_mha, "attn"): - attn.attn_mha.attn.kv_b_proj = None - - -# Absorbed batched-matmul (shared by prepare and core) -def mla_absorbed_bmm( - attn: DeepseekV2MLAAttention, - inp: torch.Tensor, - weight: torch.Tensor, - weight_scale: Optional[torch.Tensor], - weight_scale_k: Optional[torch.Tensor], - out_dim: int, -) -> torch.Tensor: - """Batched matmul for MLA absorbed weights (w_kc / w_vc). - - Handles deep_gemm, mxfp4, fp8-triton, fp8-cublas, and bf16 fallback paths. - inp: (num_tokens, num_heads, in_dim) — token-major - Returns: (num_tokens, num_heads, out_dim) — token-major - """ - effective_weight_scale = ( - weight_scale_k if weight_scale_k is not None else weight_scale - ) - - if attn.use_deep_gemm_bmm: - from sglang.srt.layers import deep_gemm_wrapper - - val, scale, masked_m, expected_m, aligned_m = ( - per_token_group_quant_mla_deep_gemm_masked_fp8(inp.transpose(0, 1)) - ) - out = inp.new_empty((attn.num_local_heads, aligned_m, out_dim)) - deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( - (val, scale), - (weight, weight_scale_k), - out, - masked_m, - expected_m, - ) - return out[:, :expected_m, :].transpose(0, 1) - - if _is_hip: - if _use_aiter_gfx95 and weight.dtype == torch.uint8: - x = inp.transpose(0, 1) - out = torch.empty( - x.shape[0], - x.shape[1], - weight.shape[2], - device=x.device, - dtype=torch.bfloat16, - ) - batched_gemm_afp4wfp4_pre_quant( - x, - weight.transpose(-2, -1), - weight_scale_k.transpose(-2, -1), - torch.bfloat16, - out, - ) - return out.transpose(0, 1) - - if (_use_aiter_gfx95 and weight.dtype == torch.float8_e4m3fn) or ( - get_is_capture_mode() and weight.dtype == torch.float8_e4m3fnuz - ): - x = inp.transpose(0, 1) - out = ( - batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( - X=x, - WQ=weight, - w_scale=effective_weight_scale, - group_size=128, - YQ=None, - transpose_bm=True, - transpose_bm_in=False, - dtype=torch.bfloat16, - ) - ) - return out - - w_bf16 = _prepare_weight_for_bmm(weight, inp.shape[-1], out_dim).to( - torch.bfloat16 - ) - if effective_weight_scale is not None: - w_bf16 = w_bf16 * effective_weight_scale - out = torch.bmm( - inp.to(torch.bfloat16).transpose(0, 1), - w_bf16, - ) - return out.transpose(0, 1) - - # CUDA fp8 path - if weight.dtype == torch.float8_e4m3fn: - val, scale = per_tensor_quant_mla_fp8( - inp.transpose(0, 1), - torch.zeros((1,), dtype=torch.float32, device=inp.device), - ) - out = bmm_fp8(val, weight, scale, effective_weight_scale, torch.bfloat16) - return out.transpose(0, 1) - - # bf16 fallback - return torch.bmm(inp.transpose(0, 1), weight).transpose(0, 1) - - -def mla_v_up_proj( - attn: DeepseekV2MLAAttention, - inp: torch.Tensor, - weight: torch.Tensor, - weight_scale: Optional[torch.Tensor], - weight_scale_k: Optional[torch.Tensor], - out_dim: int, -) -> torch.Tensor: - """Project MLA decode output to a flat o_proj input.""" - effective_weight_scale = ( - weight_scale_k if weight_scale_k is not None else weight_scale - ) - if _is_hip and _use_aiter_gfx95 and weight.dtype == torch.uint8: - x = inp.transpose(0, 1) - out = torch.empty( - (inp.shape[0], attn.num_local_heads * out_dim), - device=inp.device, - dtype=torch.bfloat16, - ) - out_3d = out.view(inp.shape[0], attn.num_local_heads, out_dim) - batched_gemm_a16wfp4( - x, - weight.transpose(-2, -1), - weight_scale_k.transpose(-2, -1), - dtype=torch.bfloat16, - y=out_3d, - transpose_bm=True, - prequant=True, - y_scale=None, - ) - return out - - if _is_hip and ( - (_use_aiter_gfx95 and weight.dtype == torch.float8_e4m3fn) - or (get_is_capture_mode() and weight.dtype == torch.float8_e4m3fnuz) - ): - x = inp.transpose(0, 1) - out = torch.empty( - (inp.shape[0], attn.num_local_heads * out_dim), - device=inp.device, - dtype=torch.bfloat16, - ) - out_3d = out.view(inp.shape[0], attn.num_local_heads, out_dim) - batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( - X=x, - WQ=weight, - w_scale=effective_weight_scale, - group_size=128, - YQ=out_3d, - transpose_bm=True, - transpose_bm_in=False, - dtype=torch.bfloat16, - ) - return out - - return mla_absorbed_bmm( - attn, inp, weight, weight_scale, weight_scale_k, out_dim - ).flatten(1, 2) - - -# Forward: prepare → core -def forward_sgl_prepare( - attn: DeepseekV2MLAAttention, - positions: torch.Tensor, - hidden_states: torch.Tensor, - **model_kwargs, -) -> SglPrepareResult: - """Prepare QKV for sglang MLA attention (adapted from sglang forward_absorb_prepare).""" - hidden_states_scale = None - if isinstance(hidden_states, tuple): - hidden_states, hidden_states_scale = hidden_states - - forward_batch = model_kwargs.get("forward_batch", None) - zero_allocator = model_kwargs.get("zero_allocator", None) - llama_4_scaling = model_kwargs.get("llama_4_scaling", None) - q_lora = None - topk_indices = None - - if attn.q_lora_rank is not None: - q, latent_cache = ( - get_attn_tp_context() - .fetch_qkv_latent() - .split( - [attn.q_lora_rank, attn.kv_lora_rank + attn.qk_rope_head_dim], - dim=-1, - ) - ) - - if ( - q.shape[0] != positions.shape[0] - and get_tensor_model_parallel_world_size() > 1 - ): - qkv_lora = torch.cat([q, latent_cache], dim=-1) - qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) - if qkv_lora.shape[0] < positions.shape[0]: - raise RuntimeError( - f"qkv_lora gather mismatch: got {qkv_lora.shape[0]}, " - f"expected {positions.shape[0]}" - ) - qkv_lora = qkv_lora[: positions.shape[0]] - q, latent_cache = torch.split( - qkv_lora, - [attn.q_lora_rank, attn.kv_lora_rank + attn.qk_rope_head_dim], - dim=-1, - ) - - k_nope = latent_cache[..., : attn.kv_lora_rank] - q_scale = None - - # Reuse native ATOM gating for q/k RMSNorm fusion. Quant fusion is used - # when DeepSeek enables qknorm-quant; otherwise keep the non-quant fused - # path aligned with native ATOM before falling back to plain layernorm. - if getattr(attn, "fuse_qknorm_quant", False): - q, q_scale, q_lora, k_nope = _fuse_qk_rmsnorm_and_q_quant( - attn, - q, - k_nope, - output_unquantized_q=attn.use_nsa, - ) - elif getattr(attn, "fuse_qknorm", False): - q, k_nope = _fuse_qk_rmsnorm(attn, q, k_nope) - # Otherwise keep the original overlap path for unfused qk norm. - elif attn.alt_stream is not None and get_is_capture_mode(): - current_stream = torch.cuda.current_stream() - attn.alt_stream.wait_stream(current_stream) - q = attn.q_a_layernorm(q) - with torch.cuda.stream(attn.alt_stream): - k_nope = attn.kv_a_layernorm(k_nope) - current_stream.wait_stream(attn.alt_stream) - else: - q = attn.q_a_layernorm(q) - k_nope = attn.kv_a_layernorm(k_nope) - - if attn.use_nsa: - if q_lora is None: - q_lora = q - - # overlap q_b_proj and indexer during decode - if ( - attn.alt_stream is not None - and get_is_capture_mode() - and forward_batch.forward_mode.is_decode_or_idle() - and q_lora is not None - ): - current_stream = torch.cuda.current_stream() - attn.alt_stream.wait_stream(current_stream) - with torch.cuda.stream(attn.alt_stream): - k_nope = k_nope.unsqueeze(1) - q = _unwrap_linear_output( - attn.q_b_proj(q, q_scale) - if q_scale is not None - else attn.q_b_proj(q) - ).view(-1, attn.num_local_heads, attn.qk_head_dim) - topk_indices = attn.indexer( - x=hidden_states, - q_lora=q_lora, - positions=positions, - forward_batch=forward_batch, - layer_id=attn.layer_num, - ) - current_stream.wait_stream(attn.alt_stream) - else: - k_nope = k_nope.unsqueeze(1) - q = _unwrap_linear_output( - attn.q_b_proj(q, q_scale) if q_scale is not None else attn.q_b_proj(q) - ).view(-1, attn.num_local_heads, attn.qk_head_dim) - if q_lora is not None: - topk_indices = attn.indexer( - x=hidden_states, - q_lora=q_lora, - positions=positions, - forward_batch=forward_batch, - layer_id=attn.layer_num, - ) - else: - q = _unwrap_linear_output(attn.q_proj(hidden_states)).view( - -1, attn.num_local_heads, attn.qk_head_dim - ) - latent_cache = _unwrap_linear_output(attn.kv_a_proj_with_mqa(hidden_states)) - k_nope = latent_cache[..., : attn.kv_lora_rank] - k_nope = attn.kv_a_layernorm(k_nope).unsqueeze(1) - - q_nope, q_pe = q.split([attn.qk_nope_head_dim, attn.qk_rope_head_dim], dim=-1) - k_pe = latent_cache[..., attn.kv_lora_rank :].unsqueeze(1) - - q_nope_out = mla_absorbed_bmm( - attn, q_nope, attn.w_kc, attn.w_scale, attn.w_scale_k, attn.kv_lora_rank - ) - - if attn.rotary_emb is not None and not attn.use_fused_qk_rope_concat_and_cache_mla: - q_pe, k_pe = attn.rotary_emb(positions, q_pe, k_pe) - - if nsa_use_prefill_cp(forward_batch): - k_nope, k_pe = attn.rebuild_cp_kv_cache( - latent_cache, forward_batch, k_nope, k_pe - ) - - return SglPrepareResult( - q_pe=q_pe, - k_pe=k_pe, - q_nope_out=q_nope_out, - k_nope=k_nope, - forward_batch=forward_batch, - zero_allocator=zero_allocator, - positions=positions, - topk_indices=topk_indices, - llama_4_scaling=llama_4_scaling, - ) - - -def forward_sgl_core( - attn: DeepseekV2MLAAttention, - prepared: SglPrepareResult, -) -> torch.Tensor: - """Core MLA attention computation for sglang (adapted from sglang forward_absorb_core).""" - save_kv_cache = True - - if attn.use_fused_qk_rope_concat_and_cache_mla: - mla_attn = _get_sglang_radix_attn(attn.mla_attn) - kv_cache = prepared.forward_batch.token_to_kv_pool.get_key_buffer( - mla_attn.layer_id - ) - q_out_dtype = ( - dtypes.fp8 - if attn.kv_cache_dtype == "fp8_e4m3" - else prepared.q_nope_out.dtype - ) - q = torch.empty( - ( - prepared.q_nope_out.shape[0], - attn.num_local_heads, - attn.kv_lora_rank + attn.qk_rope_head_dim, - ), - dtype=q_out_dtype, - device=prepared.q_nope_out.device, - ) - - fused_qk_rope_concat_and_cache_mla( - prepared.q_nope_out, - prepared.q_pe, - prepared.k_nope, - prepared.k_pe, - kv_cache, - q, - prepared.forward_batch.out_cache_loc, - mla_attn.k_scale, - mla_attn.k_scale, - prepared.positions, - attn.rotary_emb.cos_cache, - attn.rotary_emb.sin_cache, - is_neox=attn.rotary_emb.is_neox_style, - is_nope_first=True, - ) - # Decode/speculative MLA consumes q plus packed MLA cache directly. - k = None - v = None - save_kv_cache = False - else: - q = torch.cat([prepared.q_nope_out, prepared.q_pe], dim=-1) - k = torch.cat([prepared.k_nope, prepared.k_pe], dim=-1) - v = prepared.k_nope - - if prepared.llama_4_scaling is not None: - q = q * prepared.llama_4_scaling - - extra_kwargs = {} - if prepared.topk_indices is not None: - extra_kwargs["topk_indices"] = prepared.topk_indices - - attn_output = attn.mla_attn( - q, - k, - v, - forward_batch=prepared.forward_batch, - save_kv_cache=save_kv_cache, - **extra_kwargs, - ) - attn_output = attn_output.view(-1, attn.num_local_heads, attn.kv_lora_rank) - - # up-proj by w_vc - attn_bmm_output = mla_v_up_proj( - attn, attn_output, attn.w_vc, attn.w_scale, attn.w_scale_v, attn.v_head_dim - ) - - return attn.o_proj(attn_bmm_output) - - -def _dispatch_sgl_plugin_attn_path(forward_batch) -> str: - """Decide the attention algorithm for this batch based on forward_mode. - - Returns "mha" for extend/prefill-style batches (uses standard Q×K×V - with flash_attn) or "mla" for decode/verify batches (uses absorbed - weights + mla_decode_fwd). - - This is the per-batch *routing* decision, distinct from - ``_can_run_sgl_mha_now`` which is a *capability* gate checking whether - the model configuration supports the MHA path at all. - """ - if forward_batch.forward_mode.is_extend_without_speculative(): - return "mha" - - if forward_batch.forward_mode.is_draft_extend(): - # The explicit K/V path is only memory-friendly for no-prefix draft - # extend. With prefix/context, SGLang's MLA backend has to materialize - # full k_prefix/v_prefix from latent cache, which can OOM during graph - # capture. Use absorbed MLA for those batches until chunked prefix - # expansion exists here. - extend_prefix_lens_cpu = getattr(forward_batch, "extend_prefix_lens_cpu", None) - if extend_prefix_lens_cpu is not None and not any(extend_prefix_lens_cpu): - return "mha" - - return "mla" - - -def forward_sgl_plugin_mode_mla( - attn: DeepseekV2MLAAttention, - positions: torch.Tensor, - hidden_states: torch.Tensor, - **model_kwargs, -) -> torch.Tensor: - prepared = forward_sgl_prepare(attn, positions, hidden_states, **model_kwargs) - from atom.utils.forward_context import get_forward_context - - if get_forward_context().context.is_dummy_run: - base_hidden_states = ( - hidden_states[0] if isinstance(hidden_states, tuple) else hidden_states - ) - dummy_output = base_hidden_states.new_empty( - (base_hidden_states.shape[0], base_hidden_states.shape[-1]) - ) - return dummy_output - return forward_sgl_core(attn, prepared) - - -def _get_sglang_radix_attn(attn_module): - return attn_module.attn if hasattr(attn_module, "attn") else attn_module - - -def _concat_mha_k_for_sgl_mha( - attn: DeepseekV2MLAAttention, - k_nope: torch.Tensor, - k_pe: torch.Tensor, -) -> torch.Tensor: - k = k_nope.new_empty( - k_nope.shape[0], - attn.num_local_heads, - attn.qk_nope_head_dim + attn.qk_rope_head_dim, - ) - - try: - from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton - except ImportError as exc: - logger.warning( - "Unable to import concat_and_cast_mha_k_triton; " - "falling back to torch native MHA K concat: %s", - exc, - ) - else: - concat_and_cast_mha_k_triton(k, k_nope, k_pe) - return k - - k[..., : attn.qk_nope_head_dim] = k_nope - k[..., attn.qk_nope_head_dim :] = k_pe - return k - - -def _set_mla_kv_buffer_for_mha( - attn: DeepseekV2MLAAttention, - kv_a: torch.Tensor, - k_pe: torch.Tensor, - forward_batch, -) -> None: - attn_mha = _get_sglang_radix_attn(attn.attn_mha) - - kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(attn_mha.layer_id) - concat_and_cache_mla( - kv_a, - k_pe.squeeze(1), - kv_cache, - forward_batch.out_cache_loc.flatten(), - kv_cache_dtype=( - "fp8" if str(attn.kv_cache_dtype).startswith("fp8") else "auto" - ), - scale=attn_mha.k_scale, - ) - - -def _is_mxfp4_kv_b_proj(attn: DeepseekV2MLAAttention) -> bool: - kv_b_proj = attn.kv_b_proj - params_dtype = getattr(kv_b_proj, "params_dtype", None) - if params_dtype == dtypes.fp4x2 or params_dtype == getattr( - torch, "float4_e2m1fn_x2", None - ): - return True - - quant_type = getattr(kv_b_proj, "quant_type", None) - if getattr(quant_type, "name", "") == "per_1x32" or str(quant_type).endswith( - "per_1x32" - ): - return True - - quant_method = getattr(kv_b_proj, "quant_method", None) - quant_config = getattr(quant_method, "quant_config", None) - return bool( - quant_config is not None - and quant_config.get_name() == "quark" - and kv_b_proj.weight.dtype == torch.uint8 - ) - - -def _can_run_sgl_mha_now(attn: DeepseekV2MLAAttention, forward_batch) -> bool: - """Check if the model configuration supports the MHA attention path. - - This is a *capability* gate — NSA models cannot use the MHA path. - MXFP4 ``kv_b_proj`` weights are supported here because the MHA prepare - path expands K/V through ``attn.kv_b_proj`` itself, which already owns - the per_1x32 GEMM implementation. Distinct from - ``_dispatch_sgl_plugin_attn_path`` which routes each batch. - """ - del forward_batch - if attn.use_nsa: - return False - if attn.kv_b_proj.weight.dtype == torch.uint8 and not _is_mxfp4_kv_b_proj(attn): - return False - return True - - -def forward_sgl_mha_prepare( - attn: DeepseekV2MLAAttention, - positions: torch.Tensor, - hidden_states: torch.Tensor, - **model_kwargs, -) -> SglMhaPrepareResult: - - forward_batch = model_kwargs.get("forward_batch", None) - if forward_batch is None: - raise RuntimeError("forward_batch is required in forward_sgl_mha_prepare") - - hidden_states_scale = None - if isinstance(hidden_states, tuple): - hidden_states, hidden_states_scale = hidden_states - - attn_mha = _get_sglang_radix_attn(attn.attn_mha) - if getattr(attn_mha, "kv_b_proj", None) is None: - attn_mha.kv_b_proj = attn.kv_b_proj - - if attn.q_lora_rank is not None: - q, latent_cache = ( - get_attn_tp_context() - .fetch_qkv_latent() - .split( - [attn.q_lora_rank, attn.kv_lora_rank + attn.qk_rope_head_dim], - dim=-1, - ) - ) - - if ( - q.shape[0] != positions.shape[0] - and get_tensor_model_parallel_world_size() > 1 - ): - qkv_lora = torch.cat([q, latent_cache], dim=-1) - qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) - if qkv_lora.shape[0] < positions.shape[0]: - raise RuntimeError( - f"qkv_lora gather mismatch: got {qkv_lora.shape[0]}, " - f"expected {positions.shape[0]}" - ) - qkv_lora = qkv_lora[: positions.shape[0]] - q, latent_cache = torch.split( - qkv_lora, - [attn.q_lora_rank, attn.kv_lora_rank + attn.qk_rope_head_dim], - dim=-1, - ) - - if _use_aiter_gfx95 and attn.q_b_proj.weight.dtype == torch.float8_e4m3fn: - (q, q_scale), _, _, _ = _fuse_rmsnorm_quant( - q, - attn.q_a_layernorm.weight, - attn.q_a_layernorm.eps, - None, - None, - None, - res1=None, - dtype_quant=torch.float8_e4m3fn, - group_size=128, - quant_type=_linear_quant_type_value(attn.q_b_proj), - output_unquantized_inp1=False, - transpose_scale=True, - ) - q = _unwrap_linear_output(attn.q_b_proj(q, q_scale)).view( - -1, attn.num_local_heads, attn.qk_head_dim - ) - else: - q = attn.q_a_layernorm(q) - q = _unwrap_linear_output(attn.q_b_proj(q)).view( - -1, attn.num_local_heads, attn.qk_head_dim - ) - else: - q = _unwrap_linear_output(attn.q_proj(hidden_states, hidden_states_scale)).view( - -1, attn.num_local_heads, attn.qk_head_dim - ) - latent_cache = _unwrap_linear_output( - attn.kv_a_proj_with_mqa(hidden_states, hidden_states_scale) - ) - - _, q_pe = q.split([attn.qk_nope_head_dim, attn.qk_rope_head_dim], dim=-1) - kv_a, _ = latent_cache.split([attn.kv_lora_rank, attn.qk_rope_head_dim], dim=-1) - latent_cache = latent_cache.unsqueeze(1) - - if _use_aiter_gfx95 and attn.kv_b_proj.weight.dtype == torch.float8_e4m3fn: - (kv_a_quanted, kv_a_quanted_scale), kv_a, _, _ = _fuse_rmsnorm_quant( - kv_a, - attn.kv_a_layernorm.weight, - attn.kv_a_layernorm.eps, - None, - None, - None, - res1=None, - dtype_quant=torch.float8_e4m3fn, - group_size=128, - quant_type=_linear_quant_type_value(attn.kv_b_proj), - output_unquantized_inp1=True, - transpose_scale=True, - ) - else: - kv_a_quanted = None - kv_a = attn.kv_a_layernorm(kv_a) - - k_pe = latent_cache[:, :, attn.kv_lora_rank :] - if attn.rotary_emb is not None: - q_pe, k_pe = attn.rotary_emb(positions, q_pe, k_pe) - q[..., attn.qk_nope_head_dim :] = q_pe - - _set_mla_kv_buffer_for_mha(attn, kv_a, k_pe, forward_batch) - - if kv_a_quanted is not None: - kv = _unwrap_linear_output(attn.kv_b_proj(kv_a_quanted, kv_a_quanted_scale)) - else: - kv = _unwrap_linear_output(attn.kv_b_proj(kv_a)) - kv = kv.view(-1, attn.num_local_heads, attn.qk_nope_head_dim + attn.v_head_dim) - k_nope = kv[..., : attn.qk_nope_head_dim] - v = kv[..., attn.qk_nope_head_dim :] - k = _concat_mha_k_for_sgl_mha(attn, k_nope, k_pe) - return SglMhaPrepareResult(q=q, k=k, v=v, forward_batch=forward_batch) - - -def forward_sgl_mha_core( - attn: DeepseekV2MLAAttention, - prepared: SglMhaPrepareResult, -) -> torch.Tensor: - attn_output = attn.attn_mha( - prepared.q, - prepared.k, - prepared.v, - forward_batch=prepared.forward_batch, - save_kv_cache=False, - ) - attn_output = attn_output.reshape(-1, attn.num_local_heads * attn.v_head_dim) - return attn.o_proj(attn_output) - - -def forward_sgl_plugin_mode_mha( - attn: DeepseekV2MLAAttention, - positions: torch.Tensor, - hidden_states: torch.Tensor, - **model_kwargs, -) -> torch.Tensor: - forward_batch = model_kwargs.get("forward_batch", None) - if forward_batch is None: - raise RuntimeError("forward_batch is required in forward_sgl_plugin_mode_mha") - if not _can_run_sgl_mha_now(attn, forward_batch): - attn.current_sgl_plugin_attn_path = "mla_fallback" - return forward_sgl_plugin_mode_mla( - attn, - positions, - hidden_states, - **model_kwargs, - ) - prepared = forward_sgl_mha_prepare(attn, positions, hidden_states, **model_kwargs) - return forward_sgl_mha_core(attn, prepared) - - -def prepare_qkv_latent( - attn: DeepseekV2MLAAttention, - hidden_states: torch.Tensor, - forward_batch, -) -> torch.Tensor: - """Prepare QKV latent tensor for the sglang communicator.""" - assert attn.q_lora_rank is not None - hidden_states_scale = None - if isinstance(hidden_states, tuple): - hidden_states, hidden_states_scale = hidden_states - qkv_lora = attn.fused_qkv_a_proj(hidden_states, hidden_states_scale) - - # Fallback: when communicator does not enable input_scattered gather, - # force qkv latent token dimension to align with positions. - expected_tokens = 0 - if hasattr(forward_batch, "positions") and forward_batch.positions is not None: - expected_tokens = int(forward_batch.positions.shape[0]) - if expected_tokens <= 0: - expected_tokens = int(getattr(forward_batch, "seq_lens_sum", 0) or 0) - - if ( - expected_tokens > 0 - and qkv_lora.shape[0] != expected_tokens - and get_tensor_model_parallel_world_size() > 1 - ): - qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) - if qkv_lora.shape[0] > expected_tokens: - qkv_lora = qkv_lora[:expected_tokens] - elif qkv_lora.shape[0] < expected_tokens: - raise RuntimeError( - f"prepare_qkv_latent gather mismatch: got {qkv_lora.shape[0]}, " - f"expected {expected_tokens}" - ) - return qkv_lora - - -# Top-level forward entry point -def forward_sgl_plugin_mode( - attn: DeepseekV2MLAAttention, - positions: torch.Tensor, - hidden_states: torch.Tensor, - **model_kwargs, -) -> torch.Tensor: - """Full MLA forward in sglang plugin mode.""" - forward_batch = model_kwargs.get("forward_batch", None) - if forward_batch is None: - raise RuntimeError("forward_batch is required in forward_sgl_plugin_mode") - - attn_tp_context = get_attn_tp_context() - with attn_tp_context.maybe_input_scattered(forward_batch): - if attn.q_lora_rank is not None: - attn_tp_context.set_attn_inputs( - AttentionInputs( - hidden_states, - forward_batch, - lambda hs, fb: prepare_qkv_latent(attn, hs, fb), - ) - ) - attn_path = _dispatch_sgl_plugin_attn_path(forward_batch) - attn.current_sgl_plugin_attn_path = attn_path - if attn_path == "mha": - return forward_sgl_plugin_mode_mha( - attn, - positions, - hidden_states, - **model_kwargs, - ) - if attn_path == "mla": - return forward_sgl_plugin_mode_mla( - attn, - positions, - hidden_states, - **model_kwargs, - ) - raise ValueError(f"Unsupported plugin attention path: {attn_path}") - - -# Weight post-processing: decomposed into sub-functions -def _read_kv_b_proj_weight(attn: DeepseekV2MLAAttention) -> torch.Tensor: - """Read kv_b_proj weight, handling AWQ and fnuz dtypes.""" - if hasattr(attn.kv_b_proj, "qweight"): - awq_dequant = awq_dequantize_func() - if awq_dequant is None: - raise ValueError( - "AWQ dequantize function is not supported for current device" - ) - w = awq_dequant( - attn.kv_b_proj.qweight, - attn.kv_b_proj.scales, - attn.kv_b_proj.qzeros, - ).T - else: - layer_quant_config = getattr(attn.kv_b_proj, "layer_quant_config", None) - is_quark_static_mxfp4 = ( - layer_quant_config is not None - and layer_quant_config.quant_method == "quark" - and layer_quant_config.quant_dtype == dtypes.fp4x2 - and getattr(layer_quant_config.quant_type, "name", None) == "per_1x32" - ) - if is_quark_static_mxfp4: - w = getattr( - attn.kv_b_proj, - "_mxfp4_unshuffled_weight", - attn.kv_b_proj.weight, - ) - else: - w = attn.kv_b_proj.weight - - # On ROCm, ATOM creates parameters with fnuz dtype but loads fn bytes. - # View-cast back to fn so the normalize path works correctly. - if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fnuz: - w = w.view(torch.float8_e4m3fn) - - return w - - -def _get_weight_block_size(attn: DeepseekV2MLAAttention) -> Optional[list[int]]: - """Derive weight_block_size from ATOM's quant_type system.""" - from aiter import QuantType as _AiterQuantType - - qt = getattr(attn.kv_b_proj, "quant_type", None) - if qt == _AiterQuantType.per_1x128: - return [128, 128] - elif qt == _AiterQuantType.per_1x32: - return [1, 32] - return None - - -def _process_fp8_weight( - attn: DeepseekV2MLAAttention, - w: torch.Tensor, - weight_block_size: Optional[list[int]], -) -> tuple[torch.Tensor, bool, Optional[torch.Tensor]]: - """Process FP8 weights for kv_b_proj. - - Returns (w, use_deep_gemm_bmm, block_scale). - """ - from atom.model_ops.utils import normalize_e4m3fn_to_e4m3fnuz - from sglang.srt.layers.quantization.fp8_utils import ( - block_quant_dequant, - block_quant_to_tensor_quant, - channel_quant_to_tensor_quant, - inverse_transform_scale_ue8m0, - ) - from sglang.srt.layers.deep_gemm_wrapper import ( - ENABLE_JIT_DEEPGEMM, - DEEPGEMM_BLACKWELL, - ) - from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0 - - use_deep_gemm_bmm = False - block_scale = None - - if weight_block_size is not None: - assert hasattr(attn.kv_b_proj, "weight_scale_inv") or hasattr( - attn.kv_b_proj, "weight_scale" - ) - weight_scale = ( - attn.kv_b_proj.weight_scale - if hasattr(attn.kv_b_proj, "weight_scale") - else attn.kv_b_proj.weight_scale_inv - ) - - if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fn: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=w, weight_scale=weight_scale, input_scale=None - ) - else: - weight = w - - if should_deepgemm_weight_requant_ue8m0( - weight_block_size=weight_block_size - ) and getattr(weight_scale, "format_ue8m0", False): - weight_scale = inverse_transform_scale_ue8m0( - weight_scale, mn=weight.shape[-2] - ) - - if _is_cuda and weight_block_size[0] == 128 and weight_block_size[1] == 128: - if ( - ENABLE_JIT_DEEPGEMM - and not DEEPGEMM_BLACKWELL - and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false") - ): - block_scale = weight_scale - use_deep_gemm_bmm = True - else: - w = block_quant_dequant( - weight, weight_scale, weight_block_size, torch.bfloat16 - ) - else: - w, scale = block_quant_to_tensor_quant( - weight, weight_scale, weight_block_size - ) - attn.w_scale = scale - else: - if w.dtype == torch.float8_e4m3fn and _is_fp8_fnuz: - weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( - weight=w, weight_scale=attn.kv_b_proj.weight_scale, input_scale=None - ) - else: - weight = w - weight_scale = attn.kv_b_proj.weight_scale - - w, scale = channel_quant_to_tensor_quant(weight, weight_scale) - attn.w_scale = scale - - return w, use_deep_gemm_bmm, block_scale - - -def _process_int8_weight( - attn: DeepseekV2MLAAttention, - w: torch.Tensor, - weight_block_size: Optional[list[int]], -) -> torch.Tensor: - """Process INT8 weights for kv_b_proj.""" - from sglang.srt.layers.quantization.int8_utils import ( - block_dequant as int8_block_dequant, - ) - - if weight_block_size is not None: - assert hasattr(attn.kv_b_proj, "weight_scale_inv") - return int8_block_dequant( - w, attn.kv_b_proj.weight_scale_inv, weight_block_size - ).to(torch.bfloat16) - else: - return w.to(torch.bfloat16) * attn.kv_b_proj.weight_scale.to(torch.bfloat16) - - -def _split_kc_vc_like_vllm( - attn: DeepseekV2MLAAttention, w: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - """Split kv_b_proj weight using vLLM's transpose-first layout.""" - kv_b_proj_weight = w.T - assert kv_b_proj_weight.shape == ( - attn.kv_lora_rank, - attn.num_local_heads * (attn.qk_nope_head_dim + attn.v_head_dim), - ), ( - f"{kv_b_proj_weight.shape=}, " - f"{attn.kv_lora_rank=}, " - f"{attn.num_local_heads=}, " - f"{attn.qk_nope_head_dim=}, " - f"{attn.v_head_dim=}" - ) - kv_b_proj_weight = kv_b_proj_weight.view( - attn.kv_lora_rank, - attn.num_local_heads, - attn.qk_nope_head_dim + attn.v_head_dim, - ) - w_uk, w_uv = kv_b_proj_weight.split( - [attn.qk_nope_head_dim, attn.v_head_dim], dim=-1 - ) - return w_uk.transpose(0, 1).contiguous(), w_uv.permute(1, 2, 0).contiguous() - - -def _split_and_assign_kc_vc( - attn: DeepseekV2MLAAttention, - w: torch.Tensor, - use_deep_gemm_bmm: bool, - block_scale: Optional[torch.Tensor], - weight_block_size: Optional[list[int]], -) -> None: - """Split weight into kc/vc and assign to attn.""" - from atom.model_ops.utils import quark_post_load_weights - - w_kc, w_vc = w.unflatten(0, (-1, attn.qk_nope_head_dim + attn.v_head_dim)).split( - [attn.qk_nope_head_dim, attn.v_head_dim], dim=1 - ) - - # Quark MXFP4 LinearBase modules store quantization on layer_quant_config; - # there is no kv_b_proj.quant_method object to inspect in this path. - layer_quant_config = getattr(attn.kv_b_proj, "layer_quant_config", None) - quant_method = getattr(attn.kv_b_proj, "quant_method", None) - quant_config = getattr(quant_method, "quant_config", None) - is_quark_quant_method = ( - quant_config is not None and quant_config.get_name() == "quark" - ) - is_quark_mxfp4_linear = ( - layer_quant_config is not None - and layer_quant_config.quant_method == "quark" - and layer_quant_config.quant_dtype == dtypes.fp4x2 - and getattr(layer_quant_config.quant_type, "name", None) == "per_1x32" - ) - if _use_aiter_gfx95 and (is_quark_quant_method or is_quark_mxfp4_linear): - quark_weight = w - weight_scale = getattr(attn.kv_b_proj, "_mxfp4_unshuffled_weight_scale", None) - if is_quark_mxfp4_linear and weight_scale is not None: - quark_weight = mxfp4_to_f32(w.view(torch.uint8)).to(torch.bfloat16) - quark_weight = quark_weight * e8m0_to_f32( - weight_scale.repeat_interleave(32, dim=-1) - ).to(torch.bfloat16) - w_kc, attn.w_scale_k, w_vc, attn.w_scale_v = quark_post_load_weights( - attn, quark_weight, "mxfp4" - ) - - if not use_deep_gemm_bmm: - use_vllm_weight_layout = _is_hip and not ( - is_quark_quant_method or is_quark_mxfp4_linear - ) - - if use_vllm_weight_layout: - w_kc, w_vc = _split_kc_vc_like_vllm(attn, w) - else: - w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) - w_vc = w_vc.contiguous().transpose(1, 2) - - # Align bf16 kv_b_proj post-load handling with vLLM: split first, then - # quantize kc/vc independently for the fp8 BMM path. - if w.dtype == torch.bfloat16 and (_is_hip or _is_cuda): - w_kc, w_scale_k = dynamic_per_batched_tensor_quant(w_kc, dtype=dtypes.fp8) - w_vc, w_scale_v = dynamic_per_batched_tensor_quant(w_vc, dtype=dtypes.fp8) - attn.w_scale_k = bind_or_assign(attn.w_scale_k, w_scale_k) - attn.w_scale_v = bind_or_assign(attn.w_scale_v, w_scale_v) - - attn.w_kc = bind_or_assign(attn.w_kc, w_kc) - if _is_npu: - w_vc = w_vc.contiguous() - attn.w_vc = bind_or_assign(attn.w_vc, w_vc) - - if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn: - attn.w_kc = attn.w_kc.to(torch.bfloat16) * attn.w_scale - attn.w_vc = attn.w_vc.to(torch.bfloat16) * attn.w_scale - else: - num_tiles_k = attn.qk_nope_head_dim // weight_block_size[1] - num_tiles_n = attn.v_head_dim // weight_block_size[0] - ws_kc, ws_vc = block_scale.unflatten( - 0, (-1, (num_tiles_k + num_tiles_n)) - ).split([num_tiles_k, num_tiles_n], dim=1) - - attn.w_scale_k = bind_or_assign( - attn.w_scale_k, ws_kc.transpose(1, 2).contiguous() - ) - attn.w_scale_v = bind_or_assign(attn.w_scale_v, ws_vc.contiguous()) - attn.w_kc = bind_or_assign(attn.w_kc, w_kc.transpose(1, 2).contiguous()) - attn.w_vc = bind_or_assign(attn.w_vc, w_vc.contiguous()) - attn.use_deep_gemm_bmm = True - - -def process_mla_kv_b_proj_after_loading(attn: DeepseekV2MLAAttention) -> None: - """Process kv_b_proj weights after loading for sglang MLA mode. - - Orchestrates reading, quantization handling, and splitting of - kv_b_proj into absorbed w_kc / w_vc weights. - """ - if not getattr(attn.kv_b_proj, "_sgl_mxfp4_process_done", False): - attn.kv_b_proj.process_weights_after_loading() - - w = _read_kv_b_proj_weight(attn) - weight_block_size = _get_weight_block_size(attn) - - use_deep_gemm_bmm = False - block_scale = None - - # fp8 path - if w.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): - w, use_deep_gemm_bmm, block_scale = _process_fp8_weight( - attn, w, weight_block_size - ) - - # int8 path - if w.dtype == torch.int8: - w = _process_int8_weight(attn, w, weight_block_size) - - # split and assign kc/vc - _split_and_assign_kc_vc(attn, w, use_deep_gemm_bmm, block_scale, weight_block_size) - - -def _patch_kv_b_proj_for_sglang_mxfp4(attn: DeepseekV2MLAAttention) -> None: - """Preserve DeepSeek MLA kv_b_proj's original MXFP4 layout for kc/vc split.""" - kv_b_proj = attn.kv_b_proj - if getattr(kv_b_proj, "_sgl_mxfp4_preserve_patched", False): - return - - orig_process_weights_after_loading = kv_b_proj.process_weights_after_loading - - def process_weights_after_loading_with_mxfp4_preserve(): - if getattr(kv_b_proj, "_sgl_mxfp4_process_done", False): - return - - layer_quant_config = getattr(kv_b_proj, "layer_quant_config", None) - is_quark_static_mxfp4 = ( - kv_b_proj.weight.dim() == 2 - and layer_quant_config is not None - and layer_quant_config.quant_method == "quark" - and layer_quant_config.quant_dtype == dtypes.fp4x2 - and getattr(layer_quant_config.quant_type, "name", None) == "per_1x32" - and getattr(kv_b_proj, "source_quant_dtype", None) is None - ) - if is_quark_static_mxfp4: - kv_b_proj._mxfp4_unshuffled_weight = kv_b_proj.weight.detach().clone() - kv_b_proj._mxfp4_unshuffled_weight_scale = ( - kv_b_proj.weight_scale.detach().clone() - ) - orig_process_weights_after_loading() - kv_b_proj._sgl_mxfp4_process_done = True - - kv_b_proj.process_weights_after_loading = ( - process_weights_after_loading_with_mxfp4_preserve - ) - kv_b_proj._sgl_mxfp4_preserve_patched = True - - -# One-time model setup (called from base_model_wrapper.py) -def setup_deepseek_for_sglang(model) -> None: - """Patch a DeepseekV2/V3 model for sglang plugin mode. - - - Initialises sglang TP context - - Patches each MLAAttention.forward to dispatch to the sglang MLA path - - Registers process_weights_after_loading hooks - - Stores atom_config on the model - """ - config = model.config - - # Store atom_config (needed by load_weights in the OOT wrapper) - if not hasattr(model, "atom_config"): - from atom.config import get_current_atom_config - - model.atom_config = get_current_atom_config() - - kv_cache_dtype = model.atom_config.kv_cache_dtype - - # Initialise sglang TP context for MLA gather/scatter - from sglang.srt.configs.model_config import is_deepseek_nsa - from sglang.srt.layers.communicator import get_attn_tp_context - - get_attn_tp_context().init_context(config.q_lora_rank, is_deepseek_nsa(config)) - - # Patch each MLAAttention instance - from atom.models.deepseek_v2 import DeepseekV2MLAAttention - - for module in model.modules(): - if isinstance(module, DeepseekV2MLAAttention): - _patch_mla_attention_for_sglang(module, config, kv_cache_dtype) - - -def _patch_mla_attention_for_sglang(attn, config, kv_cache_dtype: str = "bf16") -> None: - """Patch a single DeepseekV2MLAAttention for sglang plugin mode. - - We patch attn.forward (rather than relying solely on ops.Attention = - RadixAttention) because MLA's absorbed-weight forward path replaces the - *entire* forward method — including RoPE, and absorbed - BMM — not just the attention backend. ops.Attention = RadixAttention - handles the backend layer (flash_attn / paged_attn dispatch) and is - already set via set_attn_cls(); this patch sits above that layer. - """ - init_sgl_attrs(attn, config, kv_cache_dtype) - _patch_kv_b_proj_for_sglang_mxfp4(attn) - - def patched_forward( - positions: torch.Tensor, - hidden_states: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - from atom.plugin.sglang.models.base_model_wrapper import ( - get_current_forward_batch, - ) - - kwargs["forward_batch"] = get_current_forward_batch() - return forward_sgl_plugin_mode(attn, positions, hidden_states, **kwargs) - - attn.forward = patched_forward - attn.process_weights_after_loading = lambda: process_mla_kv_b_proj_after_loading( - attn - ) diff --git a/atom/plugin/sglang/models/base_model_wrapper.py b/atom/plugin/sglang/models/base_model_wrapper.py index 3f2c743b14..c3bd0854be 100644 --- a/atom/plugin/sglang/models/base_model_wrapper.py +++ b/atom/plugin/sglang/models/base_model_wrapper.py @@ -6,13 +6,8 @@ To add a new model, append its architecture class name to _MODEL_NAMES. """ -import copy - import logging -from contextlib import contextmanager -from contextvars import ContextVar -from dataclasses import dataclass -from typing import Any, ClassVar, Iterable, Optional, Tuple, Union +from typing import Any, Iterable, Optional, Tuple, Union import torch from torch import nn @@ -22,312 +17,26 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors -logger = logging.getLogger("atom.plugin.sglang.models") - -_RUNTIME_SENTINEL = object() - -# Context for patched DeepSeek attention layers that need wrapper state without -# changing every intermediate forward signature. ContextVar keeps nested or -# concurrent forwards isolated and lets us reliably restore the prior value. -_current_forward_batch: ContextVar[Optional[ForwardBatch]] = ContextVar( - "atom_sglang_current_forward_batch", default=None +from atom.plugin.sglang.runtime import ( + MODEL_ARCH_SPECS, + SGLangForwardBatchMetadata, + SGLangPluginRuntime, + bind_current_forward_batch, + get_current_forward_batch, + get_model_arch_spec, + plugin_runtime_scope, ) +logger = logging.getLogger("atom.plugin.sglang.models") -def get_current_forward_batch(): - return _current_forward_batch.get() - - -def _is_dummy_forward(forward_batch: ForwardBatch) -> bool: - # SGLang's IDLE batch is the plugin-side equivalent of ATOM dummy run. - forward_mode = getattr(forward_batch, "forward_mode", None) - return bool( - forward_mode is not None - and hasattr(forward_mode, "is_idle") - and forward_mode.is_idle() - ) - - -def _pad_dummy_like( - tensor: Optional[torch.Tensor], - *, - length: int, - fill_value: int | float = 0, -) -> Optional[torch.Tensor]: - if tensor is None: - return None - shape = (length, *tensor.shape[1:]) - return torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device) - - -def _materialize_atom_dummy_forward( - input_ids: Optional[torch.Tensor], - positions: Optional[torch.Tensor], - input_embeds: Optional[torch.Tensor], - forward_batch: ForwardBatch, -) -> tuple[ - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - ForwardBatch, -]: - """Convert an empty SGLang IDLE batch into ATOM-style dummy forward inputs.""" - dummy_positions = positions.new_zeros((1,)) - dummy_input_ids = input_ids.new_zeros((1,)) - dummy_input_embeds = _pad_dummy_like(input_embeds, length=1, fill_value=0) - - model_forward_batch = copy.copy(forward_batch) - model_forward_batch.positions = dummy_positions - model_forward_batch.batch_size = 1 - model_forward_batch.seq_lens_sum = 1 - model_forward_batch.seq_lens = forward_batch.seq_lens.new_ones((1,)) - model_forward_batch.seq_lens_cpu = forward_batch.seq_lens_cpu.new_ones((1,)) - - return dummy_input_ids, dummy_positions, dummy_input_embeds, model_forward_batch - - -def _trim_hidden_states_for_output(hidden_states, num_tokens: int): - if torch.is_tensor(hidden_states): - return hidden_states[:num_tokens] - if isinstance(hidden_states, tuple): - return tuple( - tensor[:num_tokens] if torch.is_tensor(tensor) else tensor - for tensor in hidden_states - ) - return hidden_states - - -def _resolve_num_tokens_across_dp( - atom_config: Any, - forward_batch: ForwardBatch, - num_tokens: int, - is_dummy_run: bool, -) -> torch.Tensor: - """Resolve per-DP token counts for ATOM's CPU-side DPMetadata. - - Real SGLang dp-attention batches carry ``global_num_tokens_cpu`` from the - scheduler. That list is the source of truth for mixed prefill/decode/idle - batches, where token counts may look like [8, 1, 8, 8]. - - Some SGLang synthetic/static batches, especially CUDA graph capture batches, - only keep the global token buffer on GPU. ATOM's DPMetadata is CPU-side and - needs a CPU tensor before model forward, so avoid reading the GPU buffer back - to CPU. We only fallback when the batch advertises the same-shape DP buffer - layout (global_dp_buffer_len == local_num_tokens * dp_size), where the CPU - equivalent is exactly [local_num_tokens] * dp_size. - - IDLE batches are reported by SGLang as 0 tokens on the current rank, but - this wrapper materializes them as one local dummy token before entering - ATOM. Patch the current DP rank after resolving the distribution so - ``DPMetadata`` sees a local count that matches the actual ATOM input. - """ - global_num_tokens_cpu = getattr(forward_batch, "global_num_tokens_cpu", None) - if global_num_tokens_cpu is not None: - num_tokens_across_dp = torch.tensor( - global_num_tokens_cpu, dtype=torch.int32, device="cpu" - ) - else: - dp_size = atom_config.parallel_config.data_parallel_size - global_num_tokens_gpu = getattr(forward_batch, "global_num_tokens_gpu", None) - global_dp_buffer_len = getattr(forward_batch, "global_dp_buffer_len", None) - is_static_same_shape_batch = ( - global_num_tokens_gpu is not None - and global_dp_buffer_len == num_tokens * dp_size - ) - if not is_static_same_shape_batch: - raise RuntimeError( - "[SGL+ATOM] SGLang dp-attention requires " - "forward_batch.global_num_tokens_cpu unless the batch uses static " - "same-shape DP metadata." - ) - - # Static batches, such as CUDA graph capture batches, may only keep - # global token counts on GPU. Avoid GPU-to-CPU reads here and mirror - # their same-shape layout directly for ATOM's CPU DPMetadata. - num_tokens_across_dp = torch.full( - (dp_size,), num_tokens, dtype=torch.int32, device="cpu" - ) - - if is_dummy_run: - # SGLang reports idle ranks as 0 tokens, but ATOM materializes them - # as one local dummy token so collectives and DPMetadata stay aligned. - dp_rank = atom_config.parallel_config.data_parallel_rank - num_tokens_across_dp[dp_rank] = num_tokens - return num_tokens_across_dp - - -def _set_sglang_forward_context( - atom_config: Any, - forward_batch: ForwardBatch, - positions: torch.Tensor, -) -> None: - """Bridge SGLang batch metadata into ATOM's global forward context.""" - from atom.utils.forward_context import ( - AttentionMetaData, - Context, - set_forward_context, - ) - - forward_mode = forward_batch.forward_mode - # TODO: This max_seqlen_q is not the source of truth for prefill attention; - # SGLang plugin attention consumes forward_batch.attn_backend.forward_metadata - # directly. In this wrapper it is only needed by ATOM MoE padding: under - # dp-attention + TP (non-EP all_gather/reduce_scatter), decode/idle batches - # must use 1 so pad_for_all_gather keeps fixed-shape collectives aligned. - # Leaving it as 0 there can make active and dummy ranks send different - # shapes to DP all_gather and hang. - max_seqlen_q = 1 if forward_mode.is_decode_or_idle() else 0 - attn_metadata = AttentionMetaData(max_seqlen_q=max_seqlen_q) - batch_size = int(forward_batch.batch_size) - is_dummy_run = _is_dummy_forward(forward_batch) - is_prefill = forward_mode.is_prefill() - num_tokens = int(positions.shape[0]) - - enable_dp_attention = bool(atom_config.enable_dp_attention) - if enable_dp_attention: - # SGLang owns the cross-DP token distribution under dp-attention; ATOM - # uses it to derive graph_bs and fixed-size MoE gather/scatter buffers. - num_tokens_across_dp = _resolve_num_tokens_across_dp( - atom_config, forward_batch, num_tokens, is_dummy_run - ) - graph_bs = int(torch.max(num_tokens_across_dp).item()) - else: - # Without dp-attention, ATOM runs with local-rank shapes only. There is - # no cross-DP token distribution to pass into DPMetadata, so graph_bs - # follows the local prefill token count or decode batch size. - num_tokens_across_dp = None - graph_bs = num_tokens if is_prefill else batch_size - context = Context( - positions=positions, - is_prefill=is_prefill, - is_dummy_run=is_dummy_run, - batch_size=batch_size, - graph_bs=graph_bs, - ) - set_forward_context( - attn_metadata=attn_metadata, - atom_config=atom_config, - context=context, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp, - ) - - -def _reset_sglang_forward_context() -> None: - from atom.utils.forward_context import reset_forward_context - - reset_forward_context() - - -@contextmanager -def plugin_runtime_scope( - *, - framework: Optional[str] = None, - atom_config: Any = _RUNTIME_SENTINEL, -): - """Temporarily bind plugin runtime globals to one wrapper instance. - - ATOM core currently relies on process-global framework/config state. In - SGLang speculative mode both target and draft wrappers coexist, so plugin - entrypoints must save/restore those globals around each init/load/forward. - """ - - import atom.config as atom_config_module - import atom.plugin.prepare as plugin_prepare - - prev_framework = plugin_prepare._CURRENT_FRAMEWORK - prev_atom_config = getattr(atom_config_module, "_current_atom_config", None) - - if framework is not None: - plugin_prepare._set_framework_backbone(framework) - if atom_config is not _RUNTIME_SENTINEL: - atom_config_module._current_atom_config = atom_config - - try: - yield - finally: - plugin_prepare._CURRENT_FRAMEWORK = prev_framework - atom_config_module._current_atom_config = prev_atom_config - - -@dataclass(frozen=True) -class SGLangForwardBatchMetadata: - """Small context object for one SGLang model forward.""" - - forward_batch: Optional[ForwardBatch] - pp_proxy_tensors: Optional[PPProxyTensors] = None - save_kv_cache: bool = True - _current: ClassVar[ContextVar[Optional["SGLangForwardBatchMetadata"]]] = ContextVar( - "atom_sglang_current_forward_batch_metadata", - default=None, - ) - - @classmethod - def current(cls) -> Optional["SGLangForwardBatchMetadata"]: - return cls._current.get() - - @classmethod - def build( - cls, - forward_batch: Optional[ - Union[ForwardBatch, "SGLangForwardBatchMetadata"] - ] = None, - *, - pp_proxy_tensors: Optional[PPProxyTensors] = None, - save_kv_cache: Optional[bool] = None, - ) -> Optional["SGLangForwardBatchMetadata"]: - if isinstance(forward_batch, cls): - return forward_batch - if forward_batch is None and pp_proxy_tensors is None and save_kv_cache is None: - return cls.current() - return cls( - forward_batch=forward_batch, - pp_proxy_tensors=pp_proxy_tensors, - save_kv_cache=True if save_kv_cache is None else save_kv_cache, - ) - - @classmethod - @contextmanager - def bind(cls, metadata: Optional["SGLangForwardBatchMetadata"]): - meta_token = cls._current.set(metadata) - batch_token = _current_forward_batch.set( - None if metadata is None else metadata.forward_batch - ) - try: - yield metadata - finally: - _current_forward_batch.reset(batch_token) - cls._current.reset(meta_token) - - @staticmethod - def to_intermediate_tensors( - intermediate_tensors, - metadata: Optional["SGLangForwardBatchMetadata"], - ): - if intermediate_tensors is not None or metadata is None: - return intermediate_tensors - pp_proxy_tensors = metadata.pp_proxy_tensors - if pp_proxy_tensors is None: - return intermediate_tensors - tensors = getattr(pp_proxy_tensors, "tensors", None) - if tensors is None: - return intermediate_tensors - from atom.models.utils import IntermediateTensors - - return IntermediateTensors(dict(tensors)) - - -@dataclass(frozen=True) -class ModelArchSpec: - wrapper_binds_gdn_context: bool = False - apply_deepseek_patch: bool = False - - -_MODEL_ARCH_SPECS = { - "DeepseekV3ForCausalLM": ModelArchSpec(apply_deepseek_patch=True), - "Qwen3MoeForCausalLM": ModelArchSpec(), - "Qwen3NextForCausalLM": ModelArchSpec(wrapper_binds_gdn_context=True), -} +__all__ = [ + "EntryClass", + "SGLangForwardBatchMetadata", + "SGLangPluginRuntime", + "bind_current_forward_batch", + "get_current_forward_batch", + "plugin_runtime_scope", +] class _AtomCausalLMBaseForSglang(nn.Module): @@ -353,19 +62,13 @@ def __init__( self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size self.model_arch = getattr(config, "architectures", [""])[0] - self.model_arch_spec = _MODEL_ARCH_SPECS.get(self.model_arch, ModelArchSpec()) - - import atom + self.model_arch_spec = get_model_arch_spec(self.model_arch) - # TODO: prepare_model() currently handles model construction, config - # generation, attention backend registration, and distributed init. - # Refactor so this wrapper only dispatches the attention backend - # (register_ops_to_sglang + set_attn_cls), and let sglang handle - # model construction directly with plugin_runtime_scope(framework="sglang"): from atom.config import get_current_atom_config + from atom.plugin.sglang.prepare import prepare_model - self.model = atom.prepare_model(config=config, engine="sglang") + self.model = prepare_model(config=config) self.atom_config = getattr(self.model, "atom_config", None) if self.atom_config is None: self.atom_config = get_current_atom_config() @@ -383,15 +86,10 @@ def __init__( config, skip_all_gather=plugin_skip_all_gather ) - # Apply ds model-specific sglang patches (attn dispatch, weight hooks, etc.) - # TODO: will remove this after sglang supports atom attention backend - if self.model_arch_spec.apply_deepseek_patch: - from atom.plugin.sglang.attention_backend.sgl_attention_mla import ( - setup_deepseek_for_sglang, - ) - + # Apply model-specific install-time adapters (attn dispatch, weight hooks, etc.). + if self.model_arch_spec.install_adapters is not None: with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): - setup_deepseek_for_sglang(self.model) + self.model_arch_spec.install_adapters(self.model) def get_embed_and_head(self): if hasattr(self.model, "get_embed_and_head"): @@ -449,94 +147,60 @@ def forward( **model_kwargs: Any, ) -> Union[LogitsProcessorOutput, PPProxyTensors]: with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): - metadata = SGLangForwardBatchMetadata.build( - forward_batch, - pp_proxy_tensors=pp_proxy_tensors, - save_kv_cache=model_kwargs.get("save_kv_cache"), - ) - - if _is_dummy_forward(forward_batch): - ( - model_input_ids, - model_positions, - model_input_embeds, - model_forward_batch, - ) = _materialize_atom_dummy_forward( - input_ids, - positions, - input_embeds, - forward_batch, + with SGLangPluginRuntime( + atom_config=self.atom_config, + forward_batch=forward_batch, + positions=positions, + input_ids=input_ids, + input_embeds=input_embeds, + set_forward_context=not self.model_arch_spec.wrapper_binds_gdn_context, + ) as runtime: + metadata = SGLangForwardBatchMetadata.build( + runtime.forward_batch, + pp_proxy_tensors=pp_proxy_tensors, + save_kv_cache=model_kwargs.get("save_kv_cache"), ) - else: - ( - model_input_ids, - model_positions, - model_input_embeds, - model_forward_batch, - ) = ( - input_ids, - positions, - input_embeds, - forward_batch, + model_inputs = dict( + input_ids=runtime.input_ids, + positions=runtime.positions, + intermediate_tensors=SGLangForwardBatchMetadata.to_intermediate_tensors( + pp_proxy_tensors, metadata + ), + inputs_embeds=runtime.input_embeds, ) - - model_inputs = dict( - input_ids=model_input_ids, - positions=model_positions, - intermediate_tensors=SGLangForwardBatchMetadata.to_intermediate_tensors( - pp_proxy_tensors, metadata - ), - inputs_embeds=model_input_embeds, - ) - uses_context_only_forward = ( - self.model_arch_spec.apply_deepseek_patch - or self.model_arch_spec.wrapper_binds_gdn_context - ) - with SGLangForwardBatchMetadata.bind(metadata): - if self.model_arch_spec.wrapper_binds_gdn_context: - from atom.plugin.sglang.attention_backend.attention_gdn import ( - SGLangGDNForwardContext, - ) - - with SGLangGDNForwardContext.bind(metadata): - hidden_states = self.model(**model_inputs) - elif uses_context_only_forward: - try: - _set_sglang_forward_context( - self.atom_config, model_forward_batch, model_positions + uses_context_only_forward = ( + self.model_arch_spec.install_adapters is not None + or self.model_arch_spec.wrapper_binds_gdn_context + ) + with SGLangForwardBatchMetadata.bind(metadata): + if self.model_arch_spec.wrapper_binds_gdn_context: + from atom.plugin.sglang.attention_backend.attention_gdn import ( + SGLangGDNForwardContext, ) + + with SGLangGDNForwardContext.bind(metadata): + hidden_states = self.model(**model_inputs) + elif uses_context_only_forward: hidden_states = self.model(**model_inputs) - finally: - _reset_sglang_forward_context() - else: - try: - _set_sglang_forward_context( - self.atom_config, model_forward_batch, model_positions - ) + else: hidden_states = self.model( **model_inputs, - forward_batch=model_forward_batch, + forward_batch=runtime.forward_batch, get_embedding=get_embedding, pp_proxy_tensors=pp_proxy_tensors, **model_kwargs, ) - finally: - _reset_sglang_forward_context() - if self.pp_group.is_last_rank: - if _is_dummy_forward(forward_batch): - # TODO: Revisit if SGLang ever sends non-empty dummy batches. - # Today this path only runs when an empty IDLE batch is expanded - # to one ATOM dummy token, so the output boundary must trim back to - # the original SGLang-visible length: 0 tokens. - hidden_states = _trim_hidden_states_for_output(hidden_states, 0) - return self.logits_processor( - input_ids, - hidden_states, - self.model.lm_head, - forward_batch, - ) - return hidden_states + hidden_states = runtime.trim_output(hidden_states) + + if self.pp_group.is_last_rank: + return self.logits_processor( + input_ids, + hidden_states, + self.model.lm_head, + forward_batch, + ) + return hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # The passed `weights` iterable from sglang is ignored because ATOM @@ -552,7 +216,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): EntryClass = [] -for _name in _MODEL_ARCH_SPECS: +for _name in MODEL_ARCH_SPECS: _cls = type(_name, (_AtomCausalLMBaseForSglang,), {}) globals()[_name] = _cls EntryClass.append(_cls) diff --git a/atom/plugin/sglang/models/deepseek_mla.py b/atom/plugin/sglang/models/deepseek_mla.py new file mode 100644 index 0000000000..c6b62f07ac --- /dev/null +++ b/atom/plugin/sglang/models/deepseek_mla.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Model-level DeepSeek MLA patching for SGLang plugin mode. + +This module owns the install-time hooks that adapt DeepSeek MLA models to +SGLang plugin mode. The heavy DeepSeek-specific runtime helpers live in +`atom.plugin.sglang.models.deepseek_mla_forward`. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from atom.plugin.sglang.models.deepseek_mla_attention import ( + SGLangDeepseekMLAAttention, +) +from atom.plugin.sglang.models.deepseek_mla_forward import ( + _patch_kv_b_proj_for_sglang_mxfp4, + init_sgl_attrs, + process_mla_kv_b_proj_after_loading, +) + +if TYPE_CHECKING: + from atom.models.deepseek_v2 import DeepseekV2MLAAttention + + +def setup_deepseek_for_sglang(model) -> None: + """Patch a DeepSeek V2/V3 model for SGLang plugin mode.""" + config = model.config + + # Store atom_config for the OOT wrapper before install-time hooks run. + if not hasattr(model, "atom_config"): + from atom.config import get_current_atom_config + + model.atom_config = get_current_atom_config() + + kv_cache_dtype = model.atom_config.kv_cache_dtype + + # Initialise SGLang's MLA TP context before patching per-layer forwards. + from sglang.srt.configs.model_config import is_deepseek_nsa + from sglang.srt.layers.communicator import get_attn_tp_context + + get_attn_tp_context().init_context(config.q_lora_rank, is_deepseek_nsa(config)) + + from atom.models.deepseek_v2 import DeepseekV2MLAAttention + + for module in model.modules(): + if isinstance(module, DeepseekV2MLAAttention): + _patch_mla_attention_for_sglang(module, config, kv_cache_dtype) + + +def _patch_mla_attention_for_sglang( + attn: "DeepseekV2MLAAttention", + config: Any, + kv_cache_dtype: str = "bf16", +) -> None: + """Patch one DeepSeek MLA layer for SGLang plugin mode.""" + init_sgl_attrs(attn, config, kv_cache_dtype) + _patch_kv_b_proj_for_sglang_mxfp4(attn) + if not isinstance(attn.mla_attn, SGLangDeepseekMLAAttention): + attn.mla_attn = SGLangDeepseekMLAAttention(attn, attn.mla_attn) + attn.process_weights_after_loading = lambda: process_mla_kv_b_proj_after_loading( + attn + ) diff --git a/atom/plugin/sglang/models/deepseek_mla_attention.py b/atom/plugin/sglang/models/deepseek_mla_attention.py new file mode 100644 index 0000000000..696f064a6d --- /dev/null +++ b/atom/plugin/sglang/models/deepseek_mla_attention.py @@ -0,0 +1,343 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""DeepSeek MLA wrapper for SGLang plugin mode. + +This adapter keeps the model-side entry at ``self.mla_attn(...)`` and owns the +SGLang-specific runtime dispatch for DeepSeek MLA. It is intentionally shaped +closer to the vLLM plugin path than the older model-side monkey-patched entry. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import torch +from torch import nn + +if TYPE_CHECKING: + from atom.models.deepseek_v2 import DeepseekV2MLAAttention + + +class SGLangDeepseekMLAAttention(nn.Module): + """Enter SGLang DeepSeek MLA runtime through ``self.mla_attn(...)``.""" + + def __init__( + self, + owner_attn: "DeepseekV2MLAAttention", + base_attn: nn.Module, + ) -> None: + super().__init__() + # Keep a non-module back reference. Registering owner_attn as a child + # module would create owner_attn -> mla_attn(wrapper) -> owner_attn and + # make nn.Module.train/eval recurse forever. + object.__setattr__(self, "owner_attn", owner_attn) + self.base_attn = base_attn + + @property + def attn(self): + return getattr(self.base_attn, "attn", self.base_attn) + + def _get_forward_batch(self, kwargs: dict[str, Any]): + forward_batch = kwargs.get("forward_batch", None) + if forward_batch is None: + from atom.plugin.sglang.runtime import ( + get_current_forward_batch, + ) + + forward_batch = get_current_forward_batch() + kwargs["forward_batch"] = forward_batch + if forward_batch is None: + raise RuntimeError( + "forward_batch is required for SGLang DeepSeek MLA wrapper" + ) + return forward_batch + + def _infer_total_tokens(self, forward_batch, tensor: torch.Tensor) -> int: + if hasattr(forward_batch, "input_ids") and forward_batch.input_ids is not None: + return int(forward_batch.input_ids.shape[0]) + if hasattr(forward_batch, "positions") and forward_batch.positions is not None: + return int(forward_batch.positions.shape[0]) + if hasattr(forward_batch, "seq_lens_sum"): + return int(forward_batch.seq_lens_sum) + return int(tensor.shape[0]) + + def _maybe_all_gather( + self, + tensor: torch.Tensor | None, + *, + total_tokens: int, + input_scattered: bool, + ): + if tensor is None or not input_scattered: + return tensor + from sglang.srt.distributed import get_tp_group + + output = tensor.new_empty((total_tokens, *tensor.shape[1:])) + get_tp_group().all_gather_into_tensor(output, tensor) + return output + + def _gather_runtime_inputs( + self, + q_input: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + positions: torch.Tensor, + q_scale: torch.Tensor | None, + *, + forward_batch, + input_scattered: bool, + ): + total_tokens = self._infer_total_tokens(forward_batch, q_input) + q_input = self._maybe_all_gather( + q_input, + total_tokens=total_tokens, + input_scattered=input_scattered, + ) + kv_c_normed = self._maybe_all_gather( + kv_c_normed, + total_tokens=total_tokens, + input_scattered=input_scattered, + ) + k_pe = self._maybe_all_gather( + k_pe, + total_tokens=total_tokens, + input_scattered=input_scattered, + ) + positions = self._maybe_all_gather( + positions, + total_tokens=total_tokens, + input_scattered=input_scattered, + ) + q_scale = self._maybe_all_gather( + q_scale, + total_tokens=total_tokens, + input_scattered=input_scattered, + ) + return q_input, kv_c_normed, k_pe, positions, q_scale + + def _project_q( + self, + q_input: torch.Tensor, + q_scale: torch.Tensor | None, + ) -> torch.Tensor: + attn = self.owner_attn + from atom.plugin.sglang.models.deepseek_mla_forward import _unwrap_linear_output + + if attn.q_lora_rank is not None: + q = ( + attn.q_b_proj(q_input, q_scale) + if q_scale is not None + else attn.q_b_proj(q_input) + ) + else: + q = ( + attn.q_proj(q_input, q_scale) + if q_scale is not None + else attn.q_proj(q_input) + ) + return _unwrap_linear_output(q).view(-1, attn.num_local_heads, attn.qk_head_dim) + + def _forward_absorbed( + self, + q_input: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + positions: torch.Tensor, + q_scale: torch.Tensor | None, + *, + forward_batch, + ) -> torch.Tensor: + attn = self.owner_attn + from aiter import dtypes + from atom.model_ops.attention_mla import fused_qk_rope_concat_and_cache_mla + from atom.plugin.sglang.models.deepseek_mla_forward import ( + _get_sglang_radix_attn, + mla_absorbed_bmm, + mla_v_up_proj, + ) + from sglang.srt.layers.attention.nsa.utils import nsa_use_prefill_cp + + q = self._project_q(q_input, q_scale) + k_nope = kv_c_normed.unsqueeze(1) + k_pe = k_pe.unsqueeze(1) + q_nope, q_pe = q.split([attn.qk_nope_head_dim, attn.qk_rope_head_dim], dim=-1) + q_nope_out = mla_absorbed_bmm( + attn, q_nope, attn.w_kc, attn.w_scale, attn.w_scale_k, attn.kv_lora_rank + ) + + if ( + attn.rotary_emb is not None + and not attn.use_fused_qk_rope_concat_and_cache_mla + ): + q_pe, k_pe = attn.rotary_emb(positions, q_pe, k_pe) + + if nsa_use_prefill_cp(forward_batch): + latent_cache = torch.cat([k_nope.squeeze(1), k_pe.squeeze(1)], dim=-1) + k_nope, k_pe = attn.rebuild_cp_kv_cache( + latent_cache, forward_batch, k_nope, k_pe + ) + + save_kv_cache = True + if attn.use_fused_qk_rope_concat_and_cache_mla: + mla_attn = _get_sglang_radix_attn(self.base_attn) + kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(mla_attn.layer_id) + q_out_dtype = ( + dtypes.fp8 if attn.kv_cache_dtype == "fp8_e4m3" else q_nope_out.dtype + ) + q = torch.empty( + ( + q_nope_out.shape[0], + attn.num_local_heads, + attn.kv_lora_rank + attn.qk_rope_head_dim, + ), + dtype=q_out_dtype, + device=q_nope_out.device, + ) + fused_qk_rope_concat_and_cache_mla( + q_nope_out, + q_pe, + k_nope, + k_pe, + kv_cache, + q, + forward_batch.out_cache_loc, + mla_attn.k_scale, + mla_attn.k_scale, + positions, + attn.rotary_emb.cos_cache, + attn.rotary_emb.sin_cache, + is_neox=attn.rotary_emb.is_neox_style, + is_nope_first=True, + ) + k = None + v = None + save_kv_cache = False + else: + q = torch.cat([q_nope_out, q_pe], dim=-1) + k = torch.cat([k_nope, k_pe], dim=-1) + v = k_nope + + attn_output = self.base_attn( + q, + k, + v, + forward_batch=forward_batch, + save_kv_cache=save_kv_cache, + ) + attn_output = attn_output.view(-1, attn.num_local_heads, attn.kv_lora_rank) + attn_bmm_output = mla_v_up_proj( + attn, attn_output, attn.w_vc, attn.w_scale, attn.w_scale_v, attn.v_head_dim + ) + return attn.o_proj(attn_bmm_output) + + def _forward_non_absorbed( + self, + q_input: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + positions: torch.Tensor, + q_scale: torch.Tensor | None, + *, + forward_batch, + ) -> torch.Tensor: + attn = self.owner_attn + from atom.plugin.sglang.models.deepseek_mla_forward import ( + _concat_mha_k_for_non_absorbed, + _set_mla_kv_buffer_for_non_absorbed, + _unwrap_linear_output, + ) + + q = self._project_q(q_input, q_scale) + _, q_pe = q.split([attn.qk_nope_head_dim, attn.qk_rope_head_dim], dim=-1) + + kv_a = kv_c_normed + k_pe = k_pe.unsqueeze(1) + if attn.rotary_emb is not None: + q_pe, k_pe = attn.rotary_emb(positions, q_pe, k_pe) + q[..., attn.qk_nope_head_dim :] = q_pe + + _set_mla_kv_buffer_for_non_absorbed(attn, kv_a, k_pe, forward_batch) + + kv = _unwrap_linear_output(attn.kv_b_proj(kv_a)).view( + -1, attn.num_local_heads, attn.qk_nope_head_dim + attn.v_head_dim + ) + k_nope = kv[..., : attn.qk_nope_head_dim] + v = kv[..., attn.qk_nope_head_dim :] + k = _concat_mha_k_for_non_absorbed(attn, k_nope, k_pe) + + attn_output = attn.attn_non_absorbed( + q, + k, + v, + forward_batch=forward_batch, + save_kv_cache=False, + ) + attn_output = attn_output.reshape(-1, attn.num_local_heads * attn.v_head_dim) + return attn.o_proj(attn_output) + + def forward( + self, + q_input: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + positions: torch.Tensor, + q_scale: torch.Tensor | None = None, + **kwargs: Any, + ) -> torch.Tensor: + attn = self.owner_attn + forward_batch = self._get_forward_batch(kwargs) + + from atom.plugin.sglang.models.deepseek_mla_forward import ( + _can_run_non_absorbed_mla_now, + ) + from sglang.srt.layers.communicator import get_attn_tp_context + + attn_tp_context = get_attn_tp_context() + with attn_tp_context.maybe_input_scattered(forward_batch): + q_input, kv_c_normed, k_pe, positions, q_scale = ( + self._gather_runtime_inputs( + q_input, + kv_c_normed, + k_pe, + positions, + q_scale, + forward_batch=forward_batch, + input_scattered=attn_tp_context.input_scattered, + ) + ) + + use_non_absorbed = ( + forward_batch.forward_mode.is_extend_without_speculative() + ) + if not use_non_absorbed and forward_batch.forward_mode.is_draft_extend(): + extend_prefix_lens_cpu = getattr( + forward_batch, "extend_prefix_lens_cpu", None + ) + use_non_absorbed = extend_prefix_lens_cpu is not None and not any( + extend_prefix_lens_cpu + ) + + if use_non_absorbed: + if _can_run_non_absorbed_mla_now(attn, forward_batch): + attn.current_sgl_plugin_attn_path = "non_absorbed" + return self._forward_non_absorbed( + q_input, + kv_c_normed, + k_pe, + positions, + q_scale, + forward_batch=forward_batch, + ) + attn.current_sgl_plugin_attn_path = "absorbed_fallback" + else: + attn.current_sgl_plugin_attn_path = "absorbed" + + return self._forward_absorbed( + q_input, + kv_c_normed, + k_pe, + positions, + q_scale, + forward_batch=forward_batch, + ) diff --git a/atom/plugin/sglang/models/deepseek_mla_forward.py b/atom/plugin/sglang/models/deepseek_mla_forward.py new file mode 100644 index 0000000000..d18449cad7 --- /dev/null +++ b/atom/plugin/sglang/models/deepseek_mla_forward.py @@ -0,0 +1,759 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Helper functions for DeepSeek MLA in SGLang plugin mode. + +This module now contains only the low-level helpers that are still shared by +the SGLang DeepSeek MLA wrapper and the install-time weight hooks: +absorbed BMM math, small utility helpers, non-absorbed cache staging, and +kv_b_proj post-load processing. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, Optional + +import torch +from aiter import dtypes +from atom.model_ops.base_attention import Attention +from atom.model_ops.attention_mla import ( + dynamic_per_batched_tensor_quant, +) +from atom.models.deepseek_v2 import _fuse_rmsnorm_quant +from atom.models.utils import maybe_prefix + +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode +from sglang.srt.models.deepseek_common.utils import ( + _use_aiter_gfx95, + _is_hip, + _is_cpu, + _is_cpu_amx_available, + _is_cuda, + _is_fp8_fnuz, + _is_npu, + awq_dequantize_func, +) +from sglang.srt.layers.quantization.rocm_mxfp4_utils import ( + batched_gemm_afp4wfp4_pre_quant, +) +from aiter.utility.fp4_utils import e8m0_to_f32, mxfp4_to_f32 +from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import ( + batched_gemm_a16wfp4, +) +from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant, +) +from sglang.srt.layers.quantization.fp8_kernel import ( + per_tensor_quant_mla_fp8, + per_token_group_quant_mla_deep_gemm_masked_fp8, +) +from sglang.srt.utils import bind_or_assign, get_bool_env_var + +if TYPE_CHECKING: + from atom.models.deepseek_v2 import DeepseekV2MLAAttention + + +logger = logging.getLogger(__name__) + + +# bmm_fp8 custom-op wrapper (adapted from sglang forward_mla.py) +if _is_cuda: + from sgl_kernel import bmm_fp8 as _raw_bmm_fp8 + from sglang.srt.utils.custom_op import register_custom_op + + @register_custom_op(mutates_args=["out"]) + def _bmm_fp8_op( + A: torch.Tensor, + B: torch.Tensor, + out: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + ) -> None: + _raw_bmm_fp8(A, B, A_scale, B_scale, out.dtype, out) + + def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): + if out is None: + out = torch.empty( + (A.shape[0], A.shape[1], B.shape[2]), + device=A.device, + dtype=dtype, + ) + _bmm_fp8_op(A, B, out, A_scale, B_scale) + return out + +else: + + def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): + raise RuntimeError("bmm_fp8 requires CUDA (sgl_kernel)") + + +def _unwrap_linear_output(output: Any) -> torch.Tensor: + """Normalize ATOM/public-SGLang linear outputs to a tensor.""" + if isinstance(output, tuple): + return output[0] + return output + + +def _linear_quant_type_value(linear: Any) -> Optional[int]: + quant_type = getattr(linear, "quant_type", None) + return None if quant_type is None else getattr(quant_type, "value", quant_type) + + +def _fuse_qk_rmsnorm_and_q_quant( + attn: DeepseekV2MLAAttention, + q: torch.Tensor, + k_nope: torch.Tensor, + *, + output_unquantized_q: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + """Fuse q/k RMSNorm and q quant using ATOM's DeepSeek-V2 path.""" + + (q_quantized, q_scale), q_normed, k_nope_normed, _ = _fuse_rmsnorm_quant( + q, + attn.q_a_layernorm.weight, + attn.q_a_layernorm.eps, + k_nope, + attn.kv_a_layernorm.weight, + attn.kv_a_layernorm.eps, + None, + dtype_quant=attn.quant_dtype, + shuffle=False, + scale_shuffle_padding=False, + group_size=128, + quant_type=_linear_quant_type_value(attn.q_b_proj), + output_unquantized_inp1=output_unquantized_q, + transpose_scale=True, + ) + return q_quantized, q_scale, q_normed, k_nope_normed + + +def _fuse_qk_rmsnorm( + attn: DeepseekV2MLAAttention, + q: torch.Tensor, + k_nope: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fuse q/k RMSNorm without quantizing q.""" + (q_normed, _), _, k_nope_normed, _ = _fuse_rmsnorm_quant( + q, + attn.q_a_layernorm.weight, + attn.q_a_layernorm.eps, + k_nope, + attn.kv_a_layernorm.weight, + attn.kv_a_layernorm.eps, + None, + dtype_quant=torch.bfloat16, + shuffle=False, + scale_shuffle_padding=False, + group_size=128, + quant_type=None, + output_unquantized_inp1=False, + transpose_scale=False, + ) + return q_normed, k_nope_normed + + +def _prepare_weight_for_bmm( + weight: torch.Tensor, in_dim: int, out_dim: int +) -> torch.Tensor: + """Normalize absorbed weight layout for torch.bmm fallback.""" + if weight.shape[1] == in_dim and weight.shape[2] == out_dim: + return weight + if weight.shape[1] == out_dim and weight.shape[2] == in_dim: + return weight.transpose(-2, -1) + raise RuntimeError( + "Unexpected absorbed weight shape for bmm fallback: " + f"{tuple(weight.shape)} with in_dim={in_dim}, out_dim={out_dim}" + ) + + +def init_sgl_attrs( + attn: DeepseekV2MLAAttention, + config, + kv_cache_dtype: str = "bf16", +) -> None: + """Initialise sglang-only attributes on DeepseekV2MLAAttention.""" + from sglang.srt.configs.model_config import is_deepseek_nsa + + attn.use_nsa = is_deepseek_nsa(config) + attn.use_deep_gemm_bmm = False + attn.alt_stream = None + attn.kv_cache_dtype = kv_cache_dtype + attn.use_fused_qk_rope_concat_and_cache_mla = _use_aiter_gfx95 + attn.current_sgl_plugin_attn_path = None + attn.w_kc, attn.w_vc = None, None + attn.w_scale = None + attn.w_scale_k = None + attn.w_scale_v = None + attn.attn_non_absorbed = Attention( + num_heads=attn.num_local_heads, + head_dim=attn.qk_head_dim, + scale=attn.scaling, + num_kv_heads=attn.num_local_heads, + kv_cache_dtype=kv_cache_dtype, + layer_num=attn.layer_num, + use_mla=False, + v_head_dim=attn.v_head_dim, + prefix=maybe_prefix(attn.prefix, "attn_non_absorbed"), + ) + _bind_non_absorbed_kv_b_proj(attn) + + +def mla_absorbed_bmm( + attn: DeepseekV2MLAAttention, + inp: torch.Tensor, + weight: torch.Tensor, + weight_scale: Optional[torch.Tensor], + weight_scale_k: Optional[torch.Tensor], + out_dim: int, +) -> torch.Tensor: + """Batched matmul for MLA absorbed weights (w_kc / w_vc).""" + effective_weight_scale = ( + weight_scale_k if weight_scale_k is not None else weight_scale + ) + + if attn.use_deep_gemm_bmm: + from sglang.srt.layers import deep_gemm_wrapper + + val, scale, masked_m, expected_m, aligned_m = ( + per_token_group_quant_mla_deep_gemm_masked_fp8(inp.transpose(0, 1)) + ) + out = inp.new_empty((attn.num_local_heads, aligned_m, out_dim)) + deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked( + (val, scale), + (weight, weight_scale_k), + out, + masked_m, + expected_m, + ) + return out[:, :expected_m, :].transpose(0, 1) + + if _is_hip: + if _use_aiter_gfx95 and weight.dtype == torch.uint8: + x = inp.transpose(0, 1) + out = torch.empty( + x.shape[0], + x.shape[1], + weight.shape[2], + device=x.device, + dtype=torch.bfloat16, + ) + batched_gemm_afp4wfp4_pre_quant( + x, + weight.transpose(-2, -1), + weight_scale_k.transpose(-2, -1), + torch.bfloat16, + out, + ) + return out.transpose(0, 1) + + if (_use_aiter_gfx95 and weight.dtype == torch.float8_e4m3fn) or ( + get_is_capture_mode() and weight.dtype == torch.float8_e4m3fnuz + ): + x = inp.transpose(0, 1) + out = ( + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( + X=x, + WQ=weight, + w_scale=effective_weight_scale, + group_size=128, + YQ=None, + transpose_bm=True, + transpose_bm_in=False, + dtype=torch.bfloat16, + ) + ) + return out + + w_bf16 = _prepare_weight_for_bmm(weight, inp.shape[-1], out_dim).to( + torch.bfloat16 + ) + if effective_weight_scale is not None: + w_bf16 = w_bf16 * effective_weight_scale + out = torch.bmm( + inp.to(torch.bfloat16).transpose(0, 1), + w_bf16, + ) + return out.transpose(0, 1) + + if weight.dtype == torch.float8_e4m3fn: + val, scale = per_tensor_quant_mla_fp8( + inp.transpose(0, 1), + torch.zeros((1,), dtype=torch.float32, device=inp.device), + ) + out = bmm_fp8(val, weight, scale, effective_weight_scale, torch.bfloat16) + return out.transpose(0, 1) + + return torch.bmm(inp.transpose(0, 1), weight).transpose(0, 1) + + +def mla_v_up_proj( + attn: DeepseekV2MLAAttention, + inp: torch.Tensor, + weight: torch.Tensor, + weight_scale: Optional[torch.Tensor], + weight_scale_k: Optional[torch.Tensor], + out_dim: int, +) -> torch.Tensor: + """Project MLA decode output to a flat o_proj input.""" + effective_weight_scale = ( + weight_scale_k if weight_scale_k is not None else weight_scale + ) + if _is_hip and _use_aiter_gfx95 and weight.dtype == torch.uint8: + x = inp.transpose(0, 1) + out = torch.empty( + (inp.shape[0], attn.num_local_heads * out_dim), + device=inp.device, + dtype=torch.bfloat16, + ) + out_3d = out.view(inp.shape[0], attn.num_local_heads, out_dim) + batched_gemm_a16wfp4( + x, + weight.transpose(-2, -1), + weight_scale_k.transpose(-2, -1), + dtype=torch.bfloat16, + y=out_3d, + transpose_bm=True, + prequant=True, + y_scale=None, + ) + return out + + if _is_hip and ( + (_use_aiter_gfx95 and weight.dtype == torch.float8_e4m3fn) + or (get_is_capture_mode() and weight.dtype == torch.float8_e4m3fnuz) + ): + x = inp.transpose(0, 1) + out = torch.empty( + (inp.shape[0], attn.num_local_heads * out_dim), + device=inp.device, + dtype=torch.bfloat16, + ) + out_3d = out.view(inp.shape[0], attn.num_local_heads, out_dim) + batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant( + X=x, + WQ=weight, + w_scale=effective_weight_scale, + group_size=128, + YQ=out_3d, + transpose_bm=True, + transpose_bm_in=False, + dtype=torch.bfloat16, + ) + return out + + return mla_absorbed_bmm( + attn, inp, weight, weight_scale, weight_scale_k, out_dim + ).flatten(1, 2) + + +def _get_sglang_radix_attn(attn_module): + return attn_module.attn if hasattr(attn_module, "attn") else attn_module + + +def _bind_non_absorbed_kv_b_proj(attn: DeepseekV2MLAAttention) -> None: + """Expose DeepSeek's latent-KV projection on the non-absorbed SGLang layer.""" + + if not hasattr(attn, "attn_non_absorbed"): + return + attn_non_absorbed = _get_sglang_radix_attn(attn.attn_non_absorbed) + attn_non_absorbed.kv_b_proj = attn.kv_b_proj + + +def _concat_mha_k_for_non_absorbed( + attn: DeepseekV2MLAAttention, + k_nope: torch.Tensor, + k_pe: torch.Tensor, +) -> torch.Tensor: + k = k_nope.new_empty( + k_nope.shape[0], + attn.num_local_heads, + attn.qk_nope_head_dim + attn.qk_rope_head_dim, + ) + + try: + from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton + except ImportError as exc: + logger.warning( + "Unable to import concat_and_cast_mha_k_triton; " + "falling back to torch native non-absorbed K concat: %s", + exc, + ) + else: + concat_and_cast_mha_k_triton(k, k_nope, k_pe) + return k + + k[..., : attn.qk_nope_head_dim] = k_nope + k[..., attn.qk_nope_head_dim :] = k_pe + return k + + +def _set_mla_kv_buffer_for_non_absorbed( + attn: DeepseekV2MLAAttention, + kv_a: torch.Tensor, + k_pe: torch.Tensor, + forward_batch, +) -> None: + attn_non_absorbed = _get_sglang_radix_attn(attn.attn_non_absorbed) + cache_k = torch.cat([kv_a.unsqueeze(1), k_pe], dim=-1) + forward_batch.token_to_kv_pool.set_kv_buffer( + attn_non_absorbed, + forward_batch.out_cache_loc, + cache_k, + cache_k, + ) + + +def _is_mxfp4_kv_b_proj(attn: DeepseekV2MLAAttention) -> bool: + kv_b_proj = attn.kv_b_proj + params_dtype = getattr(kv_b_proj, "params_dtype", None) + if params_dtype == dtypes.fp4x2 or params_dtype == getattr( + torch, "float4_e2m1fn_x2", None + ): + return True + + quant_type = getattr(kv_b_proj, "quant_type", None) + if getattr(quant_type, "name", "") == "per_1x32" or str(quant_type).endswith( + "per_1x32" + ): + return True + + quant_method = getattr(kv_b_proj, "quant_method", None) + quant_config = getattr(quant_method, "quant_config", None) + return bool( + quant_config is not None + and quant_config.get_name() == "quark" + and kv_b_proj.weight.dtype == torch.uint8 + ) + + +def _can_run_non_absorbed_mla_now( + attn: DeepseekV2MLAAttention, + forward_batch, +) -> bool: + """Check if the model configuration supports the non-absorbed MLA path. + + This is a capability gate. NSA models cannot use the non-absorbed path. + MXFP4 ``kv_b_proj`` weights are supported because that path expands K/V + through ``attn.kv_b_proj`` itself, which owns the per_1x32 GEMM + implementation. + """ + del forward_batch + if attn.use_nsa: + return False + if attn.kv_b_proj.weight.dtype == torch.uint8 and not _is_mxfp4_kv_b_proj(attn): + return False + return True + + +def _is_static_quark_mxfp4_kv_b_proj(kv_b_proj) -> bool: + layer_quant_config = getattr(kv_b_proj, "layer_quant_config", None) + return ( + getattr(kv_b_proj, "weight", None) is not None + and kv_b_proj.weight.dim() == 2 + and layer_quant_config is not None + and layer_quant_config.quant_method == "quark" + and layer_quant_config.quant_dtype == dtypes.fp4x2 + and getattr(layer_quant_config.quant_type, "name", None) == "per_1x32" + and getattr(kv_b_proj, "source_quant_dtype", None) is None + ) + + +def _read_kv_b_proj_weight(attn: DeepseekV2MLAAttention) -> torch.Tensor: + """Read kv_b_proj weight, handling AWQ and fnuz dtypes.""" + if hasattr(attn.kv_b_proj, "qweight"): + awq_dequant = awq_dequantize_func() + if awq_dequant is None: + raise ValueError( + "AWQ dequantize function is not supported for current device" + ) + w = awq_dequant( + attn.kv_b_proj.qweight, + attn.kv_b_proj.scales, + attn.kv_b_proj.qzeros, + ).T + else: + if _is_static_quark_mxfp4_kv_b_proj(attn.kv_b_proj): + w = getattr( + attn.kv_b_proj, + "_mxfp4_unshuffled_weight", + attn.kv_b_proj.weight, + ) + else: + w = attn.kv_b_proj.weight + + if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fnuz: + w = w.view(torch.float8_e4m3fn) + + return w + + +def _get_weight_block_size(attn: DeepseekV2MLAAttention) -> Optional[list[int]]: + """Derive weight_block_size from ATOM's quant_type system.""" + from aiter import QuantType as _AiterQuantType + + qt = getattr(attn.kv_b_proj, "quant_type", None) + if qt == _AiterQuantType.per_1x128: + return [128, 128] + elif qt == _AiterQuantType.per_1x32: + return [1, 32] + return None + + +def _process_fp8_weight( + attn: DeepseekV2MLAAttention, + w: torch.Tensor, + weight_block_size: Optional[list[int]], +) -> tuple[torch.Tensor, bool, Optional[torch.Tensor]]: + """Process FP8 weights for kv_b_proj.""" + from atom.model_ops.utils import normalize_e4m3fn_to_e4m3fnuz + from sglang.srt.layers.quantization.fp8_utils import ( + block_quant_dequant, + block_quant_to_tensor_quant, + channel_quant_to_tensor_quant, + inverse_transform_scale_ue8m0, + ) + from sglang.srt.layers.deep_gemm_wrapper import ( + ENABLE_JIT_DEEPGEMM, + DEEPGEMM_BLACKWELL, + ) + from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0 + + use_deep_gemm_bmm = False + block_scale = None + + if weight_block_size is not None: + assert hasattr(attn.kv_b_proj, "weight_scale_inv") or hasattr( + attn.kv_b_proj, "weight_scale" + ) + weight_scale = ( + attn.kv_b_proj.weight_scale + if hasattr(attn.kv_b_proj, "weight_scale") + else attn.kv_b_proj.weight_scale_inv + ) + + if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fn: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, weight_scale=weight_scale, input_scale=None + ) + else: + weight = w + + if should_deepgemm_weight_requant_ue8m0( + weight_block_size=weight_block_size + ) and getattr(weight_scale, "format_ue8m0", False): + weight_scale = inverse_transform_scale_ue8m0( + weight_scale, mn=weight.shape[-2] + ) + + if _is_cuda and weight_block_size[0] == 128 and weight_block_size[1] == 128: + if ( + ENABLE_JIT_DEEPGEMM + and not DEEPGEMM_BLACKWELL + and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false") + ): + block_scale = weight_scale + use_deep_gemm_bmm = True + else: + w = block_quant_dequant( + weight, weight_scale, weight_block_size, torch.bfloat16 + ) + else: + w, scale = block_quant_to_tensor_quant( + weight, weight_scale, weight_block_size + ) + attn.w_scale = scale + else: + if w.dtype == torch.float8_e4m3fn and _is_fp8_fnuz: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, weight_scale=attn.kv_b_proj.weight_scale, input_scale=None + ) + else: + weight = w + weight_scale = attn.kv_b_proj.weight_scale + + w, scale = channel_quant_to_tensor_quant(weight, weight_scale) + attn.w_scale = scale + + return w, use_deep_gemm_bmm, block_scale + + +def _process_int8_weight( + attn: DeepseekV2MLAAttention, + w: torch.Tensor, + weight_block_size: Optional[list[int]], +) -> torch.Tensor: + """Process INT8 weights for kv_b_proj.""" + from sglang.srt.layers.quantization.int8_utils import ( + block_dequant as int8_block_dequant, + ) + + if weight_block_size is not None: + assert hasattr(attn.kv_b_proj, "weight_scale_inv") + return int8_block_dequant( + w, attn.kv_b_proj.weight_scale_inv, weight_block_size + ).to(torch.bfloat16) + else: + return w.to(torch.bfloat16) * attn.kv_b_proj.weight_scale.to(torch.bfloat16) + + +def _split_kc_vc_like_vllm( + attn: DeepseekV2MLAAttention, w: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """Split kv_b_proj weight using vLLM's transpose-first layout.""" + kv_b_proj_weight = w.T + assert kv_b_proj_weight.shape == ( + attn.kv_lora_rank, + attn.num_local_heads * (attn.qk_nope_head_dim + attn.v_head_dim), + ), ( + f"{kv_b_proj_weight.shape=}, " + f"{attn.kv_lora_rank=}, " + f"{attn.num_local_heads=}, " + f"{attn.qk_nope_head_dim=}, " + f"{attn.v_head_dim=}" + ) + kv_b_proj_weight = kv_b_proj_weight.view( + attn.kv_lora_rank, + attn.num_local_heads, + attn.qk_nope_head_dim + attn.v_head_dim, + ) + w_uk, w_uv = kv_b_proj_weight.split( + [attn.qk_nope_head_dim, attn.v_head_dim], dim=-1 + ) + return w_uk.transpose(0, 1).contiguous(), w_uv.permute(1, 2, 0).contiguous() + + +def _split_and_assign_kc_vc( + attn: DeepseekV2MLAAttention, + w: torch.Tensor, + use_deep_gemm_bmm: bool, + block_scale: Optional[torch.Tensor], + weight_block_size: Optional[list[int]], +) -> None: + """Split weight into kc/vc and assign to attn.""" + from atom.model_ops.utils import quark_post_load_weights + + w_kc, w_vc = w.unflatten(0, (-1, attn.qk_nope_head_dim + attn.v_head_dim)).split( + [attn.qk_nope_head_dim, attn.v_head_dim], dim=1 + ) + + # Quark MXFP4 LinearBase modules store quantization on layer_quant_config; + # there is no kv_b_proj.quant_method object to inspect in this path. + layer_quant_config = getattr(attn.kv_b_proj, "layer_quant_config", None) + quant_method = getattr(attn.kv_b_proj, "quant_method", None) + quant_config = getattr(quant_method, "quant_config", None) + is_quark_quant_method = ( + quant_config is not None and quant_config.get_name() == "quark" + ) + is_quark_mxfp4_linear = ( + layer_quant_config is not None + and layer_quant_config.quant_method == "quark" + and layer_quant_config.quant_dtype == dtypes.fp4x2 + and getattr(layer_quant_config.quant_type, "name", None) == "per_1x32" + ) + if _use_aiter_gfx95 and (is_quark_quant_method or is_quark_mxfp4_linear): + quark_weight = w + weight_scale = getattr(attn.kv_b_proj, "_mxfp4_unshuffled_weight_scale", None) + if is_quark_mxfp4_linear and weight_scale is not None: + quark_weight = mxfp4_to_f32(w.view(torch.uint8)).to(torch.bfloat16) + quark_weight = quark_weight * e8m0_to_f32( + weight_scale.repeat_interleave(32, dim=-1) + ).to(torch.bfloat16) + w_kc, attn.w_scale_k, w_vc, attn.w_scale_v = quark_post_load_weights( + attn, quark_weight, "mxfp4" + ) + + if not use_deep_gemm_bmm: + use_vllm_weight_layout = _is_hip and not ( + is_quark_quant_method or is_quark_mxfp4_linear + ) + + if use_vllm_weight_layout: + w_kc, w_vc = _split_kc_vc_like_vllm(attn, w) + else: + w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) + w_vc = w_vc.contiguous().transpose(1, 2) + + if w.dtype == torch.bfloat16 and (_is_hip or _is_cuda): + w_kc, w_scale_k = dynamic_per_batched_tensor_quant(w_kc, dtype=dtypes.fp8) + w_vc, w_scale_v = dynamic_per_batched_tensor_quant(w_vc, dtype=dtypes.fp8) + attn.w_scale_k = bind_or_assign(attn.w_scale_k, w_scale_k) + attn.w_scale_v = bind_or_assign(attn.w_scale_v, w_scale_v) + + attn.w_kc = bind_or_assign(attn.w_kc, w_kc) + if _is_npu: + w_vc = w_vc.contiguous() + attn.w_vc = bind_or_assign(attn.w_vc, w_vc) + + if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn: + attn.w_kc = attn.w_kc.to(torch.bfloat16) * attn.w_scale + attn.w_vc = attn.w_vc.to(torch.bfloat16) * attn.w_scale + else: + num_tiles_k = attn.qk_nope_head_dim // weight_block_size[1] + num_tiles_n = attn.v_head_dim // weight_block_size[0] + ws_kc, ws_vc = block_scale.unflatten( + 0, (-1, (num_tiles_k + num_tiles_n)) + ).split([num_tiles_k, num_tiles_n], dim=1) + + attn.w_scale_k = bind_or_assign( + attn.w_scale_k, ws_kc.transpose(1, 2).contiguous() + ) + attn.w_scale_v = bind_or_assign(attn.w_scale_v, ws_vc.contiguous()) + attn.w_kc = bind_or_assign(attn.w_kc, w_kc.transpose(1, 2).contiguous()) + attn.w_vc = bind_or_assign(attn.w_vc, w_vc.contiguous()) + attn.use_deep_gemm_bmm = True + + +def process_mla_kv_b_proj_after_loading(attn: DeepseekV2MLAAttention) -> None: + """Process kv_b_proj weights after loading for sglang MLA mode. + + Orchestrates reading, quantization handling, and splitting of + kv_b_proj into absorbed w_kc / w_vc weights. + """ + _bind_non_absorbed_kv_b_proj(attn) + if _is_static_quark_mxfp4_kv_b_proj(attn.kv_b_proj) and not getattr( + attn.kv_b_proj, "_sgl_mxfp4_process_done", False + ): + attn.kv_b_proj.process_weights_after_loading() + + w = _read_kv_b_proj_weight(attn) + weight_block_size = _get_weight_block_size(attn) + + use_deep_gemm_bmm = False + block_scale = None + + if w.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): + w, use_deep_gemm_bmm, block_scale = _process_fp8_weight( + attn, w, weight_block_size + ) + + if w.dtype == torch.int8: + w = _process_int8_weight(attn, w, weight_block_size) + + _split_and_assign_kc_vc(attn, w, use_deep_gemm_bmm, block_scale, weight_block_size) + + +def _patch_kv_b_proj_for_sglang_mxfp4(attn: DeepseekV2MLAAttention) -> None: + """Preserve DeepSeek MLA kv_b_proj's original MXFP4 layout for kc/vc split.""" + kv_b_proj = attn.kv_b_proj + if getattr(kv_b_proj, "_sgl_mxfp4_preserve_patched", False): + return + + orig_process_weights_after_loading = kv_b_proj.process_weights_after_loading + + def process_weights_after_loading_with_mxfp4_preserve(): + if getattr(kv_b_proj, "_sgl_mxfp4_process_done", False): + return + + if _is_static_quark_mxfp4_kv_b_proj(kv_b_proj): + kv_b_proj._mxfp4_unshuffled_weight = kv_b_proj.weight.detach().clone() + kv_b_proj._mxfp4_unshuffled_weight_scale = ( + kv_b_proj.weight_scale.detach().clone() + ) + orig_process_weights_after_loading() + kv_b_proj._sgl_mxfp4_process_done = True + + kv_b_proj.process_weights_after_loading = ( + process_weights_after_loading_with_mxfp4_preserve + ) + kv_b_proj._sgl_mxfp4_preserve_patched = True diff --git a/atom/plugin/sglang/models/deepseek_nextn_wrapper.py b/atom/plugin/sglang/models/deepseek_nextn_wrapper.py index 7ad85048c1..9b343a04d8 100644 --- a/atom/plugin/sglang/models/deepseek_nextn_wrapper.py +++ b/atom/plugin/sglang/models/deepseek_nextn_wrapper.py @@ -19,16 +19,11 @@ from atom.config import SpeculativeConfig from atom.plugin.config import generate_atom_config_for_plugin_mode -from atom.plugin.sglang.attention_backend.sgl_attention_mla import ( +from atom.plugin.sglang.models.deepseek_mla import ( setup_deepseek_for_sglang, ) -from atom.plugin.sglang.models.base_model_wrapper import ( - _current_forward_batch, - _is_dummy_forward, - _materialize_atom_dummy_forward, - _reset_sglang_forward_context, - _set_sglang_forward_context, - _trim_hidden_states_for_output, +from atom.plugin.sglang.runtime import ( + SGLangPluginRuntime, plugin_runtime_scope, ) @@ -78,7 +73,7 @@ def _retag_mtp_runtime_layer_ids(model: nn.Module) -> None: _set_runtime_layer_id(self_attn, local_layer_id) - for attr_name in ("mla_attn", "attn_mha"): + for attr_name in ("mla_attn", "attn_non_absorbed", "attn_mha"): attn_obj = getattr(self_attn, attr_name, None) if attn_obj is None: continue @@ -171,54 +166,28 @@ def forward( raise ValueError("DeepSeek MTP draft forward requires speculative info") with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): - if _is_dummy_forward(forward_batch): - ( - model_input_ids, - model_positions, - model_input_embeds, - model_forward_batch, - ) = _materialize_atom_dummy_forward( - input_ids, - positions, - input_embeds, - forward_batch, - ) - model_hidden_states = _materialize_dummy_hidden_states( - forward_batch.spec_info.hidden_states, - length=int(model_positions.shape[0]), - ) - else: - ( - model_input_ids, - model_positions, - model_input_embeds, - model_forward_batch, - ) = ( - input_ids, - positions, - input_embeds, - forward_batch, - ) + with SGLangPluginRuntime( + atom_config=self.atom_config, + forward_batch=forward_batch, + positions=positions, + input_ids=input_ids, + input_embeds=input_embeds, + ) as runtime: model_hidden_states = forward_batch.spec_info.hidden_states - - token = _current_forward_batch.set(model_forward_batch) - try: - _set_sglang_forward_context( - self.atom_config, model_forward_batch, model_positions - ) + if runtime.forward_batch is not forward_batch: + model_hidden_states = _materialize_dummy_hidden_states( + model_hidden_states, + length=int(runtime.positions.shape[0]), + ) hidden_states = self.model( - input_ids=model_input_ids, - positions=model_positions, + input_ids=runtime.input_ids, + positions=runtime.positions, hidden_states=model_hidden_states, - inputs_embeds=model_input_embeds, + inputs_embeds=runtime.input_embeds, ) - finally: - _reset_sglang_forward_context() - _current_forward_batch.reset(token) if self.pp_group.is_last_rank: - if _is_dummy_forward(forward_batch): - hidden_states = _trim_hidden_states_for_output(hidden_states, 0) + hidden_states = runtime.trim_output(hidden_states) return self.logits_processor( input_ids, hidden_states, diff --git a/atom/plugin/sglang/models/qwen3_5.py b/atom/plugin/sglang/models/qwen3_5.py index 22294b8548..e85af86a22 100644 --- a/atom/plugin/sglang/models/qwen3_5.py +++ b/atom/plugin/sglang/models/qwen3_5.py @@ -37,7 +37,7 @@ from atom.plugin.sglang.attention_backend.attention_gdn import ( SGLangGDNForwardContext, ) -from atom.plugin.sglang.models.base_model_wrapper import ( +from atom.plugin.sglang.runtime import ( SGLangForwardBatchMetadata, ) @@ -202,13 +202,13 @@ def __init__( prefix: str = "", ) -> None: del prefix - import atom + from atom.plugin.sglang.prepare import prepare_model nn.Module.__init__(self) root_config = type(self)._pending_vlm_root_config if root_config is None: root_config = config - atom_lm = atom.prepare_model(config=root_config, engine="sglang") + atom_lm = prepare_model(config=root_config) if atom_lm is None: arch = getattr(root_config, "architectures", ["unknown"])[0] raise ValueError(f"ATOM failed to build language model for {arch}") @@ -448,5 +448,5 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # SGLang discovers these multimodal wrappers from this module's `EntryClass`. # They are not covered by `base_model_wrapper.py`, whose generated entries only -# handle the plain causal-LM architectures in `_MODEL_ARCH_SPECS`. +# handle the plain causal-LM architectures in `MODEL_ARCH_SPECS`. EntryClass = [Qwen3_5ForConditionalGeneration, Qwen3_5MoeForConditionalGeneration] diff --git a/atom/plugin/sglang/prepare.py b/atom/plugin/sglang/prepare.py new file mode 100644 index 0000000000..a51f5f2f44 --- /dev/null +++ b/atom/plugin/sglang/prepare.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import logging +from typing import Any + +from atom.plugin.prepare import _set_framework_backbone + +logger = logging.getLogger("atom") + + +def prepare_model(config: Any): + """Prepare an ATOM model for SGLang plugin mode.""" + logger.info("Prepare model for plugin mode, the upper engine is sglang") + _set_framework_backbone("sglang") + + model_arch = config.architectures[0] + + # Import here to avoid partial initialization while SGLang discovers models. + from atom.plugin.register import ( + _ATOM_SUPPORTED_MODELS, + init_aiter_dist, + register_ops_to_sglang, + set_attn_cls, + ) + + if model_arch not in _ATOM_SUPPORTED_MODELS: + supported_archs = list(_ATOM_SUPPORTED_MODELS.keys()) + raise ValueError( + f"ATOM does not support the required model architecture: {model_arch}. " + f"For now supported model architectures: {supported_archs}" + ) + + from atom.plugin.config import generate_atom_config_for_plugin_mode + + atom_config = generate_atom_config_for_plugin_mode(config) + + model_cls = _ATOM_SUPPORTED_MODELS[model_arch] + logger.info("ATOM model class for %s is %s", model_arch, model_cls) + + from atom.plugin.sglang.runtime import get_model_arch_spec + + model_adapter = get_model_arch_spec(model_arch) + if model_adapter.prepare_config is not None: + model_adapter.prepare_config(atom_config, model_arch) + + register_ops_to_sglang(atom_config=atom_config) + set_attn_cls() + + # Init aiter dist for using aiter custom collective ops. + init_aiter_dist(config=atom_config) + + # Patch SGLang graph_capture to also enter aiter's ca_comm.capture(), + # avoiding hipMemcpyAsync in aiter collectives when model uses aiter's + # custom all_reduce (same fix as atom/plugin/vllm/graph_capture_patch.py). + from atom.plugin.sglang.graph_capture_patch import apply_graph_capture_patch + + apply_graph_capture_patch() + + try: + model = model_cls(atom_config=atom_config) + except TypeError as exc: + # Some SGLang plugin models keep SGLang's native wrapper constructor + # and only swap their internal language_model with an ATOM model. + # Those classes accept `config=...` instead of `atom_config=...`. + if "atom_config" not in str(exc): + raise + model = model_cls(config=config) + if not hasattr(model, "atom_config"): + model.atom_config = atom_config + return model + + +def prepare_model_for_sglang(config: Any): + """Backward-compatible alias for SGLang plugin model preparation.""" + return prepare_model(config) diff --git a/atom/plugin/sglang/runtime/__init__.py b/atom/plugin/sglang/runtime/__init__.py new file mode 100644 index 0000000000..fa8468c177 --- /dev/null +++ b/atom/plugin/sglang/runtime/__init__.py @@ -0,0 +1,27 @@ +"""Runtime utilities for ATOM's SGLang plugin integration.""" + +from atom.plugin.sglang.runtime.context import ( + SGLangForwardBatchMetadata, + bind_current_forward_batch, + get_current_forward_batch, + plugin_runtime_scope, +) +from atom.plugin.sglang.runtime.forward_context import SGLangPluginRuntime +from atom.plugin.sglang.runtime.model_arch import ( + MODEL_ADAPTER_SPECS, + MODEL_ARCH_SPECS, + SGLangModelAdapterSpec, + get_model_arch_spec, +) + +__all__ = [ + "MODEL_ADAPTER_SPECS", + "MODEL_ARCH_SPECS", + "SGLangForwardBatchMetadata", + "SGLangModelAdapterSpec", + "SGLangPluginRuntime", + "bind_current_forward_batch", + "get_current_forward_batch", + "get_model_arch_spec", + "plugin_runtime_scope", +] diff --git a/atom/plugin/sglang/runtime/context.py b/atom/plugin/sglang/runtime/context.py new file mode 100644 index 0000000000..1fffae92f2 --- /dev/null +++ b/atom/plugin/sglang/runtime/context.py @@ -0,0 +1,126 @@ +"""Runtime context helpers for ATOM's SGLang plugin path.""" + +from __future__ import annotations + +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass +from typing import ClassVar, Optional, Union + +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors + +_RUNTIME_SENTINEL = object() +_current_forward_batch: ContextVar[Optional[ForwardBatch]] = ContextVar( + "atom_sglang_current_forward_batch", default=None +) + + +def get_current_forward_batch(): + return _current_forward_batch.get() + + +@contextmanager +def bind_current_forward_batch(forward_batch: Optional[ForwardBatch]): + token = _current_forward_batch.set(forward_batch) + try: + yield + finally: + _current_forward_batch.reset(token) + + +@contextmanager +def plugin_runtime_scope( + *, + framework: Optional[str] = None, + atom_config=_RUNTIME_SENTINEL, +): + """Temporarily bind process-global ATOM plugin runtime state. + + SGLang target/draft wrappers can coexist during speculative decoding, while + ATOM core still reads process-global framework/config state in some paths. + Keep those globals scoped to one wrapper call and restore them afterwards. + """ + + import atom.config as atom_config_module + import atom.plugin.prepare as plugin_prepare + + prev_framework = plugin_prepare._CURRENT_FRAMEWORK + prev_atom_config = getattr(atom_config_module, "_current_atom_config", None) + + if framework is not None: + plugin_prepare._set_framework_backbone(framework) + if atom_config is not _RUNTIME_SENTINEL: + atom_config_module._current_atom_config = atom_config + + try: + yield + finally: + plugin_prepare._CURRENT_FRAMEWORK = prev_framework + atom_config_module._current_atom_config = prev_atom_config + + +@dataclass(frozen=True) +class SGLangForwardBatchMetadata: + """Small context object for one SGLang model forward.""" + + forward_batch: Optional[ForwardBatch] + pp_proxy_tensors: Optional[PPProxyTensors] = None + save_kv_cache: bool = True + _current: ClassVar[ContextVar[Optional["SGLangForwardBatchMetadata"]]] = ContextVar( + "atom_sglang_current_forward_batch_metadata", + default=None, + ) + + @classmethod + def current(cls) -> Optional["SGLangForwardBatchMetadata"]: + return cls._current.get() + + @classmethod + def build( + cls, + forward_batch: Optional[ + Union[ForwardBatch, "SGLangForwardBatchMetadata"] + ] = None, + *, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + save_kv_cache: Optional[bool] = None, + ) -> Optional["SGLangForwardBatchMetadata"]: + if isinstance(forward_batch, cls): + return forward_batch + if forward_batch is None and pp_proxy_tensors is None and save_kv_cache is None: + return cls.current() + return cls( + forward_batch=forward_batch, + pp_proxy_tensors=pp_proxy_tensors, + save_kv_cache=True if save_kv_cache is None else save_kv_cache, + ) + + @classmethod + @contextmanager + def bind(cls, metadata: Optional["SGLangForwardBatchMetadata"]): + meta_token = cls._current.set(metadata) + batch_token = _current_forward_batch.set( + None if metadata is None else metadata.forward_batch + ) + try: + yield metadata + finally: + _current_forward_batch.reset(batch_token) + cls._current.reset(meta_token) + + @staticmethod + def to_intermediate_tensors( + intermediate_tensors, + metadata: Optional["SGLangForwardBatchMetadata"], + ): + if intermediate_tensors is not None or metadata is None: + return intermediate_tensors + pp_proxy_tensors = metadata.pp_proxy_tensors + if pp_proxy_tensors is None: + return intermediate_tensors + tensors = getattr(pp_proxy_tensors, "tensors", None) + if tensors is None: + return intermediate_tensors + from atom.models.utils import IntermediateTensors + + return IntermediateTensors(dict(tensors)) diff --git a/atom/plugin/sglang/runtime/forward_context.py b/atom/plugin/sglang/runtime/forward_context.py new file mode 100644 index 0000000000..05939cc7cb --- /dev/null +++ b/atom/plugin/sglang/runtime/forward_context.py @@ -0,0 +1,234 @@ +"""Scoped runtime adapter from SGLang batches to ATOM core.""" + +from __future__ import annotations + +import copy +from contextlib import ExitStack +from dataclasses import dataclass, field +from typing import Any, Optional + +import torch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch + +from atom.plugin.sglang.runtime.context import bind_current_forward_batch + + +def _is_dummy_forward(forward_batch: ForwardBatch) -> bool: + """Return whether an SGLang batch represents an empty/idle dummy run.""" + + forward_mode = getattr(forward_batch, "forward_mode", None) + return bool( + forward_mode is not None + and hasattr(forward_mode, "is_idle") + and forward_mode.is_idle() + ) + + +def _pad_dummy_like( + tensor: Optional[torch.Tensor], + *, + length: int, + fill_value: int | float = 0, +) -> Optional[torch.Tensor]: + if tensor is None: + return None + shape = (length, *tensor.shape[1:]) + return torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device) + + +def _materialize_atom_dummy_forward( + input_ids: Optional[torch.Tensor], + positions: Optional[torch.Tensor], + input_embeds: Optional[torch.Tensor], + forward_batch: ForwardBatch, +) -> tuple[ + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + ForwardBatch, +]: + """Convert an empty SGLang IDLE batch into ATOM-style dummy inputs.""" + + if positions is None: + raise RuntimeError("SGLang dummy forward materialization requires positions") + if input_ids is None: + raise RuntimeError("SGLang dummy forward materialization requires input_ids") + + dummy_positions = positions.new_zeros((1,)) + dummy_input_ids = input_ids.new_zeros((1,)) + dummy_input_embeds = _pad_dummy_like(input_embeds, length=1, fill_value=0) + + model_forward_batch = copy.copy(forward_batch) + model_forward_batch.positions = dummy_positions + model_forward_batch.batch_size = 1 + model_forward_batch.seq_lens_sum = 1 + model_forward_batch.seq_lens = forward_batch.seq_lens.new_ones((1,)) + model_forward_batch.seq_lens_cpu = forward_batch.seq_lens_cpu.new_ones((1,)) + + return dummy_input_ids, dummy_positions, dummy_input_embeds, model_forward_batch + + +def _trim_hidden_states_for_output(hidden_states, num_tokens: int): + if torch.is_tensor(hidden_states): + return hidden_states[:num_tokens] + if isinstance(hidden_states, tuple): + return tuple( + tensor[:num_tokens] if torch.is_tensor(tensor) else tensor + for tensor in hidden_states + ) + return hidden_states + + +def _resolve_num_tokens_across_dp( + atom_config: Any, + forward_batch: ForwardBatch, + num_tokens: int, + is_dummy_run: bool, +) -> torch.Tensor: + """Resolve per-DP token counts for ATOM's CPU-side DPMetadata.""" + + global_num_tokens_cpu = getattr(forward_batch, "global_num_tokens_cpu", None) + if global_num_tokens_cpu is not None: + num_tokens_across_dp = torch.tensor( + global_num_tokens_cpu, dtype=torch.int32, device="cpu" + ) + else: + dp_size = atom_config.parallel_config.data_parallel_size + global_num_tokens_gpu = getattr(forward_batch, "global_num_tokens_gpu", None) + global_dp_buffer_len = getattr(forward_batch, "global_dp_buffer_len", None) + is_static_same_shape_batch = ( + global_num_tokens_gpu is not None + and global_dp_buffer_len == num_tokens * dp_size + ) + if not is_static_same_shape_batch: + raise RuntimeError( + "[SGL+ATOM] SGLang dp-attention requires " + "forward_batch.global_num_tokens_cpu unless the batch uses static " + "same-shape DP metadata." + ) + + # Static batches, such as CUDA graph capture batches, may only keep + # global token counts on GPU. Avoid GPU-to-CPU reads here and mirror + # their same-shape layout directly for ATOM's CPU DPMetadata. + num_tokens_across_dp = torch.full( + (dp_size,), num_tokens, dtype=torch.int32, device="cpu" + ) + + if is_dummy_run: + # SGLang reports idle ranks as 0 tokens, but ATOM materializes them + # as one local dummy token so collectives and DPMetadata stay aligned. + dp_rank = atom_config.parallel_config.data_parallel_rank + num_tokens_across_dp[dp_rank] = num_tokens + return num_tokens_across_dp + + +def _set_atom_forward_context( + atom_config: Any, + forward_batch: ForwardBatch, + positions: torch.Tensor, +) -> None: + """Bridge SGLang batch metadata into ATOM's global forward context.""" + + from atom.utils.forward_context import ( + AttentionMetaData, + Context, + set_forward_context, + ) + + forward_mode = forward_batch.forward_mode + # This value is only used by ATOM-side MoE padding in the SGLang wrapper. + max_seqlen_q = 1 if forward_mode.is_decode_or_idle() else 0 + attn_metadata = AttentionMetaData(max_seqlen_q=max_seqlen_q) + batch_size = int(forward_batch.batch_size) + is_dummy_run = _is_dummy_forward(forward_batch) + is_prefill = forward_mode.is_prefill() + num_tokens = int(positions.shape[0]) + + if bool(atom_config.enable_dp_attention): + num_tokens_across_dp = _resolve_num_tokens_across_dp( + atom_config, forward_batch, num_tokens, is_dummy_run + ) + graph_bs = int(torch.max(num_tokens_across_dp).item()) + else: + num_tokens_across_dp = None + graph_bs = num_tokens if is_prefill else batch_size + + context = Context( + positions=positions, + is_prefill=is_prefill, + is_dummy_run=is_dummy_run, + batch_size=batch_size, + graph_bs=graph_bs, + ) + set_forward_context( + attn_metadata=attn_metadata, + atom_config=atom_config, + context=context, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + ) + + +def _reset_atom_forward_context() -> None: + from atom.utils.forward_context import reset_forward_context + + reset_forward_context() + + +@dataclass +class SGLangPluginRuntime: + """Scoped adapter for running ATOM model code under SGLang plugin runtime. + + The adapter owns the temporary translation from SGLang's ``ForwardBatch`` to + ATOM's process-local runtime state. Callers should use the normalized + ``input_ids``, ``positions``, ``input_embeds``, and ``forward_batch`` exposed + by this object while inside the context. + """ + + atom_config: Any + forward_batch: ForwardBatch + positions: torch.Tensor + input_ids: Optional[torch.Tensor] = None + input_embeds: Optional[torch.Tensor] = None + set_forward_context: bool = True + _original_forward_batch: ForwardBatch = field(init=False, repr=False) + _is_dummy_run: bool = field(init=False, default=False) + _exit_stack: ExitStack = field(init=False, repr=False) + + def __enter__(self) -> "SGLangPluginRuntime": + self._original_forward_batch = self.forward_batch + self._is_dummy_run = _is_dummy_forward(self.forward_batch) + + if self._is_dummy_run: + ( + self.input_ids, + self.positions, + self.input_embeds, + self.forward_batch, + ) = _materialize_atom_dummy_forward( + self.input_ids, + self.positions, + self.input_embeds, + self.forward_batch, + ) + + self._exit_stack = ExitStack() + self._exit_stack.enter_context(bind_current_forward_batch(self.forward_batch)) + if self.set_forward_context: + _set_atom_forward_context( + self.atom_config, + self.forward_batch, + self.positions, + ) + self._exit_stack.callback(_reset_atom_forward_context) + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self._exit_stack.close() + + def trim_output(self, hidden_states): + """Map ATOM-visible outputs back to SGLang-visible token count.""" + + if self._is_dummy_run: + return _trim_hidden_states_for_output(hidden_states, 0) + return hidden_states diff --git a/atom/plugin/sglang/runtime/model_arch.py b/atom/plugin/sglang/runtime/model_arch.py new file mode 100644 index 0000000000..f103cb9427 --- /dev/null +++ b/atom/plugin/sglang/runtime/model_arch.py @@ -0,0 +1,65 @@ +"""SGLang plugin model adapter registry.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Optional + + +@dataclass(frozen=True) +class SGLangModelAdapterSpec: + """Adapter hooks for one SGLang plugin model architecture. + + The first version keeps the existing runtime flags while adding function + hooks for config preparation and install-time model adaptation. This avoids + growing a long list of booleans in the generic wrapper as new models arrive. + """ + + wrapper_binds_gdn_context: bool = False + prepare_config: Optional[Callable[[Any, str], None]] = None + install_adapters: Optional[Callable[[Any], None]] = None + + +def _prepare_qwen35_config(atom_config: Any, model_arch: str) -> None: + from atom.plugin.sglang.models.qwen3_5 import apply_prepare_model_adaptations + + apply_prepare_model_adaptations(atom_config, model_arch) + + +def _install_deepseek_mla_adapters(model: Any) -> None: + from atom.plugin.sglang.models.deepseek_mla import setup_deepseek_for_sglang + + setup_deepseek_for_sglang(model) + + +MODEL_ADAPTER_SPECS = { + "DeepseekV3ForCausalLM": SGLangModelAdapterSpec( + install_adapters=_install_deepseek_mla_adapters, + ), + "Qwen3MoeForCausalLM": SGLangModelAdapterSpec(), + "Qwen3NextForCausalLM": SGLangModelAdapterSpec( + wrapper_binds_gdn_context=True, + ), + "Qwen3_5ForConditionalGeneration": SGLangModelAdapterSpec( + prepare_config=_prepare_qwen35_config, + ), + "Qwen3_5MoeForConditionalGeneration": SGLangModelAdapterSpec( + prepare_config=_prepare_qwen35_config, + ), +} + +# Architectures whose SGLang EntryClass is generated by base_model_wrapper. +# Custom outer-wrapper modules, such as Qwen3.5 multimodal wrappers, keep their +# own EntryClass and should not appear here or SGLang will see duplicate classes. +MODEL_ARCH_SPECS = { + key: MODEL_ADAPTER_SPECS[key] + for key in ( + "DeepseekV3ForCausalLM", + "Qwen3MoeForCausalLM", + "Qwen3NextForCausalLM", + ) +} + + +def get_model_arch_spec(model_arch: str) -> SGLangModelAdapterSpec: + return MODEL_ADAPTER_SPECS.get(model_arch, SGLangModelAdapterSpec()) diff --git a/tests/plugin/test_sglang_model_wrapper.py b/tests/plugin/test_sglang_model_wrapper.py index e4015ed9dc..20d0e07923 100644 --- a/tests/plugin/test_sglang_model_wrapper.py +++ b/tests/plugin/test_sglang_model_wrapper.py @@ -54,8 +54,7 @@ def _make_fake_modules(*, is_last_rank: bool, setup_hook=None) -> dict[str, Modu forward_batch_mod.ForwardBatch = object forward_batch_mod.PPProxyTensors = object - attn_backend_pkg = _package("atom.plugin.sglang.attention_backend") - mla_mod = ModuleType("atom.plugin.sglang.attention_backend.sgl_attention_mla") + mla_mod = ModuleType("atom.plugin.sglang.models.deepseek_mla") mla_mod.setup_deepseek_for_sglang = setup_hook or (lambda model: None) return { @@ -68,8 +67,7 @@ def _make_fake_modules(*, is_last_rank: bool, setup_hook=None) -> dict[str, Modu "sglang.srt.layers.quantization.base_config": quant_base_mod, "sglang.srt.model_executor": model_executor_pkg, "sglang.srt.model_executor.forward_batch_info": forward_batch_mod, - "atom.plugin.sglang.attention_backend": attn_backend_pkg, - "atom.plugin.sglang.attention_backend.sgl_attention_mla": mla_mod, + "atom.plugin.sglang.models.deepseek_mla": mla_mod, } diff --git a/tests/plugin/test_sglang_prepare_hooks.py b/tests/plugin/test_sglang_prepare_hooks.py index cff778076b..04376cfa36 100644 --- a/tests/plugin/test_sglang_prepare_hooks.py +++ b/tests/plugin/test_sglang_prepare_hooks.py @@ -11,7 +11,8 @@ import pytest -from atom.plugin import prepare as plugin_prepare +from atom.plugin import prepare as plugin_runtime +from atom.plugin.sglang import prepare as sglang_prepare class _Obj: @@ -44,27 +45,32 @@ def _module(name: str, **attrs) -> ModuleType: return module +def _make_fake_runtime_module(model_arch: str, prepare_config): + module = ModuleType("atom.plugin.sglang.runtime") + module.get_model_arch_spec = MagicMock( + return_value=_Obj(prepare_config=prepare_config) + ) + return module + + @pytest.fixture(autouse=True) def _reset_framework_state(): - plugin_prepare._set_framework_backbone("atom") + plugin_runtime._set_framework_backbone("atom") yield - plugin_prepare._set_framework_backbone("atom") + plugin_runtime._set_framework_backbone("atom") @pytest.mark.parametrize( - "model_arch,expect_register_ops", + "model_arch", ( - ("Qwen3_5ForConditionalGeneration", False), - ("Qwen3_5MoeForConditionalGeneration", False), - ("Qwen3NextForCausalLM", False), - ("DeepseekV3ForCausalLM", True), - ("Qwen3MoeForCausalLM", True), + "Qwen3_5ForConditionalGeneration", + "Qwen3_5MoeForConditionalGeneration", + "Qwen3NextForCausalLM", + "DeepseekV3ForCausalLM", + "Qwen3MoeForCausalLM", ), ) -def test_prepare_model_register_ops_gate( - model_arch: str, - expect_register_ops: bool, -): +def test_prepare_model_register_ops_gate(model_arch: str): fake_atom_config = _Obj(plugin_config=_Obj(is_plugin_mode=True)) fake_register, _fake_model, fake_model_cls = _make_fake_register_module(model_arch) fake_config_mod = MagicMock() @@ -75,29 +81,34 @@ def test_prepare_model_register_ops_gate( "atom.plugin.sglang.models.qwen3_5", apply_prepare_model_adaptations=MagicMock(), ) + prepare_config = ( + fake_qwen35_mod.apply_prepare_model_adaptations + if model_arch + in { + "Qwen3_5ForConditionalGeneration", + "Qwen3_5MoeForConditionalGeneration", + } + else None + ) + fake_runtime_mod = _make_fake_runtime_module(model_arch, prepare_config) with patch.dict( sys.modules, { "atom.plugin.register": fake_register, "atom.plugin.config": fake_config_mod, + "atom.plugin.sglang.runtime": fake_runtime_mod, "atom.plugin.sglang.models.qwen3_5": fake_qwen35_mod, "atom.plugin.sglang.graph_capture_patch": MagicMock( apply_graph_capture_patch=MagicMock() ), }, ): - plugin_prepare.prepare_model( - config=_Obj(architectures=[model_arch]), - engine="sglang", - ) + sglang_prepare.prepare_model(config=_Obj(architectures=[model_arch])) - if expect_register_ops: - fake_register.register_ops_to_sglang.assert_called_once_with( - atom_config=fake_atom_config - ) - else: - fake_register.register_ops_to_sglang.assert_not_called() + fake_register.register_ops_to_sglang.assert_called_once_with( + atom_config=fake_atom_config + ) if model_arch in { "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration", diff --git a/tests/plugin/test_sglang_prepare_model.py b/tests/plugin/test_sglang_prepare_model.py index a9ae0a1851..b5f12b672a 100644 --- a/tests/plugin/test_sglang_prepare_model.py +++ b/tests/plugin/test_sglang_prepare_model.py @@ -13,7 +13,8 @@ import pytest from unittest.mock import MagicMock, patch -from atom.plugin import prepare as plugin_prepare +from atom.plugin import prepare as plugin_runtime +from atom.plugin.sglang import prepare as sglang_prepare class _Obj: @@ -26,9 +27,9 @@ def __init__(self, **kwargs): @pytest.fixture(autouse=True) def _reset_framework_state(): - plugin_prepare._set_framework_backbone("atom") + plugin_runtime._set_framework_backbone("atom") yield - plugin_prepare._set_framework_backbone("atom") + plugin_runtime._set_framework_backbone("atom") def _make_fake_register_module(model_dict=None): @@ -41,35 +42,37 @@ def _make_fake_register_module(model_dict=None): return mod +def _make_fake_runtime_module(): + mod = MagicMock() + mod.get_model_arch_spec = MagicMock(return_value=_Obj(prepare_config=None)) + return mod + + # --------------------------------------------------------------------------- # Engine / architecture validation # --------------------------------------------------------------------------- -def test_prepare_model_rejects_unsupported_engine(): - """Unsupported engine should raise ValueError from _set_framework_backbone.""" - config = _Obj(architectures=["SomeModel"]) - with pytest.raises(ValueError, match="Unsupported framework"): - plugin_prepare.prepare_model(config=config, engine="tensorflow") - - -def test_prepare_model_rejects_non_sglang_engine_gracefully(): - """vllm engine currently not supported in prepare_model (only sglang path).""" - config = _Obj(architectures=["Qwen3ForCausalLM"]) - with pytest.raises(ValueError, match="does not support engine"): - plugin_prepare.prepare_model(config=config, engine="vllm") - - def test_prepare_model_rejects_unsupported_architecture(): - """Known engine but unknown arch should raise ValueError.""" + """Unknown architecture should raise ValueError from the SGLang prepare path.""" fake_register = _make_fake_register_module( model_dict={"DeepseekV3ForCausalLM": MagicMock()} ) + fake_runtime = _make_fake_runtime_module() - with patch.dict(sys.modules, {"atom.plugin.register": fake_register}): + with patch.dict( + sys.modules, + { + "atom.plugin.register": fake_register, + "atom.plugin.sglang.runtime": fake_runtime, + "atom.plugin.sglang.graph_capture_patch": MagicMock( + apply_graph_capture_patch=MagicMock() + ), + }, + ): config = _Obj(architectures=["TotallyFakeModelArch"]) with pytest.raises(ValueError, match="does not support"): - plugin_prepare.prepare_model(config=config, engine="sglang") + sglang_prepare.prepare_model(config=config) # --------------------------------------------------------------------------- @@ -86,6 +89,7 @@ def test_prepare_model_sglang_happy_path(): fake_register = _make_fake_register_module( model_dict={"DeepseekV3ForCausalLM": fake_model_cls} ) + fake_runtime = _make_fake_runtime_module() mock_gen_config = MagicMock(return_value=fake_atom_config) fake_config_mod = MagicMock() @@ -96,10 +100,14 @@ def test_prepare_model_sglang_happy_path(): { "atom.plugin.register": fake_register, "atom.plugin.config": fake_config_mod, + "atom.plugin.sglang.runtime": fake_runtime, + "atom.plugin.sglang.graph_capture_patch": MagicMock( + apply_graph_capture_patch=MagicMock() + ), }, ): config = _Obj(architectures=["DeepseekV3ForCausalLM"]) - result = plugin_prepare.prepare_model(config=config, engine="sglang") + result = sglang_prepare.prepare_model(config=config) # Config generation called mock_gen_config.assert_called_once_with(config) @@ -126,6 +134,7 @@ def test_prepare_model_selects_sglang_dict_for_deepseek_v2(): fake_register = _make_fake_register_module( model_dict={"DeepseekV2ForCausalLM": fake_model_cls} ) + fake_runtime = _make_fake_runtime_module() fake_config_mod = MagicMock() fake_config_mod.generate_atom_config_for_plugin_mode = MagicMock( return_value=fake_atom_config @@ -136,10 +145,14 @@ def test_prepare_model_selects_sglang_dict_for_deepseek_v2(): { "atom.plugin.register": fake_register, "atom.plugin.config": fake_config_mod, + "atom.plugin.sglang.runtime": fake_runtime, + "atom.plugin.sglang.graph_capture_patch": MagicMock( + apply_graph_capture_patch=MagicMock() + ), }, ): config = _Obj(architectures=["DeepseekV2ForCausalLM"]) - result = plugin_prepare.prepare_model(config=config, engine="sglang") + result = sglang_prepare.prepare_model(config=config) assert result is fake_model @@ -152,6 +165,7 @@ def test_prepare_model_sets_framework_to_sglang(): fake_register = _make_fake_register_module( model_dict={"DeepseekV3ForCausalLM": fake_model_cls} ) + fake_runtime = _make_fake_runtime_module() fake_config_mod = MagicMock() fake_config_mod.generate_atom_config_for_plugin_mode = MagicMock( return_value=fake_atom_config @@ -162,10 +176,14 @@ def test_prepare_model_sets_framework_to_sglang(): { "atom.plugin.register": fake_register, "atom.plugin.config": fake_config_mod, + "atom.plugin.sglang.runtime": fake_runtime, + "atom.plugin.sglang.graph_capture_patch": MagicMock( + apply_graph_capture_patch=MagicMock() + ), }, ): config = _Obj(architectures=["DeepseekV3ForCausalLM"]) - plugin_prepare.prepare_model(config=config, engine="sglang") + sglang_prepare.prepare_model(config=config) - assert plugin_prepare.is_sglang() is True - assert plugin_prepare.is_plugin_mode() is True + assert plugin_runtime.is_sglang() is True + assert plugin_runtime.is_plugin_mode() is True diff --git a/tests/plugin/test_sglang_register.py b/tests/plugin/test_sglang_register.py index 562aaf8bc8..1adb20fb42 100644 --- a/tests/plugin/test_sglang_register.py +++ b/tests/plugin/test_sglang_register.py @@ -324,10 +324,13 @@ def __init__(self, runner): "atom.models.qwen3_moe": ModuleType("atom.models.qwen3_moe"), "atom.models.glm4_moe": ModuleType("atom.models.glm4_moe"), "atom.models.deepseek_v2": ModuleType("atom.models.deepseek_v2"), + "atom.models.minimax_m2": ModuleType("atom.models.minimax_m2"), + "atom.models.qwen3_next": ModuleType("atom.models.qwen3_next"), + "atom.models.qwen3_5": ModuleType("atom.models.qwen3_5"), "atom.config": ModuleType("atom.config"), "atom.plugin.prepare": fake_prepare_mod, - "atom.plugin.sglang.attention_backend.sgl_attn_backend": ModuleType( - "atom.plugin.sglang.attention_backend.sgl_attn_backend" + "atom.plugin.sglang.attention_backend.full_attention.full_attention_backend": ModuleType( + "atom.plugin.sglang.attention_backend.full_attention.full_attention_backend" ), } fake_modules["atom.models.qwen3"].Qwen3ForCausalLM = type( @@ -342,9 +345,21 @@ def __init__(self, runner): fake_modules["atom.models.deepseek_v2"].DeepseekV3ForCausalLM = type( "DeepseekV3ForCausalLM", (), {} ) + fake_modules["atom.models.minimax_m2"].MiniMaxM2ForCausalLM = type( + "MiniMaxM2ForCausalLM", (), {} + ) + fake_modules["atom.models.qwen3_next"].Qwen3NextForCausalLM = type( + "Qwen3NextForCausalLM", (), {} + ) + fake_modules["atom.models.qwen3_5"].Qwen3_5ForCausalLM = type( + "Qwen3_5ForCausalLM", (), {} + ) + fake_modules["atom.models.qwen3_5"].Qwen3_5MoeForCausalLM = type( + "Qwen3_5MoeForCausalLM", (), {} + ) fake_modules["atom.config"].Config = type("Config", (), {}) fake_modules[ - "atom.plugin.sglang.attention_backend.sgl_attn_backend" + "atom.plugin.sglang.attention_backend.full_attention.full_attention_backend" ].ATOMAttnBackendForSgl = _FakeBackend with patch.dict(sys.modules, fake_modules):