diff --git a/README.md b/README.md index 25a12c3..77a6649 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ application code. | LMDeploy | Yes | Yes | External node registration | Uses LMDeploy PD connection pool and RDMA migration when available. | | vLLM | Yes | Yes | Static, heartbeat | Supports two-stage KV transfer and static NIXL DP-aware rank routing. | | SGLang | Yes | Yes | Static | Uses bootstrap dual dispatch with aligned prefill bootstrap ports. | +| DLEngine | Yes | Yes | dlslime-ctrl (`nanoctrl`) | Hybrid `dlengine serve` nodes; auto-discovery when `--ctrl_address` is set. | DLRouter is configured with one backend type per router process through `--backend`. Run multiple router processes if you need separate backend types at @@ -100,6 +101,45 @@ curl -X POST http://localhost:8000/v1/chat/completions \ }' ``` +### DLEngine with dlslime-ctrl discovery + +Start the control plane and a DLEngine OpenAI server (see DLEngine +`dlengine serve`), then run DLRouter with auto-discovery: + +```bash +dlslime-ctrl server --redis-url redis://127.0.0.1:6379 + +dlengine serve /path/to/model \ + --host 0.0.0.0 --port 8100 \ + --served-model-name Qwen3-4B \ + --ctrl-address 127.0.0.1:4479 + +pip install -e ".[dlengine]" # pulls dlslime for NanoCtrlClient + +python -m dlrouter \ + --backend dlengine \ + --serving_strategy hybrid \ + --ctrl_address 127.0.0.1:4479 +``` + +DLRouter polls dlslime-ctrl for entities with kind `dlengine` and registers +their HTTP endpoints. Use the same `model` name as `--served-model-name` in +requests. Manual registration still works via `POST /nodes/add` when +`--ctrl_address` is omitted. 
+ +Send a request (the served model name, model path, and its basename are all +accepted as the `model` value): + +```bash +curl -X POST http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen3-4B", + "messages": [{"role": "user", "content": "Hello!"}], + "stream": false + }' +``` + DLRouter also installs a `dlrouter` console script, so `dlrouter ...` is equivalent to `python -m dlrouter ...` after installation. @@ -226,7 +266,7 @@ be installed in the runtime environment. |---|---|---| | `--server_name` | `0.0.0.0` | Bind address. | | `--server_port` | `8000` | Listen port. | -| `--backend` | `lmdeploy` | Backend type: `lmdeploy`, `vllm`, or `sglang`. | +| `--backend` | `lmdeploy` | Backend type: `lmdeploy`, `vllm`, `sglang`, or `dlengine`. | | `--routing_strategy` | `min_expected_latency` | Request routing strategy. | | `--serving_strategy` | `hybrid` | Serving mode: `hybrid` or `distserve`. | | `--api_keys` | `None` | Comma-separated Bearer tokens for API authentication. 
| diff --git a/dlrouter/api/app.py b/dlrouter/api/app.py index bfe9cdc..2bce75e 100644 --- a/dlrouter/api/app.py +++ b/dlrouter/api/app.py @@ -17,7 +17,7 @@ ) from dlrouter.backends.factory import create_backend from dlrouter.config import RouterConfig -from dlrouter.constants import ServingStrategy +from dlrouter.constants import ServiceDiscoveryMode, ServingStrategy from dlrouter.core.health_check import HealthChecker from dlrouter.core.node_manager import NodeManager from dlrouter.core.proxy_engine import ProxyEngine @@ -108,22 +108,24 @@ async def lifespan(application: FastAPI): cache_status=config.cache_status, ) - # Service discovery (backend-specific, e.g., ZMQ for vLLM PD mode) + # Service discovery (backend-specific) service_discovery: Optional[Any] = None - if config.serving_strategy == ServingStrategy.DISTSERVE: - discovery_mode = backend.preferred_discovery_mode(config.backend_config) - if discovery_mode is not None: - service_discovery = backend.create_service_discovery( - discovery_mode, - config.backend_config, - node_manager, - ) - # Allow heartbeat-based discovery to drop its registered-address - # cache when a node is removed (e.g. by HealthChecker after a - # crash), so a restarted instance can be re-registered. - unregister = getattr(service_discovery, 'unregister_by_url', None) - if callable(unregister): - node_manager.add_remove_listener(unregister) + discovery_mode = backend.preferred_discovery_mode(config.backend_config) + use_discovery = discovery_mode is not None and ( + config.serving_strategy == ServingStrategy.DISTSERVE + or discovery_mode == ServiceDiscoveryMode.NANOCTRL + ) + if use_discovery: + service_discovery = backend.create_service_discovery( + discovery_mode, + config.backend_config, + node_manager, + ) + # Allow discovery to drop its registered-address cache when a node is + # removed (e.g. by HealthChecker), so a restarted instance can re-register. 
+ unregister = getattr(service_discovery, 'unregister_by_url', None) + if callable(unregister): + node_manager.add_remove_listener(unregister) # Proxy engine proxy_engine = ProxyEngine(node_manager) diff --git a/dlrouter/backends/__init__.py b/dlrouter/backends/__init__.py index 2cae71b..a21bcf8 100644 --- a/dlrouter/backends/__init__.py +++ b/dlrouter/backends/__init__.py @@ -2,6 +2,11 @@ from dlrouter.backends.base import BaseBackend from dlrouter.backends.definition import BackendDefinition +from dlrouter.backends.dlengine import ( + DLENGINE_BACKEND_DEFINITION, + DLEngineBackend, + DLEngineConfig, +) from dlrouter.backends.factory import create_backend, get_backend_definition from dlrouter.backends.lmdeploy import ( LMDEPLOY_BACKEND_DEFINITION, @@ -21,11 +26,14 @@ __all__ = [ + 'DLENGINE_BACKEND_DEFINITION', 'LMDEPLOY_BACKEND_DEFINITION', 'SGLANG_BACKEND_DEFINITION', 'VLLM_BACKEND_DEFINITION', 'BackendDefinition', 'BaseBackend', + 'DLEngineBackend', + 'DLEngineConfig', 'LMDeployBackend', 'LMDeployPDConfig', 'SGLangBackend', diff --git a/dlrouter/backends/dlengine/__init__.py b/dlrouter/backends/dlengine/__init__.py new file mode 100644 index 0000000..b335d53 --- /dev/null +++ b/dlrouter/backends/dlengine/__init__.py @@ -0,0 +1,12 @@ +"""DLEngine backend package.""" + +from dlrouter.backends.dlengine.backend import DLEngineBackend +from dlrouter.backends.dlengine.config import DLEngineConfig +from dlrouter.backends.dlengine.definition import DLENGINE_BACKEND_DEFINITION + + +__all__ = [ + 'DLENGINE_BACKEND_DEFINITION', + 'DLEngineBackend', + 'DLEngineConfig', +] diff --git a/dlrouter/backends/dlengine/backend.py b/dlrouter/backends/dlengine/backend.py new file mode 100644 index 0000000..6d80fa8 --- /dev/null +++ b/dlrouter/backends/dlengine/backend.py @@ -0,0 +1,417 @@ +"""DLEngine backend adapter. + +Forwards OpenAI-compatible HTTP to DLEngine ``serve`` nodes. When +``--ctrl_address`` is set, discovers nodes via dlslime-ctrl (entity kind +``dlengine``). 
+""" + +import json +from typing import TYPE_CHECKING, Any, Optional + +import aiohttp +import requests +from fastapi import BackgroundTasks +from fastapi.responses import JSONResponse, StreamingResponse + +from dlrouter.backends.base import BaseBackend, CLIArg, PDRequestContext +from dlrouter.backends.http import BackendHTTPTransportMixin, StreamFraming +from dlrouter.backends.dlengine.config import DLEngineConfig +from dlrouter.constants import ( + AIOHTTP_TIMEOUT, + ERROR_MESSAGES, + HEALTH_CHECK_TIMEOUT, + EngineRole, + ErrorCode, + ServiceDiscoveryMode, +) +from dlrouter.core.dp_url import normalize_dp_aware_url +from dlrouter.core.node_lifecycle import post_call, pre_call +from dlrouter.logger import get_logger + + +if TYPE_CHECKING: + from dlrouter.core.node_manager import NodeManager + from dlrouter.core.service_discovery.base import BaseServiceDiscovery + + +logger = get_logger('dlrouter.backends.dlengine') + +DEFAULT_POOL_CONNECTIONS = 100 +DEFAULT_POOL_MAXSIZE = 100 + +# DLRouter adds routing metadata; DLEngine serve only needs generation fields. +# ``kv_transfer_params`` carries the PD handoff (do_remote_decode / migration). 
# Allow-list of request fields that are passed through to DLEngine serve.
# Anything else (router-side metadata, unsupported sampling knobs) is dropped.
_CHAT_FORWARD_KEYS = frozenset(
    {
        'model',
        'messages',
        'prompt',
        'stream',
        'temperature',
        'max_tokens',
        'max_completion_tokens',
        'ignore_eos',
        'stop',
        'kv_transfer_params',
    }
)


def _sanitize_chat_payload(request_data: dict[str, Any]) -> dict[str, Any]:
    """Return a reduced copy of ``request_data`` for DLEngine serve.

    Only keys listed in ``_CHAT_FORWARD_KEYS`` survive; the input mapping is
    not modified. A forwarded ``model`` value is coerced to ``str``.
    """
    forwarded: dict[str, Any] = {}
    for key in _CHAT_FORWARD_KEYS:
        if key not in request_data:
            continue
        value = request_data[key]
        # The model identifier is always sent as a string.
        forwarded[key] = str(value) if key == 'model' else value
    return forwarded
async def handle_pd_request(
    self,
    request_data: dict[str, Any],
    model_name: str,
    endpoint: str,
    stream: bool,
    context: PDRequestContext,
) -> Any:
    """Two-stage PD: prefill (1 token + KV) -> decode (RDMA-pull + stream).

    Stage 1 asks a prefill node to run a single-token prefill and return an
    opaque ``kv_transfer_params.migration`` payload (a serialized prefilled
    sequence pointing at the prefill engine's KV blocks). Stage 2 hands that
    payload to a decode node, which RDMA-pulls the KV cache and generates the
    full completion. The prefill KV blocks are released afterwards via
    ``POST /pd/free``.

    Args:
        request_data: Incoming OpenAI-style request body.
        model_name: Model used to select the prefill and decode nodes.
        endpoint: API path forwarded to both nodes.
        stream: Whether the client asked for a streaming response.
        context: Supplies the ``node_manager`` and ``request_key`` used for
            node selection and per-node call accounting.

    Returns:
        A 404 ``JSONResponse`` when no prefill or decode node is available
        for the model, a 502 ``JSONResponse`` when a stage fails, a
        ``StreamingResponse`` for streaming requests, otherwise a
        ``JSONResponse`` carrying the decode node's completion.
    """
    node_manager = context.node_manager
    request_key = context.request_key

    p_url = node_manager.get_node_url(model_name, EngineRole.PREFILL, request_key)
    if not p_url:
        return self._model_not_found_response(model_name)
    d_url = node_manager.get_node_url(model_name, EngineRole.DECODE, request_key)
    if not d_url:
        return self._model_not_found_response(model_name)

    logger.info(f'PD prefill={p_url} decode={d_url}')

    # ---- Stage 1: prefill ----
    start_p = pre_call(node_manager, p_url)
    try:
        prefill_info = await self._prefill_request(p_url, endpoint, request_data)
    finally:
        post_call(node_manager, p_url, start_p)

    # ``_prefill_request`` returns None on failure; a non-object JSON body
    # (list, string, ...) is just as unusable, so treat both as a backend error
    # instead of crashing on ``.get`` below.
    if not isinstance(prefill_info, dict):
        return self._backend_error_response()

    kv = prefill_info.get('kv_transfer_params') or {}
    migration = kv.get('migration')
    seq_id = kv.get('seq_id')
    if not migration:
        # No KV to migrate: the prefill node fully finished the request
        # locally (e.g. the first sampled token was EOS, so the scheduler
        # marked the sequence FINISHED instead of TO_BE_MIGRATED). Return
        # its completion directly instead of handing off to a decode node.
        if prefill_info.get('choices'):
            logger.info('Prefill produced a full completion; skipping decode')
            if stream:
                return StreamingResponse(
                    self._completion_as_sse(prefill_info),
                    media_type='text/event-stream',
                )
            return JSONResponse(prefill_info)
        logger.error('Prefill returned no migration payload')
        return self._backend_error_response()

    # ---- Stage 2: decode ----
    decode_data = _sanitize_chat_payload(request_data)
    decode_data['kv_transfer_params'] = {'migration': migration}
    decode_data['stream'] = stream

    free_ids = [seq_id] if seq_id is not None else []
    start_d = pre_call(node_manager, d_url)

    released = False

    async def _release() -> None:
        """Close decode accounting and free prefill KV blocks exactly once."""
        nonlocal released
        if released:
            return
        released = True
        post_call(node_manager, d_url, start_d)
        if free_ids:
            await self._free_prefill(p_url, free_ids)

    if stream:
        async def _stream():
            try:
                async for chunk in self.stream_forward(d_url, endpoint, decode_data):
                    yield chunk
            except Exception:
                # Starlette only runs ``background`` after the body was sent
                # without error, so a failure while streaming would otherwise
                # skip the cleanup and leak the prefill KV blocks.
                await _release()
                raise

        # Normal completion path: cleanup runs after the response is sent.
        bg = BackgroundTasks()
        bg.add_task(_release)
        return StreamingResponse(
            _stream(),
            background=bg,
            media_type='text/event-stream',
        )

    try:
        text = await self.forward_request(d_url, endpoint, decode_data)
        # Parse inside the ``try`` so a malformed decode body yields a 502
        # rather than an unhandled exception.
        body = json.loads(text)
    except Exception as e:
        logger.error(f'Decode error on {d_url}: {e}')
        return self._backend_error_response()
    finally:
        await _release()
    return JSONResponse(body)
@staticmethod
async def _completion_as_sse(completion: dict[str, Any]):
    """Replay an already-finished completion as a two-event SSE stream.

    Used when the prefill node fully answered the request, so there is no
    decode handoff, but the client asked for a streaming response. Only the
    first choice is emitted: one ``data:`` event with the payload, then the
    ``[DONE]`` terminator.

    Yields:
        UTF-8 encoded SSE frames.
    """
    first = (completion.get('choices') or [{}])[0]
    is_chat = (completion.get('object') or '') == 'chat.completion'

    if is_chat:
        message = first.get('message') or {}
        choice_chunk = {
            'index': 0,
            'delta': {'role': 'assistant', 'content': message.get('content', '')},
            'finish_reason': first.get('finish_reason', 'stop'),
        }
    else:
        choice_chunk = {
            'index': 0,
            'text': first.get('text', ''),
            'finish_reason': first.get('finish_reason', 'stop'),
        }

    chunk = {
        'id': completion.get('id'),
        'object': 'chat.completion.chunk' if is_chat else 'text_completion',
        'created': completion.get('created'),
        'model': completion.get('model'),
        'choices': [choice_chunk],
    }
    yield f'data: {json.dumps(chunk)}\n\n'.encode()
    yield b'data: [DONE]\n\n'
@classmethod
def parse_config(cls, **kwargs: Any) -> DLEngineConfig:
    """Build a ``DLEngineConfig`` from CLI / backend-config keyword values.

    Recognised keys are ``ctrl_address``, ``ctrl_scope``, ``ctrl_kind`` and
    ``discovery_poll_interval``; unknown keys are ignored. String values are
    stripped, and blank ``ctrl_address`` / ``ctrl_scope`` become ``None``.
    A missing or blank ``ctrl_kind`` falls back to ``'dlengine'``.

    A missing, ``None`` or non-positive ``discovery_poll_interval`` falls back
    to 5.0 seconds: ``None`` cannot be converted to ``float``, and a
    non-positive wait would make the discovery poll loop spin.

    Raises:
        ValueError: If ``discovery_poll_interval`` is not numeric.
    """

    def _clean(value: Any) -> Optional[str]:
        # None stays None; blank / whitespace-only strings collapse to None.
        if value is None:
            return None
        return str(value).strip() or None

    default_interval = 5.0
    raw_interval = kwargs.get('discovery_poll_interval')
    interval = default_interval if raw_interval is None else float(raw_interval)
    # ``not interval > 0`` also rejects NaN, which ``interval <= 0`` would miss.
    if not interval > 0:
        logger.warning(
            f'Invalid discovery_poll_interval {raw_interval!r}; '
            f'using {default_interval}s',
        )
        interval = default_interval

    return DLEngineConfig(
        ctrl_address=_clean(kwargs.get('ctrl_address')),
        ctrl_scope=_clean(kwargs.get('ctrl_scope')),
        ctrl_kind=_clean(kwargs.get('ctrl_kind')) or 'dlengine',
        discovery_poll_interval=interval,
    )
@dataclass
class DLEngineConfig:
    """Configuration for DLEngine hybrid nodes via dlslime-ctrl.

    Every field has a default, so ``DLEngineConfig()`` describes a backend
    with no control-plane address configured.
    """

    # dlslime-ctrl address as ``host:port``; None when no address was given.
    ctrl_address: str | None = None
    # Optional dlslime-ctrl scope (used for multi-tenant isolation).
    ctrl_scope: str | None = None
    # Entity kind requested from dlslime-ctrl when listing nodes.
    ctrl_kind: str = 'dlengine'
    # Seconds between dlslime-ctrl discovery polls.
    discovery_poll_interval: float = 5.0
_BACKEND_REGISTRY: dict[BackendType, BackendDefinition] = { BackendType.LMDEPLOY: LMDEPLOY_BACKEND_DEFINITION, + BackendType.DLENGINE: DLENGINE_BACKEND_DEFINITION, BackendType.SGLANG: SGLANG_BACKEND_DEFINITION, BackendType.VLLM: VLLM_BACKEND_DEFINITION, } diff --git a/dlrouter/constants.py b/dlrouter/constants.py index 9a928a1..da803d7 100644 --- a/dlrouter/constants.py +++ b/dlrouter/constants.py @@ -37,6 +37,7 @@ class BackendType(str, enum.Enum): LMDEPLOY = 'lmdeploy' VLLM = 'vllm' SGLANG = 'sglang' + DLENGINE = 'dlengine' class ServingStrategy(str, enum.Enum): @@ -69,6 +70,7 @@ class ServiceDiscoveryMode(str, enum.Enum): STATIC = 'static' # 手动配置节点列表 (绝大多数场景) HEARTBEAT = 'heartbeat' # 心跳注册模式 (仅 vLLM P2P NCCL) + NANOCTRL = 'nanoctrl' # dlslime-ctrl entity registry (DLEngine serve) class ErrorCode(enum.IntEnum): diff --git a/dlrouter/core/service_discovery/__init__.py b/dlrouter/core/service_discovery/__init__.py index ee00274..6b956bb 100644 --- a/dlrouter/core/service_discovery/__init__.py +++ b/dlrouter/core/service_discovery/__init__.py @@ -22,6 +22,9 @@ from dlrouter.core.service_discovery.heartbeat_discovery import ( HeartbeatServiceDiscovery, ) +from dlrouter.core.service_discovery.nanoctrl_discovery import ( + NanoCtrlServiceDiscovery, +) from dlrouter.core.service_discovery.static_discovery import ( StaticServiceDiscovery, ) @@ -34,6 +37,7 @@ __all__ = [ 'BaseServiceDiscovery', 'HeartbeatServiceDiscovery', + 'NanoCtrlServiceDiscovery', 'NodeInfo', 'StaticServiceDiscovery', 'ZMQHeartbeatDiscovery', diff --git a/dlrouter/core/service_discovery/nanoctrl_discovery.py b/dlrouter/core/service_discovery/nanoctrl_discovery.py new file mode 100644 index 0000000..4696ac3 --- /dev/null +++ b/dlrouter/core/service_discovery/nanoctrl_discovery.py @@ -0,0 +1,195 @@ +"""dlslime-ctrl based service discovery for DLEngine HTTP servers. + +DLEngine ``serve`` registers OpenAI-compatible HTTP endpoints with +dlslime-ctrl (entity kind ``dlengine``). 
This discovery polls +``list_entities`` and syncs them into NodeManager. +""" + +from __future__ import annotations + +import threading +import time +from typing import TYPE_CHECKING, Any, Optional + +from dlrouter.constants import EngineRole +from dlrouter.core.service_discovery.base import BaseServiceDiscovery +from dlrouter.logger import get_logger +from dlrouter.models.node import NodeStatus + + +if TYPE_CHECKING: + from dlrouter.core.node_manager import NodeManager + + +logger = get_logger('dlrouter.service_discovery.nanoctrl') + +DEFAULT_CTRL_KIND = 'dlengine' + + +def _entity_http_url(entity: dict[str, Any]) -> str | None: + """Build node URL from a dlslime-ctrl entity record.""" + endpoint = entity.get('endpoint') or {} + if not isinstance(endpoint, dict): + return None + host = endpoint.get('host') + port = endpoint.get('port') + if not host or port is None: + return None + protocol = endpoint.get('protocol', 'http') + if protocol == 'https': + return f'https://{host}:{port}' + return f'http://{host}:{port}' + + +_ROLE_BY_NAME = { + 'prefill': EngineRole.PREFILL, + 'decode': EngineRole.DECODE, + 'hybrid': EngineRole.HYBRID, +} + + +def _entity_role(entity: dict[str, Any]) -> EngineRole: + """Map the entity ``metadata.role`` to an EngineRole (default HYBRID).""" + metadata = entity.get('metadata') or {} + if not isinstance(metadata, dict): + return EngineRole.HYBRID + role = str(metadata.get('role', 'hybrid')).strip().lower() + return _ROLE_BY_NAME.get(role, EngineRole.HYBRID) + + +def _entity_models(entity: dict[str, Any]) -> list[str]: + """Model aliases for routing (served name, path, basename).""" + metadata = entity.get('metadata') or {} + if not isinstance(metadata, dict): + return [] + + names: list[str] = [] + seen: set[str] = set() + + def _add(name: str | None) -> None: + if not name: + return + key = name.strip() + if not key or key in seen: + return + seen.add(key) + names.append(key) + + _add(metadata.get('served_model_name')) + model_path = 
class NanoCtrlServiceDiscovery(BaseServiceDiscovery):
    """Poll dlslime-ctrl and reconcile DLEngine HTTP nodes.

    A daemon thread lists entities of ``ctrl_kind`` every ``poll_interval``
    seconds. URLs seen for the first time are added to the NodeManager; URLs
    that were known but are no longer listed are removed from it.
    """

    def __init__(
        self,
        ctrl_address: str,
        node_manager: Optional[NodeManager] = None,
        ctrl_scope: Optional[str] = None,
        ctrl_kind: str = DEFAULT_CTRL_KIND,
        poll_interval: float = 5.0,
    ) -> None:
        """Store the connection settings; no network access happens here.

        Args:
            ctrl_address: dlslime-ctrl address passed to ``NanoCtrlClient``.
            node_manager: Registry to reconcile; when ``None`` URLs are only
                tracked internally.
            ctrl_scope: Optional scope passed to ``NanoCtrlClient``.
            ctrl_kind: Entity kind to list.
            poll_interval: Seconds between polls. ``None`` or a non-positive
                value falls back to 5.0, because a zero wait would turn the
                poll loop into a busy loop.
        """
        super().__init__(node_manager=node_manager)
        if poll_interval is None or not poll_interval > 0:
            logger.warning(
                f'Invalid poll_interval {poll_interval!r}; using 5.0s',
            )
            poll_interval = 5.0
        self._ctrl_address = ctrl_address
        self._ctrl_scope = ctrl_scope
        self._ctrl_kind = ctrl_kind
        self._poll_interval = poll_interval
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None
        # URLs already handed to the node manager (or seen, when there is none).
        self._known_urls: set[str] = set()
        self._client: Any = None

    def _get_client(self) -> Any:
        """Return the cached ``NanoCtrlClient``, creating it on first use.

        Raises:
            ImportError: If the optional ``dlslime`` package is not installed.
        """
        if self._client is None:
            # Imported lazily: dlslime is an optional dependency.
            try:
                from dlslime.ctrl import NanoCtrlClient
            except ImportError as e:
                raise ImportError(
                    'dlslime is required for DLEngine dlslime-ctrl discovery. '
                    'Install with: pip install dlslime',
                ) from e
            self._client = NanoCtrlClient(self._ctrl_address, self._ctrl_scope)
            self._client.check_connection()
        return self._client

    def _poll_once(self) -> None:
        """List entities once and reconcile them with the node manager."""
        client = self._get_client()
        # Treat a ``None`` result as "no entities" instead of failing the poll.
        entities = client.list_entities(kind=self._ctrl_kind) or []
        live_urls: set[str] = set()

        for entity in entities:
            # One malformed record must not abort reconciliation of the rest.
            if not isinstance(entity, dict):
                logger.warning(f'Skipping malformed entity record: {entity!r}')
                continue
            node_url = _entity_http_url(entity)
            if not node_url:
                logger.warning(f'Skipping entity without HTTP endpoint: {entity}')
                continue
            live_urls.add(node_url)
            if node_url in self._known_urls:
                continue
            models = _entity_models(entity)
            if self._node_manager is None:
                self._known_urls.add(node_url)
                continue
            role = _entity_role(entity)
            status = NodeStatus(role=role, models=models)
            if self._node_manager.add(node_url, status):
                logger.info(
                    f'Discovered DLEngine node {node_url} '
                    f'role={role.name} models={models}',
                )
            # Recorded whether or not ``add`` reported success, so the same
            # URL is not re-submitted on every poll.
            self._known_urls.add(node_url)

        # Anything known but no longer listed is gone from the control plane.
        stale = self._known_urls - live_urls
        for node_url in stale:
            self._known_urls.discard(node_url)
            if self._node_manager is not None:
                self._node_manager.remove(node_url)
            logger.info(f'Removed stale DLEngine node {node_url}')

    def _loop(self) -> None:
        """Poll until ``stop()`` sets the stop event; poll errors are logged."""
        while not self._stop.wait(self._poll_interval):
            try:
                self._poll_once()
            except Exception as e:
                logger.error(f'NanoCtrl discovery poll failed: {e}', exc_info=True)

    def start(self) -> None:
        """Run one immediate poll, then start the background poll thread.

        A failing initial poll is logged and does not prevent the thread from
        starting. Calling ``start()`` while the thread is alive is a no-op, so
        a repeated call cannot spawn a second poller.
        """
        if self._thread is not None and self._thread.is_alive():
            logger.warning('NanoCtrl discovery already started; ignoring start()')
            return
        self._running = True
        try:
            self._poll_once()
        except Exception as e:
            logger.error(f'NanoCtrl discovery initial poll failed: {e}')
        self._stop.clear()
        self._thread = threading.Thread(
            target=self._loop,
            name='dlrouter-nanoctrl-discovery',
            daemon=True,
        )
        self._thread.start()
        logger.info(
            f'NanoCtrl discovery started (ctrl={self._ctrl_address}, '
            f'kind={self._ctrl_kind}, interval={self._poll_interval}s)',
        )

    def stop(self) -> None:
        """Stop the poll thread and release the control-plane client."""
        self._running = False
        self._stop.set()
        worker = self._thread
        if worker is not None and worker.is_alive():
            # Allow one full poll interval plus slack for an in-flight poll.
            worker.join(timeout=self._poll_interval + 2.0)
        self._thread = None
        if self._client is not None:
            try:
                self._client.stop()
            except Exception as e:
                # Best-effort shutdown: report the failure but keep going.
                logger.warning(f'NanoCtrl client stop failed: {e}')
            self._client = None
        logger.info('NanoCtrl discovery stopped')

    def unregister_by_url(self, node_url: str) -> None:
        """Allow HealthChecker removals to re-discover the same URL later."""
        self._known_urls.discard(node_url)
def test_entity_role_maps_metadata_role():
    """``metadata.role`` maps to an EngineRole; missing or unknown is HYBRID."""
    expectations = [
        ({}, EngineRole.HYBRID),
        ({'metadata': {'role': 'hybrid'}}, EngineRole.HYBRID),
        ({'metadata': {'role': 'prefill'}}, EngineRole.PREFILL),
        # Role names are matched case-insensitively.
        ({'metadata': {'role': 'Decode'}}, EngineRole.DECODE),
        ({'metadata': {'role': 'bogus'}}, EngineRole.HYBRID),
    ]
    for entity, expected_role in expectations:
        assert _entity_role(entity) == expected_role
def test_nanoctrl_discovery_syncs_nodes():
    """A listed entity is registered, then removed once it stops being listed."""
    node_url = 'http://127.0.0.1:8100'
    manager = NodeManager(backend=create_backend(BackendType.DLENGINE))
    discovery = NanoCtrlServiceDiscovery(
        ctrl_address='127.0.0.1:4479',
        node_manager=manager,
        ctrl_kind='dlengine',
        poll_interval=60.0,
    )

    # Inject a stub client so no real dlslime-ctrl connection is made.
    registry = MagicMock()
    registry.list_entities.return_value = [
        {
            'endpoint': {'host': '127.0.0.1', 'port': 8100, 'protocol': 'http'},
            'metadata': {'served_model_name': 'Qwen3-4B'},
        },
    ]
    discovery._client = registry

    # First poll: the entity shows up as a node serving the advertised model.
    discovery._poll_once()
    assert node_url in manager.nodes
    assert manager.nodes[node_url].models == ['Qwen3-4B']

    # Second poll with an empty listing: the node is reconciled away.
    registry.list_entities.return_value = []
    discovery._poll_once()
    assert node_url not in manager.nodes