From 1aada4becf5b05b3fb3d22ff50f4e5ddb02594e9 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Mon, 11 May 2026 10:16:06 +0000 Subject: [PATCH 01/17] add work log --- .../2026-04-30-attn-refractory.md | 354 +++++++++++++ ...26-05-07-radix-cache-vs-paged-attention.md | 482 ++++++++++++++++++ ...26-05-08-kv-indices-address-space-notes.md | 359 +++++++++++++ .../2026-05-08-metadata-layering-summary.md | 144 ++++++ ...6-05-11-attn-refractory-session-summary.md | 398 +++++++++++++++ .../attn_refractory/decoupling-diagram.md | 158 ++++++ .../vllm-integration-architecture.md | 269 ++++++++++ 7 files changed, 2164 insertions(+) create mode 100644 work_log/attn_refractory/2026-04-30-attn-refractory.md create mode 100644 work_log/attn_refractory/2026-05-07-radix-cache-vs-paged-attention.md create mode 100644 work_log/attn_refractory/2026-05-08-kv-indices-address-space-notes.md create mode 100644 work_log/attn_refractory/2026-05-08-metadata-layering-summary.md create mode 100644 work_log/attn_refractory/2026-05-11-attn-refractory-session-summary.md create mode 100644 work_log/attn_refractory/decoupling-diagram.md create mode 100644 work_log/attn_refractory/vllm-integration-architecture.md diff --git a/work_log/attn_refractory/2026-04-30-attn-refractory.md b/work_log/attn_refractory/2026-04-30-attn-refractory.md new file mode 100644 index 0000000000..f0946299a9 --- /dev/null +++ b/work_log/attn_refractory/2026-04-30-attn-refractory.md @@ -0,0 +1,354 @@ +# 2026-04-30 ATOM Attention Refactory Session Summary + +## 1. 本次会话的核心目标 + +本次连续讨论的主线是围绕 `ATOM plugin` 中 attention backend 的重构方向展开,重点包括: + +- 调研当前 `ATOM plugin` attention backend 的架构问题 +- 评估 `vLLM plugin` 与 `SGLang plugin` 的耦合程度 +- 分析 SGLang plugin 在支持新模型、特别是非传统 MHA/MLA 类型 backend 时会遇到的困难 +- 讨论如何在 plugin 层把 `vLLM` 与 `SGLang` 解耦 +- 讨论 `SGLang` 中 `hybrid/composite backend` 的接入方式 +- 评估三种模式共享 ATOM attention 实现的难度: + - `ATOM native/server` + - `ATOM vLLM plugin` + - `ATOM SGLang plugin` + +## 2. 
对当前 ATOM plugin attention 架构的主要判断 + +### 2.1 当前不是一个真正统一的 attention backend 架构 + +本次调研得到的一个核心结论是: + +- `vLLM plugin` 和 `SGLang plugin` 虽然都叫 “ATOM attention” +- 但它们并不是通过同一套干净的 backend contract 接进来的 +- 更准确地说,是两套宿主接入模型外加多层 patch / adapter / runtime glue 共同构成的系统 + +### 2.2 当前的主要问题不是底层 kernel,而是 runtime / ownership / patch 层次 + +现有问题主要集中在: + +- 全局状态过重 + - `_CURRENT_FRAMEWORK` + - `current_atom_config` + - `ops.Attention` +- vLLM 与 SGLang 的接入模型差异很大,但还共享 `plugin` 根目录里一批伪公共逻辑 +- 很多扩展点依赖 monkey patch 和 runtime override +- SGLang 侧对 DeepSeek MLA 已经上卷到 model-level patch,而不是单纯 backend 替换 + +### 2.3 attention 文件结构的拆分轴不统一 + +本次对 `plugin` 目录下 attention 相关文件的解读认为,目前至少混杂了三层不同对象: + +- host/runtime glue +- backend runtime core +- model-specific specialization + +尤其在 `SGLang` 侧: + +- `radix_attention.py` 更像 adapter +- `sgl_attn_backend.py` 更像 full-attn runtime backend core +- `sgl_attention_mla.py` 更像 DeepSeek MLA specialization + +而在根目录: + +- `attention.py` +- `attention_mha.py` +- `attention_mla.py` +- `attention_mla_sparse.py` + +虽然放在 `plugin/` 根目录,但职责上明显更接近 `vLLM plugin` 的 glue / patch 层,而不是真正共享的 plugin runtime core。 + +## 3. 关于 vLLM plugin 与 SGLang plugin 的耦合 + +### 3.1 结论:耦合度高 + +这次讨论中对二者耦合的判断是: + +- 不是轻量共享几个工具函数 +- 而是共享了一套 runtime state、attention 抽象、selector 和 plugin-mode 判断方式 + +关键耦合点包括: + +- `atom/plugin/prepare.py` +- `atom/plugin/register.py` +- `atom/plugin/config.py` +- `atom.plugin.prepare.is_vllm()/is_sglang()/is_plugin_mode()` +- `ops.Attention` 的全局切换 + +### 3.2 关键判断 + +真正的问题不是 “目录没有拆开”,而是: + +**host runtime 差异没有被限制在 adapter 边界,而是泄漏进了 backend 选择、metadata 组织和全局状态模型。** + +## 4. 
关于 plugin 层解耦的共识 + +### 4.1 plugin 层不应该继续保留共享 runtime + +讨论最终倾向于一个更激进但更干净的方向: + +- `atom/plugin/vllm/` 是一个完整子系统 +- `atom/plugin/sglang/` 是另一个完整子系统 +- `atom/plugin/` 根目录不再承担共享 runtime 中心的角色 + +共享只保留给更底层的 `ATOM core`,例如: + +- `model_ops` +- kernels +- loader +- metadata helpers +- model-family specialization 中真正 host-agnostic 的部分 + +### 4.2 “bootstrap” 的定义 + +本次还明确了 `bootstrap` 在本讨论语境中的含义: + +- 不是核心执行逻辑 +- 而是插件被宿主框架加载时的最早期接线/初始化层 + +也就是: + +- 注册扩展点 +- 安装 patch +- 选择 wrapper / adapter / backend +- 建立 host 与 plugin 的连接关系 + +## 5. 已经做过的代码原型 + +本次会话中,为了验证“按 host 拆入口”的思路,已经做了一批原型代码修改: + +### 5.1 新增的 bootstrap / prepare / register 原型 + +- `atom/plugin/sglang/bootstrap.py` +- `atom/plugin/vllm/bootstrap.py` +- `atom/plugin/sglang/prepare.py` +- `atom/plugin/sglang/register.py` + +### 5.2 `prepare.py` 已部分收缩为兼容层 + +- 实际的 SGLang prepare 逻辑已经移动到 `atom/plugin/sglang/prepare.py` +- `atom/plugin/prepare.py` 现在更像一个 legacy shim +- 但它仍保留了 framework-state helper(因为仓库里很多地方还在依赖) + +### 5.3 SGLang register 逻辑已开始下沉 + +SGLang 独有的几块逻辑已经迁入: + +- `register_ops_to_sglang` +- `set_sglang_attn_cls` +- `init_aiter_dist_for_sglang` +- `bootstrap_sglang_runtime` + +共享 `atom/plugin/register.py` 目前更多是兼容 facade。 + +### 5.4 说明 + +这些改动的定位是: + +- 用来显式化未来结构 +- 不是完整重构 +- 还保留了较多兼容层 + +## 6. 关于 SGLang attention backend 的重要判断 + +### 6.1 `ATOMAttnBackendForSgl` 与 public `AiterAttnBackend` + +在一次反复讨论后,本次对它们的关系收敛为: + +- 在 backend 角色位置上,两者基本是 apple-to-apple +- `ATOMAttnBackendForSgl` 不是“更高层的新东西” +- 而是 `SGLang full-attention runtime backend` 的 ATOM 对位实现 + +因此更合适的重构思路不是怀疑它的存在合法性,而是: + +- 保留它作为 full-attn backend core +- 再去拆它内部的 metadata / kv_cache / decode / graph 等职责 + +### 6.2 DeepSeek MLA 的特殊性 + +`sgl_attention_mla.py` 暴露出一个很重要的事实: + +- 对 `MLA`,尤其是 `DeepSeek MLA` +- SGLang plugin 已经不是单纯换 backend +- 而是 patch 了模型级 `forward` + +这说明: + +**MLA 在 SGLang plugin 里已经进入了 model specialization 维度。** + +## 7. 
关于 GDN / KDA / Lightning / Mamba2 等非传统 “attention” + +### 7.1 调研结论 + +public SGLang 已经证明,系统里不止 full-attention backend,还存在: + +- linear attention backend + - `GDNAttnBackend` + - `KDAAttnBackend` + - `LightningAttentionBackend` + - `Mamba2AttnBackend` +- hybrid/composite backend + - `HybridLinearAttnBackend` + - `HybridAttnBackend` +- wrapper/composition backend + - `TboAttnBackend` + +### 7.2 对当前 ATOM plugin 设计造成的挑战 + +当前 ATOM plugin 的主抽象仍偏向: + +- MHA +- MLA +- full attention runtime + +这会带来几个结构性问题: + +1. 还没有把 `linear backend` 当成一等公民 +2. 还没有给 `hybrid/composite backend` 预留独立宿主位置 +3. 容易把 algorithm backend、kernel backend、host glue 混成一层 +4. 容易继续把 `GDN` 等路径硬塞回 full-attn backend 体系 + +### 7.3 当前共识 + +后续不应该继续只谈 “attention backend 重构”,而应该升级为: + +**sequence backend / mixer backend 体系重构** + +建议至少在设计上并列三类: + +- `full_attn` +- `linear_attn` +- `hybrid/composite` + +## 8. hybrid/composite backend 在 plugin 侧的接入思路 + +### 8.1 核心判断 + +由于 ATOM SGLang plugin 的启动方式是: + +- 启动 `sglang server` +- 再通过 plugin runtime override 接管 model / attention + +所以要接 `hybrid/composite backend`,不可避免需要 patch 某个 seam。 + +### 8.2 最好的 seam + +本次讨论认为,public SGLang 里最好的 seam 是: + +- `model_runner.py` 中导入并调用的 `attn_backend_wrapper` + +注意: + +- 它不是 `ModelRunner` 的成员方法 +- 而是 `model_runner.py` 模块级导入的符号 + +所以如果要 patch,更稳的是 patch: + +- `sglang.srt.model_executor.model_runner.attn_backend_wrapper` + +而不是仅 patch `attention_registry.attn_backend_wrapper`。 + +### 8.3 最推荐的方式 + +对 hybrid/composite backend 的接入方式,本次最终倾向于: + +**plugin-own 一个薄的 composite wrapper/factory,只在 backend 构造 seam 做单点 patch。** + +不推荐: + +- 深度 monkey patch public hybrid backend 实现 +- 一开始就复制整套 public linear/full backend 逻辑 + +更推荐: + +- full side 先用 ATOM full backend +- linear side 初期先复用 public SGLang 的 `GDNAttnBackend` / `KDAAttnBackend` / `LightningAttentionBackend` / `Mamba2AttnBackend` +- composite wrapper 由 plugin 侧拥有 + +## 9. 
关于三种模式共享 ATOM attention 实现的难度 + +本次对: + +- `ATOM native/server` +- `ATOM vLLM plugin` +- `ATOM SGLang plugin` + +三种模式共享 ATOM attention 实现,得到的判断如下。 + +### 9.1 MHA + +难度:`Medium-High` + +原因: + +- `ATOM native` 与 `ATOM vLLM plugin` 已经在 `PagedAttentionImpl` 等层共享较多 +- 真正难的是 `SGLang plugin` 这一侧 +- 它不走 `PagedAttention` 的 ATOM 主路径,而走: + - `RadixAttention` + - `ATOMAttnBackendForSgl` + - `ForwardBatch` + +所以 MHA 共享的主要困难在于: + +**runtime orchestration 差异** + +### 9.2 MLA + +难度:`High` + +原因: + +- native/server 与 vLLM plugin 仍然较多共享 `MLAAttention` +- 但 SGLang plugin 下的 DeepSeek MLA 已经上卷到 model specialization +- 不只是 backend 不同,连 forward 组织方式都不同 + +所以 MLA 的困难在于: + +**runtime orchestration + model specialization 双重叠加** + +## 10. 已产出的文档与图 + +本次会话过程中,已额外产出下列文档/图,用于不同角度说明问题。 + +### 10.1 代码目录内 Markdown + +- `atom/plugin/decoupling-diagram.md` +- `atom/plugin/vllm-integration-architecture.md` +- `atom/plugin/sglang-attention-backend-survey.md` + +### 10.2 Canvas / 可视化分析 + +- `atom-attention-backend-architecture-review.canvas.tsx` +- `atom-plugin-coupling-risk-analysis.canvas.tsx` +- `atom-plugin-decoupling-diagram.canvas.tsx` +- `atom-vllm-plugin-architecture.canvas.tsx` +- `atom-attention-sharing-modes.canvas.tsx` + +### 10.3 这些产物对应的主题 + +- 当前 attention backend 架构缺陷 +- vLLM / SGLang plugin 耦合与风险 +- plugin 解耦方向与 bootstrap 理解 +- vLLM plugin 集成架构 +- public SGLang backend survey +- 三种模式共享 ATOM attention 的难度评估 + +## 11. 当前最值得继续推进的方向 + +如果延续本次会话的结论,后续工作最值得按下面顺序推进: + +1. 继续收缩 `atom/plugin/` 根目录的 shared runtime 语义 +2. 把 `full_attn / linear_attn / hybrid` 三层作为 plugin 后续结构设计的一等公民 +3. 在 `SGLang plugin` 侧先补出 `hybrid/composite` 组合层 +4. 初期复用 public SGLang linear backend,优先验证 runtime 结构是否成立 +5. 然后再评估哪些 linear backend 需要逐步替换成 ATOM 自己的实现 +6. 对 `MLA` 尽早拆开: + - generic MLA runtime + - DeepSeek specialization + +## 12. 
一句话总结 + +本次会话最终把问题收敛为: + +**当前 ATOM plugin attention 的关键矛盾,不是底层 kernel 能不能共享,而是 vLLM / SGLang / native 三种模式在 runtime、metadata、model specialization 和 backend ownership 上没有对齐;后续重构应从“统一 attention backend”升级为“重新定义 plugin 的 sequence backend / host-owned runtime 结构”。** diff --git a/work_log/attn_refractory/2026-05-07-radix-cache-vs-paged-attention.md b/work_log/attn_refractory/2026-05-07-radix-cache-vs-paged-attention.md new file mode 100644 index 0000000000..f0f9d5bccc --- /dev/null +++ b/work_log/attn_refractory/2026-05-07-radix-cache-vs-paged-attention.md @@ -0,0 +1,482 @@ +# 2026-05-07 Radix Cache vs Paged Attention Notes + +> 预估阅读时间:15 分钟 +> 主题:梳理 `SGLang` 中 `radix cache` / `RadixAttention` 与 `ATOM` / `vLLM` 常见的 `paged attention` 之间的关系,解释为什么它们最终可以落到相同的 kernel,以及这件事对 `sglang plugin` attention backend 重构意味着什么。 + +## 1. 这篇笔记想回答什么 + +围绕 `sglang plugin attn backend` 的重构,最近反复出现了几个问题: + +1. `SGLang` 用的是 `RadixAttention`,`ATOM`/`vLLM plugin` 里常见的是 `PagedAttention`,两者是不是两套完全不同的注意力体系? +2. 如果它们真的不同,为什么在 `DeepSeek MLA` 这类路径里,最终又会调用到相同的底层 kernel? +3. 既然 `SGLang` 还多了一层 `radix cache` 前缀管理,为什么没有一个非常直观的结论说 "`SGLang` 一定比 `vLLM` 更快 / 更省显存"? +4. 对当前重构来说,`radix` / `paged` 的差异到底应该被归类到: + - host runtime 差异 + - KV cache 管理差异 + - metadata lowering 差异 + - kernel 差异 + 的哪一层? + +本文尝试把这几个问题放进同一个分析框架里。 + +## 2. 先给结论 + +### 2.1 最核心的一句话 + +`RadixAttention` 和 `PagedAttention` 主要不是在底层注意力数学上不同,而是在 **runtime 如何管理、共享、命中和定位历史 KV** 上不同。 +一旦 runtime 把历史 KV lowering 成底层 kernel 能吃的 `page_table`、`kv_indptr`、`kv_indices`、`qo_indptr` 之类的 metadata,kernel 就不再关心这些索引最初是来自 radix tree 还是来自 paged block table。 + +### 2.2 另一句更工程化的结论 + +对重构来说,不应该把 “`RadixAttention` vs `PagedAttention`” 当作 “能不能共享底层执行实现” 的直接判断依据。 +更应该把系统拆成三层: + +1. **host/runtime 层** + 例如 `ForwardBatch`、`RadixAttention`、`ModelRunner`、prefix cache、speculative 调度 +2. **execution metadata lowering 层** + 例如 `page_table`、`block_tables`、`kv_indptr`、`kv_indices`、`qo_indptr` +3. 
**kernel / execution core 层** + 例如 `pa_fwd_asm`、`pa_persistent_fwd`、`flash_attn_varlen_func`、`mla_decode_fwd`、`mla_prefill_fwd` + +`radix` 和 `paged` 的差异,主要在第 1、2 层; +兼容性和共享机会,主要出现在第 2、3 层。 + +## 3. 为什么名字容易让人误解 + +当前最容易造成误解的点是: + +- `PagedAttention` 这个名字听起来像“最终执行算法” +- `RadixAttention` 这个名字也听起来像“最终执行算法” + +但在工程上,它们都更接近 **host-facing attention runtime abstraction**,而不是最终 kernel 名字。 + +更具体一点: + +- `PagedAttention` 更强调 **按固定 page/block 管理 KV cache** +- `RadixAttention` 更强调 **按 prefix/radix tree 管理请求历史与前缀复用** + +这两者都不是在说: + +- `QK^T` 怎么算 +- softmax 怎么做 +- decode kernel 用哪套 asm/triton/flashinfer + +这些底层执行问题,是更下一层的 backend / kernel 决定的。 + +## 4. `RadixAttention` 到底是什么 + +### 4.1 `RadixAttention` 不是 kernel + +在 public `SGLang` 里,`RadixAttention` 本身并不直接定义底层 kernel 选择。 +它更像是: + +- SGLang 模型层里统一使用的 attention layer 壳 +- 它知道自己在 `ForwardBatch` 语义里运行 +- 它把真正的执行委托给 `forward_batch.attn_backend` + +因此,`RadixAttention` 更像一个 **layer/runtime 入口点**。 + +### 4.2 `radix` 的重点是 prefix 管理 + +`radix cache` 的本质,是一棵 prefix tree,用来做: + +- 前缀命中 +- 前缀复用 +- unfinished / finished request 的缓存插入 +- 基于 prefix 的 eviction / lifecycle 管理 + +它回答的问题更像: + +- 这个请求的历史前缀已经缓存到哪里了? +- 哪一部分可以直接复用,不必重新 prefill? +- 哪些 KV slot/page 仍然属于共享前缀? +- 哪些历史片段可以回收? + +所以 `radix` 优化的主要是 **prefix reuse**,而不是单次 attention kernel 的算力效率。 + +### 4.3 `radix cache` 在 SGLang 里不是完全“反 page”的 + +这一点很重要。public `SGLang` 的 `radix_cache` 本身就已经带有 page-aware 的逻辑。 + +例如它有 page 对齐: + +- `page_align_keys()` + +也有按 page 粒度匹配 prefix: + +- `_key_match_paged()` + +这意味着: + +- `radix cache` 解决的是前缀共享与命中问题 +- 但它并不排斥最终按 page/block 粒度来表达缓存 + +也就是说,**radix tree 的逻辑管理** 与 **paged/block 的物理表达** 并不冲突。 + +## 5. `PagedAttention` 到底是什么 + +### 5.1 `PagedAttention` 的重点不是 prefix tree + +`PagedAttention` 更强调: + +- KV cache 被组织成固定大小的 page/block +- 每个请求通过 `page_table` / `block_table` 找到自己历史对应的 page +- kernel 依照这些 page/block 索引读取历史 KV + +它回答的问题更像: + +- 当前请求历史有哪些 block/page? +- 它们在物理内存中的位置是什么? +- 这些 page 如何映射到 kernel 所需的 block table? 
+ +所以 `paged attention` 更偏: + +- **物理存储布局** +- **kernel 读取 KV 的地址组织方式** + +### 5.2 `PagedAttention` 天生不等于 prefix reuse + +`paged attention` 当然也可以配合 prefix cache 使用,但 prefix reuse 并不是它名字里最核心的概念。 +它的核心价值更在于: + +- 支持非连续物理布局 +- 让 decode kernel 可以按 page/block 高效读取 KV + +## 6. 为什么 `radix` 和 `paged` 最后可以兼容 + +### 6.1 因为它们主要解决的是不同层的问题 + +可以把它们看成: + +- `radix`: 逻辑历史管理器 +- `paged`: 物理页表表达方式 + +两者关系有点像: + +- `radix` 负责决定“哪些历史是共享的、哪些是命中的、请求当前逻辑历史是什么” +- `paged` 负责决定“这些历史最终以什么 page/block 索引形式交给 kernel” + +只要 runtime 最终能把逻辑历史翻译成 page/block 或 indptr/index 形式,底层 kernel 并不需要知道前缀最初是如何组织的。 + +### 6.2 兼容发生在 lowering 之后 + +这一点非常关键: + +上层看起来是: + +- `ForwardBatch` +- `RadixAttention` +- `req_to_token` +- prefix cache + +但 lower 到 backend 之后,会变成执行态 metadata,例如: + +- `page_table` +- `block_tables` +- `kv_indptr` +- `kv_indices` +- `qo_indptr` +- `kv_last_page_len` + +到了这个层次,kernel 只看: + +- query 的 shape +- KV cache 的 shape +- 历史长度 +- 索引表 / indptr + +它不看: + +- 这套索引是不是来自 radix tree +- 是不是来自某个 prefix cache 命中 +- 是不是通过 `PagedAttention` 对象构造出来的 + +所以兼容性不是“名字兼容”,而是 **execution metadata 兼容**。 + +## 7. 
public `SGLang` 自己就在做类似 lowering + +这不是 `ATOM plugin` 独有的现象。public `SGLang` 本身就在这么做,只是不同 backend lower 到的 metadata 形态略有不同。 + +### 7.1 `aiter_backend` + +public `SGLang` 的 `aiter_backend` 会从: + +- `forward_batch.seq_lens` +- `forward_batch.req_pool_indices` +- `req_to_token` + +构造: + +- `kv_indptr` +- `kv_indices` +- `qo_indptr` + +也就是说,它会把 host runtime 的历史表示 lower 成 AITER kernel 所需的执行态索引。 + +### 7.2 `triton_backend` + +public `SGLang` 的 `triton_backend` 也做类似事情。 +它同样维护: + +- `kv_indptr` +- `kv_indices` +- `qo_indptr` + +然后再调用 Triton 的 decode/extend kernel。 + +### 7.3 `trtllm_mha_backend` / 部分 `flashinfer` 路径 + +这类 backend 更偏向: + +- 直接构造 `page_table` +- 或 `block_tables` + +再喂给 flashinfer / TRTLLM 风格的 kernel。 + +所以更准确地说,public `SGLang` 的共同模式不是: + +> “所有 backend 都把 radix runtime 翻成 paged attention” + +而是: + +> “所有 backend 都会把 radix/runtime 层的历史表示,lower 成各自 kernel 能吃的索引/页表 metadata” + +其中有的偏 `kv_indptr/kv_indices`,有的偏 `page_table/block_tables`。 + +## 8. 从 kernel 角度看,为什么可以兼容 + +### 8.1 kernel 真正关心什么 + +绝大多数 attention kernel 真正关心的是: + +1. 当前 query 的排列方式 +2. 每个请求可见的历史 KV 长度 +3. 如何从某个索引结构里拿到历史 KV 的物理地址 +4. page/block 大小、stride、layout +5. 是否需要 prefix/extend/speculative 的特殊 metadata + +它通常不关心: + +1. 这份索引最初是不是从 prefix tree 查出来的 +2. 命中前缀的逻辑是谁做的 +3. 
runtime 是 `RadixAttention` 还是 `PagedAttention` + +### 8.2 对 MHA kernel 的影响 + +MHA decode 常见地会落到: + +- `pa_fwd_asm` +- `pa_persistent_fwd` +- `flash_attn_varlen_func` + +这些 kernel 更偏 page/block 或 varlen index 风格。 +只要 runtime 最终给出: + +- `page_table` +- `context_lens` +- `block_tables` + +或者: + +- `kv_indptr` +- `kv_indices` +- `qo_indptr` + +它们都能算。 + +### 8.3 对 MLA kernel 的影响 + +`DeepSeek MLA` 更容易看出这种兼容性,因为 MLA 的底层算子路径更明确。 + +典型 kernel 包括: + +- `mla_decode_fwd` +- `mla_prefill_fwd` +- `concat_and_cache_mla` +- `fused_qk_rope_concat_and_cache_mla` + +这些 kernel 只要求: + +- 压缩后的 latent KV 表示 +- 以及相应的 `qo_indptr` / `kv_indptr` / `kv_indices` + +所以无论上层 host 是: + +- `ATOM native` +- `ATOM vLLM plugin` +- `ATOM SGLang plugin` + +只要最终 lower 成相同的 MLA execution metadata,就自然会落到同一套 MLA kernel。 + +## 9. `DeepSeek MLA` 为什么尤其容易“看起来收敛” + +### 9.1 因为 MLA 的 execution core 更专用 + +`DeepSeek MLA` 的真正执行对象不是: + +- “`RadixAttention` 版本的 MLA” +- “`PagedAttention` 版本的 MLA” + +而是: + +- 一组固定的 MLA latent/cache 表示 +- 一套固定的 rope + kv cache 写入方式 +- 一套固定的 prefill/decode kernel + +换句话说,MLA 的底层 core 更容易抽成“共享执行层”。 + +### 9.2 因为 host 差异主要体现在 metadata 准备与调度 + +对于 `DeepSeek MLA` 来说,上层 host 差异更多体现在: + +- `positions` 从哪里来 +- `forward_batch` 怎么组织 +- decode / extend / speculative / target_verify 怎么分流 +- `req_to_token` 怎么转换成 `kv_indices` + +一旦进入 `mla_decode_fwd` 或 `mla_prefill_fwd` 前,这些差异已经被消化掉了。 + +所以从外部观察,很容易得到一种印象: + +> “怎么上面完全不一样,下面却是同一套 kernel?” + +其实并不奇怪,因为两边只是上层 runtime 不同,而 MLA execution core 本来就在下层收敛。 + +## 10. 
既然兼容,为什么 `SGLang` 没有天然更快或更省显存 + +这是另外一个很容易误判的问题。 + +### 10.1 `radix cache` 省的是前缀重复,不是所有 KV + +`radix cache` 能节省的,是重复前缀带来的重复 prefill 和重复 KV 占用。 + +它主要帮助的是: + +- 大量请求共享长前缀 +- 同 prompt 多分支采样 +- prefix-heavy workload +- 某些 speculative / chunked prefill 场景 + +如果 workload 不满足这些条件,它就不会天然表现为显著优势。 + +### 10.2 它优化的是 prefix reuse,不是单 token decode kernel + +radix 管理层并不会让 `mla_decode_fwd`、`pa_fwd_asm` 这种 kernel 自己更便宜。 +kernel 的 raw 速度主要还是受: + +- 头数 +- head dim +- KV layout +- quant +- page size +- GPU kernel 实现 + +这些因素决定。 + +所以常见现象是: + +- prefix-heavy workload 下,`SGLang` 的收益更明显 +- decode-heavy workload 下,收益没那么明显 +- 如果 `vLLM` 也有自己的 prefix caching/block caching,差距会进一步缩小 + +### 10.3 radix 自己也有管理成本 + +`radix cache` 不是免费层。它也有: + +- prefix match +- tree split / insert / eviction +- request -> KV slot 映射维护 +- speculative / unfinished request 的额外 bookkeeping + +如果命中率不高,这部分成本并不会自动转化成收益。 + +## 11. 对当前重构最有价值的判断 + +### 11.1 不要把 `RadixAttention` vs `PagedAttention` 当成是否共享 execution core 的边界 + +因为从 public `SGLang` 和当前 `ATOM plugin` 的代码路径看,二者最终都在做类似事情: + +- 从 host runtime 的历史表示出发 +- 构造底层 kernel 所需的 execution metadata + +所以: + +- runtime 不同,不代表 execution core 必须不同 +- 名字不同,不代表 kernel 层必须分叉 + +### 11.2 真正的边界更适合这样划 + +#### A. `sglang plugin runtime` 应拥有的部分 + +- `ForwardBatch` +- `RadixAttention` +- prefix cache / radix cache +- `ModelRunner` / `attn_backend_wrapper` 相关调度 +- speculative / graph / verify / chunked prefill 等宿主语义 + +#### B. execution metadata lowering 层 + +这层是最值得重新定义的边界。它负责: + +- 把 `req_to_token`、prefix 命中结果、cache state +- 转成 `page_table` / `kv_indptr` / `kv_indices` / `qo_indptr` + +这层既带有 host 语义,又已经开始接近共享执行核心,是未来最值得梳理清楚的一层。 + +#### C. 
shared execution core + +例如: + +- `mla_decode_fwd` +- `mla_prefill_fwd` +- `pa_fwd_asm` +- `pa_persistent_fwd` +- `flash_attn_varlen_func` + +以及与这些 kernel 紧耦合的那部分 ATOM core helper。 + +### 11.3 这对 `sglang plugin` 的重构意味着什么 + +当前 `sglang plugin` 最不该做的是: + +- 因为底层 kernel 能共享,就把 `sglang` 的 runtime seam 也强行对齐到 `vllm` +- 或者因为上层叫 `RadixAttention`,就认为它和 `PagedAttention` 完全不能共享底层执行实现 + +更合理的重构方向是: + +1. 保持 `sglang` 的 host/runtime 语义完整 +2. 识别 runtime lowering 层的清晰边界 +3. 尽量把 execution core 继续下沉回共享 `ATOM core` + +## 12. 一种对外更容易讲清楚的表述 + +如果后续要向别人解释,可以用下面这几句话: + +### 12.1 关于 `radix` vs `paged` + +`radix cache` 主要解决的是 prefix 命中与复用,`paged attention` 主要解决的是 page/block 形式的 KV 组织与读取。 +前者偏逻辑管理,后者偏物理表达;两者不是同一层次的问题。 + +### 12.2 关于为什么最终会调到同一个 kernel + +因为 attention kernel 只关心最终的执行态 metadata,比如 `page_table`、`kv_indptr`、`kv_indices`、`qo_indptr`。 +一旦 runtime 把历史 KV lowering 成这类形式,kernel 就不再关心这些索引最初是由 radix tree 还是 paged runtime 生成的。 + +### 12.3 关于为什么 `SGLang` 没有天然显著更快 + +因为 `radix` 带来的主要收益是 prefix reuse,而不是单 token decode kernel 的 raw speed。 +如果 workload 不是 prefix-heavy,或者 decode 占主要成本,那么 radix 带来的优势不会自然放大成“总是更快 / 更省显存”。 + +## 13. 
最终总结 + +把这次探索收敛成一句话: + +**`RadixAttention` 与 `PagedAttention` 的主要差异,不在底层注意力数学,也不一定在最终 KV 的物理表示本身,而在 host runtime 如何管理、共享、命中并 lowering 历史 KV;一旦 lowering 完成,底层 kernel 完全可以是共享的。** + +对当前 `sglang plugin attn backend` 的重构来说,真正值得做的不是争论 “`radix` 和 `paged` 能不能统一成一个名字”,而是把系统拆清楚: + +- 哪些是 `sglang` 必须自有的 runtime / prefix cache / scheduling 语义 +- 哪些是 execution metadata lowering +- 哪些已经是可以共享的 `ATOM` execution core + +这比直接讨论 “`radix` 和 `paged` 谁更先进、谁应该替代谁” 更接近真正的工程问题。 diff --git a/work_log/attn_refractory/2026-05-08-kv-indices-address-space-notes.md b/work_log/attn_refractory/2026-05-08-kv-indices-address-space-notes.md new file mode 100644 index 0000000000..e5a63be0f1 --- /dev/null +++ b/work_log/attn_refractory/2026-05-08-kv-indices-address-space-notes.md @@ -0,0 +1,359 @@ +# 2026-05-08 KV Indices and Address Space Notes + +> 预估阅读时间:8-10 分钟 +> 主题:解释 `kv_indptr` / `kv_indices` 为什么不是纯 host metadata,也不是完全 storage-agnostic 的抽象,并用 `sglang` / `ATOM` / `vllm plugin` 的实际 KV cache 存储举例说明。 + +## 1. 想回答的问题 + +最近有一个很容易卡住重构讨论的问题: + +> `kv_indptr` / `kv_indices` 看起来和 KV cache 的存储方式深度绑定。 +> 既然如此,为什么还能把它们看成 `sglang` / `ATOM` / `vllm plugin` 之间可以共享的 metadata? + +这个问题的关键在于区分两件事: + +1. **host/runtime 是否相同** +2. **kernel 最终消费的地址空间模型是否相同** + +一句话先说结论: + +**`kv_indptr` / `kv_indices` 不是完全 storage-agnostic 的抽象,它们绑定的是某种“可索引的 KV 地址空间”; +但它们可以是 host/runtime-agnostic 的,只要不同宿主最终都能把自己的 runtime state lowering 到同一种地址空间模型。** + +## 2. 
三层视角 + +理解这个问题,最好先把 attention 相关代码拆成三层: + +### 2.1 host/runtime 层 + +这一层描述的是宿主框架如何组织请求和调度,例如: + +- `ATOM native` 的 `ScheduledBatch` +- `vllm plugin` 的 `query_start_loc` / `num_prefills` +- `sglang plugin` 的 `ForwardBatch` / `forward_mode` / `spec_info` + +这一层表达的是: + +- 哪些请求在 batch 中 +- 哪些是 decode,哪些是 prefill +- prefix / speculative / graph 的语义是什么 + +### 2.2 execution metadata lowering 层 + +这一层做的事情是: + +- 把 host/runtime 里的 batch 状态 +- 翻译成 kernel 真正能消费的索引结构 + +典型字段: + +- `block_tables` / `page_table` +- `kv_indptr` +- `kv_indices` +- `qo_indptr` +- `kv_last_page_len` + +### 2.3 kernel / execution core 层 + +这一层只关心: + +- query tensor +- KV cache tensor +- index / indptr / page table +- persistent kernel workspace + +例如: + +- `mla_decode_fwd` +- `mla_prefill_fwd` +- `pa_fwd_asm` +- `pa_persistent_fwd` + +## 3. `sglang` 的实际 KV cache 存储 + +在 `sglang` 里,KV cache 本身不是存在 radix tree 里。 +radix tree 负责的是 prefix 复用与命中;真正的 KV 还是存在 memory pool 里。 + +`sglang` 自己的注释说得很清楚: + +- `ReqToTokenPool` 负责 request -> token location 映射 +- `TokenToKVPoolAllocator` 负责 KV cache index 管理 +- `KVCache` 真正持有物理 K/V tensor + +也就是说,`sglang` 的 runtime 世界里至少有两层: + +1. **逻辑视角**:某个 request 的第 `i` 个 token 属于哪里 +2. **物理视角**:这个 token 对应的 K/V 存在物理 pool 的哪个 slot + +### 3.1 `ReqToTokenPool` + +`ReqToTokenPool` 本质上是一张二维表: + +```text +req_to_token[req_id, token_pos] = physical_slot_id +``` + +它回答的是: + +- 某个 request 的第 `j` 个逻辑 token +- 对应到物理 KV pool 里的哪个 slot + +### 3.2 物理 KV pool + +物理 pool 里,K 和 V 都是按 slot 存的。 +写入 KV 时,最终就是按 `loc/indices` 写入: + +```text +k_cache[indices] = k +v_cache[indices] = v +``` + +所以从 kernel 角度看,最终它看到的是: + +- 一个可以索引的 KV 地址空间 +- 一组 slot index + +而不是看到“radix tree”本身。 + +## 4. 
`ATOM` / `vllm plugin` 的实际 KV cache 存储 + +`ATOM` / `vllm plugin` 更偏 paged/block 风格。 + +典型思路是: + +- 每个 request 有一张 `block_table` +- 每个 block 对应一个固定大小的 page +- kernel 按 `block_table` 或其展开后的 index 去读 KV + +所以它的上层直觉更像: + +```text +request -> block table -> page/block address -> physical KV +``` + +和 `sglang` 相比,最大的不同不是“有没有物理地址空间”,而是: + +- `sglang` 上层先经过 `req_to_token` +- `vllm/ATOM` 上层先经过 `block_table` + +但两边最终都能 lower 到一组可供 kernel 读取的地址索引。 + +## 5. 例子一:MLA decode 中的 `kv_indptr` / `kv_indices` + +这个例子里,可以把 `kv_indices` 理解为 **token-slot index**。 + +### 5.1 在 `sglang` 里 + +假设一个 request 当前长度是 10, +它在物理 KV pool 里的 slot 顺序不是连续的,而是: + +```text +[8, 9, 10, 11, 0, 1, 2, 3, 4, 5] +``` + +那 `req_to_token[req_id, :10]` 就大致表示: + +```text +logical token position: 0 1 2 3 4 5 6 7 8 9 +physical slot id: 8 9 10 11 0 1 2 3 4 5 +``` + +如果这个 batch 只有这一个 request,那么 lowering 之后可以得到: + +```text +kv_indptr = [0, 10] +kv_indices = [8, 9, 10, 11, 0, 1, 2, 3, 4, 5] +``` + +这里: + +- `kv_indptr` 表示“第 0 个 request 的 KV 范围是 `[0, 10)`” +- `kv_indices` 表示“真正去物理 KV pool 里读这些 slot” + +### 5.2 在 `vllm/ATOM` 里 + +如果 page size 是 4,同样 10 个 token 对应的 block layout 可以写成: + +```text +block 2 -> slots [8, 9, 10, 11] +block 0 -> slots [0, 1, 2, 3] +block 1 -> slots [4, 5, 6, 7] +``` + +block table 大致就是: + +```text +[2, 0, 1] +``` + +展开前 10 个 token 后,对应的 slot 顺序仍然是: + +```text +[8, 9, 10, 11, 0, 1, 2, 3, 4, 5] +``` + +于是 lower 之后,给 MLA decode kernel 的: + +```text +kv_indptr = [0, 10] +kv_indices = [8, 9, 10, 11, 0, 1, 2, 3, 4, 5] +``` + +和 `sglang` 这边是等价的。 + +### 5.3 这个例子说明什么 + +说明在 MLA decode 这个路径里: + +- `kv_indices` 绑定的是 **token-slot 地址空间** +- 它不是“完全与存储无关” +- 但它已经不关心 slot 列表最初是由 `req_to_token` 还是 `block_table` 推出来的 + +也就是说,**它和具体宿主无关,但和被选定的地址空间模型有关。** + +## 6. 
例子二:MHA persistent decode 中的 `page_table` / `kv_indices` + +这个例子里,`kv_indices` 更接近 **page/block index**,而不是 token-slot index。 + +### 6.1 在 `sglang plugin` 里 + +仍然假设 page size = 4。 +如果某个 request 的 token-slot 顺序是: + +```text +[8, 9, 10, 11, 0, 1, 2, 3, 4, 5] +``` + +那么每页的起始 slot 是: + +```text +[8, 0, 4] +``` + +除以 page size 后,对应 page id: + +```text +[2, 0, 1] +``` + +在 `sglang plugin` 里,这就是 `page_table` 的来源: +从 `req_to_token` 按页采样,再除以 `page_size`。 + +然后 `_build_pa_metadata_for_decode()` 再把它变成 `pa_persistent_fwd` 所需的 page-level `kv_indices`。 + +### 6.2 在 `ATOM` / `vllm` 里 + +这一层本来就是 paged/block 视角,所以 `block_table` 原生就已经是: + +```text +[2, 0, 1] +``` + +也就是说: + +- `sglang` 是 `req_to_token -> page_table` +- `vllm/ATOM` 是直接 `block_table -> page_table` + +但到了 persistent kernel 眼里,最后看到的是同一种 page/block 地址空间。 + +### 6.3 这个例子说明什么 + +说明在 MHA persistent decode 这个路径里: + +- `kv_indices` 绑定的是 **page/block 地址空间** +- 它仍然不是“完全 storage-agnostic” +- 但也不需要知道上层 host 是 `sglang` 还是 `vllm` + +只要两边都能 lower 到同一个 page/block 地址空间,kernel 就能共享。 + +## 7. 这两个例子真正说明的事 + +上面两个例子其实在说明同一个结论: + +### 7.1 `kv_indptr` / `kv_indices` 不是 host metadata + +它们不描述: + +- `ForwardBatch.forward_mode` +- `query_start_loc` +- `num_prefills` +- `run_graph` +- prefix hit 的高层语义 + +它们只描述: + +- 当前这次 kernel 调用 +- 要读哪些 KV 地址 +- 每个 request 的边界在哪里 + +所以它们更接近: + +- execution metadata +- kernel-facing metadata + +而不是 host/runtime metadata。 + +### 7.2 但它们也不是完全 storage-agnostic + +它们依赖: + +- 当前 KV cache 的地址空间模型 +- kernel 对该地址空间的解释方式 + +如果底层 cache 不是: + +- 可线性索引的 token slot +- 或可 flatten 的 page/block index + +那 `kv_indptr` / `kv_indices` 这套抽象就未必成立。 + +所以它们不是“无限通用”的。 + +## 8. 
对重构的直接启示 + +这个背景知识对当前 `sglang attention` 重构最有价值的结论是: + +### 8.1 不要去统一 host metadata + +比如不要强行统一: + +- `ForwardBatch` +- `query_start_loc` +- `num_prefills` +- `run_graph` + +这些都属于宿主框架自己的 runtime 语言。 + +### 8.2 应该争取统一 execution metadata + +例如: + +- `kv_indptr` +- `kv_indices` +- `qo_indptr` +- `page_table` +- `kv_last_page_len` +- `work_indptr` +- `reduce_indptr` + +以及围绕这些字段构造的更小 dataclass,例如当前 draft 里已经开始尝试的: + +- `MLADecodeKernelMetadata` +- `MHAAsmKernelMetadata` +- `MHAPersistentKernelMetadata` + +### 8.3 正确的共享边界 + +所以真正可共享的不是: + +- `sglang` 的 `ForwardMetadata` +- `vllm plugin` 那一整套 metadata family + +而是: + +- 它们再往下切出来的 +- kernel-facing execution metadata + +## 9. 一句话总结 + +**`kv_indptr` / `kv_indices` 不是“和存储彻底无关”的抽象,它们绑定的是某种被 kernel 接受的 KV 地址空间;但正因为它们已经脱离了宿主 batch/runtime 语义,`sglang` 和 `vllm/ATOM` 即使上层完全不同,也仍然可以在这层 metadata 上收敛并共享 kernel 调用逻辑。** diff --git a/work_log/attn_refractory/2026-05-08-metadata-layering-summary.md b/work_log/attn_refractory/2026-05-08-metadata-layering-summary.md new file mode 100644 index 0000000000..86bccce8e6 --- /dev/null +++ b/work_log/attn_refractory/2026-05-08-metadata-layering-summary.md @@ -0,0 +1,144 @@ +# 2026-05-08 Metadata Layering Summary + +## 1. 目的 + +这份笔记简要总结 `ATOM native`、`vllm plugin`、`sglang plugin` 三边 metadata 的关系,重点回答: + +- `ATOM/atom/utils/forward_context.py` 中的 `AttentionMetaData` 是什么 +- `ATOM/atom/plugin/attention.py` 里的 metadata dataclass 在做什么 +- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` 里的 `ForwardMetadata` 属于哪一层 +- 哪些 metadata 值得共享,哪些不值得强行统一 + +## 2. 
三套 metadata 的定位 + +### 2.1 `AttentionMetaData` + +文件:`ATOM/atom/utils/forward_context.py` + +它是 `ATOM` attention 执行链路里的**通用外层容器**。 +里面既有: + +- 通用 attention 字段 + - `slot_mapping` + - `block_tables` + - `kv_indptr` + - `kv_indices` + - `work_meta_data` +- 也预留了 plugin 扩展入口 + - `plugin_metadata` + +所以它更像一个大容器,而不是某个 host 专属的数据模型。 + +### 2.2 `vllm plugin metadata` + +文件:`ATOM/atom/plugin/attention.py` + +`vllm plugin` 不是重写一套完全独立的 metadata,而是在 `AttentionMetaData` 外壳上,额外挂了一层 `plugin_metadata` payload。 + +代表类型包括: + +- `AiterFlashAttentionMetadataForPluginMode` +- `AiterMLACommonMetadataForPluginMode` +- `AiterMLADecodeMetadataForPluginMode` +- `AiterMLASparseMetadataForPluginMode` + +这些 dataclass 主要承载: + +- `vllm` 的 batch/runtime 语义 +- decode/prefill/extend phase 拆分 +- chunked prefill / DCP / sparse MLA 等 feature-specific 信息 + +一句话:**`vllm plugin` 是“外层复用 `AttentionMetaData`,内层自带一套 host-specific metadata family”。** + +### 2.3 `sglang plugin` 的 `ForwardMetadata` + +文件:`ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` + +`ForwardMetadata` 不是 `AttentionMetaData.plugin_metadata` 那种 payload,而是 `sglang backend` 内部的**lowering 结果缓存对象**。 + +它更接近: + +- `ForwardBatch` +- `req_to_token` +- `token_to_kv_pool` +- graph/speculative path + +向 kernel metadata 的过渡层。 + +一句话:**`ForwardMetadata` 更像 backend-local lowering state,而不是 host-agnostic 的最终执行 metadata。** + +## 3. 
三类字段 + +### 3.1 host/runtime-bound + +这类字段描述宿主框架如何组织 batch/runtime,而不是 kernel 如何计算。 + +例子: + +- `vllm plugin` + - `query_start_loc` + - `num_decodes` + - `num_prefills` + - `chunked_context` +- `sglang plugin` + - `run_graph` + - 各种 `ForwardBatch.forward_mode` 派生分支语义 + +### 3.2 execution/kernel-bound + +这类字段最接近 kernel 真正需要的执行信息。 + +例子: + +- `slot_mapping` +- `block_tables` +- `context_lens` +- `kv_indptr` +- `kv_indices` +- `qo_indptr` +- `kv_last_page_len` +- `work_indptr` +- `reduce_indptr` + +### 3.3 hardware-sensitive + +这类字段和 page size、dtype、kernel 选择、并行配置相关。 + +例子: + +- `max_q_len` +- `max_kv_len` +- `num_kv_splits` +- `attn_out_dtype` +- `head_dim` + +## 4. 哪些值得共享 + +### 不建议强行共享的 + +- `AttentionMetaData` 整体 +- `vllm plugin` 整套 `plugin_metadata` +- `sglang plugin` 整体的 `ForwardMetadata` + +原因不是它们没价值,而是它们都带着很重的 host/runtime 语义。 + +### 值得重点共享的 + +应该共享的是更小的 **execution metadata view**,例如当前 draft 已经开始尝试的: + +- `MLADecodeKernelMetadata` +- `MHAAsmKernelMetadata` +- `MHAPersistentKernelMetadata` + +这些结构只保留某个 kernel call 真正需要的字段,更适合在: + +- `ATOM core` +- `vllm plugin` +- `sglang plugin` + +之间共享。 + +## 5. 一句话结论 + +现在不是三边共用“一套 metadata”,而是三边各自都有一层 host-owned metadata / lowering 体系。 +真正值得共享的,不是这些大 metadata 本身,而是从它们里面切出来的、更小的、kernel-facing execution metadata。 diff --git a/work_log/attn_refractory/2026-05-11-attn-refractory-session-summary.md b/work_log/attn_refractory/2026-05-11-attn-refractory-session-summary.md new file mode 100644 index 0000000000..3b2128f5bd --- /dev/null +++ b/work_log/attn_refractory/2026-05-11-attn-refractory-session-summary.md @@ -0,0 +1,398 @@ +# 2026-05-11 ATOM Attention Refactory Session Summary + +> 主题:从 `sglang plugin` 设计者视角,梳理 `attention backend` 重构中的 runtime / metadata / KV cache / kernel reuse 边界,并记录几种 kernel 共用 draft 的结论。 + +## 1. 
本次会话的主线 + +本次讨论的核心不是继续从 `vllm plugin` 出发看问题,而是明确切换到: + +- **`sglang plugin attn backend` 设计者视角** +- 关注 `sglang plugin` 自己的 runtime 语义 +- 判断哪些层应该继续由 `sglang plugin` 自有 +- 判断哪些层值得尽量与 `ATOM core` 共用 + +围绕这个目标,本次会话主要探索了 6 个问题: + +1. `RadixAttention` / `radix cache` 与 `PagedAttention` 的关系 +2. public `sglang` 是否也会把 radix runtime lowering 到 paged / index metadata +3. metadata 为什么会越长越多,以及哪些字段本质上属于 host/runtime +4. `kv_indptr` / `kv_indices` 为什么既绑定存储模型,又仍然能跨 host 共用 +5. `sglang plugin` 和 public `sglang` 在 KV cache 管理上的异同 +6. 在不大改文件结构、不做过度软件工程的前提下,怎么试探性地共用 kernel call + +## 2. 对 `sglang plugin` 重构最关键的几个判断 + +### 2.1 不是所有问题都该叫 “attention backend 复用” + +当前问题至少分成三层: + +1. **host/runtime 层** + - `ForwardBatch` + - `RadixAttention` + - speculative / graph / verify / extend 调度 +2. **execution metadata lowering 层** + - `page_table` + - `kv_indptr` + - `kv_indices` + - `qo_indptr` + - `kv_last_page_len` +3. **kernel / execution core 层** + - `flash_attn_varlen_func` + - `mla_decode_fwd` + - `mla_prefill_fwd` + - `pa_fwd_asm` + - `pa_persistent_fwd` + +后续任何“共用更多 code”的讨论,都应该先说明是在第几层说话。 + +### 2.2 `sglang plugin` 不应该为了复用去对齐到 `vllm` 的 host runtime + +`vllm plugin` 和 `sglang plugin` 上层接入模型不同: + +- `vllm plugin` 更像 layer-owned / metadata-builder-owned 路径 +- `sglang plugin` 更像 `ForwardBatch` 驱动的 backend runtime 路径 + +所以: + +- 不该共享 `vllm plugin` 那层 runtime facade +- 也不该强行统一大 metadata 对象 +- 真正值得共享的,是更靠近 kernel 的 execution metadata 和 kernel call + +### 2.3 `sglang plugin` 想共用 `ATOM` code,最现实的边界在 kernel call 一层 + +如果先不做大规模结构改造,最适合先共用的是: + +- `mla_decode_fwd` +- `pa_persistent_fwd` +- `pa_fwd_asm` +- `flash_attn_varlen_func` + +也就是: + +- 保留各自的 runtime lowering +- 先把“调 kernel 前最后一跳”抽出来 + +## 3. 
`RadixAttention` / `radix cache` 和 `PagedAttention` 的结论 + +### 3.1 `RadixAttention` 不是最终 kernel + +`RadixAttention` 更像: + +- `sglang` 模型层统一使用的 attention layer 壳 +- 它知道自己跑在 `ForwardBatch` 语义里 +- 最终还是委托给 `forward_batch.attn_backend` + +所以它不是 “另一套底层注意力数学”,而是 **host-facing runtime abstraction**。 + +### 3.2 `PagedAttention` 和 `radix` 解决的问题不在同一层 + +- `radix cache` 更偏 prefix 命中 / 复用 / eviction +- `paged attention` 更偏 block/page 形式的物理组织和读取 + +可以理解成: + +- `radix` 管“历史怎么共享和命中” +- `paged` 管“命中的历史最后如何变成 page/block 索引给 kernel” + +### 3.3 public `sglang` 自己也在做类似 lowering + +这不是 `ATOM plugin` 才有的行为。 + +public `sglang` 的: + +- `aiter_backend` +- `triton_backend` +- `trtllm_mha_backend` +- 部分 `flashinfer` 路径 + +本身也会把: + +- `ForwardBatch` +- `req_to_token` +- `seq_lens` + +lower 成: + +- `kv_indptr` +- `kv_indices` +- `qo_indptr` +- `page_table` / `block_tables` + +所以,`radix runtime -> paged/index metadata` 不是 plugin 特例,而是公共设计模式的一部分。 + +## 4. Metadata 分层结论 + +### 4.1 现在至少有三套 metadata family + +1. **`ATOM native`** + - `AttentionMetaData` +2. **`vllm plugin`** + - `AttentionMetaData` 外壳 + - `plugin_metadata` 内层 payload +3. **`sglang plugin`** + - `ForwardMetadata` + +### 4.2 `AttentionMetaData` 和 `plugin/attention.py` 的关系 + +`AttentionMetaData` 是外层通用容器; +`plugin/attention.py` 里的很多 dataclass,是 `vllm plugin` 为表达宿主 batch/runtime 语义而塞进 `plugin_metadata` 的 payload。 + +所以它们不是并列关系,而是: + +```text +AttentionMetaData + └─ plugin_metadata + ├─ flash-style plugin metadata + ├─ MLA plugin metadata + └─ sparse MLA plugin metadata +``` + +### 4.3 `ForwardMetadata` 的定位 + +`ForwardMetadata` 更像: + +- `sglang backend` 内部的 lowering state +- 它不是最终统一 metadata +- 更不是天然可共享的 host-agnostic object + +它里面混着: + +- runtime control 信息 +- execution metadata +- persistent kernel 预构造结果 + +## 5. 
三类字段 + +为了判断能不能共享,本次把字段分成三类: + +### 5.1 host/runtime-bound + +典型例子: + +- `vllm` 的 `query_start_loc` +- `num_decodes` +- `num_prefills` +- `chunked_context` +- `sglang` 的 `run_graph` + +这些字段描述的是宿主框架如何组织 batch/runtime,而不是 kernel 怎么算。 + +### 5.2 execution/kernel-bound + +典型例子: + +- `slot_mapping` +- `block_tables` +- `page_table` +- `kv_indptr` +- `kv_indices` +- `qo_indptr` +- `kv_last_page_len` +- `work_indptr` +- `reduce_indptr` + +这些字段才是 kernel 真正需要消费的执行信息。 + +### 5.3 hardware-sensitive + +典型例子: + +- `max_q_len` +- `max_kv_len` +- `num_kv_splits` +- `head_dim` +- `attn_out_dtype` + +这些字段不是 host 语义,但会影响: + +- page/block 解释 +- kernel 选路 +- workspace/split 策略 + +## 6. 关于 `kv_indptr` / `kv_indices` 的关键结论 + +这个问题本次单独做了较多背景解释,结论是: + +### 6.1 它们不是完全 storage-agnostic + +因为它们绑定的是某种 **可索引的 KV 地址空间**: + +- token slot index +- page/block index +- compacted KV index + +所以它们不能脱离底层存储模型独立存在。 + +### 6.2 但它们可以是 host/runtime-agnostic + +因为它们不表达: + +- `ForwardBatch.forward_mode` +- `query_start_loc` +- prefix 命中策略 +- graph/speculative 的宿主控制语义 + +它们只表达: + +- 这次 kernel 要去哪些 KV 地址读数据 +- 每个 request 的边界在哪里 + +### 6.3 两个例子 + +#### 例子 A:`sglang` + +- `req_to_token[req_id, token_pos] = slot_id` +- lowering 后得到: + - `kv_indptr = [0, 10]` + - `kv_indices = [8, 9, 10, 11, 0, 1, 2, 3, 4, 5]` + +#### 例子 B:`vllm/ATOM` + +- 上层是 `block_table` +- 展开后同样可得到: + - `kv_indptr = [0, 10]` + - `kv_indices = [8, 9, 10, 11, 0, 1, 2, 3, 4, 5]` + +这说明: + +- 上层来源不同 +- 但 lower 后的 execution metadata 可以收敛 + +## 7. public `sglang` 与 `sglang plugin` 在 KV cache 管理上的异同 + +### 7.1 相同点 + +- owner 都还是 public `sglang` + - `ReqToTokenPool` + - `TokenToKVPool` + - `KVCache` + - `radix_cache` +- request 到 slot 的映射机制没变 +- prefix / radix 复用体系没变 + +### 7.2 不同点 + +#### A. 
MHA 写 cache 路径不同 + +public `sglang` 更多沿用 pool 提供的标准写入接口: + +- `token_to_kv_pool.set_kv_buffer(...)` + +而 `sglang plugin` 在 MHA 路径里会绕过标准写法,自己做一次: + +- `set_kv_buffer_with_layout_shuffle(...)` +- 把同一块 public pool buffer 按更偏 `ATOM` kernel 的 SHUFFLE layout 写进去 + +#### B. MLA 路径更接近 public `sglang` + +MLA 上,plugin 更多沿用 public pool 的 contract: + +- `token_to_kv_pool.set_kv_buffer(...)` +- `get_key_buffer(...)` +- `get_value_buffer(...)` + +#### C. Lowering 目标不同 + +- public `sglang` 的 metadata 更通用 +- plugin 的 `ForwardMetadata` 更偏 `ATOM` kernel + - `page_table` + - `pa_metadata_*` + +### 7.3 一个更准确的说法 + +`sglang plugin` **没有改写 public `sglang` 的 pool owner / allocator / radix tree**, +但它在 MHA 路径上: + +- 改变了同一块 public pool buffer 的布局解释与写入约定 +- 并把 lowering 结果继续朝 `ATOM` kernel 所需的 metadata 形态推进 + +## 8. Kernel 共用的两种 draft + +本次会话里,实际做了两种方向的 draft。 + +### 8.1 方案 A:共享 kernel wrapper + +思路: + +- 新增共享小 metadata + - `MLADecodeKernelMetadata` + - `MHAAsmKernelMetadata` + - `MHAPersistentKernelMetadata` +- 新增共享 wrapper + - `run_mla_decode_kernel` + - `run_mha_asm_kernel` + - `run_mha_persistent_kernel` + +优点: + +- 边界清晰 +- 更像长期可演进的 execution API +- 不需要统一大 metadata + +缺点: + +- 还是引入了一层新的 shared contract + +### 8.2 方案 B:`sglang plugin` 主动适配 `ATOM core` + +思路: + +- 不改 `ATOM core` +- 在 `sglang plugin` 里把 `ForwardMetadata` 转成 `AttentionMetaData` +- 再伪造最小 `shim/self` +- 直接调用: + - `PagedAttentionImpl.paged_attention_asm` + - `PagedAttentionImpl.paged_attention_persistent_asm` + - `MLAAttention._forward_decode` + +优点: + +- 不改 `ATOM core` +- 更直观地证明 “plugin 主动复用 core” 是可行的 + +缺点: + +- 需要 shim +- 尤其 MLA 路径更脆弱 +- 更适合验证可行性,不适合长期保留 + +## 9. 
当前最值得保留的判断 + +### 9.1 不应该强行统一大 metadata + +不建议直接统一: + +- `AttentionMetaData` +- `plugin_metadata` +- `ForwardMetadata` + +因为它们都带有较重的宿主 runtime 语义。 + +### 9.2 最值得共享的是 kernel-facing execution metadata + +真正最有价值的共享边界是: + +- `kv_indptr` +- `kv_indices` +- `qo_indptr` +- `page_table` +- `context_lens` +- `work_*` +- `reduce_*` + +以及基于它们的 kernel call / runner。 + +### 9.3 重构顺序建议 + +如果继续推进 `sglang plugin` 重构,更合理的顺序是: + +1. 保持 `sglang` runtime 自有 +2. 显式化 metadata lowering 边界 +3. 尽量共享 kernel call / execution helper +4. 最后再考虑是否继续上推到 runner 层 + +## 10. 一句话总结 + +本次会话最终收敛出的关键认识是: + +**`sglang plugin attention` 重构的关键,不是把 `sglang` runtime 改得更像 `vllm`,而是把 runtime、metadata lowering、kernel execution 这三层拆清楚;只要 lowering 之后能在 execution metadata 上对齐,就能在不统一宿主语义的前提下,共用更多 `ATOM` 的底层 attention 代码。** diff --git a/work_log/attn_refractory/decoupling-diagram.md b/work_log/attn_refractory/decoupling-diagram.md new file mode 100644 index 0000000000..b95c0031e5 --- /dev/null +++ b/work_log/attn_refractory/decoupling-diagram.md @@ -0,0 +1,158 @@ +# ATOM Plugin Decoupling Diagram + +下面这张图对比了当前入口架构和目标解耦架构。这里采用的是更激进、也更干净的方案: + +- `vLLM plugin` 和 `SGLang plugin` 在 `atom/plugin/` 层彻底分家 +- 不再保留共享的 plugin runtime +- 共享只发生在更底层的 `ATOM core` + +## Current + +```mermaid +flowchart TD + A[Shared Plugin Entry
prepare.py / register.py / config.py] + B[Global Runtime State
_CURRENT_FRAMEWORK
current_atom_config
ops.Attention] + C[vLLM Branch
platform / registry / patch / graph] + D[SGLang Branch
wrapper / registry hijack / RadixAttention / graph] + E[Mixed Plugin Core
attention selection / static_forward_context / loader glue] + + A --> B + B --> C + B --> D + C --> E + D --> E +``` + +### 当前设计的核心问题 + +- 入口层通过切换全局状态来区分 `vLLM` 和 `SGLang`,而不是通过物理隔离的子系统建立边界。 +- `prepare.py` / `register.py` / `config.py` 这些共享入口,本身就是耦合中心。 +- host 差异直接泄漏到 plugin 内部结构里,导致 backend 选择、attention 抽象和 runtime context 都知道上层框架。 + +## Target + +```mermaid +flowchart TD + A1[vLLM Plugin
bootstrap.py
register.py
config.py
runtime.py
attention glue
model adapters] + A2[SGLang Plugin
bootstrap.py
register.py
config.py
runtime.py
attention glue
model adapters] + + B[ATOM Core
model_ops
kernels
loader
metadata structs
model specialization] + + A1 --> B + A2 --> B +``` + +### 目标设计的关键点 + +- `vLLM` 和 `SGLang` 在 plugin 层是两个完整子系统,而不是一套共享 runtime 上的两个分支。 +- `atom/plugin/` 下不再存在共享的 `prepare_model()`、共享的 `register.py` 总入口、共享的 plugin config translator。 +- `vLLM plugin` 自己处理: + - platform registration + - model override + - graph / MLA / weight hook patch + - vLLM model adapter / attention glue +- `SGLang plugin` 自己处理: + - external model wrapper + - attention backend registration / override + - graph / speculative / wrapper patch + - SGLang model adapter / attention glue +- 真正共享的只保留在更底层的 `ATOM core`: + - `model_ops` + - kernels + - loader + - generic metadata structures + - model-family specialization + +## 最重要的切断点 + +```mermaid +flowchart LR + A[Cut 1
shared prepare.py] + B[Cut 2
shared register.py] + C[Cut 3
shared config.py] + D[Cut 4
framework global state] + E[Cut 5
ops.Attention global rebinding] + + A --> F[vllm/ and sglang/ own bootstrap] + B --> F + C --> F + D --> G[host-local runtime only] + E --> H[host-specific attention glue] +``` + +## 推荐目录形态 + +```text +atom/plugin/ + vllm/ + bootstrap.py + register.py + config.py + runtime.py + attention/ + models/ + graph/ + sglang/ + bootstrap.py + register.py + config.py + runtime.py + attention/ + models/ + graph/ +``` + +这里的重点不是“换目录”,而是: + +- `vllm/runtime.py` 和 `sglang/runtime.py` 各自独立 +- `vllm/config.py` 和 `sglang/config.py` 各自独立 +- `vllm/register.py` 和 `sglang/register.py` 各自独立 +- 不再有 plugin 级共享 runtime + +## 术语说明 + +### bootstrap + +`bootstrap` 在这里指的是“插件被宿主框架加载时,最先执行的接线初始化层”。 + +它负责的事情通常包括: + +- 注册 plugin 扩展点 +- 安装 patch / hook +- 构造 host adapter / wrapper +- 把宿主配置翻译到 ATOM 可用的格式 + +它不应该负责: + +- attention forward 本身 +- 长期 runtime state 管理 +- 每次请求都要走的核心执行逻辑 + +### register + +`register` 指的是把 plugin 的能力挂到宿主暴露的扩展点上,例如: + +- 注册 platform +- 注册 model class +- 注册 attention backend + +它更偏“声明接入点”,通常是 bootstrap 的一部分。 + +### adapter + +`adapter` 指的是宿主框架和 ATOM core 之间的桥接层。 +它的职责是做接口和参数语义翻译,而不是定义 attention / model 的核心语义。 + +### runtime + +`runtime` 指的是插件在宿主里长期存在、服务执行链路的那部分运行时逻辑。 +如果按这份设计推进,`runtime` 应该是 host-local 的: + +- `vLLM runtime` 只属于 `vLLM plugin` +- `SGLang runtime` 只属于 `SGLang plugin` + +而不是共享一个“Plugin runtime”。 + +## 一句话结论 + +真正的解耦不是把共享入口再包装一层,而是把 `atom/plugin` 从“共享状态机”改成“两个完全独立的 host plugin 子系统”,共享只留给更底层的 `ATOM core`。 diff --git a/work_log/attn_refractory/vllm-integration-architecture.md b/work_log/attn_refractory/vllm-integration-architecture.md new file mode 100644 index 0000000000..e8876268d0 --- /dev/null +++ b/work_log/attn_refractory/vllm-integration-architecture.md @@ -0,0 +1,269 @@ +# ATOM vLLM Plugin Integration Architecture + +## 1. 
启动入口 + +vLLM plugin 的 setuptools entrypoint 定义在: + +- `pyproject.toml` + +关键位置: + +- `vllm.platform_plugins` + - `atom = "atom.plugin.vllm.register:register_platform"` +- `vllm.general_plugins` + - `atom_model_registry = "atom.plugin.vllm.register:register_model"` + +这意味着 vLLM 启动时会先进入: + +- `atom/plugin/vllm/register.py::register_platform()` +- `atom/plugin/vllm/register.py::register_model()` + +## 2. 关键代码文件与职责 + +### 2.1 `atom/plugin/vllm/register.py` + +这是 vLLM plugin 的启动接线层,主要负责: + +- 设置 plugin mode +- 返回自定义 platform +- 覆盖 vLLM 的 `ModelRegistry` +- 安装 MLA patch +- patch `Attention.process_weights_after_loading` +- 安装 graph capture patch + +重点符号: + +- `register_platform()` +- `register_model()` +- `_VLLM_MODEL_REGISTRY_OVERRIDES` + +### 2.2 `atom/plugin/vllm/platform.py` + +这是 vLLM 平台侧 attention backend 选择入口。 + +重点符号: + +- `ATOMPlatform.get_attn_backend_cls()` + +它根据 `attn_selector_config` 决定: + +- MHA -> `atom.model_ops.attentions.aiter_attention.AiterBackend` +- MLA -> `atom.model_ops.attentions.aiter_mla.AiterMLABackend` +- sparse MLA -> `atom.plugin.vllm.attention_backend.mla_sparse.AiterMLASparseBackend` + +这里是 vLLM runtime 真正选择 “哪个 ATOM attention backend class” 的入口。 + +### 2.3 `atom/plugin/vllm/model_wrapper.py` + +这是 vLLM model integration 的核心文件。 + +主要职责: + +- 把 `VllmConfig` 翻译成 `atom_config` +- 做 plugin 运行环境准备 +- 选择并实例化 ATOM 模型类 +- 把 vLLM forward context 中的 `positions` 传给 ATOM 路径 +- 对 sparse MLA indexer 做额外注册 + +重点符号: + +- `_prepare_env()` +- `ATOMModelBase.__init__()` +- `_register_indexer_caches_with_vllm()` +- `forward()` + +### 2.4 `atom/plugin/config.py` + +这是 vLLM 配置翻译的桥。 + +重点符号: + +- `generate_atom_config_for_vllm_plugin()` +- `_generate_atom_config_from_vllm_config()` + +它负责把: + +- `VllmConfig` +- `scheduler_config` +- `cache_config` +- `parallel_config` +- `quant_config` + +映射成 ATOM 自己的 `Config` 与 `PluginConfig`。 + +### 2.5 `atom/model_ops/paged_attention.py` + +这是 ATOM attention 在 vLLM plugin 下真正进入执行的关键对象。 + +重点符号: + +- `PagedAttention.__init__()` 
+- `PagedAttention.forward()` + +在 `is_vllm()` 下,它不会直接走 native/server 的那套 `impl` 初始化,而是: + +- 构造 vLLM 的 `Attention` / `MLAAttention` 外壳 +- 把额外 impl 参数塞进去 +- 注册到 `static_forward_context` +- forward 时通过 `unified_attention_with_output_base_for_plugin_mode()` 回调到 ATOM impl + +### 2.6 `atom/plugin/attention.py` + +这是 vLLM plugin mode 的核心 glue 层。 + +虽然文件在 `plugin/` 根目录,但从职责看它明显偏 vLLM。 + +主要职责: + +- 定义 `vllmAiterAttentionBackendMethods` +- 提供 backend decorator / metadata builder decorator +- 处理 plugin-mode metadata +- 处理 `static_forward_context` +- 为 vLLM 的 attention backend 补 OOT 接口行为 + +重点符号: + +- `vllmAiterAttentionBackendMethods` +- 各类 `*DecoratorForPluginMode` + +### 2.7 `atom/plugin/attention_mha.py` + +这是 vLLM plugin mode 下 MHA impl 的补丁层。 + +重点符号: + +- `PagedAttentionImplPluginModeMethods` + +它补的是: + +- rope/cache/update +- plugin-mode forward +- 与 vLLM KV cache/metadata 对齐的逻辑 + +### 2.8 `atom/plugin/attention_mla.py` + +这是 vLLM plugin mode 下 MLA impl 的补丁层。 + +重点符号: + +- `MLAAttentionImplPluginModeMethods` +- `_mla_plugin_mode_init` + +它补的是: + +- MLA plugin-mode metadata +- qk rope / kv cache / chunked prefill +- vLLM MLA 路径专有行为 + +### 2.9 `atom/plugin/vllm/mla_patch.py` + +这是 vLLM 原生 `MLAAttention` 的 monkey patch 入口。 + +重点符号: + +- `_patch_vllm_mla_attention_forward_impl()` +- `_patch_vllm_mla_attention_process_weights_after_loading()` +- `patch_vllm_mla_attention()` + +这层作用是: + +- 把 vLLM `MLAAttention.forward_impl()` 改接到 ATOM impl +- 把权重后处理改接到 ATOM 的 `impl.process_weights_after_loading()` + +### 2.10 `atom/plugin/vllm/graph_capture_patch.py` + +这是 vLLM graph capture 补丁入口。 + +重点符号: + +- `apply_graph_capture_patch()` + +它实际委托到共享实现: + +- `atom/plugin/graph_capture_patch.py` + +但调用点是 vLLM 自己的 plugin register 流程。 + +## 3. 
当前 vLLM plugin 集成链路 + +```mermaid +flowchart TD + A[vLLM startup] + B[entrypoint
register_platform] + C[entrypoint
register_model] + D[ATOMPlatform.get_attn_backend_cls] + E[ATOMModelBase] + F[generate_atom_config_for_vllm_plugin] + G[_prepare_env
set_attn_cls + init_aiter_dist] + H[Instantiate ATOM model] + I[PagedAttention / MLAAttention wrapper] + J[plugin/attention*.py glue] + K[ATOM impl
PagedAttentionImpl / MLAAttention] + L[aiter kernels] + + A --> B + A --> C + B --> D + C --> E + E --> F + E --> G + E --> H + H --> I + I --> J + J --> K + K --> L +``` + +## 4. 最关键的架构判断 + +### 4.1 vLLM plugin 不是简单“替换 backend class” + +它实际上同时做了三件事: + +1. 替换 vLLM 的 model entry +2. 替换/选择 vLLM 的 attention backend class +3. 用 patch + decorator 的方式把 vLLM 的 layer/runtime 语义桥接到 ATOM impl + +### 4.2 真正共享得最好的是 native/server 与 vLLM plugin 的 attention impl + +尤其: + +- `PagedAttentionImpl` +- `MLAAttention` +- 低层 `aiter` kernel family + +也就是说,vLLM plugin 侧的关键复杂度主要不在 kernel,而在: + +- `ModelRegistry` override +- `VllmConfig -> atom_config` +- `static_forward_context` +- `MLAAttention` patch +- graph capture patch + +### 4.3 当前 `plugin/attention.py`、`attention_mha.py`、`attention_mla.py` + +虽然放在 `atom/plugin/` 根目录,但从 vLLM integration 的角度看,它们本质上更像: + +- `vLLM plugin impl glue` +- 而不是真正框架无关的 shared runtime + +## 5. 推荐你重点阅读的文件顺序 + +如果你是为了快速理解 vLLM plugin 集成架构,建议按这个顺序读: + +1. `pyproject.toml` +2. `atom/plugin/vllm/register.py` +3. `atom/plugin/vllm/platform.py` +4. `atom/plugin/vllm/model_wrapper.py` +5. `atom/plugin/config.py` +6. `atom/model_ops/paged_attention.py` +7. `atom/plugin/vllm/mla_patch.py` +8. `atom/plugin/attention.py` +9. `atom/plugin/attention_mha.py` +10. `atom/plugin/attention_mla.py` + +## 6. 
一句话结论 + +vLLM plugin 的集成方式,本质上是: + +**用 entrypoint + platform + model wrapper 接入 vLLM,用 ATOM 自己的 attention impl 和 aiter kernels 提供核心执行,再用 patch/decorator 把 vLLM 的 layer/runtime 语义桥接进去。** From 755cbf76952d78ae8bbca669b14964582ed364c6 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Tue, 12 May 2026 08:29:41 +0000 Subject: [PATCH 02/17] [ATOM-SGL][Attn refrac] Separate model-specific MLA from SGL full attention backend --- atom/plugin/register.py | 2 +- .../full_attention/__init__.py | 8 ++ .../full_attention_backend.py} | 12 +- .../{ => full_attention}/radix_attention.py | 0 .../sglang/models/base_model_wrapper.py | 2 +- atom/plugin/sglang/models/deepseek_mla.py | 77 +++++++++++ .../deepseek_mla_forward.py} | 130 +++--------------- tests/plugin/test_sglang_model_wrapper.py | 6 +- tests/plugin/test_sglang_register.py | 21 ++- 9 files changed, 130 insertions(+), 128 deletions(-) create mode 100644 atom/plugin/sglang/attention_backend/full_attention/__init__.py rename atom/plugin/sglang/attention_backend/{sgl_attn_backend.py => full_attention/full_attention_backend.py} (99%) rename atom/plugin/sglang/attention_backend/{ => full_attention}/radix_attention.py (100%) create mode 100644 atom/plugin/sglang/models/deepseek_mla.py rename atom/plugin/sglang/{attention_backend/sgl_attention_mla.py => models/deepseek_mla_forward.py} (90%) diff --git a/atom/plugin/register.py b/atom/plugin/register.py index f5db8a59a6..90ac34e71e 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -45,7 +45,7 @@ def _register_custom_attention_to_sglang() -> None: from sglang.srt.layers.attention.attention_registry import ( register_attention_backend, ) - from atom.plugin.sglang.attention_backend.sgl_attn_backend import ( + from atom.plugin.sglang.attention_backend.full_attention.full_attention_backend import ( ATOMAttnBackendForSgl, ) diff --git a/atom/plugin/sglang/attention_backend/full_attention/__init__.py b/atom/plugin/sglang/attention_backend/full_attention/__init__.py 
new file mode 100644 index 0000000000..2b9f6f2726 --- /dev/null +++ b/atom/plugin/sglang/attention_backend/full_attention/__init__.py @@ -0,0 +1,8 @@ +from .radix_attention import RadixAttention +from .full_attention_backend import ATOMAttnBackendForSgl, ForwardMetadata + +__all__ = [ + "RadixAttention", + "ATOMAttnBackendForSgl", + "ForwardMetadata", +] diff --git a/atom/plugin/sglang/attention_backend/sgl_attn_backend.py b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py similarity index 99% rename from atom/plugin/sglang/attention_backend/sgl_attn_backend.py rename to atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py index 9fab7eaff5..40fe70b9e7 100644 --- a/atom/plugin/sglang/attention_backend/sgl_attn_backend.py +++ b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py @@ -1,10 +1,10 @@ from __future__ import annotations -# sglang-specific attention backend replacing sglang's built-in AiterAttnBackend. -# Shared by ALL models (DeepSeek, Qwen3, etc.) — handles KV cache writes, -# page-table fixup, pa_persistent_fwd decode path, and MLA prefill kernels. -# Sits at the lowest layer of the attention stack: sglang's RadixAttention -# delegates the actual kernel dispatch here. +# SGLang full-attention backend replacing sglang's built-in AiterAttnBackend. +# Shared by ALL full-attention models (DeepSeek, Qwen3, etc.) — handles KV +# cache writes, page-table fixup, pa_persistent_fwd decode path, and MLA +# prefill kernels. Sits at the lowest layer of the attention stack: +# sglang's RadixAttention delegates the actual kernel dispatch here. # # TODO: rewrite this file once sglang's attention flow is unified into ATOM's # attention layer — KV cache management and attention kernel dispatch will then @@ -47,7 +47,7 @@ except ImportError as e: raise ImportError( "Failed to import 'aiter', which provides AMD-specific attention kernels " - "required by sgl_attn_backend. 
Please ensure 'aiter' is installed and " + "required by full_attention_backend. Please ensure 'aiter' is installed and " f"available on your AMD system. Original import error: {e}" ) from e diff --git a/atom/plugin/sglang/attention_backend/radix_attention.py b/atom/plugin/sglang/attention_backend/full_attention/radix_attention.py similarity index 100% rename from atom/plugin/sglang/attention_backend/radix_attention.py rename to atom/plugin/sglang/attention_backend/full_attention/radix_attention.py diff --git a/atom/plugin/sglang/models/base_model_wrapper.py b/atom/plugin/sglang/models/base_model_wrapper.py index 3f2c743b14..dbace2e647 100644 --- a/atom/plugin/sglang/models/base_model_wrapper.py +++ b/atom/plugin/sglang/models/base_model_wrapper.py @@ -386,7 +386,7 @@ def __init__( # Apply ds model-specific sglang patches (attn dispatch, weight hooks, etc.) # TODO: will remove this after sglang supports atom attention backend if self.model_arch_spec.apply_deepseek_patch: - from atom.plugin.sglang.attention_backend.sgl_attention_mla import ( + from atom.plugin.sglang.models.deepseek_mla import ( setup_deepseek_for_sglang, ) diff --git a/atom/plugin/sglang/models/deepseek_mla.py b/atom/plugin/sglang/models/deepseek_mla.py new file mode 100644 index 0000000000..1d7b1e44bc --- /dev/null +++ b/atom/plugin/sglang/models/deepseek_mla.py @@ -0,0 +1,77 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""Model-level DeepSeek MLA patching for SGLang plugin mode. + +This module owns the monkey-patch entrypoints that adapt DeepSeek MLA models to +SGLang plugin mode. The heavy DeepSeek-specific forward and weight helpers live +in `atom.plugin.sglang.models.deepseek_mla_forward`. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import torch + +from atom.plugin.sglang.models.deepseek_mla_forward import ( + _patch_kv_b_proj_for_sglang_mxfp4, + forward_sgl_plugin_mode, + init_sgl_attrs, + process_mla_kv_b_proj_after_loading, +) + +if TYPE_CHECKING: + from atom.models.deepseek_v2 import DeepseekV2MLAAttention + + +def setup_deepseek_for_sglang(model) -> None: + """Patch a DeepSeek V2/V3 model for SGLang plugin mode.""" + config = model.config + + # Store atom_config for the OOT wrapper before install-time hooks run. + if not hasattr(model, "atom_config"): + from atom.config import get_current_atom_config + + model.atom_config = get_current_atom_config() + + kv_cache_dtype = model.atom_config.kv_cache_dtype + + # Initialise SGLang's MLA TP context before patching per-layer forwards. + from sglang.srt.configs.model_config import is_deepseek_nsa + from sglang.srt.layers.communicator import get_attn_tp_context + + get_attn_tp_context().init_context(config.q_lora_rank, is_deepseek_nsa(config)) + + from atom.models.deepseek_v2 import DeepseekV2MLAAttention + + for module in model.modules(): + if isinstance(module, DeepseekV2MLAAttention): + _patch_mla_attention_for_sglang(module, config, kv_cache_dtype) + + +def _patch_mla_attention_for_sglang( + attn: "DeepseekV2MLAAttention", + config: Any, + kv_cache_dtype: str = "bf16", +) -> None: + """Patch one DeepSeek MLA layer for SGLang plugin mode.""" + init_sgl_attrs(attn, config, kv_cache_dtype) + _patch_kv_b_proj_for_sglang_mxfp4(attn) + + def patched_forward( + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + from atom.plugin.sglang.models.base_model_wrapper import ( + get_current_forward_batch, + ) + + kwargs["forward_batch"] = get_current_forward_batch() + return forward_sgl_plugin_mode(attn, positions, hidden_states, **kwargs) + + attn.forward = patched_forward + attn.process_weights_after_loading = lambda: 
process_mla_kv_b_proj_after_loading( + attn + ) diff --git a/atom/plugin/sglang/attention_backend/sgl_attention_mla.py b/atom/plugin/sglang/models/deepseek_mla_forward.py similarity index 90% rename from atom/plugin/sglang/attention_backend/sgl_attention_mla.py rename to atom/plugin/sglang/models/deepseek_mla_forward.py index 8a9cbb1f48..80fb075e4f 100644 --- a/atom/plugin/sglang/attention_backend/sgl_attention_mla.py +++ b/atom/plugin/sglang/models/deepseek_mla_forward.py @@ -1,17 +1,19 @@ -"""Sglang-specific MLA forward and weight processing for DeepseekV2/V3. +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -DeepSeek MLA (Multi-Latent Attention) forward logic for sglang plugin mode: +"""Model-specific DeepSeek MLA helpers for SGLang plugin mode. + +DeepSeek MLA (Multi-head Latent Attention) forward logic for SGLang plugin mode: absorbed BMM computation, MHA/MLA path dispatch (prefill -> MHA, decode -> MLA), -kv_b_proj weight splitting (w_kc/w_vc), and monkey-patch setup via -setup_deepseek_for_sglang(). +and kv_b_proj weight splitting (w_kc/w_vc). -This module is lazily imported from base_model_wrapper.py only when running in -sglang plugin mode (``is_sglang() == True``). Keeping all sglang-dependent -imports here avoids crashing when sglang is not installed. +This module lives under ``atom.plugin.sglang.models`` because the logic is +DeepSeek-model-specific rather than a generic SGLang attention backend. TODO: rewrite this file once sglang's attention flow is unified into ATOM's -attention layer — the MLA absorbed path and MHA dispatch will then be handled -natively by ATOM's attention ops, making this sglang-specific module unnecessary. +attention layer - the MLA absorbed path and MHA dispatch will then be handled +natively by ATOM's attention ops, making this sglang-specific module +unnecessary. 
""" from __future__ import annotations @@ -31,7 +33,6 @@ from atom.models.utils import maybe_prefix from atom.models.deepseek_v2 import _fuse_rmsnorm_quant -# sglang imports from sglang.srt.layers.communicator import AttentionInputs, get_attn_tp_context from sglang.srt.layers.attention.nsa.utils import nsa_use_prefill_cp from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode @@ -99,7 +100,6 @@ def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): raise RuntimeError("bmm_fp8 requires CUDA (sgl_kernel)") -# NamedTuple for prepare → core data flow class SglPrepareResult(NamedTuple): q_pe: torch.Tensor k_pe: torch.Tensor @@ -198,7 +198,6 @@ def _prepare_weight_for_bmm( ) -# Init helpers def init_sgl_attrs( attn: DeepseekV2MLAAttention, config, @@ -232,7 +231,6 @@ def init_sgl_attrs( attn.attn_mha.attn.kv_b_proj = None -# Absorbed batched-matmul (shared by prepare and core) def mla_absorbed_bmm( attn: DeepseekV2MLAAttention, inp: torch.Tensor, @@ -241,12 +239,7 @@ def mla_absorbed_bmm( weight_scale_k: Optional[torch.Tensor], out_dim: int, ) -> torch.Tensor: - """Batched matmul for MLA absorbed weights (w_kc / w_vc). - - Handles deep_gemm, mxfp4, fp8-triton, fp8-cublas, and bf16 fallback paths. 
- inp: (num_tokens, num_heads, in_dim) — token-major - Returns: (num_tokens, num_heads, out_dim) — token-major - """ + """Batched matmul for MLA absorbed weights (w_kc / w_vc).""" effective_weight_scale = ( weight_scale_k if weight_scale_k is not None else weight_scale ) @@ -315,7 +308,6 @@ def mla_absorbed_bmm( ) return out.transpose(0, 1) - # CUDA fp8 path if weight.dtype == torch.float8_e4m3fn: val, scale = per_tensor_quant_mla_fp8( inp.transpose(0, 1), @@ -324,7 +316,6 @@ def mla_absorbed_bmm( out = bmm_fp8(val, weight, scale, effective_weight_scale, torch.bfloat16) return out.transpose(0, 1) - # bf16 fallback return torch.bmm(inp.transpose(0, 1), weight).transpose(0, 1) @@ -388,14 +379,13 @@ def mla_v_up_proj( ).flatten(1, 2) -# Forward: prepare → core def forward_sgl_prepare( attn: DeepseekV2MLAAttention, positions: torch.Tensor, hidden_states: torch.Tensor, **model_kwargs, ) -> SglPrepareResult: - """Prepare QKV for sglang MLA attention (adapted from sglang forward_absorb_prepare).""" + """Prepare QKV for sglang MLA attention.""" hidden_states_scale = None if isinstance(hidden_states, tuple): hidden_states, hidden_states_scale = hidden_states @@ -437,9 +427,6 @@ def forward_sgl_prepare( k_nope = latent_cache[..., : attn.kv_lora_rank] q_scale = None - # Reuse native ATOM gating for q/k RMSNorm fusion. Quant fusion is used - # when DeepSeek enables qknorm-quant; otherwise keep the non-quant fused - # path aligned with native ATOM before falling back to plain layernorm. if getattr(attn, "fuse_qknorm_quant", False): q, q_scale, q_lora, k_nope = _fuse_qk_rmsnorm_and_q_quant( attn, @@ -449,7 +436,6 @@ def forward_sgl_prepare( ) elif getattr(attn, "fuse_qknorm", False): q, k_nope = _fuse_qk_rmsnorm(attn, q, k_nope) - # Otherwise keep the original overlap path for unfused qk norm. 
elif attn.alt_stream is not None and get_is_capture_mode(): current_stream = torch.cuda.current_stream() attn.alt_stream.wait_stream(current_stream) @@ -461,11 +447,9 @@ def forward_sgl_prepare( q = attn.q_a_layernorm(q) k_nope = attn.kv_a_layernorm(k_nope) - if attn.use_nsa: - if q_lora is None: - q_lora = q + if attn.use_nsa and q_lora is None: + q_lora = q - # overlap q_b_proj and indexer during decode if ( attn.alt_stream is not None and get_is_capture_mode() @@ -542,7 +526,7 @@ def forward_sgl_core( attn: DeepseekV2MLAAttention, prepared: SglPrepareResult, ) -> torch.Tensor: - """Core MLA attention computation for sglang (adapted from sglang forward_absorb_core).""" + """Core MLA attention computation for sglang.""" save_kv_cache = True if attn.use_fused_qk_rope_concat_and_cache_mla: @@ -581,7 +565,6 @@ def forward_sgl_core( is_neox=attn.rotary_emb.is_neox_style, is_nope_first=True, ) - # Decode/speculative MLA consumes q plus packed MLA cache directly. k = None v = None save_kv_cache = False @@ -607,7 +590,6 @@ def forward_sgl_core( ) attn_output = attn_output.view(-1, attn.num_local_heads, attn.kv_lora_rank) - # up-proj by w_vc attn_bmm_output = mla_v_up_proj( attn, attn_output, attn.w_vc, attn.w_scale, attn.w_scale_v, attn.v_head_dim ) @@ -922,8 +904,6 @@ def prepare_qkv_latent( hidden_states, hidden_states_scale = hidden_states qkv_lora = attn.fused_qkv_a_proj(hidden_states, hidden_states_scale) - # Fallback: when communicator does not enable input_scattered gather, - # force qkv latent token dimension to align with positions. 
expected_tokens = 0 if hasattr(forward_batch, "positions") and forward_batch.positions is not None: expected_tokens = int(forward_batch.positions.shape[0]) @@ -946,7 +926,6 @@ def prepare_qkv_latent( return qkv_lora -# Top-level forward entry point def forward_sgl_plugin_mode( attn: DeepseekV2MLAAttention, positions: torch.Tensor, @@ -987,7 +966,6 @@ def forward_sgl_plugin_mode( raise ValueError(f"Unsupported plugin attention path: {attn_path}") -# Weight post-processing: decomposed into sub-functions def _read_kv_b_proj_weight(attn: DeepseekV2MLAAttention) -> torch.Tensor: """Read kv_b_proj weight, handling AWQ and fnuz dtypes.""" if hasattr(attn.kv_b_proj, "qweight"): @@ -1018,8 +996,6 @@ def _read_kv_b_proj_weight(attn: DeepseekV2MLAAttention) -> torch.Tensor: else: w = attn.kv_b_proj.weight - # On ROCm, ATOM creates parameters with fnuz dtype but loads fn bytes. - # View-cast back to fn so the normalize path works correctly. if _is_fp8_fnuz and w.dtype == torch.float8_e4m3fnuz: w = w.view(torch.float8_e4m3fn) @@ -1043,10 +1019,7 @@ def _process_fp8_weight( w: torch.Tensor, weight_block_size: Optional[list[int]], ) -> tuple[torch.Tensor, bool, Optional[torch.Tensor]]: - """Process FP8 weights for kv_b_proj. - - Returns (w, use_deep_gemm_bmm, block_scale). - """ + """Process FP8 weights for kv_b_proj.""" from atom.model_ops.utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.layers.quantization.fp8_utils import ( block_quant_dequant, @@ -1215,8 +1188,6 @@ def _split_and_assign_kc_vc( w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) w_vc = w_vc.contiguous().transpose(1, 2) - # Align bf16 kv_b_proj post-load handling with vLLM: split first, then - # quantize kc/vc independently for the fp8 BMM path. 
if w.dtype == torch.bfloat16 and (_is_hip or _is_cuda): w_kc, w_scale_k = dynamic_per_batched_tensor_quant(w_kc, dtype=dtypes.fp8) w_vc, w_scale_v = dynamic_per_batched_tensor_quant(w_vc, dtype=dtypes.fp8) @@ -1262,17 +1233,14 @@ def process_mla_kv_b_proj_after_loading(attn: DeepseekV2MLAAttention) -> None: use_deep_gemm_bmm = False block_scale = None - # fp8 path if w.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): w, use_deep_gemm_bmm, block_scale = _process_fp8_weight( attn, w, weight_block_size ) - # int8 path if w.dtype == torch.int8: w = _process_int8_weight(attn, w, weight_block_size) - # split and assign kc/vc _split_and_assign_kc_vc(attn, w, use_deep_gemm_bmm, block_scale, weight_block_size) @@ -1309,67 +1277,3 @@ def process_weights_after_loading_with_mxfp4_preserve(): process_weights_after_loading_with_mxfp4_preserve ) kv_b_proj._sgl_mxfp4_preserve_patched = True - - -# One-time model setup (called from base_model_wrapper.py) -def setup_deepseek_for_sglang(model) -> None: - """Patch a DeepseekV2/V3 model for sglang plugin mode. 
- - - Initialises sglang TP context - - Patches each MLAAttention.forward to dispatch to the sglang MLA path - - Registers process_weights_after_loading hooks - - Stores atom_config on the model - """ - config = model.config - - # Store atom_config (needed by load_weights in the OOT wrapper) - if not hasattr(model, "atom_config"): - from atom.config import get_current_atom_config - - model.atom_config = get_current_atom_config() - - kv_cache_dtype = model.atom_config.kv_cache_dtype - - # Initialise sglang TP context for MLA gather/scatter - from sglang.srt.configs.model_config import is_deepseek_nsa - from sglang.srt.layers.communicator import get_attn_tp_context - - get_attn_tp_context().init_context(config.q_lora_rank, is_deepseek_nsa(config)) - - # Patch each MLAAttention instance - from atom.models.deepseek_v2 import DeepseekV2MLAAttention - - for module in model.modules(): - if isinstance(module, DeepseekV2MLAAttention): - _patch_mla_attention_for_sglang(module, config, kv_cache_dtype) - - -def _patch_mla_attention_for_sglang(attn, config, kv_cache_dtype: str = "bf16") -> None: - """Patch a single DeepseekV2MLAAttention for sglang plugin mode. - - We patch attn.forward (rather than relying solely on ops.Attention = - RadixAttention) because MLA's absorbed-weight forward path replaces the - *entire* forward method — including RoPE, and absorbed - BMM — not just the attention backend. ops.Attention = RadixAttention - handles the backend layer (flash_attn / paged_attn dispatch) and is - already set via set_attn_cls(); this patch sits above that layer. 
- """ - init_sgl_attrs(attn, config, kv_cache_dtype) - _patch_kv_b_proj_for_sglang_mxfp4(attn) - - def patched_forward( - positions: torch.Tensor, - hidden_states: torch.Tensor, - **kwargs, - ) -> torch.Tensor: - from atom.plugin.sglang.models.base_model_wrapper import ( - get_current_forward_batch, - ) - - kwargs["forward_batch"] = get_current_forward_batch() - return forward_sgl_plugin_mode(attn, positions, hidden_states, **kwargs) - - attn.forward = patched_forward - attn.process_weights_after_loading = lambda: process_mla_kv_b_proj_after_loading( - attn - ) diff --git a/tests/plugin/test_sglang_model_wrapper.py b/tests/plugin/test_sglang_model_wrapper.py index e4015ed9dc..20d0e07923 100644 --- a/tests/plugin/test_sglang_model_wrapper.py +++ b/tests/plugin/test_sglang_model_wrapper.py @@ -54,8 +54,7 @@ def _make_fake_modules(*, is_last_rank: bool, setup_hook=None) -> dict[str, Modu forward_batch_mod.ForwardBatch = object forward_batch_mod.PPProxyTensors = object - attn_backend_pkg = _package("atom.plugin.sglang.attention_backend") - mla_mod = ModuleType("atom.plugin.sglang.attention_backend.sgl_attention_mla") + mla_mod = ModuleType("atom.plugin.sglang.models.deepseek_mla") mla_mod.setup_deepseek_for_sglang = setup_hook or (lambda model: None) return { @@ -68,8 +67,7 @@ def _make_fake_modules(*, is_last_rank: bool, setup_hook=None) -> dict[str, Modu "sglang.srt.layers.quantization.base_config": quant_base_mod, "sglang.srt.model_executor": model_executor_pkg, "sglang.srt.model_executor.forward_batch_info": forward_batch_mod, - "atom.plugin.sglang.attention_backend": attn_backend_pkg, - "atom.plugin.sglang.attention_backend.sgl_attention_mla": mla_mod, + "atom.plugin.sglang.models.deepseek_mla": mla_mod, } diff --git a/tests/plugin/test_sglang_register.py b/tests/plugin/test_sglang_register.py index 562aaf8bc8..1adb20fb42 100644 --- a/tests/plugin/test_sglang_register.py +++ b/tests/plugin/test_sglang_register.py @@ -324,10 +324,13 @@ def __init__(self, runner): 
"atom.models.qwen3_moe": ModuleType("atom.models.qwen3_moe"), "atom.models.glm4_moe": ModuleType("atom.models.glm4_moe"), "atom.models.deepseek_v2": ModuleType("atom.models.deepseek_v2"), + "atom.models.minimax_m2": ModuleType("atom.models.minimax_m2"), + "atom.models.qwen3_next": ModuleType("atom.models.qwen3_next"), + "atom.models.qwen3_5": ModuleType("atom.models.qwen3_5"), "atom.config": ModuleType("atom.config"), "atom.plugin.prepare": fake_prepare_mod, - "atom.plugin.sglang.attention_backend.sgl_attn_backend": ModuleType( - "atom.plugin.sglang.attention_backend.sgl_attn_backend" + "atom.plugin.sglang.attention_backend.full_attention.full_attention_backend": ModuleType( + "atom.plugin.sglang.attention_backend.full_attention.full_attention_backend" ), } fake_modules["atom.models.qwen3"].Qwen3ForCausalLM = type( @@ -342,9 +345,21 @@ def __init__(self, runner): fake_modules["atom.models.deepseek_v2"].DeepseekV3ForCausalLM = type( "DeepseekV3ForCausalLM", (), {} ) + fake_modules["atom.models.minimax_m2"].MiniMaxM2ForCausalLM = type( + "MiniMaxM2ForCausalLM", (), {} + ) + fake_modules["atom.models.qwen3_next"].Qwen3NextForCausalLM = type( + "Qwen3NextForCausalLM", (), {} + ) + fake_modules["atom.models.qwen3_5"].Qwen3_5ForCausalLM = type( + "Qwen3_5ForCausalLM", (), {} + ) + fake_modules["atom.models.qwen3_5"].Qwen3_5MoeForCausalLM = type( + "Qwen3_5MoeForCausalLM", (), {} + ) fake_modules["atom.config"].Config = type("Config", (), {}) fake_modules[ - "atom.plugin.sglang.attention_backend.sgl_attn_backend" + "atom.plugin.sglang.attention_backend.full_attention.full_attention_backend" ].ATOMAttnBackendForSgl = _FakeBackend with patch.dict(sys.modules, fake_modules): From 83c5f69590257abf9106a2e819af182b50b9ea03 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Tue, 12 May 2026 08:33:02 +0000 Subject: [PATCH 03/17] remove work log --- .../2026-04-30-attn-refractory.md | 354 ------------- ...26-05-07-radix-cache-vs-paged-attention.md | 482 ------------------ 
...26-05-08-kv-indices-address-space-notes.md | 359 ------------- .../2026-05-08-metadata-layering-summary.md | 144 ------ ...6-05-11-attn-refractory-session-summary.md | 398 --------------- .../attn_refractory/decoupling-diagram.md | 158 ------ .../vllm-integration-architecture.md | 269 ---------- 7 files changed, 2164 deletions(-) delete mode 100644 work_log/attn_refractory/2026-04-30-attn-refractory.md delete mode 100644 work_log/attn_refractory/2026-05-07-radix-cache-vs-paged-attention.md delete mode 100644 work_log/attn_refractory/2026-05-08-kv-indices-address-space-notes.md delete mode 100644 work_log/attn_refractory/2026-05-08-metadata-layering-summary.md delete mode 100644 work_log/attn_refractory/2026-05-11-attn-refractory-session-summary.md delete mode 100644 work_log/attn_refractory/decoupling-diagram.md delete mode 100644 work_log/attn_refractory/vllm-integration-architecture.md diff --git a/work_log/attn_refractory/2026-04-30-attn-refractory.md b/work_log/attn_refractory/2026-04-30-attn-refractory.md deleted file mode 100644 index f0946299a9..0000000000 --- a/work_log/attn_refractory/2026-04-30-attn-refractory.md +++ /dev/null @@ -1,354 +0,0 @@ -# 2026-04-30 ATOM Attention Refactory Session Summary - -## 1. 本次会话的核心目标 - -本次连续讨论的主线是围绕 `ATOM plugin` 中 attention backend 的重构方向展开,重点包括: - -- 调研当前 `ATOM plugin` attention backend 的架构问题 -- 评估 `vLLM plugin` 与 `SGLang plugin` 的耦合程度 -- 分析 SGLang plugin 在支持新模型、特别是非传统 MHA/MLA 类型 backend 时会遇到的困难 -- 讨论如何在 plugin 层把 `vLLM` 与 `SGLang` 解耦 -- 讨论 `SGLang` 中 `hybrid/composite backend` 的接入方式 -- 评估三种模式共享 ATOM attention 实现的难度: - - `ATOM native/server` - - `ATOM vLLM plugin` - - `ATOM SGLang plugin` - -## 2. 
对当前 ATOM plugin attention 架构的主要判断 - -### 2.1 当前不是一个真正统一的 attention backend 架构 - -本次调研得到的一个核心结论是: - -- `vLLM plugin` 和 `SGLang plugin` 虽然都叫 “ATOM attention” -- 但它们并不是通过同一套干净的 backend contract 接进来的 -- 更准确地说,是两套宿主接入模型外加多层 patch / adapter / runtime glue 共同构成的系统 - -### 2.2 当前的主要问题不是底层 kernel,而是 runtime / ownership / patch 层次 - -现有问题主要集中在: - -- 全局状态过重 - - `_CURRENT_FRAMEWORK` - - `current_atom_config` - - `ops.Attention` -- vLLM 与 SGLang 的接入模型差异很大,但还共享 `plugin` 根目录里一批伪公共逻辑 -- 很多扩展点依赖 monkey patch 和 runtime override -- SGLang 侧对 DeepSeek MLA 已经上卷到 model-level patch,而不是单纯 backend 替换 - -### 2.3 attention 文件结构的拆分轴不统一 - -本次对 `plugin` 目录下 attention 相关文件的解读认为,目前至少混杂了三层不同对象: - -- host/runtime glue -- backend runtime core -- model-specific specialization - -尤其在 `SGLang` 侧: - -- `radix_attention.py` 更像 adapter -- `sgl_attn_backend.py` 更像 full-attn runtime backend core -- `sgl_attention_mla.py` 更像 DeepSeek MLA specialization - -而在根目录: - -- `attention.py` -- `attention_mha.py` -- `attention_mla.py` -- `attention_mla_sparse.py` - -虽然放在 `plugin/` 根目录,但职责上明显更接近 `vLLM plugin` 的 glue / patch 层,而不是真正共享的 plugin runtime core。 - -## 3. 关于 vLLM plugin 与 SGLang plugin 的耦合 - -### 3.1 结论:耦合度高 - -这次讨论中对二者耦合的判断是: - -- 不是轻量共享几个工具函数 -- 而是共享了一套 runtime state、attention 抽象、selector 和 plugin-mode 判断方式 - -关键耦合点包括: - -- `atom/plugin/prepare.py` -- `atom/plugin/register.py` -- `atom/plugin/config.py` -- `atom.plugin.prepare.is_vllm()/is_sglang()/is_plugin_mode()` -- `ops.Attention` 的全局切换 - -### 3.2 关键判断 - -真正的问题不是 “目录没有拆开”,而是: - -**host runtime 差异没有被限制在 adapter 边界,而是泄漏进了 backend 选择、metadata 组织和全局状态模型。** - -## 4. 
关于 plugin 层解耦的共识 - -### 4.1 plugin 层不应该继续保留共享 runtime - -讨论最终倾向于一个更激进但更干净的方向: - -- `atom/plugin/vllm/` 是一个完整子系统 -- `atom/plugin/sglang/` 是另一个完整子系统 -- `atom/plugin/` 根目录不再承担共享 runtime 中心的角色 - -共享只保留给更底层的 `ATOM core`,例如: - -- `model_ops` -- kernels -- loader -- metadata helpers -- model-family specialization 中真正 host-agnostic 的部分 - -### 4.2 “bootstrap” 的定义 - -本次还明确了 `bootstrap` 在本讨论语境中的含义: - -- 不是核心执行逻辑 -- 而是插件被宿主框架加载时的最早期接线/初始化层 - -也就是: - -- 注册扩展点 -- 安装 patch -- 选择 wrapper / adapter / backend -- 建立 host 与 plugin 的连接关系 - -## 5. 已经做过的代码原型 - -本次会话中,为了验证“按 host 拆入口”的思路,已经做了一批原型代码修改: - -### 5.1 新增的 bootstrap / prepare / register 原型 - -- `atom/plugin/sglang/bootstrap.py` -- `atom/plugin/vllm/bootstrap.py` -- `atom/plugin/sglang/prepare.py` -- `atom/plugin/sglang/register.py` - -### 5.2 `prepare.py` 已部分收缩为兼容层 - -- 实际的 SGLang prepare 逻辑已经移动到 `atom/plugin/sglang/prepare.py` -- `atom/plugin/prepare.py` 现在更像一个 legacy shim -- 但它仍保留了 framework-state helper(因为仓库里很多地方还在依赖) - -### 5.3 SGLang register 逻辑已开始下沉 - -SGLang 独有的几块逻辑已经迁入: - -- `register_ops_to_sglang` -- `set_sglang_attn_cls` -- `init_aiter_dist_for_sglang` -- `bootstrap_sglang_runtime` - -共享 `atom/plugin/register.py` 目前更多是兼容 facade。 - -### 5.4 说明 - -这些改动的定位是: - -- 用来显式化未来结构 -- 不是完整重构 -- 还保留了较多兼容层 - -## 6. 关于 SGLang attention backend 的重要判断 - -### 6.1 `ATOMAttnBackendForSgl` 与 public `AiterAttnBackend` - -在一次反复讨论后,本次对它们的关系收敛为: - -- 在 backend 角色位置上,两者基本是 apple-to-apple -- `ATOMAttnBackendForSgl` 不是“更高层的新东西” -- 而是 `SGLang full-attention runtime backend` 的 ATOM 对位实现 - -因此更合适的重构思路不是怀疑它的存在合法性,而是: - -- 保留它作为 full-attn backend core -- 再去拆它内部的 metadata / kv_cache / decode / graph 等职责 - -### 6.2 DeepSeek MLA 的特殊性 - -`sgl_attention_mla.py` 暴露出一个很重要的事实: - -- 对 `MLA`,尤其是 `DeepSeek MLA` -- SGLang plugin 已经不是单纯换 backend -- 而是 patch 了模型级 `forward` - -这说明: - -**MLA 在 SGLang plugin 里已经进入了 model specialization 维度。** - -## 7. 
关于 GDN / KDA / Lightning / Mamba2 等非传统 “attention” - -### 7.1 调研结论 - -public SGLang 已经证明,系统里不止 full-attention backend,还存在: - -- linear attention backend - - `GDNAttnBackend` - - `KDAAttnBackend` - - `LightningAttentionBackend` - - `Mamba2AttnBackend` -- hybrid/composite backend - - `HybridLinearAttnBackend` - - `HybridAttnBackend` -- wrapper/composition backend - - `TboAttnBackend` - -### 7.2 对当前 ATOM plugin 设计造成的挑战 - -当前 ATOM plugin 的主抽象仍偏向: - -- MHA -- MLA -- full attention runtime - -这会带来几个结构性问题: - -1. 还没有把 `linear backend` 当成一等公民 -2. 还没有给 `hybrid/composite backend` 预留独立宿主位置 -3. 容易把 algorithm backend、kernel backend、host glue 混成一层 -4. 容易继续把 `GDN` 等路径硬塞回 full-attn backend 体系 - -### 7.3 当前共识 - -后续不应该继续只谈 “attention backend 重构”,而应该升级为: - -**sequence backend / mixer backend 体系重构** - -建议至少在设计上并列三类: - -- `full_attn` -- `linear_attn` -- `hybrid/composite` - -## 8. hybrid/composite backend 在 plugin 侧的接入思路 - -### 8.1 核心判断 - -由于 ATOM SGLang plugin 的启动方式是: - -- 启动 `sglang server` -- 再通过 plugin runtime override 接管 model / attention - -所以要接 `hybrid/composite backend`,不可避免需要 patch 某个 seam。 - -### 8.2 最好的 seam - -本次讨论认为,public SGLang 里最好的 seam 是: - -- `model_runner.py` 中导入并调用的 `attn_backend_wrapper` - -注意: - -- 它不是 `ModelRunner` 的成员方法 -- 而是 `model_runner.py` 模块级导入的符号 - -所以如果要 patch,更稳的是 patch: - -- `sglang.srt.model_executor.model_runner.attn_backend_wrapper` - -而不是仅 patch `attention_registry.attn_backend_wrapper`。 - -### 8.3 最推荐的方式 - -对 hybrid/composite backend 的接入方式,本次最终倾向于: - -**plugin-own 一个薄的 composite wrapper/factory,只在 backend 构造 seam 做单点 patch。** - -不推荐: - -- 深度 monkey patch public hybrid backend 实现 -- 一开始就复制整套 public linear/full backend 逻辑 - -更推荐: - -- full side 先用 ATOM full backend -- linear side 初期先复用 public SGLang 的 `GDNAttnBackend` / `KDAAttnBackend` / `LightningAttentionBackend` / `Mamba2AttnBackend` -- composite wrapper 由 plugin 侧拥有 - -## 9. 
关于三种模式共享 ATOM attention 实现的难度 - -本次对: - -- `ATOM native/server` -- `ATOM vLLM plugin` -- `ATOM SGLang plugin` - -三种模式共享 ATOM attention 实现,得到的判断如下。 - -### 9.1 MHA - -难度:`Medium-High` - -原因: - -- `ATOM native` 与 `ATOM vLLM plugin` 已经在 `PagedAttentionImpl` 等层共享较多 -- 真正难的是 `SGLang plugin` 这一侧 -- 它不走 `PagedAttention` 的 ATOM 主路径,而走: - - `RadixAttention` - - `ATOMAttnBackendForSgl` - - `ForwardBatch` - -所以 MHA 共享的主要困难在于: - -**runtime orchestration 差异** - -### 9.2 MLA - -难度:`High` - -原因: - -- native/server 与 vLLM plugin 仍然较多共享 `MLAAttention` -- 但 SGLang plugin 下的 DeepSeek MLA 已经上卷到 model specialization -- 不只是 backend 不同,连 forward 组织方式都不同 - -所以 MLA 的困难在于: - -**runtime orchestration + model specialization 双重叠加** - -## 10. 已产出的文档与图 - -本次会话过程中,已额外产出下列文档/图,用于不同角度说明问题。 - -### 10.1 代码目录内 Markdown - -- `atom/plugin/decoupling-diagram.md` -- `atom/plugin/vllm-integration-architecture.md` -- `atom/plugin/sglang-attention-backend-survey.md` - -### 10.2 Canvas / 可视化分析 - -- `atom-attention-backend-architecture-review.canvas.tsx` -- `atom-plugin-coupling-risk-analysis.canvas.tsx` -- `atom-plugin-decoupling-diagram.canvas.tsx` -- `atom-vllm-plugin-architecture.canvas.tsx` -- `atom-attention-sharing-modes.canvas.tsx` - -### 10.3 这些产物对应的主题 - -- 当前 attention backend 架构缺陷 -- vLLM / SGLang plugin 耦合与风险 -- plugin 解耦方向与 bootstrap 理解 -- vLLM plugin 集成架构 -- public SGLang backend survey -- 三种模式共享 ATOM attention 的难度评估 - -## 11. 当前最值得继续推进的方向 - -如果延续本次会话的结论,后续工作最值得按下面顺序推进: - -1. 继续收缩 `atom/plugin/` 根目录的 shared runtime 语义 -2. 把 `full_attn / linear_attn / hybrid` 三层作为 plugin 后续结构设计的一等公民 -3. 在 `SGLang plugin` 侧先补出 `hybrid/composite` 组合层 -4. 初期复用 public SGLang linear backend,优先验证 runtime 结构是否成立 -5. 然后再评估哪些 linear backend 需要逐步替换成 ATOM 自己的实现 -6. 对 `MLA` 尽早拆开: - - generic MLA runtime - - DeepSeek specialization - -## 12. 
一句话总结 - -本次会话最终把问题收敛为: - -**当前 ATOM plugin attention 的关键矛盾,不是底层 kernel 能不能共享,而是 vLLM / SGLang / native 三种模式在 runtime、metadata、model specialization 和 backend ownership 上没有对齐;后续重构应从“统一 attention backend”升级为“重新定义 plugin 的 sequence backend / host-owned runtime 结构”。** diff --git a/work_log/attn_refractory/2026-05-07-radix-cache-vs-paged-attention.md b/work_log/attn_refractory/2026-05-07-radix-cache-vs-paged-attention.md deleted file mode 100644 index f0f9d5bccc..0000000000 --- a/work_log/attn_refractory/2026-05-07-radix-cache-vs-paged-attention.md +++ /dev/null @@ -1,482 +0,0 @@ -# 2026-05-07 Radix Cache vs Paged Attention Notes - -> 预估阅读时间:15 分钟 -> 主题:梳理 `SGLang` 中 `radix cache` / `RadixAttention` 与 `ATOM` / `vLLM` 常见的 `paged attention` 之间的关系,解释为什么它们最终可以落到相同的 kernel,以及这件事对 `sglang plugin` attention backend 重构意味着什么。 - -## 1. 这篇笔记想回答什么 - -围绕 `sglang plugin attn backend` 的重构,最近反复出现了几个问题: - -1. `SGLang` 用的是 `RadixAttention`,`ATOM`/`vLLM plugin` 里常见的是 `PagedAttention`,两者是不是两套完全不同的注意力体系? -2. 如果它们真的不同,为什么在 `DeepSeek MLA` 这类路径里,最终又会调用到相同的底层 kernel? -3. 既然 `SGLang` 还多了一层 `radix cache` 前缀管理,为什么没有一个非常直观的结论说 "`SGLang` 一定比 `vLLM` 更快 / 更省显存"? -4. 对当前重构来说,`radix` / `paged` 的差异到底应该被归类到: - - host runtime 差异 - - KV cache 管理差异 - - metadata lowering 差异 - - kernel 差异 - 的哪一层? - -本文尝试把这几个问题放进同一个分析框架里。 - -## 2. 先给结论 - -### 2.1 最核心的一句话 - -`RadixAttention` 和 `PagedAttention` 主要不是在底层注意力数学上不同,而是在 **runtime 如何管理、共享、命中和定位历史 KV** 上不同。 -一旦 runtime 把历史 KV lowering 成底层 kernel 能吃的 `page_table`、`kv_indptr`、`kv_indices`、`qo_indptr` 之类的 metadata,kernel 就不再关心这些索引最初是来自 radix tree 还是来自 paged block table。 - -### 2.2 另一句更工程化的结论 - -对重构来说,不应该把 “`RadixAttention` vs `PagedAttention`” 当作 “能不能共享底层执行实现” 的直接判断依据。 -更应该把系统拆成三层: - -1. **host/runtime 层** - 例如 `ForwardBatch`、`RadixAttention`、`ModelRunner`、prefix cache、speculative 调度 -2. **execution metadata lowering 层** - 例如 `page_table`、`block_tables`、`kv_indptr`、`kv_indices`、`qo_indptr` -3. 
**kernel / execution core 层** - 例如 `pa_fwd_asm`、`pa_persistent_fwd`、`flash_attn_varlen_func`、`mla_decode_fwd`、`mla_prefill_fwd` - -`radix` 和 `paged` 的差异,主要在第 1、2 层; -兼容性和共享机会,主要出现在第 2、3 层。 - -## 3. 为什么名字容易让人误解 - -当前最容易造成误解的点是: - -- `PagedAttention` 这个名字听起来像“最终执行算法” -- `RadixAttention` 这个名字也听起来像“最终执行算法” - -但在工程上,它们都更接近 **host-facing attention runtime abstraction**,而不是最终 kernel 名字。 - -更具体一点: - -- `PagedAttention` 更强调 **按固定 page/block 管理 KV cache** -- `RadixAttention` 更强调 **按 prefix/radix tree 管理请求历史与前缀复用** - -这两者都不是在说: - -- `QK^T` 怎么算 -- softmax 怎么做 -- decode kernel 用哪套 asm/triton/flashinfer - -这些底层执行问题,是更下一层的 backend / kernel 决定的。 - -## 4. `RadixAttention` 到底是什么 - -### 4.1 `RadixAttention` 不是 kernel - -在 public `SGLang` 里,`RadixAttention` 本身并不直接定义底层 kernel 选择。 -它更像是: - -- SGLang 模型层里统一使用的 attention layer 壳 -- 它知道自己在 `ForwardBatch` 语义里运行 -- 它把真正的执行委托给 `forward_batch.attn_backend` - -因此,`RadixAttention` 更像一个 **layer/runtime 入口点**。 - -### 4.2 `radix` 的重点是 prefix 管理 - -`radix cache` 的本质,是一棵 prefix tree,用来做: - -- 前缀命中 -- 前缀复用 -- unfinished / finished request 的缓存插入 -- 基于 prefix 的 eviction / lifecycle 管理 - -它回答的问题更像: - -- 这个请求的历史前缀已经缓存到哪里了? -- 哪一部分可以直接复用,不必重新 prefill? -- 哪些 KV slot/page 仍然属于共享前缀? -- 哪些历史片段可以回收? - -所以 `radix` 优化的主要是 **prefix reuse**,而不是单次 attention kernel 的算力效率。 - -### 4.3 `radix cache` 在 SGLang 里不是完全“反 page”的 - -这一点很重要。public `SGLang` 的 `radix_cache` 本身就已经带有 page-aware 的逻辑。 - -例如它有 page 对齐: - -- `page_align_keys()` - -也有按 page 粒度匹配 prefix: - -- `_key_match_paged()` - -这意味着: - -- `radix cache` 解决的是前缀共享与命中问题 -- 但它并不排斥最终按 page/block 粒度来表达缓存 - -也就是说,**radix tree 的逻辑管理** 与 **paged/block 的物理表达** 并不冲突。 - -## 5. `PagedAttention` 到底是什么 - -### 5.1 `PagedAttention` 的重点不是 prefix tree - -`PagedAttention` 更强调: - -- KV cache 被组织成固定大小的 page/block -- 每个请求通过 `page_table` / `block_table` 找到自己历史对应的 page -- kernel 依照这些 page/block 索引读取历史 KV - -它回答的问题更像: - -- 当前请求历史有哪些 block/page? -- 它们在物理内存中的位置是什么? -- 这些 page 如何映射到 kernel 所需的 block table? 
- -所以 `paged attention` 更偏: - -- **物理存储布局** -- **kernel 读取 KV 的地址组织方式** - -### 5.2 `PagedAttention` 天生不等于 prefix reuse - -`paged attention` 当然也可以配合 prefix cache 使用,但 prefix reuse 并不是它名字里最核心的概念。 -它的核心价值更在于: - -- 支持非连续物理布局 -- 让 decode kernel 可以按 page/block 高效读取 KV - -## 6. 为什么 `radix` 和 `paged` 最后可以兼容 - -### 6.1 因为它们主要解决的是不同层的问题 - -可以把它们看成: - -- `radix`: 逻辑历史管理器 -- `paged`: 物理页表表达方式 - -两者关系有点像: - -- `radix` 负责决定“哪些历史是共享的、哪些是命中的、请求当前逻辑历史是什么” -- `paged` 负责决定“这些历史最终以什么 page/block 索引形式交给 kernel” - -只要 runtime 最终能把逻辑历史翻译成 page/block 或 indptr/index 形式,底层 kernel 并不需要知道前缀最初是如何组织的。 - -### 6.2 兼容发生在 lowering 之后 - -这一点非常关键: - -上层看起来是: - -- `ForwardBatch` -- `RadixAttention` -- `req_to_token` -- prefix cache - -但 lower 到 backend 之后,会变成执行态 metadata,例如: - -- `page_table` -- `block_tables` -- `kv_indptr` -- `kv_indices` -- `qo_indptr` -- `kv_last_page_len` - -到了这个层次,kernel 只看: - -- query 的 shape -- KV cache 的 shape -- 历史长度 -- 索引表 / indptr - -它不看: - -- 这套索引是不是来自 radix tree -- 是不是来自某个 prefix cache 命中 -- 是不是通过 `PagedAttention` 对象构造出来的 - -所以兼容性不是“名字兼容”,而是 **execution metadata 兼容**。 - -## 7. 
public `SGLang` 自己就在做类似 lowering - -这不是 `ATOM plugin` 独有的现象。public `SGLang` 本身就在这么做,只是不同 backend lower 到的 metadata 形态略有不同。 - -### 7.1 `aiter_backend` - -public `SGLang` 的 `aiter_backend` 会从: - -- `forward_batch.seq_lens` -- `forward_batch.req_pool_indices` -- `req_to_token` - -构造: - -- `kv_indptr` -- `kv_indices` -- `qo_indptr` - -也就是说,它会把 host runtime 的历史表示 lower 成 AITER kernel 所需的执行态索引。 - -### 7.2 `triton_backend` - -public `SGLang` 的 `triton_backend` 也做类似事情。 -它同样维护: - -- `kv_indptr` -- `kv_indices` -- `qo_indptr` - -然后再调用 Triton 的 decode/extend kernel。 - -### 7.3 `trtllm_mha_backend` / 部分 `flashinfer` 路径 - -这类 backend 更偏向: - -- 直接构造 `page_table` -- 或 `block_tables` - -再喂给 flashinfer / TRTLLM 风格的 kernel。 - -所以更准确地说,public `SGLang` 的共同模式不是: - -> “所有 backend 都把 radix runtime 翻成 paged attention” - -而是: - -> “所有 backend 都会把 radix/runtime 层的历史表示,lower 成各自 kernel 能吃的索引/页表 metadata” - -其中有的偏 `kv_indptr/kv_indices`,有的偏 `page_table/block_tables`。 - -## 8. 从 kernel 角度看,为什么可以兼容 - -### 8.1 kernel 真正关心什么 - -绝大多数 attention kernel 真正关心的是: - -1. 当前 query 的排列方式 -2. 每个请求可见的历史 KV 长度 -3. 如何从某个索引结构里拿到历史 KV 的物理地址 -4. page/block 大小、stride、layout -5. 是否需要 prefix/extend/speculative 的特殊 metadata - -它通常不关心: - -1. 这份索引最初是不是从 prefix tree 查出来的 -2. 命中前缀的逻辑是谁做的 -3. 
runtime 是 `RadixAttention` 还是 `PagedAttention` - -### 8.2 对 MHA kernel 的影响 - -MHA decode 常见地会落到: - -- `pa_fwd_asm` -- `pa_persistent_fwd` -- `flash_attn_varlen_func` - -这些 kernel 更偏 page/block 或 varlen index 风格。 -只要 runtime 最终给出: - -- `page_table` -- `context_lens` -- `block_tables` - -或者: - -- `kv_indptr` -- `kv_indices` -- `qo_indptr` - -它们都能算。 - -### 8.3 对 MLA kernel 的影响 - -`DeepSeek MLA` 更容易看出这种兼容性,因为 MLA 的底层算子路径更明确。 - -典型 kernel 包括: - -- `mla_decode_fwd` -- `mla_prefill_fwd` -- `concat_and_cache_mla` -- `fused_qk_rope_concat_and_cache_mla` - -这些 kernel 只要求: - -- 压缩后的 latent KV 表示 -- 以及相应的 `qo_indptr` / `kv_indptr` / `kv_indices` - -所以无论上层 host 是: - -- `ATOM native` -- `ATOM vLLM plugin` -- `ATOM SGLang plugin` - -只要最终 lower 成相同的 MLA execution metadata,就自然会落到同一套 MLA kernel。 - -## 9. `DeepSeek MLA` 为什么尤其容易“看起来收敛” - -### 9.1 因为 MLA 的 execution core 更专用 - -`DeepSeek MLA` 的真正执行对象不是: - -- “`RadixAttention` 版本的 MLA” -- “`PagedAttention` 版本的 MLA” - -而是: - -- 一组固定的 MLA latent/cache 表示 -- 一套固定的 rope + kv cache 写入方式 -- 一套固定的 prefill/decode kernel - -换句话说,MLA 的底层 core 更容易抽成“共享执行层”。 - -### 9.2 因为 host 差异主要体现在 metadata 准备与调度 - -对于 `DeepSeek MLA` 来说,上层 host 差异更多体现在: - -- `positions` 从哪里来 -- `forward_batch` 怎么组织 -- decode / extend / speculative / target_verify 怎么分流 -- `req_to_token` 怎么转换成 `kv_indices` - -一旦进入 `mla_decode_fwd` 或 `mla_prefill_fwd` 前,这些差异已经被消化掉了。 - -所以从外部观察,很容易得到一种印象: - -> “怎么上面完全不一样,下面却是同一套 kernel?” - -其实并不奇怪,因为两边只是上层 runtime 不同,而 MLA execution core 本来就在下层收敛。 - -## 10. 
既然兼容,为什么 `SGLang` 没有天然更快或更省显存 - -这是另外一个很容易误判的问题。 - -### 10.1 `radix cache` 省的是前缀重复,不是所有 KV - -`radix cache` 能节省的,是重复前缀带来的重复 prefill 和重复 KV 占用。 - -它主要帮助的是: - -- 大量请求共享长前缀 -- 同 prompt 多分支采样 -- prefix-heavy workload -- 某些 speculative / chunked prefill 场景 - -如果 workload 不满足这些条件,它就不会天然表现为显著优势。 - -### 10.2 它优化的是 prefix reuse,不是单 token decode kernel - -radix 管理层并不会让 `mla_decode_fwd`、`pa_fwd_asm` 这种 kernel 自己更便宜。 -kernel 的 raw 速度主要还是受: - -- 头数 -- head dim -- KV layout -- quant -- page size -- GPU kernel 实现 - -这些因素决定。 - -所以常见现象是: - -- prefix-heavy workload 下,`SGLang` 的收益更明显 -- decode-heavy workload 下,收益没那么明显 -- 如果 `vLLM` 也有自己的 prefix caching/block caching,差距会进一步缩小 - -### 10.3 radix 自己也有管理成本 - -`radix cache` 不是免费层。它也有: - -- prefix match -- tree split / insert / eviction -- request -> KV slot 映射维护 -- speculative / unfinished request 的额外 bookkeeping - -如果命中率不高,这部分成本并不会自动转化成收益。 - -## 11. 对当前重构最有价值的判断 - -### 11.1 不要把 `RadixAttention` vs `PagedAttention` 当成是否共享 execution core 的边界 - -因为从 public `SGLang` 和当前 `ATOM plugin` 的代码路径看,二者最终都在做类似事情: - -- 从 host runtime 的历史表示出发 -- 构造底层 kernel 所需的 execution metadata - -所以: - -- runtime 不同,不代表 execution core 必须不同 -- 名字不同,不代表 kernel 层必须分叉 - -### 11.2 真正的边界更适合这样划 - -#### A. `sglang plugin runtime` 应拥有的部分 - -- `ForwardBatch` -- `RadixAttention` -- prefix cache / radix cache -- `ModelRunner` / `attn_backend_wrapper` 相关调度 -- speculative / graph / verify / chunked prefill 等宿主语义 - -#### B. execution metadata lowering 层 - -这层是最值得重新定义的边界。它负责: - -- 把 `req_to_token`、prefix 命中结果、cache state -- 转成 `page_table` / `kv_indptr` / `kv_indices` / `qo_indptr` - -这层既带有 host 语义,又已经开始接近共享执行核心,是未来最值得梳理清楚的一层。 - -#### C. 
shared execution core - -例如: - -- `mla_decode_fwd` -- `mla_prefill_fwd` -- `pa_fwd_asm` -- `pa_persistent_fwd` -- `flash_attn_varlen_func` - -以及与这些 kernel 紧耦合的那部分 ATOM core helper。 - -### 11.3 这对 `sglang plugin` 的重构意味着什么 - -当前 `sglang plugin` 最不该做的是: - -- 因为底层 kernel 能共享,就把 `sglang` 的 runtime seam 也强行对齐到 `vllM` -- 或者因为上层叫 `RadixAttention`,就认为它和 `PagedAttention` 完全不能共享底层执行实现 - -更合理的重构方向是: - -1. 保持 `sglang` 的 host/runtime 语义完整 -2. 识别 runtime lowering 层的清晰边界 -3. 尽量把 execution core 继续下沉回共享 `ATOM core` - -## 12. 一种对外更容易讲清楚的表述 - -如果后续要向别人解释,可以用下面这几句话: - -### 12.1 关于 `radix` vs `paged` - -`radix cache` 主要解决的是 prefix 命中与复用,`paged attention` 主要解决的是 page/block 形式的 KV 组织与读取。 -前者偏逻辑管理,后者偏物理表达;两者不是同一层次的问题。 - -### 12.2 关于为什么最终会调到同一个 kernel - -因为 attention kernel 只关心最终的执行态 metadata,比如 `page_table`、`kv_indptr`、`kv_indices`、`qo_indptr`。 -一旦 runtime 把历史 KV lowering 成这类形式,kernel 就不再关心这些索引最初是由 radix tree 还是 paged runtime 生成的。 - -### 12.3 关于为什么 `SGLang` 没有天然显著更快 - -因为 `radix` 带来的主要收益是 prefix reuse,而不是单 token decode kernel 的 raw speed。 -如果 workload 不是 prefix-heavy,或者 decode 占主要成本,那么 radix 带来的优势不会自然放大成“总是更快 / 更省显存”。 - -## 13. 
最终总结 - -把这次探索收敛成一句话: - -**`RadixAttention` 与 `PagedAttention` 的主要差异,不在底层注意力数学,也不一定在最终 KV 的物理表示本身,而在 host runtime 如何管理、共享、命中并 lowering 历史 KV;一旦 lowering 完成,底层 kernel 完全可以是共享的。** - -对当前 `sglang plugin attn backend` 的重构来说,真正值得做的不是争论 “`radix` 和 `paged` 能不能统一成一个名字”,而是把系统拆清楚: - -- 哪些是 `sglang` 必须自有的 runtime / prefix cache / scheduling 语义 -- 哪些是 execution metadata lowering -- 哪些已经是可以共享的 `ATOM` execution core - -这比直接讨论 “`radix` 和 `paged` 谁更先进、谁应该替代谁” 更接近真正的工程问题。 diff --git a/work_log/attn_refractory/2026-05-08-kv-indices-address-space-notes.md b/work_log/attn_refractory/2026-05-08-kv-indices-address-space-notes.md deleted file mode 100644 index e5a63be0f1..0000000000 --- a/work_log/attn_refractory/2026-05-08-kv-indices-address-space-notes.md +++ /dev/null @@ -1,359 +0,0 @@ -# 2026-05-08 KV Indices and Address Space Notes - -> 预估阅读时间:8-10 分钟 -> 主题:解释 `kv_indptr` / `kv_indices` 为什么不是纯 host metadata,也不是完全 storage-agnostic 的抽象,并用 `sglang` / `ATOM` / `vllm plugin` 的实际 KV cache 存储举例说明。 - -## 1. 想回答的问题 - -最近有一个很容易卡住重构讨论的问题: - -> `kv_indptr` / `kv_indices` 看起来和 KV cache 的存储方式深度绑定。 -> 既然如此,为什么还能把它们看成 `sglang` / `ATOM` / `vllm plugin` 之间可以共享的 metadata? - -这个问题的关键在于区分两件事: - -1. **host/runtime 是否相同** -2. **kernel 最终消费的地址空间模型是否相同** - -一句话先说结论: - -**`kv_indptr` / `kv_indices` 不是完全 storage-agnostic 的抽象,它们绑定的是某种“可索引的 KV 地址空间”; -但它们可以是 host/runtime-agnostic 的,只要不同宿主最终都能把自己的 runtime state lowering 到同一种地址空间模型。** - -## 2. 
三层视角 - -理解这个问题,最好先把 attention 相关代码拆成三层: - -### 2.1 host/runtime 层 - -这一层描述的是宿主框架如何组织请求和调度,例如: - -- `ATOM native` 的 `ScheduledBatch` -- `vllm plugin` 的 `query_start_loc` / `num_prefills` -- `sglang plugin` 的 `ForwardBatch` / `forward_mode` / `spec_info` - -这一层表达的是: - -- 哪些请求在 batch 中 -- 哪些是 decode,哪些是 prefill -- prefix / speculative / graph 的语义是什么 - -### 2.2 execution metadata lowering 层 - -这一层做的事情是: - -- 把 host/runtime 里的 batch 状态 -- 翻译成 kernel 真正能消费的索引结构 - -典型字段: - -- `block_tables` / `page_table` -- `kv_indptr` -- `kv_indices` -- `qo_indptr` -- `kv_last_page_len` - -### 2.3 kernel / execution core 层 - -这一层只关心: - -- query tensor -- KV cache tensor -- index / indptr / page table -- persistent kernel workspace - -例如: - -- `mla_decode_fwd` -- `mla_prefill_fwd` -- `pa_fwd_asm` -- `pa_persistent_fwd` - -## 3. `sglang` 的实际 KV cache 存储 - -在 `sglang` 里,KV cache 本身不是存在 radix tree 里。 -radix tree 负责的是 prefix 复用与命中;真正的 KV 还是存在 memory pool 里。 - -`sglang` 自己的注释说得很清楚: - -- `ReqToTokenPool` 负责 request -> token location 映射 -- `TokenToKVPoolAllocator` 负责 KV cache index 管理 -- `KVCache` 真正持有物理 K/V tensor - -也就是说,`sglang` 的 runtime 世界里至少有两层: - -1. **逻辑视角**:某个 request 的第 `i` 个 token 属于哪里 -2. **物理视角**:这个 token 对应的 K/V 存在物理 pool 的哪个 slot - -### 3.1 `ReqToTokenPool` - -`ReqToTokenPool` 本质上是一张二维表: - -```text -req_to_token[req_id, token_pos] = physical_slot_id -``` - -它回答的是: - -- 某个 request 的第 `j` 个逻辑 token -- 对应到物理 KV pool 里的哪个 slot - -### 3.2 物理 KV pool - -物理 pool 里,K 和 V 都是按 slot 存的。 -写入 KV 时,最终就是按 `loc/indices` 写入: - -```text -k_cache[indices] = k -v_cache[indices] = v -``` - -所以从 kernel 角度看,最终它看到的是: - -- 一个可以索引的 KV 地址空间 -- 一组 slot index - -而不是看到“radix tree”本身。 - -## 4. 
`ATOM` / `vllm plugin` 的实际 KV cache 存储 - -`ATOM` / `vllm plugin` 更偏 paged/block 风格。 - -典型思路是: - -- 每个 request 有一张 `block_table` -- 每个 block 对应一个固定大小的 page -- kernel 按 `block_table` 或其展开后的 index 去读 KV - -所以它的上层直觉更像: - -```text -request -> block table -> page/block address -> physical KV -``` - -和 `sglang` 相比,最大的不同不是“有没有物理地址空间”,而是: - -- `sglang` 上层先经过 `req_to_token` -- `vllm/ATOM` 上层先经过 `block_table` - -但两边最终都能 lower 到一组可供 kernel 读取的地址索引。 - -## 5. 例子一:MLA decode 中的 `kv_indptr` / `kv_indices` - -这个例子里,可以把 `kv_indices` 理解为 **token-slot index**。 - -### 5.1 在 `sglang` 里 - -假设一个 request 当前长度是 10, -它在物理 KV pool 里的 slot 顺序不是连续的,而是: - -```text -[8, 9, 10, 11, 0, 1, 2, 3, 4, 5] -``` - -那 `req_to_token[req_id, :10]` 就大致表示: - -```text -logical token position: 0 1 2 3 4 5 6 7 8 9 -physical slot id: 8 9 10 11 0 1 2 3 4 5 -``` - -如果这个 batch 只有这一个 request,那么 lowering 之后可以得到: - -```text -kv_indptr = [0, 10] -kv_indices = [8, 9, 10, 11, 0, 1, 2, 3, 4, 5] -``` - -这里: - -- `kv_indptr` 表示“第 0 个 request 的 KV 范围是 `[0, 10)`” -- `kv_indices` 表示“真正去物理 KV pool 里读这些 slot” - -### 5.2 在 `vllm/ATOM` 里 - -如果 page size 是 4,同样 10 个 token 对应的 block layout 可以写成: - -```text -block 2 -> slots [8, 9, 10, 11] -block 0 -> slots [0, 1, 2, 3] -block 1 -> slots [4, 5, 6, 7] -``` - -block table 大致就是: - -```text -[2, 0, 1] -``` - -展开前 10 个 token 后,对应的 slot 顺序仍然是: - -```text -[8, 9, 10, 11, 0, 1, 2, 3, 4, 5] -``` - -于是 lower 之后,给 MLA decode kernel 的: - -```text -kv_indptr = [0, 10] -kv_indices = [8, 9, 10, 11, 0, 1, 2, 3, 4, 5] -``` - -和 `sglang` 这边是等价的。 - -### 5.3 这个例子说明什么 - -说明在 MLA decode 这个路径里: - -- `kv_indices` 绑定的是 **token-slot 地址空间** -- 它不是“完全与存储无关” -- 但它已经不关心 slot 列表最初是由 `req_to_token` 还是 `block_table` 推出来的 - -也就是说,**它和具体宿主无关,但和被选定的地址空间模型有关。** - -## 6. 
例子二:MHA persistent decode 中的 `page_table` / `kv_indices` - -这个例子里,`kv_indices` 更接近 **page/block index**,而不是 token-slot index。 - -### 6.1 在 `sglang plugin` 里 - -仍然假设 page size = 4。 -如果某个 request 的 token-slot 顺序是: - -```text -[8, 9, 10, 11, 0, 1, 2, 3, 4, 5] -``` - -那么每页的起始 slot 是: - -```text -[8, 0, 4] -``` - -除以 page size 后,对应 page id: - -```text -[2, 0, 1] -``` - -在 `sglang plugin` 里,这就是 `page_table` 的来源: -从 `req_to_token` 按页采样,再除以 `page_size`。 - -然后 `_build_pa_metadata_for_decode()` 再把它变成 `pa_persistent_fwd` 所需的 page-level `kv_indices`。 - -### 6.2 在 `ATOM` / `vllm` 里 - -这一层本来就是 paged/block 视角,所以 `block_table` 原生就已经是: - -```text -[2, 0, 1] -``` - -也就是说: - -- `sglang` 是 `req_to_token -> page_table` -- `vllm/ATOM` 是直接 `block_table -> page_table` - -但到了 persistent kernel 眼里,最后看到的是同一种 page/block 地址空间。 - -### 6.3 这个例子说明什么 - -说明在 MHA persistent decode 这个路径里: - -- `kv_indices` 绑定的是 **page/block 地址空间** -- 它仍然不是“完全 storage-agnostic” -- 但也不需要知道上层 host 是 `sglang` 还是 `vllm` - -只要两边都能 lower 到同一个 page/block 地址空间,kernel 就能共享。 - -## 7. 这两个例子真正说明的事 - -上面两个例子其实在说明同一个结论: - -### 7.1 `kv_indptr` / `kv_indices` 不是 host metadata - -它们不描述: - -- `ForwardBatch.forward_mode` -- `query_start_loc` -- `num_prefills` -- `run_graph` -- prefix hit 的高层语义 - -它们只描述: - -- 当前这次 kernel 调用 -- 要读哪些 KV 地址 -- 每个 request 的边界在哪里 - -所以它们更接近: - -- execution metadata -- kernel-facing metadata - -而不是 host/runtime metadata。 - -### 7.2 但它们也不是完全 storage-agnostic - -它们依赖: - -- 当前 KV cache 的地址空间模型 -- kernel 对该地址空间的解释方式 - -如果底层 cache 不是: - -- 可线性索引的 token slot -- 或可 flatten 的 page/block index - -那 `kv_indptr` / `kv_indices` 这套抽象就未必成立。 - -所以它们不是“无限通用”的。 - -## 8. 
对重构的直接启示 - -这个背景知识对当前 `sglang attention` 重构最有价值的结论是: - -### 8.1 不要去统一 host metadata - -比如不要强行统一: - -- `ForwardBatch` -- `query_start_loc` -- `num_prefills` -- `run_graph` - -这些都属于宿主框架自己的 runtime 语言。 - -### 8.2 应该争取统一 execution metadata - -例如: - -- `kv_indptr` -- `kv_indices` -- `qo_indptr` -- `page_table` -- `kv_last_page_len` -- `work_indptr` -- `reduce_indptr` - -以及围绕这些字段构造的更小 dataclass,例如当前 draft 里已经开始尝试的: - -- `MLADecodeKernelMetadata` -- `MHAAsmKernelMetadata` -- `MHAPersistentKernelMetadata` - -### 8.3 正确的共享边界 - -所以真正可共享的不是: - -- `sglang` 的 `ForwardMetadata` -- `vllm plugin` 那一整套 metadata family - -而是: - -- 它们再往下切出来的 -- kernel-facing execution metadata - -## 9. 一句话总结 - -**`kv_indptr` / `kv_indices` 不是“和存储彻底无关”的抽象,它们绑定的是某种被 kernel 接受的 KV 地址空间;但正因为它们已经脱离了宿主 batch/runtime 语义,`sglang` 和 `vllm/ATOM` 即使上层完全不同,也仍然可以在这层 metadata 上收敛并共享 kernel 调用逻辑。** diff --git a/work_log/attn_refractory/2026-05-08-metadata-layering-summary.md b/work_log/attn_refractory/2026-05-08-metadata-layering-summary.md deleted file mode 100644 index 86bccce8e6..0000000000 --- a/work_log/attn_refractory/2026-05-08-metadata-layering-summary.md +++ /dev/null @@ -1,144 +0,0 @@ -# 2026-05-08 Metadata Layering Summary - -## 1. 目的 - -这份笔记简要总结 `ATOM native`、`vllm plugin`、`sglang plugin` 三边 metadata 的关系,重点回答: - -- `ATOM/atom/utils/forward_context.py` 中的 `AttentionMetaData` 是什么 -- `ATOM/atom/plugin/attention.py` 里的 metadata dataclass 在做什么 -- `ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` 里的 `ForwardMetadata` 属于哪一层 -- 哪些 metadata 值得共享,哪些不值得强行统一 - -## 2. 
三套 metadata 的定位 - -### 2.1 `AttentionMetaData` - -文件:`ATOM/atom/utils/forward_context.py` - -它是 `ATOM` attention 执行链路里的**通用外层容器**。 -里面既有: - -- 通用 attention 字段 - - `slot_mapping` - - `block_tables` - - `kv_indptr` - - `kv_indices` - - `work_meta_data` -- 也预留了 plugin 扩展入口 - - `plugin_metadata` - -所以它更像一个大容器,而不是某个 host 专属的数据模型。 - -### 2.2 `vllm plugin metadata` - -文件:`ATOM/atom/plugin/attention.py` - -`vllm plugin` 不是重写一套完全独立的 metadata,而是在 `AttentionMetaData` 外壳上,额外挂了一层 `plugin_metadata` payload。 - -代表类型包括: - -- `AiterFlashAttentionMetadataForPluginMode` -- `AiterMLACommonMetadataForPluginMode` -- `AiterMLADecodeMetadataForPluginMode` -- `AiterMLASparseMetadataForPluginMode` - -这些 dataclass 主要承载: - -- `vllm` 的 batch/runtime 语义 -- decode/prefill/extend phase 拆分 -- chunked prefill / DCP / sparse MLA 等 feature-specific 信息 - -一句话:**`vllm plugin` 是“外层复用 `AttentionMetaData`,内层自带一套 host-specific metadata family”。** - -### 2.3 `sglang plugin` 的 `ForwardMetadata` - -文件:`ATOM/atom/plugin/sglang/attention_backend/sgl_attn_backend.py` - -`ForwardMetadata` 不是 `AttentionMetaData.plugin_metadata` 那种 payload,而是 `sglang backend` 内部的**lowering 结果缓存对象**。 - -它更接近: - -- `ForwardBatch` -- `req_to_token` -- `token_to_kv_pool` -- graph/speculative path - -向 kernel metadata 的过渡层。 - -一句话:**`ForwardMetadata` 更像 backend-local lowering state,而不是 host-agnostic 的最终执行 metadata。** - -## 3. 
三类字段 - -### 3.1 host/runtime-bound - -这类字段描述宿主框架如何组织 batch/runtime,而不是 kernel 如何计算。 - -例子: - -- `vllm plugin` - - `query_start_loc` - - `num_decodes` - - `num_prefills` - - `chunked_context` -- `sglang plugin` - - `run_graph` - - 各种 `ForwardBatch.forward_mode` 派生分支语义 - -### 3.2 execution/kernel-bound - -这类字段最接近 kernel 真正需要的执行信息。 - -例子: - -- `slot_mapping` -- `block_tables` -- `context_lens` -- `kv_indptr` -- `kv_indices` -- `qo_indptr` -- `kv_last_page_len` -- `work_indptr` -- `reduce_indptr` - -### 3.3 hardware-sensitive - -这类字段和 page size、dtype、kernel 选择、并行配置相关。 - -例子: - -- `max_q_len` -- `max_kv_len` -- `num_kv_splits` -- `attn_out_dtype` -- `head_dim` - -## 4. 哪些值得共享 - -### 不建议强行共享的 - -- `AttentionMetaData` 整体 -- `vllm plugin` 整套 `plugin_metadata` -- `sglang plugin` 整体的 `ForwardMetadata` - -原因不是它们没价值,而是它们都带着很重的 host/runtime 语义。 - -### 值得重点共享的 - -应该共享的是更小的 **execution metadata view**,例如当前 draft 已经开始尝试的: - -- `MLADecodeKernelMetadata` -- `MHAAsmKernelMetadata` -- `MHAPersistentKernelMetadata` - -这些结构只保留某个 kernel call 真正需要的字段,更适合在: - -- `ATOM core` -- `vllm plugin` -- `sglang plugin` - -之间共享。 - -## 5. 一句话结论 - -现在不是三边共用“一套 metadata”,而是三边各自都有一层 host-owned metadata / lowering 体系。 -真正值得共享的,不是这些大 metadata 本身,而是从它们里面切出来的、更小的、kernel-facing execution metadata。 diff --git a/work_log/attn_refractory/2026-05-11-attn-refractory-session-summary.md b/work_log/attn_refractory/2026-05-11-attn-refractory-session-summary.md deleted file mode 100644 index 3b2128f5bd..0000000000 --- a/work_log/attn_refractory/2026-05-11-attn-refractory-session-summary.md +++ /dev/null @@ -1,398 +0,0 @@ -# 2026-05-11 ATOM Attention Refactory Session Summary - -> 主题:从 `sglang plugin` 设计者视角,梳理 `attention backend` 重构中的 runtime / metadata / KV cache / kernel reuse 边界,并记录几种 kernel 共用 draft 的结论。 - -## 1. 
本次会话的主线 - -本次讨论的核心不是继续从 `vllm plugin` 出发看问题,而是明确切换到: - -- **`sglang plugin attn backend` 设计者视角** -- 关注 `sglang plugin` 自己的 runtime 语义 -- 判断哪些层应该继续由 `sglang plugin` 自有 -- 判断哪些层值得尽量与 `ATOM core` 共用 - -围绕这个目标,本次会话主要探索了 6 个问题: - -1. `RadixAttention` / `radix cache` 与 `PagedAttention` 的关系 -2. public `sglang` 是否也会把 radix runtime lowering 到 paged / index metadata -3. metadata 为什么会越长越多,以及哪些字段本质上属于 host/runtime -4. `kv_indptr` / `kv_indices` 为什么既绑定存储模型,又仍然能跨 host 共用 -5. `sglang plugin` 和 public `sglang` 在 KV cache 管理上的异同 -6. 在不大改文件结构、不做过度软件工程的前提下,怎么试探性地共用 kernel call - -## 2. 对 `sglang plugin` 重构最关键的几个判断 - -### 2.1 不是所有问题都该叫 “attention backend 复用” - -当前问题至少分成三层: - -1. **host/runtime 层** - - `ForwardBatch` - - `RadixAttention` - - speculative / graph / verify / extend 调度 -2. **execution metadata lowering 层** - - `page_table` - - `kv_indptr` - - `kv_indices` - - `qo_indptr` - - `kv_last_page_len` -3. **kernel / execution core 层** - - `flash_attn_varlen_func` - - `mla_decode_fwd` - - `mla_prefill_fwd` - - `pa_fwd_asm` - - `pa_persistent_fwd` - -后续任何“共用更多 code”的讨论,都应该先说明是在第几层说话。 - -### 2.2 `sglang plugin` 不应该为了复用去对齐到 `vllm` 的 host runtime - -`vllm plugin` 和 `sglang plugin` 上层接入模型不同: - -- `vllm plugin` 更像 layer-owned / metadata-builder-owned 路径 -- `sglang plugin` 更像 `ForwardBatch` 驱动的 backend runtime 路径 - -所以: - -- 不该共享 `vllm plugin` 那层 runtime facade -- 也不该强行统一大 metadata 对象 -- 真正值得共享的,是更靠近 kernel 的 execution metadata 和 kernel call - -### 2.3 `sglang plugin` 想共用 `ATOM` code,最现实的边界在 kernel call 一层 - -如果先不做大规模结构改造,最适合先共用的是: - -- `mla_decode_fwd` -- `pa_persistent_fwd` -- `pa_fwd_asm` -- `flash_attn_varlen_func` - -也就是: - -- 保留各自的 runtime lowering -- 先把“调 kernel 前最后一跳”抽出来 - -## 3.
`RadixAttention` / `radix cache` 和 `PagedAttention` 的结论 - -### 3.1 `RadixAttention` 不是最终 kernel - -`RadixAttention` 更像: - -- `sglang` 模型层统一使用的 attention layer 壳 -- 它知道自己跑在 `ForwardBatch` 语义里 -- 最终还是委托给 `forward_batch.attn_backend` - -所以它不是 “另一套底层注意力数学”,而是 **host-facing runtime abstraction**。 - -### 3.2 `PagedAttention` 和 `radix` 解决的问题不在同一层 - -- `radix cache` 更偏 prefix 命中 / 复用 / eviction -- `paged attention` 更偏 block/page 形式的物理组织和读取 - -可以理解成: - -- `radix` 管“历史怎么共享和命中” -- `paged` 管“命中的历史最后如何变成 page/block 索引给 kernel” - -### 3.3 public `sglang` 自己也在做类似 lowering - -这不是 `ATOM plugin` 才有的行为。 - -public `sglang` 的: - -- `aiter_backend` -- `triton_backend` -- `trtllm_mha_backend` -- 部分 `flashinfer` 路径 - -本身也会把: - -- `ForwardBatch` -- `req_to_token` -- `seq_lens` - -lower 成: - -- `kv_indptr` -- `kv_indices` -- `qo_indptr` -- `page_table` / `block_tables` - -所以,`radix runtime -> paged/index metadata` 不是 plugin 特例,而是公共设计模式的一部分。 - -## 4. Metadata 分层结论 - -### 4.1 现在至少有三套 metadata family - -1. **`ATOM native`** - - `AttentionMetaData` -2. **`vllm plugin`** - - `AttentionMetaData` 外壳 - - `plugin_metadata` 内层 payload -3. **`sglang plugin`** - - `ForwardMetadata` - -### 4.2 `AttentionMetaData` 和 `plugin/attention.py` 的关系 - -`AttentionMetaData` 是外层通用容器; -`plugin/attention.py` 里的很多 dataclass,是 `vllm plugin` 为表达宿主 batch/runtime 语义而塞进 `plugin_metadata` 的 payload。 - -所以它们不是并列关系,而是: - -```text -AttentionMetaData - └─ plugin_metadata - ├─ flash-style plugin metadata - ├─ MLA plugin metadata - └─ sparse MLA plugin metadata -``` - -### 4.3 `ForwardMetadata` 的定位 - -`ForwardMetadata` 更像: - -- `sglang backend` 内部的 lowering state -- 它不是最终统一 metadata -- 更不是天然可共享的 host-agnostic object - -它里面混着: - -- runtime control 信息 -- execution metadata -- persistent kernel 预构造结果 - -## 5. 
三类字段 - -为了判断能不能共享,本次把字段分成三类: - -### 5.1 host/runtime-bound - -典型例子: - -- `vllm` 的 `query_start_loc` -- `num_decodes` -- `num_prefills` -- `chunked_context` -- `sglang` 的 `run_graph` - -这些字段描述的是宿主框架如何组织 batch/runtime,而不是 kernel 怎么算。 - -### 5.2 execution/kernel-bound - -典型例子: - -- `slot_mapping` -- `block_tables` -- `page_table` -- `kv_indptr` -- `kv_indices` -- `qo_indptr` -- `kv_last_page_len` -- `work_indptr` -- `reduce_indptr` - -这些字段才是 kernel 真正需要消费的执行信息。 - -### 5.3 hardware-sensitive - -典型例子: - -- `max_q_len` -- `max_kv_len` -- `num_kv_splits` -- `head_dim` -- `attn_out_dtype` - -这些字段不是 host 语义,但会影响: - -- page/block 解释 -- kernel 选路 -- workspace/split 策略 - -## 6. 关于 `kv_indptr` / `kv_indices` 的关键结论 - -这个问题本次单独做了较多背景解释,结论是: - -### 6.1 它们不是完全 storage-agnostic - -因为它们绑定的是某种 **可索引的 KV 地址空间**: - -- token slot index -- page/block index -- compacted KV index - -所以它们不能脱离底层存储模型独立存在。 - -### 6.2 但它们可以是 host/runtime-agnostic - -因为它们不表达: - -- `ForwardBatch.forward_mode` -- `query_start_loc` -- prefix 命中策略 -- graph/speculative 的宿主控制语义 - -它们只表达: - -- 这次 kernel 要去哪些 KV 地址读数据 -- 每个 request 的边界在哪里 - -### 6.3 两个例子 - -#### 例子 A:`sglang` - -- `req_to_token[req_id, token_pos] = slot_id` -- lowering 后得到: - - `kv_indptr = [0, 10]` - - `kv_indices = [8, 9, 10, 11, 0, 1, 2, 3, 4, 5]` - -#### 例子 B:`vllm/ATOM` - -- 上层是 `block_table` -- 展开后同样可得到: - - `kv_indptr = [0, 10]` - - `kv_indices = [8, 9, 10, 11, 0, 1, 2, 3, 4, 5]` - -这说明: - -- 上层来源不同 -- 但 lower 后的 execution metadata 可以收敛 - -## 7. public `sglang` 与 `sglang plugin` 在 KV cache 管理上的异同 - -### 7.1 相同点 - -- owner 都还是 public `sglang` - - `ReqToTokenPool` - - `TokenToKVPool` - - `KVCache` - - `radix_cache` -- request 到 slot 的映射机制没变 -- prefix / radix 复用体系没变 - -### 7.2 不同点 - -#### A. 
MHA 写 cache 路径不同 - -public `sglang` 更多沿用 pool 提供的标准写入接口: - -- `token_to_kv_pool.set_kv_buffer(...)` - -而 `sglang plugin` 在 MHA 路径里会绕过标准写法,自己做一次: - -- `set_kv_buffer_with_layout_shuffle(...)` -- 把同一块 public pool buffer 按更偏 `ATOM` kernel 的 SHUFFLE layout 写进去 - -#### B. MLA 路径更接近 public `sglang` - -MLA 上,plugin 更多沿用 public pool 的 contract: - -- `token_to_kv_pool.set_kv_buffer(...)` -- `get_key_buffer(...)` -- `get_value_buffer(...)` - -#### C. Lowering 目标不同 - -- public `sglang` 的 metadata 更通用 -- plugin 的 `ForwardMetadata` 更偏 `ATOM` kernel - - `page_table` - - `pa_metadata_*` - -### 7.3 一个更准确的说法 - -`sglang plugin` **没有改写 public `sglang` 的 pool owner / allocator / radix tree**, -但它在 MHA 路径上: - -- 改变了同一块 public pool buffer 的布局解释与写入约定 -- 并把 lowering 结果继续朝 `ATOM` kernel 所需的 metadata 形态推进 - -## 8. Kernel 共用的两种 draft - -本次会话里,实际做了两种方向的 draft。 - -### 8.1 方案 A:共享 kernel wrapper - -思路: - -- 新增共享小 metadata - - `MLADecodeKernelMetadata` - - `MHAAsmKernelMetadata` - - `MHAPersistentKernelMetadata` -- 新增共享 wrapper - - `run_mla_decode_kernel` - - `run_mha_asm_kernel` - - `run_mha_persistent_kernel` - -优点: - -- 边界清晰 -- 更像长期可演进的 execution API -- 不需要统一大 metadata - -缺点: - -- 还是引入了一层新的 shared contract - -### 8.2 方案 B:`sglang plugin` 主动适配 `ATOM core` - -思路: - -- 不改 `ATOM core` -- 在 `sglang plugin` 里把 `ForwardMetadata` 转成 `AttentionMetaData` -- 再伪造最小 `shim/self` -- 直接调用: - - `PagedAttentionImpl.paged_attention_asm` - - `PagedAttentionImpl.paged_attention_persistent_asm` - - `MLAAttention._forward_decode` - -优点: - -- 不改 `ATOM core` -- 更直观地证明 “plugin 主动复用 core” 是可行的 - -缺点: - -- 需要 shim -- 尤其 MLA 路径更脆弱 -- 更适合验证可行性,不适合长期保留 - -## 9. 
当前最值得保留的判断 - -### 9.1 不应该强行统一大 metadata - -不建议直接统一: - -- `AttentionMetaData` -- `plugin_metadata` -- `ForwardMetadata` - -因为它们都带有较重的宿主 runtime 语义。 - -### 9.2 最值得共享的是 kernel-facing execution metadata - -真正最有价值的共享边界是: - -- `kv_indptr` -- `kv_indices` -- `qo_indptr` -- `page_table` -- `context_lens` -- `work_*` -- `reduce_*` - -以及基于它们的 kernel call / runner。 - -### 9.3 重构顺序建议 - -如果继续推进 `sglang plugin` 重构,更合理的顺序是: - -1. 保持 `sglang` runtime 自有 -2. 显式化 metadata lowering 边界 -3. 尽量共享 kernel call / execution helper -4. 最后再考虑是否继续上推到 runner 层 - -## 10. 一句话总结 - -本次会话最终收敛出的关键认识是: - -**`sglang plugin attention` 重构的关键,不是把 `sglang` runtime 改得更像 `vllm`,而是把 runtime、metadata lowering、kernel execution 这三层拆清楚;只要 lowering 之后能在 execution metadata 上对齐,就能在不统一宿主语义的前提下,共用更多 `ATOM` 的底层 attention 代码。** diff --git a/work_log/attn_refractory/decoupling-diagram.md b/work_log/attn_refractory/decoupling-diagram.md deleted file mode 100644 index b95c0031e5..0000000000 --- a/work_log/attn_refractory/decoupling-diagram.md +++ /dev/null @@ -1,158 +0,0 @@ -# ATOM Plugin Decoupling Diagram - -下面这张图对比了当前入口架构和目标解耦架构。这里采用的是更激进、也更干净的方案: - -- `vLLM plugin` 和 `SGLang plugin` 在 `atom/plugin/` 层彻底分家 -- 不再保留共享的 plugin runtime -- 共享只发生在更底层的 `ATOM core` - -## Current - -```mermaid -flowchart TD - A[Shared Plugin Entry
<br/>prepare.py / register.py / config.py] - B[Global Runtime State<br/>_CURRENT_FRAMEWORK<br/>current_atom_config<br/>ops.Attention] - C[vLLM Branch<br/>platform / registry / patch / graph] - D[SGLang Branch<br/>wrapper / registry hijack / RadixAttention / graph] - E[Mixed Plugin Core<br/>
attention selection / static_forward_context / loader glue] - - A --> B - B --> C - B --> D - C --> E - D --> E -``` - -### 当前设计的核心问题 - -- 入口层通过切换全局状态来区分 `vLLM` 和 `SGLang`,而不是通过物理隔离的子系统建立边界。 -- `prepare.py` / `register.py` / `config.py` 这些共享入口,本身就是耦合中心。 -- host 差异直接泄漏到 plugin 内部结构里,导致 backend 选择、attention 抽象和 runtime context 都知道上层框架。 - -## Target - -```mermaid -flowchart TD - A1[vLLM Plugin
<br/>bootstrap.py<br/>register.py<br/>config.py<br/>runtime.py<br/>attention glue<br/>model adapters] - A2[SGLang Plugin<br/>bootstrap.py<br/>register.py<br/>config.py<br/>runtime.py<br/>attention glue<br/>model adapters] - - B[ATOM Core<br/>model_ops<br/>kernels<br/>loader<br/>metadata structs<br/>
model specialization] - - A1 --> B - A2 --> B -``` - -### 目标设计的关键点 - -- `vLLM` 和 `SGLang` 在 plugin 层是两个完整子系统,而不是一套共享 runtime 上的两个分支。 -- `atom/plugin/` 下不再存在共享的 `prepare_model()`、共享的 `register.py` 总入口、共享的 plugin config translator。 -- `vLLM plugin` 自己处理: - - platform registration - - model override - - graph / MLA / weight hook patch - - vLLM model adapter / attention glue -- `SGLang plugin` 自己处理: - - external model wrapper - - attention backend registration / override - - graph / speculative / wrapper patch - - SGLang model adapter / attention glue -- 真正共享的只保留在更底层的 `ATOM core`: - - `model_ops` - - kernels - - loader - - generic metadata structures - - model-family specialization - -## 最重要的切断点 - -```mermaid -flowchart LR - A[Cut 1
<br/>shared prepare.py] - B[Cut 2<br/>shared register.py] - C[Cut 3<br/>shared config.py] - D[Cut 4<br/>framework global state] - E[Cut 5<br/>
ops.Attention global rebinding] - - A --> F[vllm/ and sglang/ own bootstrap] - B --> F - C --> F - D --> G[host-local runtime only] - E --> H[host-specific attention glue] -``` - -## 推荐目录形态 - -```text -atom/plugin/ - vllm/ - bootstrap.py - register.py - config.py - runtime.py - attention/ - models/ - graph/ - sglang/ - bootstrap.py - register.py - config.py - runtime.py - attention/ - models/ - graph/ -``` - -这里的重点不是“换目录”,而是: - -- `vllm/runtime.py` 和 `sglang/runtime.py` 各自独立 -- `vllm/config.py` 和 `sglang/config.py` 各自独立 -- `vllm/register.py` 和 `sglang/register.py` 各自独立 -- 不再有 plugin 级共享 runtime - -## 术语说明 - -### bootstrap - -`bootstrap` 在这里指的是“插件被宿主框架加载时,最先执行的接线初始化层”。 - -它负责的事情通常包括: - -- 注册 plugin 扩展点 -- 安装 patch / hook -- 构造 host adapter / wrapper -- 把宿主配置翻译到 ATOM 可用的格式 - -它不应该负责: - -- attention forward 本身 -- 长期 runtime state 管理 -- 每次请求都要走的核心执行逻辑 - -### register - -`register` 指的是把 plugin 的能力挂到宿主暴露的扩展点上,例如: - -- 注册 platform -- 注册 model class -- 注册 attention backend - -它更偏“声明接入点”,通常是 bootstrap 的一部分。 - -### adapter - -`adapter` 指的是宿主框架和 ATOM core 之间的桥接层。 -它的职责是做接口和参数语义翻译,而不是定义 attention / model 的核心语义。 - -### runtime - -`runtime` 指的是插件在宿主里长期存在、服务执行链路的那部分运行时逻辑。 -如果按这份设计推进,`runtime` 应该是 host-local 的: - -- `vLLM runtime` 只属于 `vLLM plugin` -- `SGLang runtime` 只属于 `SGLang plugin` - -而不是共享一个“Plugin runtime”。 - -## 一句话结论 - -真正的解耦不是把共享入口再包装一层,而是把 `atom/plugin` 从“共享状态机”改成“两个完全独立的 host plugin 子系统”,共享只留给更底层的 `ATOM core`。 diff --git a/work_log/attn_refractory/vllm-integration-architecture.md b/work_log/attn_refractory/vllm-integration-architecture.md deleted file mode 100644 index e8876268d0..0000000000 --- a/work_log/attn_refractory/vllm-integration-architecture.md +++ /dev/null @@ -1,269 +0,0 @@ -# ATOM vLLM Plugin Integration Architecture - -## 1. 
启动入口 - -vLLM plugin 的 setuptools entrypoint 定义在: - -- `pyproject.toml` - -关键位置: - -- `vllm.platform_plugins` - - `atom = "atom.plugin.vllm.register:register_platform"` -- `vllm.general_plugins` - - `atom_model_registry = "atom.plugin.vllm.register:register_model"` - -这意味着 vLLM 启动时会先进入: - -- `atom/plugin/vllm/register.py::register_platform()` -- `atom/plugin/vllm/register.py::register_model()` - -## 2. 关键代码文件与职责 - -### 2.1 `atom/plugin/vllm/register.py` - -这是 vLLM plugin 的启动接线层,主要负责: - -- 设置 plugin mode -- 返回自定义 platform -- 覆盖 vLLM 的 `ModelRegistry` -- 安装 MLA patch -- patch `Attention.process_weights_after_loading` -- 安装 graph capture patch - -重点符号: - -- `register_platform()` -- `register_model()` -- `_VLLM_MODEL_REGISTRY_OVERRIDES` - -## 2.2 `atom/plugin/vllm/platform.py` - -这是 vLLM 平台侧 attention backend 选择入口。 - -重点符号: - -- `ATOMPlatform.get_attn_backend_cls()` - -它根据 `attn_selector_config` 决定: - -- MHA -> `atom.model_ops.attentions.aiter_attention.AiterBackend` -- MLA -> `atom.model_ops.attentions.aiter_mla.AiterMLABackend` -- sparse MLA -> `atom.plugin.vllm.attention_backend.mla_sparse.AiterMLASparseBackend` - -这里是 vLLM runtime 真正选择 “哪个 ATOM attention backend class” 的入口。 - -## 2.3 `atom/plugin/vllm/model_wrapper.py` - -这是 vLLM model integration 的核心文件。 - -主要职责: - -- 把 `VllmConfig` 翻译成 `atom_config` -- 做 plugin 运行环境准备 -- 选择并实例化 ATOM 模型类 -- 把 vLLM forward context 中的 `positions` 传给 ATOM 路径 -- 对 sparse MLA indexer 做额外注册 - -重点符号: - -- `_prepare_env()` -- `ATOMModelBase.__init__()` -- `_register_indexer_caches_with_vllm()` -- `forward()` - -## 2.4 `atom/plugin/config.py` - -这是 vLLM 配置翻译的桥。 - -重点符号: - -- `generate_atom_config_for_vllm_plugin()` -- `_generate_atom_config_from_vllm_config()` - -它负责把: - -- `VllmConfig` -- `scheduler_config` -- `cache_config` -- `parallel_config` -- `quant_config` - -映射成 ATOM 自己的 `Config` 与 `PluginConfig`。 - -## 2.5 `atom/model_ops/paged_attention.py` - -这是 ATOM attention 在 vLLM plugin 下真正进入执行的关键对象。 - -重点符号: - -- `PagedAttention.__init__()` 
-- `PagedAttention.forward()` - -在 `is_vllm()` 下,它不会直接走 native/server 的那套 `impl` 初始化,而是: - -- 构造 vLLM 的 `Attention` / `MLAAttention` 外壳 -- 把额外 impl 参数塞进去 -- 注册到 `static_forward_context` -- forward 时通过 `unified_attention_with_output_base_for_plugin_mode()` 回调到 ATOM impl - -## 2.6 `atom/plugin/attention.py` - -这是 vLLM plugin mode 的核心 glue 层。 - -虽然文件在 `plugin/` 根目录,但从职责看它明显偏 vLLM。 - -主要职责: - -- 定义 `vllmAiterAttentionBackendMethods` -- 提供 backend decorator / metadata builder decorator -- 处理 plugin-mode metadata -- 处理 `static_forward_context` -- 为 vLLM 的 attention backend 补 OOT 接口行为 - -重点符号: - -- `vllmAiterAttentionBackendMethods` -- 各类 `*DecoratorForPluginMode` - -## 2.7 `atom/plugin/attention_mha.py` - -这是 vLLM plugin mode 下 MHA impl 的补丁层。 - -重点符号: - -- `PagedAttentionImplPluginModeMethods` - -它补的是: - -- rope/cache/update -- plugin-mode forward -- 与 vLLM KV cache/metadata 对齐的逻辑 - -## 2.8 `atom/plugin/attention_mla.py` - -这是 vLLM plugin mode 下 MLA impl 的补丁层。 - -重点符号: - -- `MLAAttentionImplPluginModeMethods` -- `_mla_plugin_mode_init` - -它补的是: - -- MLA plugin-mode metadata -- qk rope / kv cache / chunked prefill -- vLLM MLA 路径专有行为 - -## 2.9 `atom/plugin/vllm/mla_patch.py` - -这是 vLLM 原生 `MLAAttention` 的 monkey patch 入口。 - -重点符号: - -- `_patch_vllm_mla_attention_forward_impl()` -- `_patch_vllm_mla_attention_process_weights_after_loading()` -- `patch_vllm_mla_attention()` - -这层作用是: - -- 把 vLLM `MLAAttention.forward_impl()` 改接到 ATOM impl -- 把权重后处理改接到 ATOM 的 `impl.process_weights_after_loading()` - -## 2.10 `atom/plugin/vllm/graph_capture_patch.py` - -这是 vLLM graph capture 补丁入口。 - -重点符号: - -- `apply_graph_capture_patch()` - -它实际委托到共享实现: - -- `atom/plugin/graph_capture_patch.py` - -但调用点是 vLLM 自己的 plugin register 流程。 - -## 3. 当前 vLLM plugin 集成链路 - -```mermaid -flowchart TD - A[vLLM startup] - B[entrypoint
<br/>register_platform] - C[entrypoint<br/>register_model] - D[ATOMPlatform.get_attn_backend_cls] - E[ATOMModelBase] - F[generate_atom_config_for_vllm_plugin] - G[_prepare_env<br/>set_attn_cls + init_aiter_dist] - H[Instantiate ATOM model] - I[PagedAttention / MLAAttention wrapper] - J[plugin/attention*.py glue] - K[ATOM impl<br/>
PagedAttentionImpl / MLAAttention] - L[aiter kernels] - - A --> B - A --> C - B --> D - C --> E - E --> F - E --> G - E --> H - H --> I - I --> J - J --> K - K --> L -``` - -## 4. 最关键的架构判断 - -### 4.1 vLLM plugin 不是简单“替换 backend class” - -它实际上同时做了三件事: - -1. 替换 vLLM 的 model entry -2. 替换/选择 vLLM 的 attention backend class -3. 用 patch + decorator 的方式把 vLLM 的 layer/runtime 语义桥接到 ATOM impl - -### 4.2 真正共享得最好的是 native/server 与 vLLM plugin 的 attention impl - -尤其: - -- `PagedAttentionImpl` -- `MLAAttention` -- 低层 `aiter` kernel family - -也就是说,vLLM plugin 侧的关键复杂度主要不在 kernel,而在: - -- `ModelRegistry` override -- `VllmConfig -> atom_config` -- `static_forward_context` -- `MLAAttention` patch -- graph capture patch - -### 4.3 当前 `plugin/attention.py`、`attention_mha.py`、`attention_mla.py` - -虽然放在 `atom/plugin/` 根目录,但从 vLLM integration 的角度看,它们本质上更像: - -- `vLLM plugin impl glue` -- 而不是真正框架无关的 shared runtime - -## 5. 推荐你重点阅读的文件顺序 - -如果你是为了快速理解 vLLM plugin 集成架构,建议按这个顺序读: - -1. `pyproject.toml` -2. `atom/plugin/vllm/register.py` -3. `atom/plugin/vllm/platform.py` -4. `atom/plugin/vllm/model_wrapper.py` -5. `atom/plugin/config.py` -6. `atom/model_ops/paged_attention.py` -7. `atom/plugin/vllm/mla_patch.py` -8. `atom/plugin/attention.py` -9. `atom/plugin/attention_mha.py` -10. `atom/plugin/attention_mla.py` - -## 6. 一句话结论 - -vLLM plugin 的集成方式,本质上是: - -**用 entrypoint + platform + model wrapper 接入 vLLM,用 ATOM 自己的 attention impl 和 aiter kernels 提供核心执行,再用 patch/decorator 把 vLLM 的 layer/runtime 语义桥接进去。** From 77aa82ad2b845f470c9cef707695a7cf0ea7d5d1 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Wed, 13 May 2026 07:21:25 +0000 Subject: [PATCH 04/17] [ATOM-SGL][Attn refrac] Route DeepSeek MLA through an SGLang wrapper Move the SGLang DeepSeek MLA runtime entry from legacy forward glue into SGLangDeepseekMLAAttention while keeping RadixAttention and the full-attention backend as the host/backend layers. 
Shrink deepseek_mla_forward.py into a helper module and clarify absorbed vs non-absorbed path naming. --- atom/plugin/sglang/models/deepseek_mla.py | 28 +- .../sglang/models/deepseek_mla_attention.py | 340 ++++++++++ .../sglang/models/deepseek_mla_forward.py | 611 +----------------- 3 files changed, 378 insertions(+), 601 deletions(-) create mode 100644 atom/plugin/sglang/models/deepseek_mla_attention.py diff --git a/atom/plugin/sglang/models/deepseek_mla.py b/atom/plugin/sglang/models/deepseek_mla.py index 1d7b1e44bc..c6b62f07ac 100644 --- a/atom/plugin/sglang/models/deepseek_mla.py +++ b/atom/plugin/sglang/models/deepseek_mla.py @@ -3,20 +3,20 @@ """Model-level DeepSeek MLA patching for SGLang plugin mode. -This module owns the monkey-patch entrypoints that adapt DeepSeek MLA models to -SGLang plugin mode. The heavy DeepSeek-specific forward and weight helpers live -in `atom.plugin.sglang.models.deepseek_mla_forward`. +This module owns the install-time hooks that adapt DeepSeek MLA models to +SGLang plugin mode. The heavy DeepSeek-specific runtime helpers live in +`atom.plugin.sglang.models.deepseek_mla_forward`. 
""" from __future__ import annotations from typing import TYPE_CHECKING, Any -import torch - +from atom.plugin.sglang.models.deepseek_mla_attention import ( + SGLangDeepseekMLAAttention, +) from atom.plugin.sglang.models.deepseek_mla_forward import ( _patch_kv_b_proj_for_sglang_mxfp4, - forward_sgl_plugin_mode, init_sgl_attrs, process_mla_kv_b_proj_after_loading, ) @@ -58,20 +58,8 @@ def _patch_mla_attention_for_sglang( """Patch one DeepSeek MLA layer for SGLang plugin mode.""" init_sgl_attrs(attn, config, kv_cache_dtype) _patch_kv_b_proj_for_sglang_mxfp4(attn) - - def patched_forward( - positions: torch.Tensor, - hidden_states: torch.Tensor, - **kwargs: Any, - ) -> torch.Tensor: - from atom.plugin.sglang.models.base_model_wrapper import ( - get_current_forward_batch, - ) - - kwargs["forward_batch"] = get_current_forward_batch() - return forward_sgl_plugin_mode(attn, positions, hidden_states, **kwargs) - - attn.forward = patched_forward + if not isinstance(attn.mla_attn, SGLangDeepseekMLAAttention): + attn.mla_attn = SGLangDeepseekMLAAttention(attn, attn.mla_attn) attn.process_weights_after_loading = lambda: process_mla_kv_b_proj_after_loading( attn ) diff --git a/atom/plugin/sglang/models/deepseek_mla_attention.py b/atom/plugin/sglang/models/deepseek_mla_attention.py new file mode 100644 index 0000000000..ce20aae418 --- /dev/null +++ b/atom/plugin/sglang/models/deepseek_mla_attention.py @@ -0,0 +1,340 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +"""DeepSeek MLA wrapper for SGLang plugin mode. + +This adapter keeps the model-side entry at ``self.mla_attn(...)`` and owns the +SGLang-specific runtime dispatch for DeepSeek MLA. It is intentionally shaped +closer to the vLLM plugin path than the older model-side monkey-patched entry. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import torch +from torch import nn + +if TYPE_CHECKING: + from atom.models.deepseek_v2 import DeepseekV2MLAAttention + + +class SGLangDeepseekMLAAttention(nn.Module): + """Enter SGLang DeepSeek MLA runtime through ``self.mla_attn(...)``.""" + + def __init__( + self, + owner_attn: "DeepseekV2MLAAttention", + base_attn: nn.Module, + ) -> None: + super().__init__() + self.owner_attn = owner_attn + self.base_attn = base_attn + + @property + def attn(self): + return getattr(self.base_attn, "attn", self.base_attn) + + def _get_forward_batch(self, kwargs: dict[str, Any]): + forward_batch = kwargs.get("forward_batch", None) + if forward_batch is None: + from atom.plugin.sglang.models.base_model_wrapper import ( + get_current_forward_batch, + ) + + forward_batch = get_current_forward_batch() + kwargs["forward_batch"] = forward_batch + if forward_batch is None: + raise RuntimeError( + "forward_batch is required for SGLang DeepSeek MLA wrapper" + ) + return forward_batch + + def _infer_total_tokens(self, forward_batch, tensor: torch.Tensor) -> int: + if hasattr(forward_batch, "input_ids") and forward_batch.input_ids is not None: + return int(forward_batch.input_ids.shape[0]) + if hasattr(forward_batch, "positions") and forward_batch.positions is not None: + return int(forward_batch.positions.shape[0]) + if hasattr(forward_batch, "seq_lens_sum"): + return int(forward_batch.seq_lens_sum) + return int(tensor.shape[0]) + + def _maybe_all_gather( + self, + tensor: torch.Tensor | None, + *, + total_tokens: int, + input_scattered: bool, + ): + if tensor is None or not input_scattered: + return tensor + from sglang.srt.distributed import get_tp_group + + output = tensor.new_empty((total_tokens, *tensor.shape[1:])) + get_tp_group().all_gather_into_tensor(output, tensor) + return output + + def _gather_runtime_inputs( + self, + q_input: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, 
+ positions: torch.Tensor, + q_scale: torch.Tensor | None, + *, + forward_batch, + input_scattered: bool, + ): + total_tokens = self._infer_total_tokens(forward_batch, q_input) + q_input = self._maybe_all_gather( + q_input, + total_tokens=total_tokens, + input_scattered=input_scattered, + ) + kv_c_normed = self._maybe_all_gather( + kv_c_normed, + total_tokens=total_tokens, + input_scattered=input_scattered, + ) + k_pe = self._maybe_all_gather( + k_pe, + total_tokens=total_tokens, + input_scattered=input_scattered, + ) + positions = self._maybe_all_gather( + positions, + total_tokens=total_tokens, + input_scattered=input_scattered, + ) + q_scale = self._maybe_all_gather( + q_scale, + total_tokens=total_tokens, + input_scattered=input_scattered, + ) + return q_input, kv_c_normed, k_pe, positions, q_scale + + def _project_q( + self, + q_input: torch.Tensor, + q_scale: torch.Tensor | None, + ) -> torch.Tensor: + attn = self.owner_attn + from atom.plugin.sglang.models.deepseek_mla_forward import _unwrap_linear_output + + if attn.q_lora_rank is not None: + q = ( + attn.q_b_proj(q_input, q_scale) + if q_scale is not None + else attn.q_b_proj(q_input) + ) + else: + q = ( + attn.q_proj(q_input, q_scale) + if q_scale is not None + else attn.q_proj(q_input) + ) + return _unwrap_linear_output(q).view( + -1, attn.num_local_heads, attn.qk_head_dim + ) + + def _forward_absorbed( + self, + q_input: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + positions: torch.Tensor, + q_scale: torch.Tensor | None, + *, + forward_batch, + ) -> torch.Tensor: + attn = self.owner_attn + from aiter import dtypes + from atom.model_ops.attention_mla import fused_qk_rope_concat_and_cache_mla + from atom.plugin.sglang.models.deepseek_mla_forward import ( + _get_sglang_radix_attn, + mla_absorbed_bmm, + mla_v_up_proj, + ) + from sglang.srt.layers.attention.nsa.utils import nsa_use_prefill_cp + + q = self._project_q(q_input, q_scale) + k_nope = kv_c_normed.unsqueeze(1) + k_pe = 
k_pe.unsqueeze(1) + q_nope, q_pe = q.split( + [attn.qk_nope_head_dim, attn.qk_rope_head_dim], dim=-1 + ) + q_nope_out = mla_absorbed_bmm( + attn, q_nope, attn.w_kc, attn.w_scale, attn.w_scale_k, attn.kv_lora_rank + ) + + if attn.rotary_emb is not None and not attn.use_fused_qk_rope_concat_and_cache_mla: + q_pe, k_pe = attn.rotary_emb(positions, q_pe, k_pe) + + if nsa_use_prefill_cp(forward_batch): + latent_cache = torch.cat([k_nope.squeeze(1), k_pe.squeeze(1)], dim=-1) + k_nope, k_pe = attn.rebuild_cp_kv_cache( + latent_cache, forward_batch, k_nope, k_pe + ) + + save_kv_cache = True + if attn.use_fused_qk_rope_concat_and_cache_mla: + mla_attn = _get_sglang_radix_attn(self.base_attn) + kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(mla_attn.layer_id) + q_out_dtype = ( + dtypes.fp8 + if attn.kv_cache_dtype == "fp8_e4m3" + else q_nope_out.dtype + ) + q = torch.empty( + ( + q_nope_out.shape[0], + attn.num_local_heads, + attn.kv_lora_rank + attn.qk_rope_head_dim, + ), + dtype=q_out_dtype, + device=q_nope_out.device, + ) + fused_qk_rope_concat_and_cache_mla( + q_nope_out, + q_pe, + k_nope, + k_pe, + kv_cache, + q, + forward_batch.out_cache_loc, + mla_attn.k_scale, + mla_attn.k_scale, + positions, + attn.rotary_emb.cos_cache, + attn.rotary_emb.sin_cache, + is_neox=attn.rotary_emb.is_neox_style, + is_nope_first=True, + ) + k = None + v = None + save_kv_cache = False + else: + q = torch.cat([q_nope_out, q_pe], dim=-1) + k = torch.cat([k_nope, k_pe], dim=-1) + v = k_nope + + attn_output = self.base_attn( + q, + k, + v, + forward_batch=forward_batch, + save_kv_cache=save_kv_cache, + ) + attn_output = attn_output.view(-1, attn.num_local_heads, attn.kv_lora_rank) + attn_bmm_output = mla_v_up_proj( + attn, attn_output, attn.w_vc, attn.w_scale, attn.w_scale_v, attn.v_head_dim + ) + return attn.o_proj(attn_bmm_output) + + def _forward_non_absorbed( + self, + q_input: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + positions: torch.Tensor, + q_scale: 
torch.Tensor | None, + *, + forward_batch, + ) -> torch.Tensor: + attn = self.owner_attn + from atom.plugin.sglang.models.deepseek_mla_forward import ( + _concat_mha_k_for_non_absorbed, + _set_mla_kv_buffer_for_non_absorbed, + _unwrap_linear_output, + ) + + q = self._project_q(q_input, q_scale) + _, q_pe = q.split([attn.qk_nope_head_dim, attn.qk_rope_head_dim], dim=-1) + + kv_a = kv_c_normed + k_pe = k_pe.unsqueeze(1) + if attn.rotary_emb is not None: + q_pe, k_pe = attn.rotary_emb(positions, q_pe, k_pe) + q[..., attn.qk_nope_head_dim :] = q_pe + + _set_mla_kv_buffer_for_non_absorbed(attn, kv_a, k_pe, forward_batch) + + kv = _unwrap_linear_output(attn.kv_b_proj(kv_a)).view( + -1, attn.num_local_heads, attn.qk_nope_head_dim + attn.v_head_dim + ) + k_nope = kv[..., : attn.qk_nope_head_dim] + v = kv[..., attn.qk_nope_head_dim :] + k = _concat_mha_k_for_non_absorbed(attn, k_nope, k_pe) + + attn_output = attn.attn_non_absorbed( + q, + k, + v, + forward_batch=forward_batch, + save_kv_cache=False, + ) + attn_output = attn_output.reshape(-1, attn.num_local_heads * attn.v_head_dim) + return attn.o_proj(attn_output) + + def forward( + self, + q_input: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + positions: torch.Tensor, + q_scale: torch.Tensor | None = None, + **kwargs: Any, + ) -> torch.Tensor: + attn = self.owner_attn + forward_batch = self._get_forward_batch(kwargs) + + from atom.plugin.sglang.models.deepseek_mla_forward import ( + _can_run_non_absorbed_mla_now, + ) + from sglang.srt.layers.communicator import get_attn_tp_context + + attn_tp_context = get_attn_tp_context() + with attn_tp_context.maybe_input_scattered(forward_batch): + q_input, kv_c_normed, k_pe, positions, q_scale = self._gather_runtime_inputs( + q_input, + kv_c_normed, + k_pe, + positions, + q_scale, + forward_batch=forward_batch, + input_scattered=attn_tp_context.input_scattered, + ) + + use_non_absorbed = forward_batch.forward_mode.is_extend_without_speculative() + if not 
use_non_absorbed and forward_batch.forward_mode.is_draft_extend(): + extend_prefix_lens_cpu = getattr( + forward_batch, "extend_prefix_lens_cpu", None + ) + use_non_absorbed = ( + extend_prefix_lens_cpu is not None + and not any(extend_prefix_lens_cpu) + ) + + if use_non_absorbed: + if _can_run_non_absorbed_mla_now(attn, forward_batch): + attn.current_sgl_plugin_attn_path = "non_absorbed" + return self._forward_non_absorbed( + q_input, + kv_c_normed, + k_pe, + positions, + q_scale, + forward_batch=forward_batch, + ) + attn.current_sgl_plugin_attn_path = "absorbed_fallback" + else: + attn.current_sgl_plugin_attn_path = "absorbed" + + return self._forward_absorbed( + q_input, + kv_c_normed, + k_pe, + positions, + q_scale, + forward_batch=forward_batch, + ) diff --git a/atom/plugin/sglang/models/deepseek_mla_forward.py b/atom/plugin/sglang/models/deepseek_mla_forward.py index 80fb075e4f..58362c54dd 100644 --- a/atom/plugin/sglang/models/deepseek_mla_forward.py +++ b/atom/plugin/sglang/models/deepseek_mla_forward.py @@ -1,40 +1,28 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -"""Model-specific DeepSeek MLA helpers for SGLang plugin mode. +"""Helper functions for DeepSeek MLA in SGLang plugin mode. -DeepSeek MLA (Multi-Latent Attention) forward logic for SGLang plugin mode: -absorbed BMM computation, MHA/MLA path dispatch (prefill -> MHA, decode -> MLA), -and kv_b_proj weight splitting (w_kc/w_vc). - -This module lives under ``atom.plugin.sglang.models`` because the logic is -DeepSeek-model-specific rather than a generic SGLang attention backend. - -TODO: rewrite this file once sglang's attention flow is unified into ATOM's -attention layer - the MLA absorbed path and MHA dispatch will then be handled -natively by ATOM's attention ops, making this sglang-specific module -unnecessary. 
+This module now contains only the low-level helpers that are still shared by +the SGLang DeepSeek MLA wrapper and the install-time weight hooks: +absorbed BMM math, small utility helpers, non-absorbed cache staging, and +kv_b_proj post-load processing. """ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, NamedTuple, Optional +from typing import TYPE_CHECKING, Any, Optional import torch from aiter import dtypes -from aiter.dist.parallel_state import get_tensor_model_parallel_world_size, get_tp_group from atom.model_ops.base_attention import Attention from atom.model_ops.attention_mla import ( - concat_and_cache_mla, dynamic_per_batched_tensor_quant, - fused_qk_rope_concat_and_cache_mla, ) -from atom.models.utils import maybe_prefix from atom.models.deepseek_v2 import _fuse_rmsnorm_quant +from atom.models.utils import maybe_prefix -from sglang.srt.layers.communicator import AttentionInputs, get_attn_tp_context -from sglang.srt.layers.attention.nsa.utils import nsa_use_prefill_cp from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.models.deepseek_common.utils import ( _use_aiter_gfx95, @@ -100,25 +88,6 @@ def bmm_fp8(A, B, A_scale, B_scale, dtype, out=None): raise RuntimeError("bmm_fp8 requires CUDA (sgl_kernel)") -class SglPrepareResult(NamedTuple): - q_pe: torch.Tensor - k_pe: torch.Tensor - q_nope_out: torch.Tensor - k_nope: torch.Tensor - forward_batch: Any - zero_allocator: Any - positions: torch.Tensor - topk_indices: Optional[torch.Tensor] - llama_4_scaling: Optional[Any] - - -class SglMhaPrepareResult(NamedTuple): - q: torch.Tensor - k: torch.Tensor - v: torch.Tensor - forward_batch: Any - - def _unwrap_linear_output(output: Any) -> torch.Tensor: """Normalize ATOM/public-SGLang linear outputs to a tensor.""" if isinstance(output, tuple): @@ -216,7 +185,7 @@ def init_sgl_attrs( attn.w_scale = None attn.w_scale_k = None attn.w_scale_v = None - attn.attn_mha = Attention( + 
attn.attn_non_absorbed = Attention( num_heads=attn.num_local_heads, head_dim=attn.qk_head_dim, scale=attn.scaling, @@ -225,10 +194,10 @@ def init_sgl_attrs( layer_num=attn.layer_num, use_mla=False, v_head_dim=attn.v_head_dim, - prefix=maybe_prefix(attn.prefix, "attn_mha"), + prefix=maybe_prefix(attn.prefix, "attn_non_absorbed"), ) - if hasattr(attn.attn_mha, "attn"): - attn.attn_mha.attn.kv_b_proj = None + if hasattr(attn.attn_non_absorbed, "attn"): + attn.attn_non_absorbed.attn.kv_b_proj = None def mla_absorbed_bmm( @@ -379,276 +348,11 @@ def mla_v_up_proj( ).flatten(1, 2) -def forward_sgl_prepare( - attn: DeepseekV2MLAAttention, - positions: torch.Tensor, - hidden_states: torch.Tensor, - **model_kwargs, -) -> SglPrepareResult: - """Prepare QKV for sglang MLA attention.""" - hidden_states_scale = None - if isinstance(hidden_states, tuple): - hidden_states, hidden_states_scale = hidden_states - - forward_batch = model_kwargs.get("forward_batch", None) - zero_allocator = model_kwargs.get("zero_allocator", None) - llama_4_scaling = model_kwargs.get("llama_4_scaling", None) - q_lora = None - topk_indices = None - - if attn.q_lora_rank is not None: - q, latent_cache = ( - get_attn_tp_context() - .fetch_qkv_latent() - .split( - [attn.q_lora_rank, attn.kv_lora_rank + attn.qk_rope_head_dim], - dim=-1, - ) - ) - - if ( - q.shape[0] != positions.shape[0] - and get_tensor_model_parallel_world_size() > 1 - ): - qkv_lora = torch.cat([q, latent_cache], dim=-1) - qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) - if qkv_lora.shape[0] < positions.shape[0]: - raise RuntimeError( - f"qkv_lora gather mismatch: got {qkv_lora.shape[0]}, " - f"expected {positions.shape[0]}" - ) - qkv_lora = qkv_lora[: positions.shape[0]] - q, latent_cache = torch.split( - qkv_lora, - [attn.q_lora_rank, attn.kv_lora_rank + attn.qk_rope_head_dim], - dim=-1, - ) - - k_nope = latent_cache[..., : attn.kv_lora_rank] - q_scale = None - - if getattr(attn, "fuse_qknorm_quant", False): - q, q_scale, q_lora, 
k_nope = _fuse_qk_rmsnorm_and_q_quant( - attn, - q, - k_nope, - output_unquantized_q=attn.use_nsa, - ) - elif getattr(attn, "fuse_qknorm", False): - q, k_nope = _fuse_qk_rmsnorm(attn, q, k_nope) - elif attn.alt_stream is not None and get_is_capture_mode(): - current_stream = torch.cuda.current_stream() - attn.alt_stream.wait_stream(current_stream) - q = attn.q_a_layernorm(q) - with torch.cuda.stream(attn.alt_stream): - k_nope = attn.kv_a_layernorm(k_nope) - current_stream.wait_stream(attn.alt_stream) - else: - q = attn.q_a_layernorm(q) - k_nope = attn.kv_a_layernorm(k_nope) - - if attn.use_nsa and q_lora is None: - q_lora = q - - if ( - attn.alt_stream is not None - and get_is_capture_mode() - and forward_batch.forward_mode.is_decode_or_idle() - and q_lora is not None - ): - current_stream = torch.cuda.current_stream() - attn.alt_stream.wait_stream(current_stream) - with torch.cuda.stream(attn.alt_stream): - k_nope = k_nope.unsqueeze(1) - q = _unwrap_linear_output( - attn.q_b_proj(q, q_scale) - if q_scale is not None - else attn.q_b_proj(q) - ).view(-1, attn.num_local_heads, attn.qk_head_dim) - topk_indices = attn.indexer( - x=hidden_states, - q_lora=q_lora, - positions=positions, - forward_batch=forward_batch, - layer_id=attn.layer_num, - ) - current_stream.wait_stream(attn.alt_stream) - else: - k_nope = k_nope.unsqueeze(1) - q = _unwrap_linear_output( - attn.q_b_proj(q, q_scale) if q_scale is not None else attn.q_b_proj(q) - ).view(-1, attn.num_local_heads, attn.qk_head_dim) - if q_lora is not None: - topk_indices = attn.indexer( - x=hidden_states, - q_lora=q_lora, - positions=positions, - forward_batch=forward_batch, - layer_id=attn.layer_num, - ) - else: - q = _unwrap_linear_output(attn.q_proj(hidden_states)).view( - -1, attn.num_local_heads, attn.qk_head_dim - ) - latent_cache = _unwrap_linear_output(attn.kv_a_proj_with_mqa(hidden_states)) - k_nope = latent_cache[..., : attn.kv_lora_rank] - k_nope = attn.kv_a_layernorm(k_nope).unsqueeze(1) - - q_nope, q_pe = 
q.split([attn.qk_nope_head_dim, attn.qk_rope_head_dim], dim=-1) - k_pe = latent_cache[..., attn.kv_lora_rank :].unsqueeze(1) - - q_nope_out = mla_absorbed_bmm( - attn, q_nope, attn.w_kc, attn.w_scale, attn.w_scale_k, attn.kv_lora_rank - ) - - if attn.rotary_emb is not None and not attn.use_fused_qk_rope_concat_and_cache_mla: - q_pe, k_pe = attn.rotary_emb(positions, q_pe, k_pe) - - if nsa_use_prefill_cp(forward_batch): - k_nope, k_pe = attn.rebuild_cp_kv_cache( - latent_cache, forward_batch, k_nope, k_pe - ) - - return SglPrepareResult( - q_pe=q_pe, - k_pe=k_pe, - q_nope_out=q_nope_out, - k_nope=k_nope, - forward_batch=forward_batch, - zero_allocator=zero_allocator, - positions=positions, - topk_indices=topk_indices, - llama_4_scaling=llama_4_scaling, - ) - - -def forward_sgl_core( - attn: DeepseekV2MLAAttention, - prepared: SglPrepareResult, -) -> torch.Tensor: - """Core MLA attention computation for sglang.""" - save_kv_cache = True - - if attn.use_fused_qk_rope_concat_and_cache_mla: - mla_attn = _get_sglang_radix_attn(attn.mla_attn) - kv_cache = prepared.forward_batch.token_to_kv_pool.get_key_buffer( - mla_attn.layer_id - ) - q_out_dtype = ( - dtypes.fp8 - if attn.kv_cache_dtype == "fp8_e4m3" - else prepared.q_nope_out.dtype - ) - q = torch.empty( - ( - prepared.q_nope_out.shape[0], - attn.num_local_heads, - attn.kv_lora_rank + attn.qk_rope_head_dim, - ), - dtype=q_out_dtype, - device=prepared.q_nope_out.device, - ) - - fused_qk_rope_concat_and_cache_mla( - prepared.q_nope_out, - prepared.q_pe, - prepared.k_nope, - prepared.k_pe, - kv_cache, - q, - prepared.forward_batch.out_cache_loc, - mla_attn.k_scale, - mla_attn.k_scale, - prepared.positions, - attn.rotary_emb.cos_cache, - attn.rotary_emb.sin_cache, - is_neox=attn.rotary_emb.is_neox_style, - is_nope_first=True, - ) - k = None - v = None - save_kv_cache = False - else: - q = torch.cat([prepared.q_nope_out, prepared.q_pe], dim=-1) - k = torch.cat([prepared.k_nope, prepared.k_pe], dim=-1) - v = prepared.k_nope 
- - if prepared.llama_4_scaling is not None: - q = q * prepared.llama_4_scaling - - extra_kwargs = {} - if prepared.topk_indices is not None: - extra_kwargs["topk_indices"] = prepared.topk_indices - - attn_output = attn.mla_attn( - q, - k, - v, - forward_batch=prepared.forward_batch, - save_kv_cache=save_kv_cache, - **extra_kwargs, - ) - attn_output = attn_output.view(-1, attn.num_local_heads, attn.kv_lora_rank) - - attn_bmm_output = mla_v_up_proj( - attn, attn_output, attn.w_vc, attn.w_scale, attn.w_scale_v, attn.v_head_dim - ) - - return attn.o_proj(attn_bmm_output) - - -def _dispatch_sgl_plugin_attn_path(forward_batch) -> str: - """Decide the attention algorithm for this batch based on forward_mode. - - Returns "mha" for extend/prefill-style batches (uses standard Q×K×V - with flash_attn) or "mla" for decode/verify batches (uses absorbed - weights + mla_decode_fwd). - - This is the per-batch *routing* decision, distinct from - ``_can_run_sgl_mha_now`` which is a *capability* gate checking whether - the model configuration supports the MHA path at all. - """ - if forward_batch.forward_mode.is_extend_without_speculative(): - return "mha" - - if forward_batch.forward_mode.is_draft_extend(): - # The explicit K/V path is only memory-friendly for no-prefix draft - # extend. With prefix/context, SGLang's MLA backend has to materialize - # full k_prefix/v_prefix from latent cache, which can OOM during graph - # capture. Use absorbed MLA for those batches until chunked prefix - # expansion exists here. 
- extend_prefix_lens_cpu = getattr(forward_batch, "extend_prefix_lens_cpu", None) - if extend_prefix_lens_cpu is not None and not any(extend_prefix_lens_cpu): - return "mha" - - return "mla" - - -def forward_sgl_plugin_mode_mla( - attn: DeepseekV2MLAAttention, - positions: torch.Tensor, - hidden_states: torch.Tensor, - **model_kwargs, -) -> torch.Tensor: - prepared = forward_sgl_prepare(attn, positions, hidden_states, **model_kwargs) - from atom.utils.forward_context import get_forward_context - - if get_forward_context().context.is_dummy_run: - base_hidden_states = ( - hidden_states[0] if isinstance(hidden_states, tuple) else hidden_states - ) - dummy_output = base_hidden_states.new_empty( - (base_hidden_states.shape[0], base_hidden_states.shape[-1]) - ) - return dummy_output - return forward_sgl_core(attn, prepared) - - def _get_sglang_radix_attn(attn_module): return attn_module.attn if hasattr(attn_module, "attn") else attn_module -def _concat_mha_k_for_sgl_mha( +def _concat_mha_k_for_non_absorbed( attn: DeepseekV2MLAAttention, k_nope: torch.Tensor, k_pe: torch.Tensor, @@ -664,7 +368,7 @@ def _concat_mha_k_for_sgl_mha( except ImportError as exc: logger.warning( "Unable to import concat_and_cast_mha_k_triton; " - "falling back to torch native MHA K concat: %s", + "falling back to torch native non-absorbed K concat: %s", exc, ) else: @@ -676,58 +380,32 @@ def _concat_mha_k_for_sgl_mha( return k -def _set_mla_kv_buffer_for_mha( +def _set_mla_kv_buffer_for_non_absorbed( attn: DeepseekV2MLAAttention, kv_a: torch.Tensor, k_pe: torch.Tensor, forward_batch, ) -> None: - attn_mha = _get_sglang_radix_attn(attn.attn_mha) - - kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(attn_mha.layer_id) - concat_and_cache_mla( - kv_a, - k_pe.squeeze(1), - kv_cache, - forward_batch.out_cache_loc.flatten(), - kv_cache_dtype=( - "fp8" if str(attn.kv_cache_dtype).startswith("fp8") else "auto" - ), - scale=attn_mha.k_scale, + attn_non_absorbed = 
_get_sglang_radix_attn(attn.attn_non_absorbed) + cache_k = torch.cat([kv_a.unsqueeze(1), k_pe], dim=-1) + forward_batch.token_to_kv_pool.set_kv_buffer( + attn_non_absorbed, + forward_batch.out_cache_loc, + cache_k, + cache_k, ) -def _is_mxfp4_kv_b_proj(attn: DeepseekV2MLAAttention) -> bool: - kv_b_proj = attn.kv_b_proj - params_dtype = getattr(kv_b_proj, "params_dtype", None) - if params_dtype == dtypes.fp4x2 or params_dtype == getattr( - torch, "float4_e2m1fn_x2", None - ): - return True - - quant_type = getattr(kv_b_proj, "quant_type", None) - if getattr(quant_type, "name", "") == "per_1x32" or str(quant_type).endswith( - "per_1x32" - ): - return True - - quant_method = getattr(kv_b_proj, "quant_method", None) - quant_config = getattr(quant_method, "quant_config", None) - return bool( - quant_config is not None - and quant_config.get_name() == "quark" - and kv_b_proj.weight.dtype == torch.uint8 - ) - - -def _can_run_sgl_mha_now(attn: DeepseekV2MLAAttention, forward_batch) -> bool: - """Check if the model configuration supports the MHA attention path. +def _can_run_non_absorbed_mla_now( + attn: DeepseekV2MLAAttention, + forward_batch, +) -> bool: + """Check if the model configuration supports the non-absorbed MLA path. - This is a *capability* gate — NSA models cannot use the MHA path. - MXFP4 ``kv_b_proj`` weights are supported here because the MHA prepare - path expands K/V through ``attn.kv_b_proj`` itself, which already owns - the per_1x32 GEMM implementation. Distinct from - ``_dispatch_sgl_plugin_attn_path`` which routes each batch. + This is a capability gate. NSA models cannot use the non-absorbed path. + MXFP4 ``kv_b_proj`` weights are supported because that path expands K/V + through ``attn.kv_b_proj`` itself, which owns the per_1x32 GEMM + implementation. 
""" del forward_batch if attn.use_nsa: @@ -737,235 +415,6 @@ def _can_run_sgl_mha_now(attn: DeepseekV2MLAAttention, forward_batch) -> bool: return True -def forward_sgl_mha_prepare( - attn: DeepseekV2MLAAttention, - positions: torch.Tensor, - hidden_states: torch.Tensor, - **model_kwargs, -) -> SglMhaPrepareResult: - - forward_batch = model_kwargs.get("forward_batch", None) - if forward_batch is None: - raise RuntimeError("forward_batch is required in forward_sgl_mha_prepare") - - hidden_states_scale = None - if isinstance(hidden_states, tuple): - hidden_states, hidden_states_scale = hidden_states - - attn_mha = _get_sglang_radix_attn(attn.attn_mha) - if getattr(attn_mha, "kv_b_proj", None) is None: - attn_mha.kv_b_proj = attn.kv_b_proj - - if attn.q_lora_rank is not None: - q, latent_cache = ( - get_attn_tp_context() - .fetch_qkv_latent() - .split( - [attn.q_lora_rank, attn.kv_lora_rank + attn.qk_rope_head_dim], - dim=-1, - ) - ) - - if ( - q.shape[0] != positions.shape[0] - and get_tensor_model_parallel_world_size() > 1 - ): - qkv_lora = torch.cat([q, latent_cache], dim=-1) - qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) - if qkv_lora.shape[0] < positions.shape[0]: - raise RuntimeError( - f"qkv_lora gather mismatch: got {qkv_lora.shape[0]}, " - f"expected {positions.shape[0]}" - ) - qkv_lora = qkv_lora[: positions.shape[0]] - q, latent_cache = torch.split( - qkv_lora, - [attn.q_lora_rank, attn.kv_lora_rank + attn.qk_rope_head_dim], - dim=-1, - ) - - if _use_aiter_gfx95 and attn.q_b_proj.weight.dtype == torch.float8_e4m3fn: - (q, q_scale), _, _, _ = _fuse_rmsnorm_quant( - q, - attn.q_a_layernorm.weight, - attn.q_a_layernorm.eps, - None, - None, - None, - res1=None, - dtype_quant=torch.float8_e4m3fn, - group_size=128, - quant_type=_linear_quant_type_value(attn.q_b_proj), - output_unquantized_inp1=False, - transpose_scale=True, - ) - q = _unwrap_linear_output(attn.q_b_proj(q, q_scale)).view( - -1, attn.num_local_heads, attn.qk_head_dim - ) - else: - q = 
attn.q_a_layernorm(q) - q = _unwrap_linear_output(attn.q_b_proj(q)).view( - -1, attn.num_local_heads, attn.qk_head_dim - ) - else: - q = _unwrap_linear_output(attn.q_proj(hidden_states, hidden_states_scale)).view( - -1, attn.num_local_heads, attn.qk_head_dim - ) - latent_cache = _unwrap_linear_output( - attn.kv_a_proj_with_mqa(hidden_states, hidden_states_scale) - ) - - _, q_pe = q.split([attn.qk_nope_head_dim, attn.qk_rope_head_dim], dim=-1) - kv_a, _ = latent_cache.split([attn.kv_lora_rank, attn.qk_rope_head_dim], dim=-1) - latent_cache = latent_cache.unsqueeze(1) - - if _use_aiter_gfx95 and attn.kv_b_proj.weight.dtype == torch.float8_e4m3fn: - (kv_a_quanted, kv_a_quanted_scale), kv_a, _, _ = _fuse_rmsnorm_quant( - kv_a, - attn.kv_a_layernorm.weight, - attn.kv_a_layernorm.eps, - None, - None, - None, - res1=None, - dtype_quant=torch.float8_e4m3fn, - group_size=128, - quant_type=_linear_quant_type_value(attn.kv_b_proj), - output_unquantized_inp1=True, - transpose_scale=True, - ) - else: - kv_a_quanted = None - kv_a = attn.kv_a_layernorm(kv_a) - - k_pe = latent_cache[:, :, attn.kv_lora_rank :] - if attn.rotary_emb is not None: - q_pe, k_pe = attn.rotary_emb(positions, q_pe, k_pe) - q[..., attn.qk_nope_head_dim :] = q_pe - - _set_mla_kv_buffer_for_mha(attn, kv_a, k_pe, forward_batch) - - if kv_a_quanted is not None: - kv = _unwrap_linear_output(attn.kv_b_proj(kv_a_quanted, kv_a_quanted_scale)) - else: - kv = _unwrap_linear_output(attn.kv_b_proj(kv_a)) - kv = kv.view(-1, attn.num_local_heads, attn.qk_nope_head_dim + attn.v_head_dim) - k_nope = kv[..., : attn.qk_nope_head_dim] - v = kv[..., attn.qk_nope_head_dim :] - k = _concat_mha_k_for_sgl_mha(attn, k_nope, k_pe) - return SglMhaPrepareResult(q=q, k=k, v=v, forward_batch=forward_batch) - - -def forward_sgl_mha_core( - attn: DeepseekV2MLAAttention, - prepared: SglMhaPrepareResult, -) -> torch.Tensor: - attn_output = attn.attn_mha( - prepared.q, - prepared.k, - prepared.v, - forward_batch=prepared.forward_batch, - 
save_kv_cache=False, - ) - attn_output = attn_output.reshape(-1, attn.num_local_heads * attn.v_head_dim) - return attn.o_proj(attn_output) - - -def forward_sgl_plugin_mode_mha( - attn: DeepseekV2MLAAttention, - positions: torch.Tensor, - hidden_states: torch.Tensor, - **model_kwargs, -) -> torch.Tensor: - forward_batch = model_kwargs.get("forward_batch", None) - if forward_batch is None: - raise RuntimeError("forward_batch is required in forward_sgl_plugin_mode_mha") - if not _can_run_sgl_mha_now(attn, forward_batch): - attn.current_sgl_plugin_attn_path = "mla_fallback" - return forward_sgl_plugin_mode_mla( - attn, - positions, - hidden_states, - **model_kwargs, - ) - prepared = forward_sgl_mha_prepare(attn, positions, hidden_states, **model_kwargs) - return forward_sgl_mha_core(attn, prepared) - - -def prepare_qkv_latent( - attn: DeepseekV2MLAAttention, - hidden_states: torch.Tensor, - forward_batch, -) -> torch.Tensor: - """Prepare QKV latent tensor for the sglang communicator.""" - assert attn.q_lora_rank is not None - hidden_states_scale = None - if isinstance(hidden_states, tuple): - hidden_states, hidden_states_scale = hidden_states - qkv_lora = attn.fused_qkv_a_proj(hidden_states, hidden_states_scale) - - expected_tokens = 0 - if hasattr(forward_batch, "positions") and forward_batch.positions is not None: - expected_tokens = int(forward_batch.positions.shape[0]) - if expected_tokens <= 0: - expected_tokens = int(getattr(forward_batch, "seq_lens_sum", 0) or 0) - - if ( - expected_tokens > 0 - and qkv_lora.shape[0] != expected_tokens - and get_tensor_model_parallel_world_size() > 1 - ): - qkv_lora = get_tp_group().all_gather(qkv_lora, dim=0) - if qkv_lora.shape[0] > expected_tokens: - qkv_lora = qkv_lora[:expected_tokens] - elif qkv_lora.shape[0] < expected_tokens: - raise RuntimeError( - f"prepare_qkv_latent gather mismatch: got {qkv_lora.shape[0]}, " - f"expected {expected_tokens}" - ) - return qkv_lora - - -def forward_sgl_plugin_mode( - attn: 
DeepseekV2MLAAttention, - positions: torch.Tensor, - hidden_states: torch.Tensor, - **model_kwargs, -) -> torch.Tensor: - """Full MLA forward in sglang plugin mode.""" - forward_batch = model_kwargs.get("forward_batch", None) - if forward_batch is None: - raise RuntimeError("forward_batch is required in forward_sgl_plugin_mode") - - attn_tp_context = get_attn_tp_context() - with attn_tp_context.maybe_input_scattered(forward_batch): - if attn.q_lora_rank is not None: - attn_tp_context.set_attn_inputs( - AttentionInputs( - hidden_states, - forward_batch, - lambda hs, fb: prepare_qkv_latent(attn, hs, fb), - ) - ) - attn_path = _dispatch_sgl_plugin_attn_path(forward_batch) - attn.current_sgl_plugin_attn_path = attn_path - if attn_path == "mha": - return forward_sgl_plugin_mode_mha( - attn, - positions, - hidden_states, - **model_kwargs, - ) - if attn_path == "mla": - return forward_sgl_plugin_mode_mla( - attn, - positions, - hidden_states, - **model_kwargs, - ) - raise ValueError(f"Unsupported plugin attention path: {attn_path}") - - def _read_kv_b_proj_weight(attn: DeepseekV2MLAAttention) -> torch.Tensor: """Read kv_b_proj weight, handling AWQ and fnuz dtypes.""" if hasattr(attn.kv_b_proj, "qweight"): From 128b37cc42f9d1ba74c49fee0255be372f6d10d8 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Mon, 18 May 2026 15:25:41 +0000 Subject: [PATCH 05/17] [ATOM SGL] runtime extraction --- .../sglang/attention_backend/attention_gdn.py | 2 +- .../full_attention/radix_attention.py | 2 +- .../sglang/models/base_model_wrapper.py | 453 +++--------------- .../sglang/models/deepseek_mla_attention.py | 2 +- .../sglang/models/deepseek_nextn_wrapper.py | 71 +-- atom/plugin/sglang/models/qwen3_5.py | 4 +- atom/plugin/sglang/runtime/__init__.py | 20 + atom/plugin/sglang/runtime/context.py | 126 +++++ atom/plugin/sglang/runtime/forward_context.py | 234 +++++++++ atom/plugin/sglang/runtime/model_arch.py | 22 + 10 files changed, 491 insertions(+), 445 deletions(-) create mode 100644 
atom/plugin/sglang/runtime/__init__.py create mode 100644 atom/plugin/sglang/runtime/context.py create mode 100644 atom/plugin/sglang/runtime/forward_context.py create mode 100644 atom/plugin/sglang/runtime/model_arch.py diff --git a/atom/plugin/sglang/attention_backend/attention_gdn.py b/atom/plugin/sglang/attention_backend/attention_gdn.py index 2403efaa9e..f29a43b1b7 100644 --- a/atom/plugin/sglang/attention_backend/attention_gdn.py +++ b/atom/plugin/sglang/attention_backend/attention_gdn.py @@ -164,7 +164,7 @@ def _build_gdn_metadata( def build( cls, forward_batch_or_metadata: Any ) -> Optional["SGLangGDNForwardContext"]: - from atom.plugin.sglang.models.base_model_wrapper import ( + from atom.plugin.sglang.runtime import ( SGLangForwardBatchMetadata, ) diff --git a/atom/plugin/sglang/attention_backend/full_attention/radix_attention.py b/atom/plugin/sglang/attention_backend/full_attention/radix_attention.py index 8e20eecb14..613458ee37 100644 --- a/atom/plugin/sglang/attention_backend/full_attention/radix_attention.py +++ b/atom/plugin/sglang/attention_backend/full_attention/radix_attention.py @@ -129,7 +129,7 @@ def forward_impl_plugin_mode( # for sglang, forward_batch is required forward_batch = kwargs.get("forward_batch", None) if forward_batch is None: - from atom.plugin.sglang.models.base_model_wrapper import ( + from atom.plugin.sglang.runtime import ( get_current_forward_batch, ) diff --git a/atom/plugin/sglang/models/base_model_wrapper.py b/atom/plugin/sglang/models/base_model_wrapper.py index dbace2e647..85bc6608e1 100644 --- a/atom/plugin/sglang/models/base_model_wrapper.py +++ b/atom/plugin/sglang/models/base_model_wrapper.py @@ -6,13 +6,8 @@ To add a new model, append its architecture class name to _MODEL_NAMES. 
""" -import copy - import logging -from contextlib import contextmanager -from contextvars import ContextVar -from dataclasses import dataclass -from typing import Any, ClassVar, Iterable, Optional, Tuple, Union +from typing import Any, Iterable, Optional, Tuple, Union import torch from torch import nn @@ -22,312 +17,26 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors -logger = logging.getLogger("atom.plugin.sglang.models") - -_RUNTIME_SENTINEL = object() - -# Context for patched DeepSeek attention layers that need wrapper state without -# changing every intermediate forward signature. ContextVar keeps nested or -# concurrent forwards isolated and lets us reliably restore the prior value. -_current_forward_batch: ContextVar[Optional[ForwardBatch]] = ContextVar( - "atom_sglang_current_forward_batch", default=None +from atom.plugin.sglang.runtime import ( + MODEL_ARCH_SPECS, + SGLangForwardBatchMetadata, + SGLangPluginRuntime, + bind_current_forward_batch, + get_current_forward_batch, + get_model_arch_spec, + plugin_runtime_scope, ) +logger = logging.getLogger("atom.plugin.sglang.models") -def get_current_forward_batch(): - return _current_forward_batch.get() - - -def _is_dummy_forward(forward_batch: ForwardBatch) -> bool: - # SGLang's IDLE batch is the plugin-side equivalent of ATOM dummy run. 
- forward_mode = getattr(forward_batch, "forward_mode", None) - return bool( - forward_mode is not None - and hasattr(forward_mode, "is_idle") - and forward_mode.is_idle() - ) - - -def _pad_dummy_like( - tensor: Optional[torch.Tensor], - *, - length: int, - fill_value: int | float = 0, -) -> Optional[torch.Tensor]: - if tensor is None: - return None - shape = (length, *tensor.shape[1:]) - return torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device) - - -def _materialize_atom_dummy_forward( - input_ids: Optional[torch.Tensor], - positions: Optional[torch.Tensor], - input_embeds: Optional[torch.Tensor], - forward_batch: ForwardBatch, -) -> tuple[ - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - ForwardBatch, -]: - """Convert an empty SGLang IDLE batch into ATOM-style dummy forward inputs.""" - dummy_positions = positions.new_zeros((1,)) - dummy_input_ids = input_ids.new_zeros((1,)) - dummy_input_embeds = _pad_dummy_like(input_embeds, length=1, fill_value=0) - - model_forward_batch = copy.copy(forward_batch) - model_forward_batch.positions = dummy_positions - model_forward_batch.batch_size = 1 - model_forward_batch.seq_lens_sum = 1 - model_forward_batch.seq_lens = forward_batch.seq_lens.new_ones((1,)) - model_forward_batch.seq_lens_cpu = forward_batch.seq_lens_cpu.new_ones((1,)) - - return dummy_input_ids, dummy_positions, dummy_input_embeds, model_forward_batch - - -def _trim_hidden_states_for_output(hidden_states, num_tokens: int): - if torch.is_tensor(hidden_states): - return hidden_states[:num_tokens] - if isinstance(hidden_states, tuple): - return tuple( - tensor[:num_tokens] if torch.is_tensor(tensor) else tensor - for tensor in hidden_states - ) - return hidden_states - - -def _resolve_num_tokens_across_dp( - atom_config: Any, - forward_batch: ForwardBatch, - num_tokens: int, - is_dummy_run: bool, -) -> torch.Tensor: - """Resolve per-DP token counts for ATOM's CPU-side DPMetadata. 
- - Real SGLang dp-attention batches carry ``global_num_tokens_cpu`` from the - scheduler. That list is the source of truth for mixed prefill/decode/idle - batches, where token counts may look like [8, 1, 8, 8]. - - Some SGLang synthetic/static batches, especially CUDA graph capture batches, - only keep the global token buffer on GPU. ATOM's DPMetadata is CPU-side and - needs a CPU tensor before model forward, so avoid reading the GPU buffer back - to CPU. We only fallback when the batch advertises the same-shape DP buffer - layout (global_dp_buffer_len == local_num_tokens * dp_size), where the CPU - equivalent is exactly [local_num_tokens] * dp_size. - - IDLE batches are reported by SGLang as 0 tokens on the current rank, but - this wrapper materializes them as one local dummy token before entering - ATOM. Patch the current DP rank after resolving the distribution so - ``DPMetadata`` sees a local count that matches the actual ATOM input. - """ - global_num_tokens_cpu = getattr(forward_batch, "global_num_tokens_cpu", None) - if global_num_tokens_cpu is not None: - num_tokens_across_dp = torch.tensor( - global_num_tokens_cpu, dtype=torch.int32, device="cpu" - ) - else: - dp_size = atom_config.parallel_config.data_parallel_size - global_num_tokens_gpu = getattr(forward_batch, "global_num_tokens_gpu", None) - global_dp_buffer_len = getattr(forward_batch, "global_dp_buffer_len", None) - is_static_same_shape_batch = ( - global_num_tokens_gpu is not None - and global_dp_buffer_len == num_tokens * dp_size - ) - if not is_static_same_shape_batch: - raise RuntimeError( - "[SGL+ATOM] SGLang dp-attention requires " - "forward_batch.global_num_tokens_cpu unless the batch uses static " - "same-shape DP metadata." - ) - - # Static batches, such as CUDA graph capture batches, may only keep - # global token counts on GPU. Avoid GPU-to-CPU reads here and mirror - # their same-shape layout directly for ATOM's CPU DPMetadata. 
- num_tokens_across_dp = torch.full( - (dp_size,), num_tokens, dtype=torch.int32, device="cpu" - ) - - if is_dummy_run: - # SGLang reports idle ranks as 0 tokens, but ATOM materializes them - # as one local dummy token so collectives and DPMetadata stay aligned. - dp_rank = atom_config.parallel_config.data_parallel_rank - num_tokens_across_dp[dp_rank] = num_tokens - return num_tokens_across_dp - - -def _set_sglang_forward_context( - atom_config: Any, - forward_batch: ForwardBatch, - positions: torch.Tensor, -) -> None: - """Bridge SGLang batch metadata into ATOM's global forward context.""" - from atom.utils.forward_context import ( - AttentionMetaData, - Context, - set_forward_context, - ) - - forward_mode = forward_batch.forward_mode - # TODO: This max_seqlen_q is not the source of truth for prefill attention; - # SGLang plugin attention consumes forward_batch.attn_backend.forward_metadata - # directly. In this wrapper it is only needed by ATOM MoE padding: under - # dp-attention + TP (non-EP all_gather/reduce_scatter), decode/idle batches - # must use 1 so pad_for_all_gather keeps fixed-shape collectives aligned. - # Leaving it as 0 there can make active and dummy ranks send different - # shapes to DP all_gather and hang. - max_seqlen_q = 1 if forward_mode.is_decode_or_idle() else 0 - attn_metadata = AttentionMetaData(max_seqlen_q=max_seqlen_q) - batch_size = int(forward_batch.batch_size) - is_dummy_run = _is_dummy_forward(forward_batch) - is_prefill = forward_mode.is_prefill() - num_tokens = int(positions.shape[0]) - - enable_dp_attention = bool(atom_config.enable_dp_attention) - if enable_dp_attention: - # SGLang owns the cross-DP token distribution under dp-attention; ATOM - # uses it to derive graph_bs and fixed-size MoE gather/scatter buffers. 
- num_tokens_across_dp = _resolve_num_tokens_across_dp( - atom_config, forward_batch, num_tokens, is_dummy_run - ) - graph_bs = int(torch.max(num_tokens_across_dp).item()) - else: - # Without dp-attention, ATOM runs with local-rank shapes only. There is - # no cross-DP token distribution to pass into DPMetadata, so graph_bs - # follows the local prefill token count or decode batch size. - num_tokens_across_dp = None - graph_bs = num_tokens if is_prefill else batch_size - context = Context( - positions=positions, - is_prefill=is_prefill, - is_dummy_run=is_dummy_run, - batch_size=batch_size, - graph_bs=graph_bs, - ) - set_forward_context( - attn_metadata=attn_metadata, - atom_config=atom_config, - context=context, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp, - ) - - -def _reset_sglang_forward_context() -> None: - from atom.utils.forward_context import reset_forward_context - - reset_forward_context() - - -@contextmanager -def plugin_runtime_scope( - *, - framework: Optional[str] = None, - atom_config: Any = _RUNTIME_SENTINEL, -): - """Temporarily bind plugin runtime globals to one wrapper instance. - - ATOM core currently relies on process-global framework/config state. In - SGLang speculative mode both target and draft wrappers coexist, so plugin - entrypoints must save/restore those globals around each init/load/forward. 
- """ - - import atom.config as atom_config_module - import atom.plugin.prepare as plugin_prepare - - prev_framework = plugin_prepare._CURRENT_FRAMEWORK - prev_atom_config = getattr(atom_config_module, "_current_atom_config", None) - - if framework is not None: - plugin_prepare._set_framework_backbone(framework) - if atom_config is not _RUNTIME_SENTINEL: - atom_config_module._current_atom_config = atom_config - - try: - yield - finally: - plugin_prepare._CURRENT_FRAMEWORK = prev_framework - atom_config_module._current_atom_config = prev_atom_config - - -@dataclass(frozen=True) -class SGLangForwardBatchMetadata: - """Small context object for one SGLang model forward.""" - - forward_batch: Optional[ForwardBatch] - pp_proxy_tensors: Optional[PPProxyTensors] = None - save_kv_cache: bool = True - _current: ClassVar[ContextVar[Optional["SGLangForwardBatchMetadata"]]] = ContextVar( - "atom_sglang_current_forward_batch_metadata", - default=None, - ) - - @classmethod - def current(cls) -> Optional["SGLangForwardBatchMetadata"]: - return cls._current.get() - - @classmethod - def build( - cls, - forward_batch: Optional[ - Union[ForwardBatch, "SGLangForwardBatchMetadata"] - ] = None, - *, - pp_proxy_tensors: Optional[PPProxyTensors] = None, - save_kv_cache: Optional[bool] = None, - ) -> Optional["SGLangForwardBatchMetadata"]: - if isinstance(forward_batch, cls): - return forward_batch - if forward_batch is None and pp_proxy_tensors is None and save_kv_cache is None: - return cls.current() - return cls( - forward_batch=forward_batch, - pp_proxy_tensors=pp_proxy_tensors, - save_kv_cache=True if save_kv_cache is None else save_kv_cache, - ) - - @classmethod - @contextmanager - def bind(cls, metadata: Optional["SGLangForwardBatchMetadata"]): - meta_token = cls._current.set(metadata) - batch_token = _current_forward_batch.set( - None if metadata is None else metadata.forward_batch - ) - try: - yield metadata - finally: - _current_forward_batch.reset(batch_token) - 
cls._current.reset(meta_token) - - @staticmethod - def to_intermediate_tensors( - intermediate_tensors, - metadata: Optional["SGLangForwardBatchMetadata"], - ): - if intermediate_tensors is not None or metadata is None: - return intermediate_tensors - pp_proxy_tensors = metadata.pp_proxy_tensors - if pp_proxy_tensors is None: - return intermediate_tensors - tensors = getattr(pp_proxy_tensors, "tensors", None) - if tensors is None: - return intermediate_tensors - from atom.models.utils import IntermediateTensors - - return IntermediateTensors(dict(tensors)) - - -@dataclass(frozen=True) -class ModelArchSpec: - wrapper_binds_gdn_context: bool = False - apply_deepseek_patch: bool = False - - -_MODEL_ARCH_SPECS = { - "DeepseekV3ForCausalLM": ModelArchSpec(apply_deepseek_patch=True), - "Qwen3MoeForCausalLM": ModelArchSpec(), - "Qwen3NextForCausalLM": ModelArchSpec(wrapper_binds_gdn_context=True), -} +__all__ = [ + "EntryClass", + "SGLangForwardBatchMetadata", + "SGLangPluginRuntime", + "bind_current_forward_batch", + "get_current_forward_batch", + "plugin_runtime_scope", +] class _AtomCausalLMBaseForSglang(nn.Module): @@ -353,7 +62,7 @@ def __init__( self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size self.model_arch = getattr(config, "architectures", [""])[0] - self.model_arch_spec = _MODEL_ARCH_SPECS.get(self.model_arch, ModelArchSpec()) + self.model_arch_spec = get_model_arch_spec(self.model_arch) import atom @@ -449,94 +158,60 @@ def forward( **model_kwargs: Any, ) -> Union[LogitsProcessorOutput, PPProxyTensors]: with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): - metadata = SGLangForwardBatchMetadata.build( - forward_batch, - pp_proxy_tensors=pp_proxy_tensors, - save_kv_cache=model_kwargs.get("save_kv_cache"), - ) - - if _is_dummy_forward(forward_batch): - ( - model_input_ids, - model_positions, - model_input_embeds, - model_forward_batch, - ) = _materialize_atom_dummy_forward( - input_ids, - positions, - 
input_embeds, - forward_batch, + with SGLangPluginRuntime( + atom_config=self.atom_config, + forward_batch=forward_batch, + positions=positions, + input_ids=input_ids, + input_embeds=input_embeds, + set_forward_context=not self.model_arch_spec.wrapper_binds_gdn_context, + ) as runtime: + metadata = SGLangForwardBatchMetadata.build( + runtime.forward_batch, + pp_proxy_tensors=pp_proxy_tensors, + save_kv_cache=model_kwargs.get("save_kv_cache"), ) - else: - ( - model_input_ids, - model_positions, - model_input_embeds, - model_forward_batch, - ) = ( - input_ids, - positions, - input_embeds, - forward_batch, + model_inputs = dict( + input_ids=runtime.input_ids, + positions=runtime.positions, + intermediate_tensors=SGLangForwardBatchMetadata.to_intermediate_tensors( + pp_proxy_tensors, metadata + ), + inputs_embeds=runtime.input_embeds, ) - - model_inputs = dict( - input_ids=model_input_ids, - positions=model_positions, - intermediate_tensors=SGLangForwardBatchMetadata.to_intermediate_tensors( - pp_proxy_tensors, metadata - ), - inputs_embeds=model_input_embeds, - ) - uses_context_only_forward = ( - self.model_arch_spec.apply_deepseek_patch - or self.model_arch_spec.wrapper_binds_gdn_context - ) - with SGLangForwardBatchMetadata.bind(metadata): - if self.model_arch_spec.wrapper_binds_gdn_context: - from atom.plugin.sglang.attention_backend.attention_gdn import ( - SGLangGDNForwardContext, - ) - - with SGLangGDNForwardContext.bind(metadata): - hidden_states = self.model(**model_inputs) - elif uses_context_only_forward: - try: - _set_sglang_forward_context( - self.atom_config, model_forward_batch, model_positions + uses_context_only_forward = ( + self.model_arch_spec.apply_deepseek_patch + or self.model_arch_spec.wrapper_binds_gdn_context + ) + with SGLangForwardBatchMetadata.bind(metadata): + if self.model_arch_spec.wrapper_binds_gdn_context: + from atom.plugin.sglang.attention_backend.attention_gdn import ( + SGLangGDNForwardContext, ) + + with 
SGLangGDNForwardContext.bind(metadata): + hidden_states = self.model(**model_inputs) + elif uses_context_only_forward: hidden_states = self.model(**model_inputs) - finally: - _reset_sglang_forward_context() - else: - try: - _set_sglang_forward_context( - self.atom_config, model_forward_batch, model_positions - ) + else: hidden_states = self.model( **model_inputs, - forward_batch=model_forward_batch, + forward_batch=runtime.forward_batch, get_embedding=get_embedding, pp_proxy_tensors=pp_proxy_tensors, **model_kwargs, ) - finally: - _reset_sglang_forward_context() - - if self.pp_group.is_last_rank: - if _is_dummy_forward(forward_batch): - # TODO: Revisit if SGLang ever sends non-empty dummy batches. - # Today this path only runs when an empty IDLE batch is expanded - # to one ATOM dummy token, so the output boundary must trim back to - # the original SGLang-visible length: 0 tokens. - hidden_states = _trim_hidden_states_for_output(hidden_states, 0) - return self.logits_processor( - input_ids, - hidden_states, - self.model.lm_head, - forward_batch, - ) - return hidden_states + + hidden_states = runtime.trim_output(hidden_states) + + if self.pp_group.is_last_rank: + return self.logits_processor( + input_ids, + hidden_states, + self.model.lm_head, + forward_batch, + ) + return hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # The passed `weights` iterable from sglang is ignored because ATOM @@ -552,7 +227,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): EntryClass = [] -for _name in _MODEL_ARCH_SPECS: +for _name in MODEL_ARCH_SPECS: _cls = type(_name, (_AtomCausalLMBaseForSglang,), {}) globals()[_name] = _cls EntryClass.append(_cls) diff --git a/atom/plugin/sglang/models/deepseek_mla_attention.py b/atom/plugin/sglang/models/deepseek_mla_attention.py index ce20aae418..65a240efd2 100644 --- a/atom/plugin/sglang/models/deepseek_mla_attention.py +++ b/atom/plugin/sglang/models/deepseek_mla_attention.py @@ -38,7 
+38,7 @@ def attn(self): def _get_forward_batch(self, kwargs: dict[str, Any]): forward_batch = kwargs.get("forward_batch", None) if forward_batch is None: - from atom.plugin.sglang.models.base_model_wrapper import ( + from atom.plugin.sglang.runtime import ( get_current_forward_batch, ) diff --git a/atom/plugin/sglang/models/deepseek_nextn_wrapper.py b/atom/plugin/sglang/models/deepseek_nextn_wrapper.py index 7ad85048c1..9b343a04d8 100644 --- a/atom/plugin/sglang/models/deepseek_nextn_wrapper.py +++ b/atom/plugin/sglang/models/deepseek_nextn_wrapper.py @@ -19,16 +19,11 @@ from atom.config import SpeculativeConfig from atom.plugin.config import generate_atom_config_for_plugin_mode -from atom.plugin.sglang.attention_backend.sgl_attention_mla import ( +from atom.plugin.sglang.models.deepseek_mla import ( setup_deepseek_for_sglang, ) -from atom.plugin.sglang.models.base_model_wrapper import ( - _current_forward_batch, - _is_dummy_forward, - _materialize_atom_dummy_forward, - _reset_sglang_forward_context, - _set_sglang_forward_context, - _trim_hidden_states_for_output, +from atom.plugin.sglang.runtime import ( + SGLangPluginRuntime, plugin_runtime_scope, ) @@ -78,7 +73,7 @@ def _retag_mtp_runtime_layer_ids(model: nn.Module) -> None: _set_runtime_layer_id(self_attn, local_layer_id) - for attr_name in ("mla_attn", "attn_mha"): + for attr_name in ("mla_attn", "attn_non_absorbed", "attn_mha"): attn_obj = getattr(self_attn, attr_name, None) if attn_obj is None: continue @@ -171,54 +166,28 @@ def forward( raise ValueError("DeepSeek MTP draft forward requires speculative info") with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): - if _is_dummy_forward(forward_batch): - ( - model_input_ids, - model_positions, - model_input_embeds, - model_forward_batch, - ) = _materialize_atom_dummy_forward( - input_ids, - positions, - input_embeds, - forward_batch, - ) - model_hidden_states = _materialize_dummy_hidden_states( - forward_batch.spec_info.hidden_states, - 
length=int(model_positions.shape[0]), - ) - else: - ( - model_input_ids, - model_positions, - model_input_embeds, - model_forward_batch, - ) = ( - input_ids, - positions, - input_embeds, - forward_batch, - ) + with SGLangPluginRuntime( + atom_config=self.atom_config, + forward_batch=forward_batch, + positions=positions, + input_ids=input_ids, + input_embeds=input_embeds, + ) as runtime: model_hidden_states = forward_batch.spec_info.hidden_states - - token = _current_forward_batch.set(model_forward_batch) - try: - _set_sglang_forward_context( - self.atom_config, model_forward_batch, model_positions - ) + if runtime.forward_batch is not forward_batch: + model_hidden_states = _materialize_dummy_hidden_states( + model_hidden_states, + length=int(runtime.positions.shape[0]), + ) hidden_states = self.model( - input_ids=model_input_ids, - positions=model_positions, + input_ids=runtime.input_ids, + positions=runtime.positions, hidden_states=model_hidden_states, - inputs_embeds=model_input_embeds, + inputs_embeds=runtime.input_embeds, ) - finally: - _reset_sglang_forward_context() - _current_forward_batch.reset(token) if self.pp_group.is_last_rank: - if _is_dummy_forward(forward_batch): - hidden_states = _trim_hidden_states_for_output(hidden_states, 0) + hidden_states = runtime.trim_output(hidden_states) return self.logits_processor( input_ids, hidden_states, diff --git a/atom/plugin/sglang/models/qwen3_5.py b/atom/plugin/sglang/models/qwen3_5.py index 22294b8548..c0b6b0ca0b 100644 --- a/atom/plugin/sglang/models/qwen3_5.py +++ b/atom/plugin/sglang/models/qwen3_5.py @@ -37,7 +37,7 @@ from atom.plugin.sglang.attention_backend.attention_gdn import ( SGLangGDNForwardContext, ) -from atom.plugin.sglang.models.base_model_wrapper import ( +from atom.plugin.sglang.runtime import ( SGLangForwardBatchMetadata, ) @@ -448,5 +448,5 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # SGLang discovers these multimodal wrappers from this module's 
`EntryClass`. # They are not covered by `base_model_wrapper.py`, whose generated entries only -# handle the plain causal-LM architectures in `_MODEL_ARCH_SPECS`. +# handle the plain causal-LM architectures in `MODEL_ARCH_SPECS`. EntryClass = [Qwen3_5ForConditionalGeneration, Qwen3_5MoeForConditionalGeneration] diff --git a/atom/plugin/sglang/runtime/__init__.py b/atom/plugin/sglang/runtime/__init__.py new file mode 100644 index 0000000000..f63890418f --- /dev/null +++ b/atom/plugin/sglang/runtime/__init__.py @@ -0,0 +1,20 @@ +"""Runtime utilities for ATOM's SGLang plugin integration.""" + +from atom.plugin.sglang.runtime.context import ( + SGLangForwardBatchMetadata, + bind_current_forward_batch, + get_current_forward_batch, + plugin_runtime_scope, +) +from atom.plugin.sglang.runtime.forward_context import SGLangPluginRuntime +from atom.plugin.sglang.runtime.model_arch import MODEL_ARCH_SPECS, get_model_arch_spec + +__all__ = [ + "MODEL_ARCH_SPECS", + "SGLangForwardBatchMetadata", + "SGLangPluginRuntime", + "bind_current_forward_batch", + "get_current_forward_batch", + "get_model_arch_spec", + "plugin_runtime_scope", +] diff --git a/atom/plugin/sglang/runtime/context.py b/atom/plugin/sglang/runtime/context.py new file mode 100644 index 0000000000..1fffae92f2 --- /dev/null +++ b/atom/plugin/sglang/runtime/context.py @@ -0,0 +1,126 @@ +"""Runtime context helpers for ATOM's SGLang plugin path.""" + +from __future__ import annotations + +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass +from typing import ClassVar, Optional, Union + +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors + +_RUNTIME_SENTINEL = object() +_current_forward_batch: ContextVar[Optional[ForwardBatch]] = ContextVar( + "atom_sglang_current_forward_batch", default=None +) + + +def get_current_forward_batch(): + return _current_forward_batch.get() + + +@contextmanager +def 
bind_current_forward_batch(forward_batch: Optional[ForwardBatch]): + token = _current_forward_batch.set(forward_batch) + try: + yield + finally: + _current_forward_batch.reset(token) + + +@contextmanager +def plugin_runtime_scope( + *, + framework: Optional[str] = None, + atom_config=_RUNTIME_SENTINEL, +): + """Temporarily bind process-global ATOM plugin runtime state. + + SGLang target/draft wrappers can coexist during speculative decoding, while + ATOM core still reads process-global framework/config state in some paths. + Keep those globals scoped to one wrapper call and restore them afterwards. + """ + + import atom.config as atom_config_module + import atom.plugin.prepare as plugin_prepare + + prev_framework = plugin_prepare._CURRENT_FRAMEWORK + prev_atom_config = getattr(atom_config_module, "_current_atom_config", None) + + if framework is not None: + plugin_prepare._set_framework_backbone(framework) + if atom_config is not _RUNTIME_SENTINEL: + atom_config_module._current_atom_config = atom_config + + try: + yield + finally: + plugin_prepare._CURRENT_FRAMEWORK = prev_framework + atom_config_module._current_atom_config = prev_atom_config + + +@dataclass(frozen=True) +class SGLangForwardBatchMetadata: + """Small context object for one SGLang model forward.""" + + forward_batch: Optional[ForwardBatch] + pp_proxy_tensors: Optional[PPProxyTensors] = None + save_kv_cache: bool = True + _current: ClassVar[ContextVar[Optional["SGLangForwardBatchMetadata"]]] = ContextVar( + "atom_sglang_current_forward_batch_metadata", + default=None, + ) + + @classmethod + def current(cls) -> Optional["SGLangForwardBatchMetadata"]: + return cls._current.get() + + @classmethod + def build( + cls, + forward_batch: Optional[ + Union[ForwardBatch, "SGLangForwardBatchMetadata"] + ] = None, + *, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + save_kv_cache: Optional[bool] = None, + ) -> Optional["SGLangForwardBatchMetadata"]: + if isinstance(forward_batch, cls): + return 
forward_batch + if forward_batch is None and pp_proxy_tensors is None and save_kv_cache is None: + return cls.current() + return cls( + forward_batch=forward_batch, + pp_proxy_tensors=pp_proxy_tensors, + save_kv_cache=True if save_kv_cache is None else save_kv_cache, + ) + + @classmethod + @contextmanager + def bind(cls, metadata: Optional["SGLangForwardBatchMetadata"]): + meta_token = cls._current.set(metadata) + batch_token = _current_forward_batch.set( + None if metadata is None else metadata.forward_batch + ) + try: + yield metadata + finally: + _current_forward_batch.reset(batch_token) + cls._current.reset(meta_token) + + @staticmethod + def to_intermediate_tensors( + intermediate_tensors, + metadata: Optional["SGLangForwardBatchMetadata"], + ): + if intermediate_tensors is not None or metadata is None: + return intermediate_tensors + pp_proxy_tensors = metadata.pp_proxy_tensors + if pp_proxy_tensors is None: + return intermediate_tensors + tensors = getattr(pp_proxy_tensors, "tensors", None) + if tensors is None: + return intermediate_tensors + from atom.models.utils import IntermediateTensors + + return IntermediateTensors(dict(tensors)) diff --git a/atom/plugin/sglang/runtime/forward_context.py b/atom/plugin/sglang/runtime/forward_context.py new file mode 100644 index 0000000000..05939cc7cb --- /dev/null +++ b/atom/plugin/sglang/runtime/forward_context.py @@ -0,0 +1,234 @@ +"""Scoped runtime adapter from SGLang batches to ATOM core.""" + +from __future__ import annotations + +import copy +from contextlib import ExitStack +from dataclasses import dataclass, field +from typing import Any, Optional + +import torch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch + +from atom.plugin.sglang.runtime.context import bind_current_forward_batch + + +def _is_dummy_forward(forward_batch: ForwardBatch) -> bool: + """Return whether an SGLang batch represents an empty/idle dummy run.""" + + forward_mode = getattr(forward_batch, "forward_mode", None) 
+ return bool( + forward_mode is not None + and hasattr(forward_mode, "is_idle") + and forward_mode.is_idle() + ) + + +def _pad_dummy_like( + tensor: Optional[torch.Tensor], + *, + length: int, + fill_value: int | float = 0, +) -> Optional[torch.Tensor]: + if tensor is None: + return None + shape = (length, *tensor.shape[1:]) + return torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device) + + +def _materialize_atom_dummy_forward( + input_ids: Optional[torch.Tensor], + positions: Optional[torch.Tensor], + input_embeds: Optional[torch.Tensor], + forward_batch: ForwardBatch, +) -> tuple[ + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + ForwardBatch, +]: + """Convert an empty SGLang IDLE batch into ATOM-style dummy inputs.""" + + if positions is None: + raise RuntimeError("SGLang dummy forward materialization requires positions") + if input_ids is None: + raise RuntimeError("SGLang dummy forward materialization requires input_ids") + + dummy_positions = positions.new_zeros((1,)) + dummy_input_ids = input_ids.new_zeros((1,)) + dummy_input_embeds = _pad_dummy_like(input_embeds, length=1, fill_value=0) + + model_forward_batch = copy.copy(forward_batch) + model_forward_batch.positions = dummy_positions + model_forward_batch.batch_size = 1 + model_forward_batch.seq_lens_sum = 1 + model_forward_batch.seq_lens = forward_batch.seq_lens.new_ones((1,)) + model_forward_batch.seq_lens_cpu = forward_batch.seq_lens_cpu.new_ones((1,)) + + return dummy_input_ids, dummy_positions, dummy_input_embeds, model_forward_batch + + +def _trim_hidden_states_for_output(hidden_states, num_tokens: int): + if torch.is_tensor(hidden_states): + return hidden_states[:num_tokens] + if isinstance(hidden_states, tuple): + return tuple( + tensor[:num_tokens] if torch.is_tensor(tensor) else tensor + for tensor in hidden_states + ) + return hidden_states + + +def _resolve_num_tokens_across_dp( + atom_config: Any, + forward_batch: ForwardBatch, + num_tokens: 
int, + is_dummy_run: bool, +) -> torch.Tensor: + """Resolve per-DP token counts for ATOM's CPU-side DPMetadata.""" + + global_num_tokens_cpu = getattr(forward_batch, "global_num_tokens_cpu", None) + if global_num_tokens_cpu is not None: + num_tokens_across_dp = torch.tensor( + global_num_tokens_cpu, dtype=torch.int32, device="cpu" + ) + else: + dp_size = atom_config.parallel_config.data_parallel_size + global_num_tokens_gpu = getattr(forward_batch, "global_num_tokens_gpu", None) + global_dp_buffer_len = getattr(forward_batch, "global_dp_buffer_len", None) + is_static_same_shape_batch = ( + global_num_tokens_gpu is not None + and global_dp_buffer_len == num_tokens * dp_size + ) + if not is_static_same_shape_batch: + raise RuntimeError( + "[SGL+ATOM] SGLang dp-attention requires " + "forward_batch.global_num_tokens_cpu unless the batch uses static " + "same-shape DP metadata." + ) + + # Static batches, such as CUDA graph capture batches, may only keep + # global token counts on GPU. Avoid GPU-to-CPU reads here and mirror + # their same-shape layout directly for ATOM's CPU DPMetadata. + num_tokens_across_dp = torch.full( + (dp_size,), num_tokens, dtype=torch.int32, device="cpu" + ) + + if is_dummy_run: + # SGLang reports idle ranks as 0 tokens, but ATOM materializes them + # as one local dummy token so collectives and DPMetadata stay aligned. + dp_rank = atom_config.parallel_config.data_parallel_rank + num_tokens_across_dp[dp_rank] = num_tokens + return num_tokens_across_dp + + +def _set_atom_forward_context( + atom_config: Any, + forward_batch: ForwardBatch, + positions: torch.Tensor, +) -> None: + """Bridge SGLang batch metadata into ATOM's global forward context.""" + + from atom.utils.forward_context import ( + AttentionMetaData, + Context, + set_forward_context, + ) + + forward_mode = forward_batch.forward_mode + # This value is only used by ATOM-side MoE padding in the SGLang wrapper. 
+ max_seqlen_q = 1 if forward_mode.is_decode_or_idle() else 0 + attn_metadata = AttentionMetaData(max_seqlen_q=max_seqlen_q) + batch_size = int(forward_batch.batch_size) + is_dummy_run = _is_dummy_forward(forward_batch) + is_prefill = forward_mode.is_prefill() + num_tokens = int(positions.shape[0]) + + if bool(atom_config.enable_dp_attention): + num_tokens_across_dp = _resolve_num_tokens_across_dp( + atom_config, forward_batch, num_tokens, is_dummy_run + ) + graph_bs = int(torch.max(num_tokens_across_dp).item()) + else: + num_tokens_across_dp = None + graph_bs = num_tokens if is_prefill else batch_size + + context = Context( + positions=positions, + is_prefill=is_prefill, + is_dummy_run=is_dummy_run, + batch_size=batch_size, + graph_bs=graph_bs, + ) + set_forward_context( + attn_metadata=attn_metadata, + atom_config=atom_config, + context=context, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + ) + + +def _reset_atom_forward_context() -> None: + from atom.utils.forward_context import reset_forward_context + + reset_forward_context() + + +@dataclass +class SGLangPluginRuntime: + """Scoped adapter for running ATOM model code under SGLang plugin runtime. + + The adapter owns the temporary translation from SGLang's ``ForwardBatch`` to + ATOM's process-local runtime state. Callers should use the normalized + ``input_ids``, ``positions``, ``input_embeds``, and ``forward_batch`` exposed + by this object while inside the context. 
+ """ + + atom_config: Any + forward_batch: ForwardBatch + positions: torch.Tensor + input_ids: Optional[torch.Tensor] = None + input_embeds: Optional[torch.Tensor] = None + set_forward_context: bool = True + _original_forward_batch: ForwardBatch = field(init=False, repr=False) + _is_dummy_run: bool = field(init=False, default=False) + _exit_stack: ExitStack = field(init=False, repr=False) + + def __enter__(self) -> "SGLangPluginRuntime": + self._original_forward_batch = self.forward_batch + self._is_dummy_run = _is_dummy_forward(self.forward_batch) + + if self._is_dummy_run: + ( + self.input_ids, + self.positions, + self.input_embeds, + self.forward_batch, + ) = _materialize_atom_dummy_forward( + self.input_ids, + self.positions, + self.input_embeds, + self.forward_batch, + ) + + self._exit_stack = ExitStack() + self._exit_stack.enter_context(bind_current_forward_batch(self.forward_batch)) + if self.set_forward_context: + _set_atom_forward_context( + self.atom_config, + self.forward_batch, + self.positions, + ) + self._exit_stack.callback(_reset_atom_forward_context) + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self._exit_stack.close() + + def trim_output(self, hidden_states): + """Map ATOM-visible outputs back to SGLang-visible token count.""" + + if self._is_dummy_run: + return _trim_hidden_states_for_output(hidden_states, 0) + return hidden_states diff --git a/atom/plugin/sglang/runtime/model_arch.py b/atom/plugin/sglang/runtime/model_arch.py new file mode 100644 index 0000000000..1ddd3772c9 --- /dev/null +++ b/atom/plugin/sglang/runtime/model_arch.py @@ -0,0 +1,22 @@ +"""Runtime behavior flags for SGLang plugin model wrappers.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ModelArchSpec: + wrapper_binds_gdn_context: bool = False + apply_deepseek_patch: bool = False + + +MODEL_ARCH_SPECS = { + "DeepseekV3ForCausalLM": ModelArchSpec(apply_deepseek_patch=True), + 
"Qwen3MoeForCausalLM": ModelArchSpec(), + "Qwen3NextForCausalLM": ModelArchSpec(wrapper_binds_gdn_context=True), +} + + +def get_model_arch_spec(model_arch: str) -> ModelArchSpec: + return MODEL_ARCH_SPECS.get(model_arch, ModelArchSpec()) From 853ee3ff2b7f37c19fcfad8dacd0ed549b442659 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Wed, 20 May 2026 08:35:46 +0000 Subject: [PATCH 06/17] [ATOM-SGL][Runtime] Introduce model adapter specs Co-authored-by: Cursor --- atom/plugin/prepare.py | 12 ++--- .../sglang/models/base_model_wrapper.py | 13 ++--- atom/plugin/sglang/runtime/__init__.py | 9 +++- atom/plugin/sglang/runtime/model_arch.py | 52 +++++++++++++++---- 4 files changed, 59 insertions(+), 27 deletions(-) diff --git a/atom/plugin/prepare.py b/atom/plugin/prepare.py index 3c7b722d29..bbcb369bcd 100644 --- a/atom/plugin/prepare.py +++ b/atom/plugin/prepare.py @@ -74,15 +74,11 @@ def prepare_model(config: Any, engine: str): model_cls = _ATOM_SUPPORTED_MODELS[model_arch] logger.info(f"ATOM model class for {model_arch} is {model_cls}") - if model_arch in { - "Qwen3_5ForConditionalGeneration", - "Qwen3_5MoeForConditionalGeneration", - }: - from atom.plugin.sglang.models.qwen3_5 import ( - apply_prepare_model_adaptations, - ) + from atom.plugin.sglang.runtime import get_model_arch_spec - apply_prepare_model_adaptations(atom_config, model_arch) + model_adapter = get_model_arch_spec(model_arch) + if model_adapter.prepare_config is not None: + model_adapter.prepare_config(atom_config, model_arch) register_ops_to_sglang(atom_config=atom_config) set_attn_cls() diff --git a/atom/plugin/sglang/models/base_model_wrapper.py b/atom/plugin/sglang/models/base_model_wrapper.py index 85bc6608e1..02ac6edc77 100644 --- a/atom/plugin/sglang/models/base_model_wrapper.py +++ b/atom/plugin/sglang/models/base_model_wrapper.py @@ -92,15 +92,10 @@ def __init__( config, skip_all_gather=plugin_skip_all_gather ) - # Apply ds model-specific sglang patches (attn dispatch, weight hooks, etc.) 
- # TODO: will remove this after sglang supports atom attention backend - if self.model_arch_spec.apply_deepseek_patch: - from atom.plugin.sglang.models.deepseek_mla import ( - setup_deepseek_for_sglang, - ) - + # Apply model-specific install-time adapters (attn dispatch, weight hooks, etc.). + if self.model_arch_spec.install_adapters is not None: with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): - setup_deepseek_for_sglang(self.model) + self.model_arch_spec.install_adapters(self.model) def get_embed_and_head(self): if hasattr(self.model, "get_embed_and_head"): @@ -180,7 +175,7 @@ def forward( inputs_embeds=runtime.input_embeds, ) uses_context_only_forward = ( - self.model_arch_spec.apply_deepseek_patch + self.model_arch_spec.install_adapters is not None or self.model_arch_spec.wrapper_binds_gdn_context ) with SGLangForwardBatchMetadata.bind(metadata): diff --git a/atom/plugin/sglang/runtime/__init__.py b/atom/plugin/sglang/runtime/__init__.py index f63890418f..fa8468c177 100644 --- a/atom/plugin/sglang/runtime/__init__.py +++ b/atom/plugin/sglang/runtime/__init__.py @@ -7,11 +7,18 @@ plugin_runtime_scope, ) from atom.plugin.sglang.runtime.forward_context import SGLangPluginRuntime -from atom.plugin.sglang.runtime.model_arch import MODEL_ARCH_SPECS, get_model_arch_spec +from atom.plugin.sglang.runtime.model_arch import ( + MODEL_ADAPTER_SPECS, + MODEL_ARCH_SPECS, + SGLangModelAdapterSpec, + get_model_arch_spec, +) __all__ = [ + "MODEL_ADAPTER_SPECS", "MODEL_ARCH_SPECS", "SGLangForwardBatchMetadata", + "SGLangModelAdapterSpec", "SGLangPluginRuntime", "bind_current_forward_batch", "get_current_forward_batch", diff --git a/atom/plugin/sglang/runtime/model_arch.py b/atom/plugin/sglang/runtime/model_arch.py index 1ddd3772c9..4765baa483 100644 --- a/atom/plugin/sglang/runtime/model_arch.py +++ b/atom/plugin/sglang/runtime/model_arch.py @@ -1,22 +1,56 @@ -"""Runtime behavior flags for SGLang plugin model wrappers.""" +"""SGLang plugin model 
adapter registry.""" from __future__ import annotations from dataclasses import dataclass +from typing import Any, Callable, Optional @dataclass(frozen=True) -class ModelArchSpec: +class SGLangModelAdapterSpec: + """Adapter hooks for one SGLang plugin model architecture. + + The first version keeps the existing runtime flags while adding function + hooks for config preparation and install-time model adaptation. This avoids + growing a long list of booleans in the generic wrapper as new models arrive. + """ + wrapper_binds_gdn_context: bool = False - apply_deepseek_patch: bool = False + prepare_config: Optional[Callable[[Any, str], None]] = None + install_adapters: Optional[Callable[[Any], None]] = None + + +def _prepare_qwen35_config(atom_config: Any, model_arch: str) -> None: + from atom.plugin.sglang.models.qwen3_5 import apply_prepare_model_adaptations + + apply_prepare_model_adaptations(atom_config, model_arch) -MODEL_ARCH_SPECS = { - "DeepseekV3ForCausalLM": ModelArchSpec(apply_deepseek_patch=True), - "Qwen3MoeForCausalLM": ModelArchSpec(), - "Qwen3NextForCausalLM": ModelArchSpec(wrapper_binds_gdn_context=True), +def _install_deepseek_mla_adapters(model: Any) -> None: + from atom.plugin.sglang.models.deepseek_mla import setup_deepseek_for_sglang + + setup_deepseek_for_sglang(model) + + +MODEL_ADAPTER_SPECS = { + "DeepseekV3ForCausalLM": SGLangModelAdapterSpec( + install_adapters=_install_deepseek_mla_adapters, + ), + "Qwen3MoeForCausalLM": SGLangModelAdapterSpec(), + "Qwen3NextForCausalLM": SGLangModelAdapterSpec( + wrapper_binds_gdn_context=True, + ), + "Qwen3_5ForConditionalGeneration": SGLangModelAdapterSpec( + prepare_config=_prepare_qwen35_config, + ), + "Qwen3_5MoeForConditionalGeneration": SGLangModelAdapterSpec( + prepare_config=_prepare_qwen35_config, + ), } +# Backwards-compatible alias for callers that only need generated EntryClass names. 
+MODEL_ARCH_SPECS = MODEL_ADAPTER_SPECS + -def get_model_arch_spec(model_arch: str) -> ModelArchSpec: - return MODEL_ARCH_SPECS.get(model_arch, ModelArchSpec()) +def get_model_arch_spec(model_arch: str) -> SGLangModelAdapterSpec: + return MODEL_ADAPTER_SPECS.get(model_arch, SGLangModelAdapterSpec()) From 6ed65b9121bcb881a7cbbc9d0fa322ce18839b3c Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Thu, 21 May 2026 02:32:48 +0000 Subject: [PATCH 07/17] [ATOM-SGL][Runtime] Keep custom wrappers out of generated entries Co-authored-by: Cursor --- atom/plugin/sglang/runtime/model_arch.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/atom/plugin/sglang/runtime/model_arch.py b/atom/plugin/sglang/runtime/model_arch.py index 4765baa483..f103cb9427 100644 --- a/atom/plugin/sglang/runtime/model_arch.py +++ b/atom/plugin/sglang/runtime/model_arch.py @@ -48,8 +48,17 @@ def _install_deepseek_mla_adapters(model: Any) -> None: ), } -# Backwards-compatible alias for callers that only need generated EntryClass names. -MODEL_ARCH_SPECS = MODEL_ADAPTER_SPECS +# Architectures whose SGLang EntryClass is generated by base_model_wrapper. +# Custom outer-wrapper modules, such as Qwen3.5 multimodal wrappers, keep their +# own EntryClass and should not appear here or SGLang will see duplicate classes. 
+MODEL_ARCH_SPECS = { + key: MODEL_ADAPTER_SPECS[key] + for key in ( + "DeepseekV3ForCausalLM", + "Qwen3MoeForCausalLM", + "Qwen3NextForCausalLM", + ) +} def get_model_arch_spec(model_arch: str) -> SGLangModelAdapterSpec: From 6322f7779e0bddcf477016a1da7e1cd3c223290a Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Wed, 20 May 2026 08:33:35 +0000 Subject: [PATCH 08/17] [ATOM-SGL][Attn refrac] Split full attention backend helpers Co-authored-by: Cursor --- .../full_attention/full_attention_backend.py | 321 +----------------- .../full_attention/kv_cache.py | 140 ++++++++ .../full_attention/metadata.py | 37 ++ .../full_attention/pa_metadata.py | 126 +++++++ 4 files changed, 321 insertions(+), 303 deletions(-) create mode 100644 atom/plugin/sglang/attention_backend/full_attention/kv_cache.py create mode 100644 atom/plugin/sglang/attention_backend/full_attention/metadata.py create mode 100644 atom/plugin/sglang/attention_backend/full_attention/pa_metadata.py diff --git a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py index 40fe70b9e7..b459433f1e 100644 --- a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py +++ b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py @@ -11,12 +11,9 @@ # be handled by ATOM's native backend, making sglang-specific overrides # unnecessary. 
-from dataclasses import dataclass from typing import TYPE_CHECKING, Optional import torch -import triton -import triton.language as tl import sglang.srt.layers.attention.aiter_backend as _sglang_aiter from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend @@ -28,6 +25,16 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import get_bool_env_var +from atom.plugin.sglang.attention_backend.full_attention.kv_cache import ( + set_kv_buffer_with_layout_shuffle as _set_kv_buffer_with_layout_shuffle, +) +from atom.plugin.sglang.attention_backend.full_attention.metadata import ForwardMetadata +from atom.plugin.sglang.attention_backend.full_attention.pa_metadata import ( + allocate_pa_metadata_buffers as _allocate_pa_metadata_buffers, + build_pa_metadata_for_decode as _build_pa_metadata_for_decode, + build_pa_metadata_for_prefill as _build_pa_metadata_for_prefill, +) + if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner @@ -70,141 +77,6 @@ except ImportError: pass - -@triton.jit -def reshape_and_cache_shuffle_kernel( - key_ptr, # [num_tokens, num_kv_heads, head_size] - value_ptr, # [num_tokens, num_kv_heads, head_size] - key_cache_ptr, # [num_blocks, num_kv_heads, head_size // x, block_size, x] - value_cache_ptr, # [num_blocks, num_kv_heads, block_size // x, head_size, x] - slot_mapping_ptr, # [num_tokens] - k_scale_ptr, - v_scale_ptr, - x, - k_stride0, - v_stride0, - block_size, - head_size, - num_kv_heads, - BLOCK_SIZE: tl.constexpr, - QUANT: tl.constexpr, -): - tid = tl.program_id(0) - head_id = tl.program_id(1) - offset = tl.arange(0, BLOCK_SIZE) - src_offset_k = tid * k_stride0 + head_id * head_size - src_offset_v = tid * v_stride0 + head_id * head_size - slot_id = tl.load(slot_mapping_ptr + tid) - if slot_id < 0: - return - block_id = slot_id // block_size - block_offset = slot_id % block_size - dst_offset = 
( - block_id * num_kv_heads * head_size * block_size - + head_id * head_size * block_size - ) - dst_k_shuffle_offset = ( - dst_offset + offset // x * block_size * x + block_offset * x + offset % x - ) - dst_v_shuffle_offset = ( - dst_offset + block_offset // x * head_size * x + offset * x + block_offset % x - ) - k_val = tl.load(key_ptr + src_offset_k + offset) - v_val = tl.load(value_ptr + src_offset_v + offset) - if QUANT: - k_scale = tl.load(k_scale_ptr) - v_scale = tl.load(v_scale_ptr) - k_dtype = key_cache_ptr.type.element_ty - v_dtype = value_cache_ptr.type.element_ty - k_val = (k_val.to(tl.float32) / k_scale).to(k_dtype) - v_val = (v_val.to(tl.float32) / v_scale).to(v_dtype) - tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val) - tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val) - - -def reshape_and_cache_shuffle_triton( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scales: torch.Tensor, - v_scales: torch.Tensor, -): - num_tokens = slot_mapping.shape[0] - _, num_kv_heads, head_size = key.shape - num_blocks, block_size, _, _ = key_cache.shape - x = 16 // key_cache.element_size() - k_cache_template = torch.empty( - [num_blocks, num_kv_heads, head_size // x, block_size, x], - dtype=key_cache.dtype, - device="meta", - ) - v_cache_template = torch.empty( - [num_blocks, num_kv_heads, block_size // x, head_size, x], - dtype=value_cache.dtype, - device="meta", - ) - new_key_cache = key_cache.view_as(k_cache_template) - new_value_cache = value_cache.view_as(v_cache_template) - QUANT = False - if kv_cache_dtype.startswith("fp8"): - QUANT = True - grid = ( - num_tokens, - num_kv_heads, - ) - reshape_and_cache_shuffle_kernel[grid]( - key, - value, - new_key_cache, - new_value_cache, - slot_mapping, - k_scales, - v_scales, - x, - key.stride(0), - value.stride(0), - block_size, - head_size, - num_kv_heads, - BLOCK_SIZE=head_size, - QUANT=QUANT, - ) - - 
-@dataclass -class ForwardMetadata: - """Per-batch metadata consumed by ATOM's attention kernels (pa_fwd_asm, mla_decode_fwd, etc.).""" - - # kv_indptr and kv_indices are only used in MLA mode, optional for non-MLA mode - kv_indptr: Optional[torch.Tensor] - kv_indices: Optional[torch.Tensor] - qo_indptr: Optional[torch.Tensor] - kv_last_page_len: Optional[torch.Tensor] - max_q_len: Optional[int] - max_kv_len: Optional[int] - page_table: Optional[torch.Tensor] - kv_lens: Optional[torch.Tensor] - # mla - work_metadata: Optional[torch.Tensor] = None - work_info_set: Optional[torch.Tensor] = None - work_indptr: Optional[torch.Tensor] = None - reduce_indptr: Optional[torch.Tensor] = None - reduce_final_map: Optional[torch.Tensor] = None - reduce_partial_map: Optional[torch.Tensor] = None - fp8_prefill_kv_indices: Optional[torch.Tensor] = None - num_kv_splits: Optional[int] = None - run_graph: Optional[bool] = True - # PA metadata for pa_persistent_fwd (only used in decode mode, non-MLA) - pa_metadata_qo_indptr: Optional[torch.Tensor] = None - pa_metadata_pages_kv_indptr: Optional[torch.Tensor] = None - pa_metadata_kv_indices: Optional[torch.Tensor] = None - pa_metadata_context_lens: Optional[torch.Tensor] = None - pa_metadata_max_qlen: Optional[int] = None - - class ATOMAttnBackendForSgl(AiterAttnBackend): """ATOM's custom attention backend for sglang plugin mode. 
@@ -410,7 +282,7 @@ def _init_decode_mha(self, bs, kv_indptr, kv_indices, forward_batch): page_table, seq_lens, ) - self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + _build_pa_metadata_for_decode(self, bs, tp_q_head_num=self.num_head) else: page_table = forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, : @@ -826,7 +698,7 @@ def _fixup_page_table(self, forward_batch: ForwardBatch): // self.page_size ) if self.decode_using_pa_ps: - self._build_pa_metadata_for_prefill(forward_batch.batch_size) + _build_pa_metadata_for_prefill(self, forward_batch.batch_size) if ( not self.decode_using_pa_ps and self.page_size > 1 @@ -837,151 +709,6 @@ def _fixup_page_table(self, forward_batch: ForwardBatch): // self.page_size ) - def _ensure_buffer(self, name, size, dtype, zero=True): - """Allocate or reuse a pa_metadata buffer, growing if needed.""" - if self.pa_metadata_buffers is None: - self.pa_metadata_buffers = {} - size_val = size[0] if isinstance(size, (tuple, list)) else size - buf = self.pa_metadata_buffers.get(name) - needs_alloc = ( - buf is None - or buf.shape[0] < size_val - or (isinstance(size, (tuple, list)) and len(buf.shape) < len(size)) - ) - if needs_alloc: - factory = torch.zeros if zero else torch.empty - self.pa_metadata_buffers[name] = factory( - size, dtype=dtype, device=self.device - ) - elif zero: - self.pa_metadata_buffers[name].zero_() - - def _allocate_pa_metadata_buffers(self, buffer_specs): - """Allocate or reuse pa_metadata buffers. - - Args: - buffer_specs: sequence of ((size, dtype), ...) tuples from get_pa_metadata_info_v1, - in order: work_metadata_ptrs, work_indptr, work_info, - reduce_indptr, reduce_final_map, reduce_partial_map. 
- """ - names = [ - "work_metadata_ptrs", - "work_indptr", - "work_info", - "reduce_indptr", - "reduce_final_map", - "reduce_partial_map", - ] - zero_flags = [False, True, True, True, True, True] - for name, (size, dtype), zero in zip(names, buffer_specs, zero_flags): - self._ensure_buffer(name, size, dtype, zero=zero) - - def _build_pa_metadata_for_decode( - self, - batch_size: int, - tp_q_head_num: Optional[int] = None, - ): - """Build pa_metadata buffers for pa_persistent_fwd in decode mode. - - This method prepares all metadata buffers needed for pa_persistent_fwd kernel. - The metadata can be reused across multiple layers in the same forward pass. - - Args: - batch_size: Batch size for the current forward pass - tp_q_head_num: Number of Q heads per TP rank. If None, uses self.num_head. - """ - max_qlen = 1 - - # Use provided tp_q_head_num or default to self.num_head - if tp_q_head_num is None: - tp_q_head_num = self.num_head - - buffer_specs = get_pa_metadata_info_v1(batch_size, self.num_kv_head) - self._allocate_pa_metadata_buffers(buffer_specs) - qo_indptr = self.pa_decode_qo_indptr[: batch_size + 1] - - # Get context_lens (kv_lens is always set before calling _build_pa_metadata_for_decode) - # Note: kv_lens comes from self.seq_lens which is already int32 - context_lens = self.forward_metadata.kv_lens - - kernel_block_size = self.page_size - num_blocks_per_seq = (context_lens + kernel_block_size - 1) // kernel_block_size - # Use dedicated pa_kv_indptr buffer (similar to self.kv_indptr, but for pa_persistent_fwd) - pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] - pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) - - # Convert page_table to kv_indices (block indices) using Triton kernel to avoid sync - # page_table shape: [batch_size, max_num_blocks_per_seq] - # Note: page_table comes from self.page_table which is already int32 and always set before this call - page_table = self.forward_metadata.page_table - - # Use Triton 
kernel to gather kv_indices from page_table (avoids high-level indexing sync) - create_flashinfer_kv_indices_triton[(batch_size,)]( - page_table, - self.pa_batch_indices[:batch_size], # [0, 1, 2, ..., batch_size-1] - num_blocks_per_seq, - pages_kv_indptr, - None, # kv_start_idx - self.pa_kv_indices, - page_table.stride(0), - ) - # Use the full buffer - pa_persistent_fwd reads only valid elements based on pages_kv_indptr - kv_indices = self.pa_kv_indices - - get_pa_metadata_v1( - seqlens_qo_indptr=qo_indptr, - pages_kv_indptr=pages_kv_indptr, - context_lens=context_lens.int(), - num_heads_per_head_k=tp_q_head_num // self.num_kv_head, - num_heads_k=self.num_kv_head, - is_causal=True, - work_metadata_ptrs=self.pa_metadata_buffers["work_metadata_ptrs"], - work_indptr=self.pa_metadata_buffers["work_indptr"], - work_info=self.pa_metadata_buffers["work_info"], - reduce_indptr=self.pa_metadata_buffers["reduce_indptr"], - reduce_final_map=self.pa_metadata_buffers["reduce_final_map"], - reduce_partial_map=self.pa_metadata_buffers["reduce_partial_map"], - kv_granularity=max(kernel_block_size, 16), - block_size=kernel_block_size, - max_seqlen_qo=max_qlen, - uni_seqlen_qo=max_qlen, - fast_mode=True, - topk=-1, - max_split_per_batch=-1, - ) - # Store computed values in ForwardMetadata for reuse in forward_decode - self.forward_metadata.pa_metadata_qo_indptr = qo_indptr - self.forward_metadata.pa_metadata_pages_kv_indptr = pages_kv_indptr - self.forward_metadata.pa_metadata_kv_indices = kv_indices - self.forward_metadata.pa_metadata_context_lens = context_lens - self.forward_metadata.pa_metadata_max_qlen = max_qlen - - def _build_pa_metadata_for_prefill(self, batch_size: int): - """Build metadata for mha_batch_prefill_func in prefill mode. - - This method prepares page-level metadata needed for mha_batch_prefill_func. - The metadata is computed once per forward pass and reused across all layers. 
- """ - block_size = self.page_size - context_lens = self.forward_metadata.kv_lens - num_blocks_per_seq = (context_lens + block_size - 1) // block_size - - # Page-level kv_indptr (reuse pa_kv_indptr buffer) - pages_kv_indptr = self.pa_kv_indptr[: batch_size + 1] - pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) - - # Build kv_indices from page_table using triton kernel - page_table = self.forward_metadata.page_table - create_flashinfer_kv_indices_triton[(batch_size,)]( - page_table, - self.pa_batch_indices[:batch_size], - num_blocks_per_seq, - pages_kv_indptr, - None, # kv_start_idx - self.pa_kv_indices, - page_table.stride(0), - ) - def init_cuda_graph_state( self, max_bs: int, @@ -1036,7 +763,7 @@ def init_cuda_graph_state( if self.decode_using_pa_ps and not self.use_mla: buffer_specs = get_pa_metadata_info_v1(max_bs, self.num_kv_head) - self._allocate_pa_metadata_buffers(buffer_specs) + _allocate_pa_metadata_buffers(self, buffer_specs) def _init_mla_cuda_graph_metadata(self, bs, req_pool_indices, seq_lens): """Shared MLA decode metadata setup for CUDA graph capture/replay.""" @@ -1165,7 +892,7 @@ def init_forward_metadata_capture_cuda_graph( seq_lens_persistent, ) if self.decode_using_pa_ps: - self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + _build_pa_metadata_for_decode(self, bs, tp_q_head_num=self.num_head) elif forward_mode.is_target_verify(): qo_indptr = self.qo_indptr[: bs + 1] qo_indptr[: bs + 1] = torch.arange( @@ -1501,7 +1228,7 @@ def init_forward_metadata_replay_cuda_graph( seq_lens_persistent[:bs], ) if self.decode_using_pa_ps: - self._build_pa_metadata_for_decode(bs, tp_q_head_num=self.num_head) + _build_pa_metadata_for_decode(self, bs, tp_q_head_num=self.num_head) elif forward_mode.is_target_verify(): bs = len(req_pool_indices) qo_indptr = self.qo_indptr[: bs + 1] @@ -1879,27 +1606,15 @@ def set_kv_buffer_with_layout_shuffle( v_scale, block_size, ): - num_slots, num_kv_heads, head_dim = 
k_buffer.shape - num_blocks = num_slots // block_size - num_slots_with_block = num_blocks * block_size - k_buffer = k_buffer[:num_slots_with_block].view( - num_blocks, block_size, num_kv_heads, head_dim - ) - v_buffer = v_buffer[:num_slots_with_block].view( - num_blocks, block_size, num_kv_heads, head_dim - ) - kv_cache_dtype = "auto" - if k_buffer.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): - kv_cache_dtype = "fp8" - reshape_and_cache_shuffle_triton( + _set_kv_buffer_with_layout_shuffle( + cache_loc, k, v, k_buffer, v_buffer, - cache_loc, - kv_cache_dtype, k_scale, v_scale, + block_size, ) def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True): diff --git a/atom/plugin/sglang/attention_backend/full_attention/kv_cache.py b/atom/plugin/sglang/attention_backend/full_attention/kv_cache.py new file mode 100644 index 0000000000..00c2f0db0e --- /dev/null +++ b/atom/plugin/sglang/attention_backend/full_attention/kv_cache.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import torch +import triton +import triton.language as tl + + +@triton.jit +def reshape_and_cache_shuffle_kernel( + key_ptr, # [num_tokens, num_kv_heads, head_size] + value_ptr, # [num_tokens, num_kv_heads, head_size] + key_cache_ptr, # [num_blocks, num_kv_heads, head_size // x, block_size, x] + value_cache_ptr, # [num_blocks, num_kv_heads, block_size // x, head_size, x] + slot_mapping_ptr, # [num_tokens] + k_scale_ptr, + v_scale_ptr, + x, + k_stride0, + v_stride0, + block_size, + head_size, + num_kv_heads, + BLOCK_SIZE: tl.constexpr, + QUANT: tl.constexpr, +): + tid = tl.program_id(0) + head_id = tl.program_id(1) + offset = tl.arange(0, BLOCK_SIZE) + src_offset_k = tid * k_stride0 + head_id * head_size + src_offset_v = tid * v_stride0 + head_id * head_size + slot_id = tl.load(slot_mapping_ptr + tid) + if slot_id < 0: + return + block_id = slot_id // block_size + block_offset = slot_id % block_size + dst_offset = ( + block_id * num_kv_heads * head_size * block_size 
+ + head_id * head_size * block_size + ) + dst_k_shuffle_offset = ( + dst_offset + offset // x * block_size * x + block_offset * x + offset % x + ) + dst_v_shuffle_offset = ( + dst_offset + block_offset // x * head_size * x + offset * x + block_offset % x + ) + k_val = tl.load(key_ptr + src_offset_k + offset) + v_val = tl.load(value_ptr + src_offset_v + offset) + if QUANT: + k_scale = tl.load(k_scale_ptr) + v_scale = tl.load(v_scale_ptr) + k_dtype = key_cache_ptr.type.element_ty + v_dtype = value_cache_ptr.type.element_ty + k_val = (k_val.to(tl.float32) / k_scale).to(k_dtype) + v_val = (v_val.to(tl.float32) / v_scale).to(v_dtype) + tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val) + tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val) + + +def reshape_and_cache_shuffle_triton( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scales: torch.Tensor, + v_scales: torch.Tensor, +): + num_tokens = slot_mapping.shape[0] + _, num_kv_heads, head_size = key.shape + num_blocks, block_size, _, _ = key_cache.shape + x = 16 // key_cache.element_size() + k_cache_template = torch.empty( + [num_blocks, num_kv_heads, head_size // x, block_size, x], + dtype=key_cache.dtype, + device="meta", + ) + v_cache_template = torch.empty( + [num_blocks, num_kv_heads, block_size // x, head_size, x], + dtype=value_cache.dtype, + device="meta", + ) + new_key_cache = key_cache.view_as(k_cache_template) + new_value_cache = value_cache.view_as(v_cache_template) + quant = kv_cache_dtype.startswith("fp8") + grid = ( + num_tokens, + num_kv_heads, + ) + reshape_and_cache_shuffle_kernel[grid]( + key, + value, + new_key_cache, + new_value_cache, + slot_mapping, + k_scales, + v_scales, + x, + key.stride(0), + value.stride(0), + block_size, + head_size, + num_kv_heads, + BLOCK_SIZE=head_size, + QUANT=quant, + ) + + +def set_kv_buffer_with_layout_shuffle( + cache_loc, + k, + v, + k_buffer, + 
v_buffer, + k_scale, + v_scale, + block_size, +): + num_slots, num_kv_heads, head_dim = k_buffer.shape + num_blocks = num_slots // block_size + num_slots_with_block = num_blocks * block_size + k_buffer = k_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) + v_buffer = v_buffer[:num_slots_with_block].view( + num_blocks, block_size, num_kv_heads, head_dim + ) + kv_cache_dtype = "auto" + if k_buffer.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): + kv_cache_dtype = "fp8" + reshape_and_cache_shuffle_triton( + k, + v, + k_buffer, + v_buffer, + cache_loc, + kv_cache_dtype, + k_scale, + v_scale, + ) diff --git a/atom/plugin/sglang/attention_backend/full_attention/metadata.py b/atom/plugin/sglang/attention_backend/full_attention/metadata.py new file mode 100644 index 0000000000..b66feaa756 --- /dev/null +++ b/atom/plugin/sglang/attention_backend/full_attention/metadata.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class ForwardMetadata: + """Per-batch metadata consumed by SGLang full-attention backend kernels.""" + + # kv_indptr and kv_indices are only used in MLA mode, optional for non-MLA mode + kv_indptr: Optional[torch.Tensor] + kv_indices: Optional[torch.Tensor] + qo_indptr: Optional[torch.Tensor] + kv_last_page_len: Optional[torch.Tensor] + max_q_len: Optional[int] + max_kv_len: Optional[int] + page_table: Optional[torch.Tensor] + kv_lens: Optional[torch.Tensor] + # MLA metadata + work_metadata: Optional[torch.Tensor] = None + work_info_set: Optional[torch.Tensor] = None + work_indptr: Optional[torch.Tensor] = None + reduce_indptr: Optional[torch.Tensor] = None + reduce_final_map: Optional[torch.Tensor] = None + reduce_partial_map: Optional[torch.Tensor] = None + fp8_prefill_kv_indices: Optional[torch.Tensor] = None + num_kv_splits: Optional[int] = None + run_graph: Optional[bool] = True + # PA metadata for 
pa_persistent_fwd (only used in decode mode, non-MLA) + pa_metadata_qo_indptr: Optional[torch.Tensor] = None + pa_metadata_pages_kv_indptr: Optional[torch.Tensor] = None + pa_metadata_kv_indices: Optional[torch.Tensor] = None + pa_metadata_context_lens: Optional[torch.Tensor] = None + pa_metadata_max_qlen: Optional[int] = None diff --git a/atom/plugin/sglang/attention_backend/full_attention/pa_metadata.py b/atom/plugin/sglang/attention_backend/full_attention/pa_metadata.py new file mode 100644 index 0000000000..5ab66056e0 --- /dev/null +++ b/atom/plugin/sglang/attention_backend/full_attention/pa_metadata.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from typing import Optional + +import torch +from aiter import get_pa_metadata_info_v1, get_pa_metadata_v1 +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton + + +def _ensure_buffer(backend, name, size, dtype, zero=True): + """Allocate or reuse a pa_metadata buffer, growing if needed.""" + if backend.pa_metadata_buffers is None: + backend.pa_metadata_buffers = {} + size_val = size[0] if isinstance(size, (tuple, list)) else size + buf = backend.pa_metadata_buffers.get(name) + needs_alloc = ( + buf is None + or buf.shape[0] < size_val + or (isinstance(size, (tuple, list)) and len(buf.shape) < len(size)) + ) + if needs_alloc: + factory = torch.zeros if zero else torch.empty + backend.pa_metadata_buffers[name] = factory( + size, dtype=dtype, device=backend.device + ) + elif zero: + backend.pa_metadata_buffers[name].zero_() + + +def allocate_pa_metadata_buffers(backend, buffer_specs): + """Allocate or reuse pa_metadata buffers for the backend.""" + names = [ + "work_metadata_ptrs", + "work_indptr", + "work_info", + "reduce_indptr", + "reduce_final_map", + "reduce_partial_map", + ] + zero_flags = [False, True, True, True, True, True] + for name, (size, dtype), zero in zip(names, buffer_specs, zero_flags): + _ensure_buffer(backend, name, size, dtype, zero=zero) + + +def 
build_pa_metadata_for_decode( + backend, + batch_size: int, + tp_q_head_num: Optional[int] = None, +): + """Build pa_metadata buffers for pa_persistent_fwd in decode mode.""" + max_qlen = 1 + + if tp_q_head_num is None: + tp_q_head_num = backend.num_head + + buffer_specs = get_pa_metadata_info_v1(batch_size, backend.num_kv_head) + allocate_pa_metadata_buffers(backend, buffer_specs) + qo_indptr = backend.pa_decode_qo_indptr[: batch_size + 1] + + context_lens = backend.forward_metadata.kv_lens + + kernel_block_size = backend.page_size + num_blocks_per_seq = (context_lens + kernel_block_size - 1) // kernel_block_size + pages_kv_indptr = backend.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + page_table = backend.forward_metadata.page_table + + create_flashinfer_kv_indices_triton[(batch_size,)]( + page_table, + backend.pa_batch_indices[:batch_size], + num_blocks_per_seq, + pages_kv_indptr, + None, + backend.pa_kv_indices, + page_table.stride(0), + ) + kv_indices = backend.pa_kv_indices + + get_pa_metadata_v1( + seqlens_qo_indptr=qo_indptr, + pages_kv_indptr=pages_kv_indptr, + context_lens=context_lens.int(), + num_heads_per_head_k=tp_q_head_num // backend.num_kv_head, + num_heads_k=backend.num_kv_head, + is_causal=True, + work_metadata_ptrs=backend.pa_metadata_buffers["work_metadata_ptrs"], + work_indptr=backend.pa_metadata_buffers["work_indptr"], + work_info=backend.pa_metadata_buffers["work_info"], + reduce_indptr=backend.pa_metadata_buffers["reduce_indptr"], + reduce_final_map=backend.pa_metadata_buffers["reduce_final_map"], + reduce_partial_map=backend.pa_metadata_buffers["reduce_partial_map"], + kv_granularity=max(kernel_block_size, 16), + block_size=kernel_block_size, + max_seqlen_qo=max_qlen, + uni_seqlen_qo=max_qlen, + fast_mode=True, + topk=-1, + max_split_per_batch=-1, + ) + backend.forward_metadata.pa_metadata_qo_indptr = qo_indptr + backend.forward_metadata.pa_metadata_pages_kv_indptr = 
pages_kv_indptr + backend.forward_metadata.pa_metadata_kv_indices = kv_indices + backend.forward_metadata.pa_metadata_context_lens = context_lens + backend.forward_metadata.pa_metadata_max_qlen = max_qlen + + +def build_pa_metadata_for_prefill(backend, batch_size: int): + """Build page-level metadata for non-MLA prefill mode.""" + block_size = backend.page_size + context_lens = backend.forward_metadata.kv_lens + num_blocks_per_seq = (context_lens + block_size - 1) // block_size + + pages_kv_indptr = backend.pa_kv_indptr[: batch_size + 1] + pages_kv_indptr[1 : batch_size + 1] = torch.cumsum(num_blocks_per_seq, dim=0) + + page_table = backend.forward_metadata.page_table + create_flashinfer_kv_indices_triton[(batch_size,)]( + page_table, + backend.pa_batch_indices[:batch_size], + num_blocks_per_seq, + pages_kv_indptr, + None, + backend.pa_kv_indices, + page_table.stride(0), + ) From 1ba3960acbedd84bec4b2d9bbb7acc402582cf94 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Thu, 21 May 2026 09:31:08 +0000 Subject: [PATCH 09/17] [ATOM-SGL][Attn refrac] Format refactored attention files Co-authored-by: Cursor --- .../full_attention/full_attention_backend.py | 1 + .../sglang/models/deepseek_mla_attention.py | 35 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py index b459433f1e..6dc70a9996 100644 --- a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py +++ b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py @@ -77,6 +77,7 @@ except ImportError: pass + class ATOMAttnBackendForSgl(AiterAttnBackend): """ATOM's custom attention backend for sglang plugin mode. 
diff --git a/atom/plugin/sglang/models/deepseek_mla_attention.py b/atom/plugin/sglang/models/deepseek_mla_attention.py index 65a240efd2..79279505f5 100644 --- a/atom/plugin/sglang/models/deepseek_mla_attention.py +++ b/atom/plugin/sglang/models/deepseek_mla_attention.py @@ -133,9 +133,7 @@ def _project_q( if q_scale is not None else attn.q_proj(q_input) ) - return _unwrap_linear_output(q).view( - -1, attn.num_local_heads, attn.qk_head_dim - ) + return _unwrap_linear_output(q).view(-1, attn.num_local_heads, attn.qk_head_dim) def _forward_absorbed( self, @@ -160,14 +158,15 @@ def _forward_absorbed( q = self._project_q(q_input, q_scale) k_nope = kv_c_normed.unsqueeze(1) k_pe = k_pe.unsqueeze(1) - q_nope, q_pe = q.split( - [attn.qk_nope_head_dim, attn.qk_rope_head_dim], dim=-1 - ) + q_nope, q_pe = q.split([attn.qk_nope_head_dim, attn.qk_rope_head_dim], dim=-1) q_nope_out = mla_absorbed_bmm( attn, q_nope, attn.w_kc, attn.w_scale, attn.w_scale_k, attn.kv_lora_rank ) - if attn.rotary_emb is not None and not attn.use_fused_qk_rope_concat_and_cache_mla: + if ( + attn.rotary_emb is not None + and not attn.use_fused_qk_rope_concat_and_cache_mla + ): q_pe, k_pe = attn.rotary_emb(positions, q_pe, k_pe) if nsa_use_prefill_cp(forward_batch): @@ -181,9 +180,7 @@ def _forward_absorbed( mla_attn = _get_sglang_radix_attn(self.base_attn) kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(mla_attn.layer_id) q_out_dtype = ( - dtypes.fp8 - if attn.kv_cache_dtype == "fp8_e4m3" - else q_nope_out.dtype + dtypes.fp8 if attn.kv_cache_dtype == "fp8_e4m3" else q_nope_out.dtype ) q = torch.empty( ( @@ -295,14 +292,16 @@ def forward( attn_tp_context = get_attn_tp_context() with attn_tp_context.maybe_input_scattered(forward_batch): - q_input, kv_c_normed, k_pe, positions, q_scale = self._gather_runtime_inputs( - q_input, - kv_c_normed, - k_pe, - positions, - q_scale, - forward_batch=forward_batch, - input_scattered=attn_tp_context.input_scattered, + q_input, kv_c_normed, k_pe, positions, 
q_scale = ( + self._gather_runtime_inputs( + q_input, + kv_c_normed, + k_pe, + positions, + q_scale, + forward_batch=forward_batch, + input_scattered=attn_tp_context.input_scattered, + ) ) use_non_absorbed = forward_batch.forward_mode.is_extend_without_speculative() From 0680a4b3e58bec2d6da17b1b4ba4df93b7709d80 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Thu, 21 May 2026 09:41:40 +0000 Subject: [PATCH 10/17] [ATOM-SGL][Attn refrac] Fix ruff findings in refactored attention code Co-authored-by: Cursor --- .../attention_backend/full_attention/full_attention_backend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py index 6dc70a9996..26ff848340 100644 --- a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py +++ b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py @@ -45,7 +45,6 @@ flash_attn_varlen_func, dtypes, get_pa_metadata_info_v1, - get_pa_metadata_v1, mha_batch_prefill_func, pa_fwd_asm, pa_persistent_fwd, From 21d6c2518e525d1c89014c244c069fd500bfccff Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Thu, 21 May 2026 10:20:59 +0000 Subject: [PATCH 11/17] [ATOM-SGL][Attn refrac] Avoid DeepSeek MLA wrapper module cycle Co-authored-by: Cursor --- atom/plugin/sglang/models/deepseek_mla_attention.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/atom/plugin/sglang/models/deepseek_mla_attention.py b/atom/plugin/sglang/models/deepseek_mla_attention.py index 79279505f5..f9e6edd849 100644 --- a/atom/plugin/sglang/models/deepseek_mla_attention.py +++ b/atom/plugin/sglang/models/deepseek_mla_attention.py @@ -28,7 +28,10 @@ def __init__( base_attn: nn.Module, ) -> None: super().__init__() - self.owner_attn = owner_attn + # Keep a non-module back reference. 
Registering owner_attn as a child + # module would create owner_attn -> mla_attn(wrapper) -> owner_attn and + # make nn.Module.train/eval recurse forever. + object.__setattr__(self, "owner_attn", owner_attn) self.base_attn = base_attn @property From 74829e99e343c262dc21e787680038d5cc15767d Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Mon, 1 Jun 2026 06:33:18 +0000 Subject: [PATCH 12/17] fix rebase issue --- atom/plugin/sglang/models/deepseek_mla_forward.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/atom/plugin/sglang/models/deepseek_mla_forward.py b/atom/plugin/sglang/models/deepseek_mla_forward.py index 58362c54dd..219426c3ae 100644 --- a/atom/plugin/sglang/models/deepseek_mla_forward.py +++ b/atom/plugin/sglang/models/deepseek_mla_forward.py @@ -196,8 +196,7 @@ def init_sgl_attrs( v_head_dim=attn.v_head_dim, prefix=maybe_prefix(attn.prefix, "attn_non_absorbed"), ) - if hasattr(attn.attn_non_absorbed, "attn"): - attn.attn_non_absorbed.attn.kv_b_proj = None + _bind_non_absorbed_kv_b_proj(attn) def mla_absorbed_bmm( @@ -352,6 +351,15 @@ def _get_sglang_radix_attn(attn_module): return attn_module.attn if hasattr(attn_module, "attn") else attn_module +def _bind_non_absorbed_kv_b_proj(attn: DeepseekV2MLAAttention) -> None: + """Expose DeepSeek's latent-KV projection on the non-absorbed SGLang layer.""" + + if not hasattr(attn, "attn_non_absorbed"): + return + attn_non_absorbed = _get_sglang_radix_attn(attn.attn_non_absorbed) + attn_non_absorbed.kv_b_proj = attn.kv_b_proj + + def _concat_mha_k_for_non_absorbed( attn: DeepseekV2MLAAttention, k_nope: torch.Tensor, @@ -673,6 +681,7 @@ def process_mla_kv_b_proj_after_loading(attn: DeepseekV2MLAAttention) -> None: Orchestrates reading, quantization handling, and splitting of kv_b_proj into absorbed w_kc / w_vc weights. 
""" + _bind_non_absorbed_kv_b_proj(attn) if not getattr(attn.kv_b_proj, "_sgl_mxfp4_process_done", False): attn.kv_b_proj.process_weights_after_loading() From f8474a1b2843f5aaf5d92cf46dda593d325b6675 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Mon, 1 Jun 2026 06:54:31 +0000 Subject: [PATCH 13/17] precheckin --- .../sglang/models/deepseek_mla_attention.py | 9 ++++---- .../sglang/models/deepseek_mla_forward.py | 23 +++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/atom/plugin/sglang/models/deepseek_mla_attention.py b/atom/plugin/sglang/models/deepseek_mla_attention.py index f9e6edd849..696f064a6d 100644 --- a/atom/plugin/sglang/models/deepseek_mla_attention.py +++ b/atom/plugin/sglang/models/deepseek_mla_attention.py @@ -307,14 +307,15 @@ def forward( ) ) - use_non_absorbed = forward_batch.forward_mode.is_extend_without_speculative() + use_non_absorbed = ( + forward_batch.forward_mode.is_extend_without_speculative() + ) if not use_non_absorbed and forward_batch.forward_mode.is_draft_extend(): extend_prefix_lens_cpu = getattr( forward_batch, "extend_prefix_lens_cpu", None ) - use_non_absorbed = ( - extend_prefix_lens_cpu is not None - and not any(extend_prefix_lens_cpu) + use_non_absorbed = extend_prefix_lens_cpu is not None and not any( + extend_prefix_lens_cpu ) if use_non_absorbed: diff --git a/atom/plugin/sglang/models/deepseek_mla_forward.py b/atom/plugin/sglang/models/deepseek_mla_forward.py index 219426c3ae..fcb8696d02 100644 --- a/atom/plugin/sglang/models/deepseek_mla_forward.py +++ b/atom/plugin/sglang/models/deepseek_mla_forward.py @@ -404,6 +404,29 @@ def _set_mla_kv_buffer_for_non_absorbed( ) +def _is_mxfp4_kv_b_proj(attn: DeepseekV2MLAAttention) -> bool: + kv_b_proj = attn.kv_b_proj + params_dtype = getattr(kv_b_proj, "params_dtype", None) + if params_dtype == dtypes.fp4x2 or params_dtype == getattr( + torch, "float4_e2m1fn_x2", None + ): + return True + + quant_type = getattr(kv_b_proj, "quant_type", None) + if 
getattr(quant_type, "name", "") == "per_1x32" or str(quant_type).endswith( + "per_1x32" + ): + return True + + quant_method = getattr(kv_b_proj, "quant_method", None) + quant_config = getattr(quant_method, "quant_config", None) + return bool( + quant_config is not None + and quant_config.get_name() == "quark" + and kv_b_proj.weight.dtype == torch.uint8 + ) + + def _can_run_non_absorbed_mla_now( attn: DeepseekV2MLAAttention, forward_batch, From bf1cd4f501c7beb71ef5ae5a0e242f1fd948c467 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Tue, 2 Jun 2026 06:23:57 +0000 Subject: [PATCH 14/17] prepare for sglang only --- atom/__init__.py | 5 +- atom/plugin/__init__.py | 8 +- atom/plugin/prepare.py | 77 ------------------- atom/plugin/sglang/__init__.py | 3 + .../sglang/models/base_model_wrapper.py | 10 +-- atom/plugin/sglang/models/qwen3_5.py | 4 +- atom/plugin/sglang/prepare.py | 75 ++++++++++++++++++ tests/plugin/test_sglang_prepare_hooks.py | 57 ++++++++------ tests/plugin/test_sglang_prepare_model.py | 68 ++++++++++------ 9 files changed, 162 insertions(+), 145 deletions(-) create mode 100644 atom/plugin/sglang/prepare.py diff --git a/atom/__init__.py b/atom/__init__.py index 7c1c75eb3a..9049532b61 100644 --- a/atom/__init__.py +++ b/atom/__init__.py @@ -4,11 +4,10 @@ from atom.model_engine.llm_engine import LLMEngine from atom.sampling_params import SamplingParams -# interface for upper framework to construct the model from ATOM -from atom.plugin import prepare_model +from atom.plugin.sglang import prepare_model_for_sglang __all__ = [ "LLMEngine", "SamplingParams", - "prepare_model", + "prepare_model_for_sglang", ] diff --git a/atom/plugin/__init__.py b/atom/plugin/__init__.py index 27c855e511..315b40cf75 100644 --- a/atom/plugin/__init__.py +++ b/atom/plugin/__init__.py @@ -1,12 +1,6 @@ -from .prepare import ( - prepare_model, - is_sglang, - is_vllm, - is_plugin_mode, -) +from .prepare import is_plugin_mode, is_sglang, is_vllm __all__ = [ - "prepare_model", 
"is_sglang", "is_vllm", "is_plugin_mode", diff --git a/atom/plugin/prepare.py b/atom/plugin/prepare.py index bbcb369bcd..ede7c9de64 100644 --- a/atom/plugin/prepare.py +++ b/atom/plugin/prepare.py @@ -1,8 +1,3 @@ -from typing import Any -import logging - -logger = logging.getLogger("atom") - # all of the supported frameworks, including server mode and plugin mode _SUPPORTED_FRAMEWORKS = ["vllm", "sglang", "sgl", "atom"] @@ -33,75 +28,3 @@ def _set_framework_backbone(framework: str) -> None: raise ValueError(f"Unsupported framework {framework} for ATOM to plug in") global _CURRENT_FRAMEWORK _CURRENT_FRAMEWORK = framework - - -def prepare_model(config: Any, engine: str): - """ - Prepare the model to upper framework SGLang - """ - logger.info(f"Prepare model for plugin mode, the upper engine is {engine}") - - _set_framework_backbone(engine) - - if is_sglang(): - model_arch = config.architectures[0] - else: - raise ValueError( - f"prepare_model does not support engine {engine!r} " - f"with config type {type(config)}" - ) - - # import here to avoid partial initialization - from .register import ( - _ATOM_SUPPORTED_MODELS, - # register_ops_to_vllm, - register_ops_to_sglang, - init_aiter_dist, - set_attn_cls, - ) - - if model_arch not in _ATOM_SUPPORTED_MODELS: - supported_archs = list(_ATOM_SUPPORTED_MODELS.keys()) - raise ValueError( - f"ATOM does not support the required model architecture: {model_arch}. 
" - f"For now supported model architectures: {supported_archs}" - ) - - from atom.plugin.config import generate_atom_config_for_plugin_mode - - atom_config = generate_atom_config_for_plugin_mode(config) - - model_cls = _ATOM_SUPPORTED_MODELS[model_arch] - logger.info(f"ATOM model class for {model_arch} is {model_cls}") - - from atom.plugin.sglang.runtime import get_model_arch_spec - - model_adapter = get_model_arch_spec(model_arch) - if model_adapter.prepare_config is not None: - model_adapter.prepare_config(atom_config, model_arch) - - register_ops_to_sglang(atom_config=atom_config) - set_attn_cls() - - # init aiter dist for using aiter custom collective ops - init_aiter_dist(config=atom_config) - - # Patch SGLang graph_capture to also enter aiter's ca_comm.capture(), - # avoiding hipMemcpyAsync in aiter collectives when model uses aiter's - # custom all_reduce (same fix as atom/plugin/vllm/graph_capture_patch.py) - from atom.plugin.sglang.graph_capture_patch import apply_graph_capture_patch - - apply_graph_capture_patch() - - try: - model = model_cls(atom_config=atom_config) - except TypeError as exc: - # Some SGLang plugin models keep SGLang's native wrapper constructor - # and only swap their internal language_model with an ATOM model. - # Those classes accept `config=...` instead of `atom_config=...`. 
- if "atom_config" not in str(exc): - raise - model = model_cls(config=config) - if not hasattr(model, "atom_config"): - model.atom_config = atom_config - return model diff --git a/atom/plugin/sglang/__init__.py b/atom/plugin/sglang/__init__.py index e69de29bb2..d03f99a045 100644 --- a/atom/plugin/sglang/__init__.py +++ b/atom/plugin/sglang/__init__.py @@ -0,0 +1,3 @@ +from atom.plugin.sglang.prepare import prepare_model_for_sglang + +__all__ = ["prepare_model_for_sglang"] diff --git a/atom/plugin/sglang/models/base_model_wrapper.py b/atom/plugin/sglang/models/base_model_wrapper.py index 02ac6edc77..c3bd0854be 100644 --- a/atom/plugin/sglang/models/base_model_wrapper.py +++ b/atom/plugin/sglang/models/base_model_wrapper.py @@ -64,17 +64,11 @@ def __init__( self.model_arch = getattr(config, "architectures", [""])[0] self.model_arch_spec = get_model_arch_spec(self.model_arch) - import atom - - # TODO: prepare_model() currently handles model construction, config - # generation, attention backend registration, and distributed init. 
- # Refactor so this wrapper only dispatches the attention backend - # (register_ops_to_sglang + set_attn_cls), and let sglang handle - # model construction directly with plugin_runtime_scope(framework="sglang"): from atom.config import get_current_atom_config + from atom.plugin.sglang.prepare import prepare_model - self.model = atom.prepare_model(config=config, engine="sglang") + self.model = prepare_model(config=config) self.atom_config = getattr(self.model, "atom_config", None) if self.atom_config is None: self.atom_config = get_current_atom_config() diff --git a/atom/plugin/sglang/models/qwen3_5.py b/atom/plugin/sglang/models/qwen3_5.py index c0b6b0ca0b..e85af86a22 100644 --- a/atom/plugin/sglang/models/qwen3_5.py +++ b/atom/plugin/sglang/models/qwen3_5.py @@ -202,13 +202,13 @@ def __init__( prefix: str = "", ) -> None: del prefix - import atom + from atom.plugin.sglang.prepare import prepare_model nn.Module.__init__(self) root_config = type(self)._pending_vlm_root_config if root_config is None: root_config = config - atom_lm = atom.prepare_model(config=root_config, engine="sglang") + atom_lm = prepare_model(config=root_config) if atom_lm is None: arch = getattr(root_config, "architectures", ["unknown"])[0] raise ValueError(f"ATOM failed to build language model for {arch}") diff --git a/atom/plugin/sglang/prepare.py b/atom/plugin/sglang/prepare.py new file mode 100644 index 0000000000..a51f5f2f44 --- /dev/null +++ b/atom/plugin/sglang/prepare.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import logging +from typing import Any + +from atom.plugin.prepare import _set_framework_backbone + +logger = logging.getLogger("atom") + + +def prepare_model(config: Any): + """Prepare an ATOM model for SGLang plugin mode.""" + logger.info("Prepare model for plugin mode, the upper engine is sglang") + _set_framework_backbone("sglang") + + model_arch = config.architectures[0] + + # Import here to avoid partial initialization while SGLang discovers models. 
+ from atom.plugin.register import ( + _ATOM_SUPPORTED_MODELS, + init_aiter_dist, + register_ops_to_sglang, + set_attn_cls, + ) + + if model_arch not in _ATOM_SUPPORTED_MODELS: + supported_archs = list(_ATOM_SUPPORTED_MODELS.keys()) + raise ValueError( + f"ATOM does not support the required model architecture: {model_arch}. " + f"For now supported model architectures: {supported_archs}" + ) + + from atom.plugin.config import generate_atom_config_for_plugin_mode + + atom_config = generate_atom_config_for_plugin_mode(config) + + model_cls = _ATOM_SUPPORTED_MODELS[model_arch] + logger.info("ATOM model class for %s is %s", model_arch, model_cls) + + from atom.plugin.sglang.runtime import get_model_arch_spec + + model_adapter = get_model_arch_spec(model_arch) + if model_adapter.prepare_config is not None: + model_adapter.prepare_config(atom_config, model_arch) + + register_ops_to_sglang(atom_config=atom_config) + set_attn_cls() + + # Init aiter dist for using aiter custom collective ops. + init_aiter_dist(config=atom_config) + + # Patch SGLang graph_capture to also enter aiter's ca_comm.capture(), + # avoiding hipMemcpyAsync in aiter collectives when model uses aiter's + # custom all_reduce (same fix as atom/plugin/vllm/graph_capture_patch.py). + from atom.plugin.sglang.graph_capture_patch import apply_graph_capture_patch + + apply_graph_capture_patch() + + try: + model = model_cls(atom_config=atom_config) + except TypeError as exc: + # Some SGLang plugin models keep SGLang's native wrapper constructor + # and only swap their internal language_model with an ATOM model. + # Those classes accept `config=...` instead of `atom_config=...`. 
+ if "atom_config" not in str(exc): + raise + model = model_cls(config=config) + if not hasattr(model, "atom_config"): + model.atom_config = atom_config + return model + + +def prepare_model_for_sglang(config: Any): + """Backward-compatible alias for SGLang plugin model preparation.""" + return prepare_model(config) diff --git a/tests/plugin/test_sglang_prepare_hooks.py b/tests/plugin/test_sglang_prepare_hooks.py index cff778076b..04376cfa36 100644 --- a/tests/plugin/test_sglang_prepare_hooks.py +++ b/tests/plugin/test_sglang_prepare_hooks.py @@ -11,7 +11,8 @@ import pytest -from atom.plugin import prepare as plugin_prepare +from atom.plugin import prepare as plugin_runtime +from atom.plugin.sglang import prepare as sglang_prepare class _Obj: @@ -44,27 +45,32 @@ def _module(name: str, **attrs) -> ModuleType: return module +def _make_fake_runtime_module(model_arch: str, prepare_config): + module = ModuleType("atom.plugin.sglang.runtime") + module.get_model_arch_spec = MagicMock( + return_value=_Obj(prepare_config=prepare_config) + ) + return module + + @pytest.fixture(autouse=True) def _reset_framework_state(): - plugin_prepare._set_framework_backbone("atom") + plugin_runtime._set_framework_backbone("atom") yield - plugin_prepare._set_framework_backbone("atom") + plugin_runtime._set_framework_backbone("atom") @pytest.mark.parametrize( - "model_arch,expect_register_ops", + "model_arch", ( - ("Qwen3_5ForConditionalGeneration", False), - ("Qwen3_5MoeForConditionalGeneration", False), - ("Qwen3NextForCausalLM", False), - ("DeepseekV3ForCausalLM", True), - ("Qwen3MoeForCausalLM", True), + "Qwen3_5ForConditionalGeneration", + "Qwen3_5MoeForConditionalGeneration", + "Qwen3NextForCausalLM", + "DeepseekV3ForCausalLM", + "Qwen3MoeForCausalLM", ), ) -def test_prepare_model_register_ops_gate( - model_arch: str, - expect_register_ops: bool, -): +def test_prepare_model_register_ops_gate(model_arch: str): fake_atom_config = _Obj(plugin_config=_Obj(is_plugin_mode=True)) 
fake_register, _fake_model, fake_model_cls = _make_fake_register_module(model_arch) fake_config_mod = MagicMock() @@ -75,29 +81,34 @@ def test_prepare_model_register_ops_gate( "atom.plugin.sglang.models.qwen3_5", apply_prepare_model_adaptations=MagicMock(), ) + prepare_config = ( + fake_qwen35_mod.apply_prepare_model_adaptations + if model_arch + in { + "Qwen3_5ForConditionalGeneration", + "Qwen3_5MoeForConditionalGeneration", + } + else None + ) + fake_runtime_mod = _make_fake_runtime_module(model_arch, prepare_config) with patch.dict( sys.modules, { "atom.plugin.register": fake_register, "atom.plugin.config": fake_config_mod, + "atom.plugin.sglang.runtime": fake_runtime_mod, "atom.plugin.sglang.models.qwen3_5": fake_qwen35_mod, "atom.plugin.sglang.graph_capture_patch": MagicMock( apply_graph_capture_patch=MagicMock() ), }, ): - plugin_prepare.prepare_model( - config=_Obj(architectures=[model_arch]), - engine="sglang", - ) + sglang_prepare.prepare_model(config=_Obj(architectures=[model_arch])) - if expect_register_ops: - fake_register.register_ops_to_sglang.assert_called_once_with( - atom_config=fake_atom_config - ) - else: - fake_register.register_ops_to_sglang.assert_not_called() + fake_register.register_ops_to_sglang.assert_called_once_with( + atom_config=fake_atom_config + ) if model_arch in { "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration", diff --git a/tests/plugin/test_sglang_prepare_model.py b/tests/plugin/test_sglang_prepare_model.py index a9ae0a1851..b5f12b672a 100644 --- a/tests/plugin/test_sglang_prepare_model.py +++ b/tests/plugin/test_sglang_prepare_model.py @@ -13,7 +13,8 @@ import pytest from unittest.mock import MagicMock, patch -from atom.plugin import prepare as plugin_prepare +from atom.plugin import prepare as plugin_runtime +from atom.plugin.sglang import prepare as sglang_prepare class _Obj: @@ -26,9 +27,9 @@ def __init__(self, **kwargs): @pytest.fixture(autouse=True) def _reset_framework_state(): - 
plugin_prepare._set_framework_backbone("atom") + plugin_runtime._set_framework_backbone("atom") yield - plugin_prepare._set_framework_backbone("atom") + plugin_runtime._set_framework_backbone("atom") def _make_fake_register_module(model_dict=None): @@ -41,35 +42,37 @@ def _make_fake_register_module(model_dict=None): return mod +def _make_fake_runtime_module(): + mod = MagicMock() + mod.get_model_arch_spec = MagicMock(return_value=_Obj(prepare_config=None)) + return mod + + # --------------------------------------------------------------------------- # Engine / architecture validation # --------------------------------------------------------------------------- -def test_prepare_model_rejects_unsupported_engine(): - """Unsupported engine should raise ValueError from _set_framework_backbone.""" - config = _Obj(architectures=["SomeModel"]) - with pytest.raises(ValueError, match="Unsupported framework"): - plugin_prepare.prepare_model(config=config, engine="tensorflow") - - -def test_prepare_model_rejects_non_sglang_engine_gracefully(): - """vllm engine currently not supported in prepare_model (only sglang path).""" - config = _Obj(architectures=["Qwen3ForCausalLM"]) - with pytest.raises(ValueError, match="does not support engine"): - plugin_prepare.prepare_model(config=config, engine="vllm") - - def test_prepare_model_rejects_unsupported_architecture(): - """Known engine but unknown arch should raise ValueError.""" + """Unknown architecture should raise ValueError from the SGLang prepare path.""" fake_register = _make_fake_register_module( model_dict={"DeepseekV3ForCausalLM": MagicMock()} ) + fake_runtime = _make_fake_runtime_module() - with patch.dict(sys.modules, {"atom.plugin.register": fake_register}): + with patch.dict( + sys.modules, + { + "atom.plugin.register": fake_register, + "atom.plugin.sglang.runtime": fake_runtime, + "atom.plugin.sglang.graph_capture_patch": MagicMock( + apply_graph_capture_patch=MagicMock() + ), + }, + ): config = 
_Obj(architectures=["TotallyFakeModelArch"]) with pytest.raises(ValueError, match="does not support"): - plugin_prepare.prepare_model(config=config, engine="sglang") + sglang_prepare.prepare_model(config=config) # --------------------------------------------------------------------------- @@ -86,6 +89,7 @@ def test_prepare_model_sglang_happy_path(): fake_register = _make_fake_register_module( model_dict={"DeepseekV3ForCausalLM": fake_model_cls} ) + fake_runtime = _make_fake_runtime_module() mock_gen_config = MagicMock(return_value=fake_atom_config) fake_config_mod = MagicMock() @@ -96,10 +100,14 @@ def test_prepare_model_sglang_happy_path(): { "atom.plugin.register": fake_register, "atom.plugin.config": fake_config_mod, + "atom.plugin.sglang.runtime": fake_runtime, + "atom.plugin.sglang.graph_capture_patch": MagicMock( + apply_graph_capture_patch=MagicMock() + ), }, ): config = _Obj(architectures=["DeepseekV3ForCausalLM"]) - result = plugin_prepare.prepare_model(config=config, engine="sglang") + result = sglang_prepare.prepare_model(config=config) # Config generation called mock_gen_config.assert_called_once_with(config) @@ -126,6 +134,7 @@ def test_prepare_model_selects_sglang_dict_for_deepseek_v2(): fake_register = _make_fake_register_module( model_dict={"DeepseekV2ForCausalLM": fake_model_cls} ) + fake_runtime = _make_fake_runtime_module() fake_config_mod = MagicMock() fake_config_mod.generate_atom_config_for_plugin_mode = MagicMock( return_value=fake_atom_config @@ -136,10 +145,14 @@ def test_prepare_model_selects_sglang_dict_for_deepseek_v2(): { "atom.plugin.register": fake_register, "atom.plugin.config": fake_config_mod, + "atom.plugin.sglang.runtime": fake_runtime, + "atom.plugin.sglang.graph_capture_patch": MagicMock( + apply_graph_capture_patch=MagicMock() + ), }, ): config = _Obj(architectures=["DeepseekV2ForCausalLM"]) - result = plugin_prepare.prepare_model(config=config, engine="sglang") + result = sglang_prepare.prepare_model(config=config) assert 
result is fake_model @@ -152,6 +165,7 @@ def test_prepare_model_sets_framework_to_sglang(): fake_register = _make_fake_register_module( model_dict={"DeepseekV3ForCausalLM": fake_model_cls} ) + fake_runtime = _make_fake_runtime_module() fake_config_mod = MagicMock() fake_config_mod.generate_atom_config_for_plugin_mode = MagicMock( return_value=fake_atom_config @@ -162,10 +176,14 @@ def test_prepare_model_sets_framework_to_sglang(): { "atom.plugin.register": fake_register, "atom.plugin.config": fake_config_mod, + "atom.plugin.sglang.runtime": fake_runtime, + "atom.plugin.sglang.graph_capture_patch": MagicMock( + apply_graph_capture_patch=MagicMock() + ), }, ): config = _Obj(architectures=["DeepseekV3ForCausalLM"]) - plugin_prepare.prepare_model(config=config, engine="sglang") + sglang_prepare.prepare_model(config=config) - assert plugin_prepare.is_sglang() is True - assert plugin_prepare.is_plugin_mode() is True + assert plugin_runtime.is_sglang() is True + assert plugin_runtime.is_plugin_mode() is True From 961821ec2d39a3770165b2a2135d03e436c6793b Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Tue, 2 Jun 2026 06:56:57 +0000 Subject: [PATCH 15/17] import error meet in qwen3.5 --- atom/plugin/sglang/attention.py | 4 +++- .../full_attention/full_attention_backend.py | 16 +++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/atom/plugin/sglang/attention.py b/atom/plugin/sglang/attention.py index f1474c45b1..44a810c152 100644 --- a/atom/plugin/sglang/attention.py +++ b/atom/plugin/sglang/attention.py @@ -1,4 +1,6 @@ -from atom.plugin.sglang.attention_backend.radix_attention import RadixAttention +from atom.plugin.sglang.attention_backend.full_attention.radix_attention import ( + RadixAttention, +) class AttentionForSGLang(RadixAttention): diff --git a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py index 26ff848340..5e8ac5a498 100644 
--- a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py +++ b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py @@ -19,7 +19,6 @@ from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend from sglang.srt.layers.attention.utils import ( create_flashinfer_kv_indices_triton, - launch_reshape_and_cache_flash, pad_sequence_with_mask, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -1578,16 +1577,15 @@ def _set_kv_buffer_native_dense(self, layer, cache_loc, k, v, forward_batch): k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( layer.layer_id ) - launch_reshape_and_cache_flash( + self.set_kv_buffer_with_layout_shuffle( + cache_loc, k.view(-1, layer.tp_k_head_num, layer.qk_head_dim), v.view(-1, layer.tp_v_head_num, layer.v_head_dim), - k_cache.view( - -1, self.page_size, layer.tp_k_head_num, layer.qk_head_dim - ), - v_cache.view(-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim), - cache_loc, - k_scale=k_descale, - v_scale=v_descale, + k_cache, + v_cache, + k_descale, + v_descale, + self.page_size, ) return From 94829c77e71775e6c12524dec0bdaa7e3ad9c457 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Tue, 2 Jun 2026 10:27:53 +0000 Subject: [PATCH 16/17] qwen3.5 acc fix --- .../full_attention/full_attention_backend.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py index 5e8ac5a498..26ff848340 100644 --- a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py +++ b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py @@ -19,6 +19,7 @@ from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend from sglang.srt.layers.attention.utils import ( create_flashinfer_kv_indices_triton, + 
launch_reshape_and_cache_flash, pad_sequence_with_mask, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -1577,15 +1578,16 @@ def _set_kv_buffer_native_dense(self, layer, cache_loc, k, v, forward_batch): k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( layer.layer_id ) - self.set_kv_buffer_with_layout_shuffle( - cache_loc, + launch_reshape_and_cache_flash( k.view(-1, layer.tp_k_head_num, layer.qk_head_dim), v.view(-1, layer.tp_v_head_num, layer.v_head_dim), - k_cache, - v_cache, - k_descale, - v_descale, - self.page_size, + k_cache.view( + -1, self.page_size, layer.tp_k_head_num, layer.qk_head_dim + ), + v_cache.view(-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim), + cache_loc, + k_scale=k_descale, + v_scale=v_descale, ) return From ffe41aeb262ade52138b573231131ac7e0a4986d Mon Sep 17 00:00:00 2001 From: qichu-yun Date: Tue, 2 Jun 2026 08:19:19 -0500 Subject: [PATCH 17/17] [Fix] Limit static FP4 linear kv_b_proj post-processing --- .../sglang/models/deepseek_mla_forward.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/atom/plugin/sglang/models/deepseek_mla_forward.py b/atom/plugin/sglang/models/deepseek_mla_forward.py index fcb8696d02..d18449cad7 100644 --- a/atom/plugin/sglang/models/deepseek_mla_forward.py +++ b/atom/plugin/sglang/models/deepseek_mla_forward.py @@ -446,6 +446,19 @@ def _can_run_non_absorbed_mla_now( return True +def _is_static_quark_mxfp4_kv_b_proj(kv_b_proj) -> bool: + layer_quant_config = getattr(kv_b_proj, "layer_quant_config", None) + return ( + getattr(kv_b_proj, "weight", None) is not None + and kv_b_proj.weight.dim() == 2 + and layer_quant_config is not None + and layer_quant_config.quant_method == "quark" + and layer_quant_config.quant_dtype == dtypes.fp4x2 + and getattr(layer_quant_config.quant_type, "name", None) == "per_1x32" + and getattr(kv_b_proj, "source_quant_dtype", None) is None + ) + + def _read_kv_b_proj_weight(attn: 
DeepseekV2MLAAttention) -> torch.Tensor: """Read kv_b_proj weight, handling AWQ and fnuz dtypes.""" if hasattr(attn.kv_b_proj, "qweight"): @@ -460,14 +473,7 @@ def _read_kv_b_proj_weight(attn: DeepseekV2MLAAttention) -> torch.Tensor: attn.kv_b_proj.qzeros, ).T else: - layer_quant_config = getattr(attn.kv_b_proj, "layer_quant_config", None) - is_quark_static_mxfp4 = ( - layer_quant_config is not None - and layer_quant_config.quant_method == "quark" - and layer_quant_config.quant_dtype == dtypes.fp4x2 - and getattr(layer_quant_config.quant_type, "name", None) == "per_1x32" - ) - if is_quark_static_mxfp4: + if _is_static_quark_mxfp4_kv_b_proj(attn.kv_b_proj): w = getattr( attn.kv_b_proj, "_mxfp4_unshuffled_weight", @@ -705,7 +711,9 @@ def process_mla_kv_b_proj_after_loading(attn: DeepseekV2MLAAttention) -> None: kv_b_proj into absorbed w_kc / w_vc weights. """ _bind_non_absorbed_kv_b_proj(attn) - if not getattr(attn.kv_b_proj, "_sgl_mxfp4_process_done", False): + if _is_static_quark_mxfp4_kv_b_proj(attn.kv_b_proj) and not getattr( + attn.kv_b_proj, "_sgl_mxfp4_process_done", False + ): attn.kv_b_proj.process_weights_after_loading() w = _read_kv_b_proj_weight(attn) @@ -737,16 +745,7 @@ def process_weights_after_loading_with_mxfp4_preserve(): if getattr(kv_b_proj, "_sgl_mxfp4_process_done", False): return - layer_quant_config = getattr(kv_b_proj, "layer_quant_config", None) - is_quark_static_mxfp4 = ( - kv_b_proj.weight.dim() == 2 - and layer_quant_config is not None - and layer_quant_config.quant_method == "quark" - and layer_quant_config.quant_dtype == dtypes.fp4x2 - and getattr(layer_quant_config.quant_type, "name", None) == "per_1x32" - and getattr(kv_b_proj, "source_quant_dtype", None) is None - ) - if is_quark_static_mxfp4: + if _is_static_quark_mxfp4_kv_b_proj(kv_b_proj): kv_b_proj._mxfp4_unshuffled_weight = kv_b_proj.weight.detach().clone() kv_b_proj._mxfp4_unshuffled_weight_scale = ( kv_b_proj.weight_scale.detach().clone()