Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions atom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
from atom.model_engine.llm_engine import LLMEngine
from atom.sampling_params import SamplingParams

# interface for upper framework to construct the model from ATOM
from atom.plugin import prepare_model
from atom.plugin.sglang import prepare_model_for_sglang

__all__ = [
"LLMEngine",
"SamplingParams",
"prepare_model",
"prepare_model_for_sglang",
]
Comment on lines 4 to 13
8 changes: 1 addition & 7 deletions atom/plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
from .prepare import (
prepare_model,
is_sglang,
is_vllm,
is_plugin_mode,
)
from .prepare import is_plugin_mode, is_sglang, is_vllm

__all__ = [
"prepare_model",
"is_sglang",
"is_vllm",
"is_plugin_mode",
Expand Down
81 changes: 0 additions & 81 deletions atom/plugin/prepare.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
from typing import Any
import logging

logger = logging.getLogger("atom")

# all of the supported frameworks, including server mode and plugin mode
_SUPPORTED_FRAMEWORKS = ["vllm", "sglang", "sgl", "atom"]

Expand Down Expand Up @@ -33,79 +28,3 @@ def _set_framework_backbone(framework: str) -> None:
raise ValueError(f"Unsupported framework {framework} for ATOM to plug in")
global _CURRENT_FRAMEWORK
_CURRENT_FRAMEWORK = framework


def prepare_model(config: Any, engine: str):
"""
Prepare the model to upper framework SGLang
"""
logger.info(f"Prepare model for plugin mode, the upper engine is {engine}")

_set_framework_backbone(engine)

if is_sglang():
model_arch = config.architectures[0]
else:
raise ValueError(
f"prepare_model does not support engine {engine!r} "
f"with config type {type(config)}"
)

# import here to avoid partial initialization
from .register import (
_ATOM_SUPPORTED_MODELS,
# register_ops_to_vllm,
register_ops_to_sglang,
init_aiter_dist,
set_attn_cls,
)

if model_arch not in _ATOM_SUPPORTED_MODELS:
supported_archs = list(_ATOM_SUPPORTED_MODELS.keys())
raise ValueError(
f"ATOM does not support the required model architecture: {model_arch}. "
f"For now supported model architectures: {supported_archs}"
)

from atom.plugin.config import generate_atom_config_for_plugin_mode

atom_config = generate_atom_config_for_plugin_mode(config)

model_cls = _ATOM_SUPPORTED_MODELS[model_arch]
logger.info(f"ATOM model class for {model_arch} is {model_cls}")

if model_arch in {
"Qwen3_5ForConditionalGeneration",
"Qwen3_5MoeForConditionalGeneration",
}:
from atom.plugin.sglang.models.qwen3_5 import (
apply_prepare_model_adaptations,
)

apply_prepare_model_adaptations(atom_config, model_arch)

register_ops_to_sglang(atom_config=atom_config)
set_attn_cls()

# init aiter dist for using aiter custom collective ops
init_aiter_dist(config=atom_config)

# Patch SGLang graph_capture to also enter aiter's ca_comm.capture(),
# avoiding hipMemcpyAsync in aiter collectives when model uses aiter's
# custom all_reduce (same fix as atom/plugin/vllm/graph_capture_patch.py)
from atom.plugin.sglang.graph_capture_patch import apply_graph_capture_patch

apply_graph_capture_patch()

try:
model = model_cls(atom_config=atom_config)
except TypeError as exc:
# Some SGLang plugin models keep SGLang's native wrapper constructor
# and only swap their internal language_model with an ATOM model.
# Those classes accept `config=...` instead of `atom_config=...`.
if "atom_config" not in str(exc):
raise
model = model_cls(config=config)
if not hasattr(model, "atom_config"):
model.atom_config = atom_config
return model
2 changes: 1 addition & 1 deletion atom/plugin/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _register_custom_attention_to_sglang() -> None:
from sglang.srt.layers.attention.attention_registry import (
register_attention_backend,
)
from atom.plugin.sglang.attention_backend.sgl_attn_backend import (
from atom.plugin.sglang.attention_backend.full_attention.full_attention_backend import (
ATOMAttnBackendForSgl,
)

Expand Down
3 changes: 3 additions & 0 deletions atom/plugin/sglang/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from atom.plugin.sglang.prepare import prepare_model_for_sglang

__all__ = ["prepare_model_for_sglang"]
4 changes: 3 additions & 1 deletion atom/plugin/sglang/attention.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from atom.plugin.sglang.attention_backend.radix_attention import RadixAttention
from atom.plugin.sglang.attention_backend.full_attention.radix_attention import (
RadixAttention,
)


class AttentionForSGLang(RadixAttention):
Expand Down
2 changes: 1 addition & 1 deletion atom/plugin/sglang/attention_backend/attention_gdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _build_gdn_metadata(
def build(
cls, forward_batch_or_metadata: Any
) -> Optional["SGLangGDNForwardContext"]:
from atom.plugin.sglang.models.base_model_wrapper import (
from atom.plugin.sglang.runtime import (
SGLangForwardBatchMetadata,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .radix_attention import RadixAttention
from .full_attention_backend import ATOMAttnBackendForSgl, ForwardMetadata
Comment thread
ZhiweiYan-96 marked this conversation as resolved.

__all__ = [
"RadixAttention",
"ATOMAttnBackendForSgl",
"ForwardMetadata",
]
Loading
Loading