feat: Train-Inference Disaggregation for Remote Target Model Serving #569
feat: Train-Inference Disaggregation for Remote Target Model Serving #569moehanabi wants to merge 18 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces train-inference disaggregation (remote training) to SpecForge, allowing the target model to run on a standalone inference server while the draft model trains on a separate GPU. This architecture is supported by a dedicated NCCL transport layer for GPU-to-GPU tensor transfer, a custom binary wire format fallback, and a prefetch pipeline to overlap target inference with draft training. Additionally, it adds support for Qwen3.5 VLM training, custom model-parallel initialization, and updated chat templates. Critical feedback was provided regarding an invalid transformers dependency version constraint, a spacing calculation bug in target layer ID selection, a position ID permutation bug when the batch size is exactly 3, a thread-safety issue with collective distributed calls in the multi-threaded HTTP server, and a potential shutdown hang when broadcasting the exit signal to worker ranks.
| "torchaudio==2.9.1", | ||
| "torchvision==0.24.1", | ||
| "transformers==4.57.1", | ||
| "transformers>=5.2.0", |
There was a problem hiding this comment.
The dependency transformers>=5.2.0 is specified, but Hugging Face transformers is currently on version 4.x. Version 5.x does not exist on PyPI, which will cause pip install to fail. Please adjust the version constraint to a valid 4.x release.
| "transformers>=5.2.0", | |
| "transformers>=4.45.0", |
| if num_draft_layers >= len(eligible): | ||
| return eligible[:num_draft_layers] | ||
| step = len(eligible) / num_draft_layers | ||
| return [eligible[int(round(i * step))] for i in range(num_draft_layers)] |
There was a problem hiding this comment.
In _build_target_layer_ids, when layer_types is not None, the spacing step is calculated as step = len(eligible) / num_draft_layers. This does not span the full range of eligible layers and leaves out the last eligible layer (e.g., for num_draft_layers = 2 and len(eligible) = 5, it selects indices 0 and 2 instead of 0 and 4). Spanning the full range is critical for the draft model to capture features from the end of the target model.
| if num_draft_layers >= len(eligible): | |
| return eligible[:num_draft_layers] | |
| step = len(eligible) / num_draft_layers | |
| return [eligible[int(round(i * step))] for i in range(num_draft_layers)] | |
| if num_draft_layers >= len(eligible): | |
| return eligible[:num_draft_layers] | |
| if num_draft_layers == 1: | |
| return [eligible[len(eligible) // 2]] | |
| step = (len(eligible) - 1) / (num_draft_layers - 1) | |
| return [eligible[int(round(i * step))] for i in range(num_draft_layers)] |
| if context_position_ids.ndim == 3 and context_position_ids.shape[0] != 3: | ||
| if context_position_ids.shape[1] == 3: | ||
| context_position_ids = context_position_ids.permute( | ||
| 1, 0, 2 | ||
| ).contiguous() |
There was a problem hiding this comment.
In forward, the permutation of context_position_ids from [B, 3, S] to [3, B, S] is skipped if context_position_ids.shape[0] == 3. However, if the batch size B is exactly 3, the shape is [3, 3, S], and the permutation is incorrectly skipped. This causes the batch dimension to be mixed up with the RoPE component dimension, leading to incorrect position IDs and training instability/loss explosion. Since the DataLoader collates the batch as the first dimension, the input layout is guaranteed to be [B, 3, S], so we should check if shape[0] == bsz and shape[1] == 3 to safely permute.
| if context_position_ids.ndim == 3 and context_position_ids.shape[0] != 3: | |
| if context_position_ids.shape[1] == 3: | |
| context_position_ids = context_position_ids.permute( | |
| 1, 0, 2 | |
| ).contiguous() | |
| if context_position_ids.ndim == 3: | |
| if context_position_ids.shape[0] == bsz and context_position_ids.shape[1] == 3: | |
| context_position_ids = context_position_ids.permute( | |
| 1, 0, 2 | |
| ).contiguous() |
| def _route_request_synced(app, path, body): | ||
| """Route a request, broadcasting across tp ranks so all ranks participate in the | ||
| forward pass (required by NCCL allreduce inside SGLang model_runner). | ||
|
|
||
| For heavy forward-pass endpoints a rank-0-only serialisation path is used | ||
| so that worker ranks do not waste CPU serialising results that are discarded. | ||
| """ | ||
| if not dist.is_initialized() or dist.get_world_size() == 1: | ||
| try: | ||
| return _route_request(app, path, body) | ||
| except Exception: | ||
| logger.exception("Error handling %s", path) | ||
| return json.dumps({"error": "Internal server error"}).encode(), 500 | ||
|
|
||
| # /init_nccl must run ONLY on rank 0 — it sets up a separate 2-rank NCCL | ||
| # group (server rank 0 + training client rank 1). TP worker ranks must | ||
| # NOT participate. | ||
| if path == "/init_nccl": | ||
| return _route_request(app, path, body) | ||
|
|
||
| rank = dist.get_rank() | ||
| synced_path, body_bytes = _broadcast_request(path, body) | ||
|
|
||
| # Heavy endpoints: all ranks execute the forward pass (NCCL requires it), | ||
| # but only rank 0 needs post-processing (target_p, serialisation). | ||
| # Non-rank-0 workers skip post-processing to avoid blocking the next step. | ||
| if synced_path == "/generate_eagle3_data": | ||
| out = None | ||
| try: | ||
| out = app._run_generate_eagle3_data( | ||
| body_bytes, rank_only_forward=(rank != 0) | ||
| ) | ||
| except Exception: | ||
| logger.exception("Error handling %s (rank %d)", synced_path, rank) | ||
| if rank == 0: | ||
| if out is not None: | ||
| return app._serialize_response(out) | ||
| return ( | ||
| json.dumps( | ||
| {"error": "Internal server error during forward pass"} | ||
| ).encode(), | ||
| 500, | ||
| ) | ||
| return None | ||
|
|
||
| if synced_path == "/generate_dflash_data": | ||
| out = None | ||
| try: | ||
| out = app._run_generate_dflash_data( | ||
| body_bytes, rank_only_forward=(rank != 0) | ||
| ) | ||
| except Exception: | ||
| logger.exception("Error handling %s (rank %d)", synced_path, rank) | ||
| if rank == 0: | ||
| if out is not None: | ||
| return app._serialize_response(out) | ||
| return ( | ||
| json.dumps( | ||
| {"error": "Internal server error during forward pass"} | ||
| ).encode(), | ||
| 500, | ||
| ) | ||
| return None | ||
|
|
||
| # Lightweight endpoints: broadcast + execute on all ranks | ||
| try: | ||
| result = _route_request(app, synced_path, body_bytes) | ||
| except Exception: | ||
| logger.exception("Error handling %s (rank %d)", synced_path, rank) | ||
| result = None | ||
|
|
||
| if rank == 0: | ||
| if result is not None: | ||
| return result | ||
| return json.dumps({"error": "Internal server error"}).encode(), 500 | ||
| return None | ||
|
|
There was a problem hiding this comment.
In remote_target_server.py, _route_request_synced performs collective PyTorch distributed calls (dist.broadcast_object_list and dist.broadcast) to synchronize requests with worker ranks. However, the HTTP server is multi-threaded (ThreadingHTTPServer), and collective calls are not thread-safe. While heavy forward passes are protected by _forward_semaphore, lightweight endpoints (like set_vocab_mapping or get_model_info) bypass the semaphore and can execute _route_request_synced concurrently on separate threads. This will corrupt the collective communication state and cause deadlocks or desynchronization crashes. We should introduce a global lock to serialize all synchronized requests.
_synced_request_lock = threading.Lock()
def _route_request_synced(app, path, body):
"""Route a request, broadcasting across tp ranks so all ranks participate in the
forward pass (required by NCCL allreduce inside SGLang model_runner).
For heavy forward-pass endpoints a rank-0-only serialisation path is used
so that worker ranks do not waste CPU serialising results that are discarded.
"""
if not dist.is_initialized() or dist.get_world_size() == 1:
try:
return _route_request(app, path, body)
except Exception:
logger.exception("Error handling %s", path)
return json.dumps({"error": "Internal server error"}).encode(), 500
# /init_nccl must run ONLY on rank 0 — it sets up a separate 2-rank NCCL
# group (server rank 0 + training client rank 1). TP worker ranks must
# NOT participate.
if path == "/init_nccl":
return _route_request(app, path, body)
with _synced_request_lock:
rank = dist.get_rank()
synced_path, body_bytes = _broadcast_request(path, body)
# Heavy endpoints: all ranks execute the forward pass (NCCL requires it),
# but only rank 0 needs post-processing (target_p, serialisation).
# Non-rank-0 workers skip post-processing to avoid blocking the next step.
if synced_path == "/generate_eagle3_data":
out = None
try:
out = app._run_generate_eagle3_data(
body_bytes, rank_only_forward=(rank != 0)
)
except Exception:
logger.exception("Error handling %s (rank %d)", synced_path, rank)
if rank == 0:
if out is not None:
return app._serialize_response(out)
return (
json.dumps(
{"error": "Internal server error during forward pass"}
).encode(),
500,
)
return None
if synced_path == "/generate_dflash_data":
out = None
try:
out = app._run_generate_dflash_data(
body_bytes, rank_only_forward=(rank != 0)
)
except Exception:
logger.exception("Error handling %s (rank %d)", synced_path, rank)
if rank == 0:
if out is not None:
return app._serialize_response(out)
return (
json.dumps(
{"error": "Internal server error during forward pass"}
).encode(),
500,
)
return None
# Lightweight endpoints: broadcast + execute on all ranks
try:
result = _route_request(app, synced_path, body_bytes)
except Exception:
logger.exception("Error handling %s (rank %d)", synced_path, rank)
result = None
if rank == 0:
if result is not None:
return result
return json.dumps({"error": "Internal server error"}).encode(), 500
return None| if dist.get_world_size() > 1: | ||
| from specforge.modeling.target.remote_target_server import ( | ||
| _SENTINEL_EXIT, | ||
| ) | ||
|
|
||
| dist.broadcast_object_list([_SENTINEL_EXIT], src=0) |
There was a problem hiding this comment.
During shutdown, rank 0 attempts to broadcast _SENTINEL_EXIT to worker ranks using dist.broadcast_object_list. However, if worker ranks have already exited or are in the process of exiting (e.g., due to receiving SIGTERM from torchrun at the same time), this collective call will hang or raise a distributed communication error. Wrapping this call in a try...except block ensures a clean and robust shutdown without hanging.
| if dist.get_world_size() > 1: | |
| from specforge.modeling.target.remote_target_server import ( | |
| _SENTINEL_EXIT, | |
| ) | |
| dist.broadcast_object_list([_SENTINEL_EXIT], src=0) | |
| if dist.get_world_size() > 1: | |
| from specforge.modeling.target.remote_target_server import ( | |
| _SENTINEL_EXIT, | |
| ) | |
| try: | |
| dist.broadcast_object_list([_SENTINEL_EXIT], src=0) | |
| except Exception: | |
| logger.warning("Failed to broadcast exit signal to worker ranks (they may have already exited).") |
Signed-off-by: EanWang211123 <wangyiheng@sangfor.com.cn>
Signed-off-by: EanWang211123 <wangyiheng@sangfor.com.cn>
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR introduces train-inference disaggregation for SpecForge, deploying the target model as a standalone HTTP/NCCL server while draft model training runs separately. It adds remote backend support for both Eagle3 and DFlash training paths, plus several SGLang patches and Qwen3.5 VLM support.
Changes:
- New remote target server/client with NCCL data plane and custom wire-format fallback for tensor transport
- Training scripts (
train_eagle3.py,train_dflash.py) gainremotebackend + async target prefetch pipelining - SGLang backend patches updated for newer SGLang API; Qwen3.5/VLM handling for DFlash; FSDP/distributed cleanup improvements
Reviewed changes
Copilot reviewed 31 out of 31 changed files in this pull request and generated 19 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/test_modeling/test_target/test_nccl_transport.py | New unit + GPU + multiprocess integration tests for NCCL transport |
| specforge/modeling/target/_nccl_transport.py | New NCCL transport (TCPStore rendezvous, send/recv, metadata encode/decode) |
| specforge/modeling/target/_tensor_wire.py | New compact binary tensor wire format |
| specforge/modeling/target/remote_target_server.py | HTTP server wrapping SGLang target model with TP broadcast + NCCL send |
| specforge/modeling/target/remote_target_client.py | Remote client with NCCL recv, async prefetch, TP broadcast |
| specforge/modeling/target/eagle3_target_model.py / dflash_target_model.py | Add remote backend hooks, rank-only-forward path, Qwen3.5 VLM support |
| specforge/modeling/target/sglang_backend/{patch,utils,model_runner}.py | SGLang version-compat updates (removed multi-item scoring, new group params) |
| specforge/modeling/target/init.py | Export remote target classes |
| specforge/modeling/target/custom_backend/{llama,llama4,phi3,gpt_oss}.py | Remove check_model_inputs decorator import/usage |
| specforge/core/eagle3.py / dflash.py | Support precomputed position_mask; multimodal position_ids in DFlash |
| specforge/distributed.py | GPU id env override; safer destroy_distributed |
| specforge/utils.py | Remote config helpers; hf_config_dict plumbing; preformatted text row passthrough |
| specforge/data/{template,preprocessing}.py | Register qwen3.5 templates; pass empty tools for preformatted |
| specforge/args.py | Add RemoteBackendArgs; rename SGLang piecewise flag |
| scripts/{train_eagle3,train_dflash,launch_target_server}.py | Add remote backend, prefetch pipeline, server launcher |
| pyproject.toml | Bump transformers>=5.2.0 |
| docs/, examples/, configs/ | New remote training docs, qwen3.5 example, dflash config |
Comments suppressed due to low confidence (1)
tests/test_modeling/test_target/test_nccl_transport.py:1
- The
as eclause is unused (onlytraceback.format_exc()is used). Either drop theas eto silence linters (except Exception:) or includerepr(e)in the error payload. The same applies to the other_gpu_*helper functions in this file.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| except (requests.ConnectionError, requests.Timeout) as exc: | ||
| if attempt < self.max_retries: | ||
| time.sleep(2**attempt) | ||
| continue | ||
| raise RuntimeError( | ||
| f"Remote request failed after {self.max_retries + 1} attempts: {exc}" | ||
| ) from exc |
| def _request(self, endpoint: str, payload: bytes) -> bytes: | ||
| """POST payload to endpoint with exponential-backoff retry.""" | ||
| url = f"{self.url}/{endpoint.lstrip('/')}" | ||
| last_exc = None | ||
|
|
||
| for attempt in range(self.max_retries + 1): | ||
| try: | ||
| resp = self._session.post( | ||
| url, | ||
| data=payload, | ||
| timeout=self.timeout, | ||
| headers={"Content-Type": "application/octet-stream"}, | ||
| ) | ||
| resp.raise_for_status() | ||
| return resp.content | ||
| except (requests.ConnectionError, requests.Timeout) as exc: | ||
| last_exc = exc | ||
| if attempt < self.max_retries: |
| enable_torch_compile: bool = True, | ||
| nccl_port: int = None, | ||
| host: str = "0.0.0.0", | ||
| attention_backend: str | None = None, |
| prefetch_queue = [] | ||
| torch_profiler = None | ||
|
|
||
| def next_prefetch_boundary() -> int | float: |
| if dist.get_rank() == 0: | ||
| progress_bar = tqdm( | ||
| train_dataloader, desc=f"Training Epoch {epoch}", leave=True | ||
| def next_prefetch_boundary() -> int | float: |
| QWEN3_5_MODEL_TYPES = {"qwen3_5", "qwen3_5_moe"} | ||
| VLM_MODEL_TYPES = QWEN3_5_MODEL_TYPES |
| tracker.close() | ||
| destroy_distributed() | ||
| # Save final checkpoint if training ended without saving | ||
| if args.save_interval <= 0 or global_step % args.save_interval != 0: |
| from torch.distributed.distributed_c10d import ( | ||
| _unregister_process_group, | ||
| _world, | ||
| ) |
| "hf_config_dict": hf_config_dict, | ||
| "server_model_path": server_model_path, | ||
| } | ||
| ).encode("utf-8") |
| ( | ||
| "sliding_attention" | ||
| if sliding_window is not None and layer_idx >= max_window_layers | ||
| else "full_attention" | ||
| ) | ||
| for layer_idx in range(num_hidden_layers) | ||
| ] |
7499e80 to
f12b8eb
Compare
f12b8eb to
b0f94b4
Compare
960b2ac to
f2dbaa5
Compare
|
Move to #573 to rebase newest commit |
Motivation
In the current co-located training setup, the target model (e.g., Qwen3-30B-A3B-FP8) and draft model share the same GPU, causing memory contention, compute serialization, and scaling limitations. This PR introduces train-inference disaggregation — deploying the target model as an independent inference server while the draft model trains on a separate GPU. This enables:
Modifications
remote_target_server.py,launch_target_server.py): Standalone HTTP + NCCL server that runs target model forward and sends hidden states / target_p to the training client via GPU-direct NCCL transfer.remote_target_client.py): Training-side backend that sends requests over HTTP and receives large tensors via NCCL recv, with support for TP broadcast._nccl_transport.py): Dedicated 2-process NCCL group for GPU-to-GPU tensor transfer, supporting both intra-node NVLink and inter-node RDMA._tensor_wire.py): Compact binary serialization when NCCL is unavailable.train_eagle3.py,train_dflash.py): Extended with--target-model-backend remote,--remote-url(s), and--target-prefetch-deptharguments.args.py,distributed.py): AddedRemoteBackendArgs,SGLangBackendArgs, andSPECFORGE_GPU_IDsupport.docs/advanced_features/remote_training.md.test_nccl_transport.py).Related Issues
Accuracy Test
Loss consistency verified across all configurations (100 steps, Qwen3-30B-A3B-FP8):
Benchmark & Profiling
Cross-machine benchmark (2x 8xHopper 80G, RoCE v2 RDMA 388 Gb/s, Qwen3-30B-A3B-FP8, max_length=12288):
Key findings:
Checklist