feat: Train-Inference Disaggregation for Remote Target Model Serving #569

Closed
moehanabi wants to merge 18 commits into sgl-project:main from moehanabi:remote_train_rebase

@moehanabi moehanabi commented May 28, 2026

Motivation

In the current co-located training setup, the target model (e.g., Qwen3-30B-A3B-FP8) and draft model share the same GPU, causing memory contention, compute serialization, and scaling limitations. This PR introduces train-inference disaggregation — deploying the target model as an independent inference server while the draft model trains on a separate GPU. This enables:

  • Elimination of GPU memory contention between target and draft models
  • Pipeline parallelism via prefetch to overlap target inference with draft training
  • Independent scaling of inference and training resources
  • Up to 2.37x training speedup with dual-server prefetch

Modifications

  • Remote target server (remote_target_server.py, launch_target_server.py): Standalone HTTP + NCCL server that runs target model forward and sends hidden states / target_p to the training client via GPU-direct NCCL transfer.
  • Remote target client (remote_target_client.py): Training-side backend that sends requests over HTTP and receives large tensors via NCCL recv, with support for TP broadcast.
  • NCCL transport layer (_nccl_transport.py): Dedicated 2-process NCCL group for GPU-to-GPU tensor transfer, supporting both intra-node NVLink and inter-node RDMA.
  • Wire format fallback (_tensor_wire.py): Compact binary serialization when NCCL is unavailable.
  • Prefetch pipeline: Async prefetch queue with configurable depth and multi-server round-robin scheduling to fully hide target forward latency.
  • Training scripts (train_eagle3.py, train_dflash.py): Extended with --target-model-backend remote, --remote-url(s), and --target-prefetch-depth arguments.
  • Args & distributed (args.py, distributed.py): Added RemoteBackendArgs, SGLangBackendArgs, and SPECFORGE_GPU_ID support.
  • Bug fixes: Padding memory optimization, SGLang v0.5.10 compatibility, cross-machine NCCL support, correct exit handling.
  • Qwen3.5 support: Added DFlash config and training script for Qwen3.5. (This PR is based on [Feature] Support DFlash Speculative Decoding Training for Qwen3.5 Models #495. Thanks to the contributor @EanWang211123!)
  • Documentation: Added comprehensive English documentation under docs/advanced_features/remote_training.md.
  • Tests: Added NCCL transport unit tests (test_nccl_transport.py).
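
The prefetch pipeline in the list above can be pictured with a minimal sketch. It is not the PR's implementation: `fetch_target_output`, `prefetching_iter`, and the server names are hypothetical stand-ins, and a thread pool replaces the real HTTP + NCCL path. It only illustrates keeping `depth` requests queued ahead of the batch being consumed while assigning servers round-robin.

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from itertools import cycle


def fetch_target_output(server, batch):
    # Hypothetical stand-in for a remote target forward pass.
    return f"{server}:hidden({batch})"


def prefetching_iter(batches, servers, depth):
    """Yield (batch, target_output), prefetching up to `depth` batches ahead."""
    if depth <= 0:
        # depth=0: no overlap, each batch is fetched synchronously.
        for server, batch in zip(cycle(servers), batches):
            yield batch, fetch_target_output(server, batch)
        return

    server_cycle = cycle(servers)   # round-robin over target servers
    batch_iter = iter(batches)
    pending = deque()               # (batch, future) in submission order

    with ThreadPoolExecutor(max_workers=depth) as pool:

        def submit_next():
            # Submit at most one more batch, if any remain.
            for batch in batch_iter:
                server = next(server_cycle)
                future = pool.submit(fetch_target_output, server, batch)
                pending.append((batch, future))
                return

        for _ in range(depth):
            submit_next()
        while pending:
            batch, future = pending.popleft()
            submit_next()  # refill the queue before blocking on the result
            yield batch, future.result()


for batch, out in prefetching_iter(range(4), ["s0", "s1"], depth=2):
    print(batch, out)
```

In this sketch the consumer's work between two `yield`s plays the role of the draft training step, so the queued fetches overlap with it.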

Related Issues

Accuracy Test

Loss consistency verified across all configurations (100 steps, Qwen3-30B-A3B-FP8):

| Model | Baseline Loss | Remote Loss | Match |
|---|---|---|---|
| DFlash TP=1 | 6.4594 | 6.4594 | ✅ Exact |
| DFlash TP=2 | 6.4624 | 6.4624 | ✅ Exact |
| EAGLE3 TP=1 | 0.1999 | 0.1999 | ✅ Exact |
| EAGLE3 TP=2 | 0.2006 | 0.2006 | ✅ Exact |

  • Prefetch depth (0/1/2) has no impact on accuracy.
  • Cross-machine RDMA introduces no precision loss.
  • Single-server vs dual-server accuracy is identical.

Benchmark & Profiling

Cross-machine benchmark (2x 8xHopper 80G, RoCE v2 RDMA 388 Gb/s, Qwen3-30B-A3B-FP8, max_length=12288):

| Scenario | Config | Iter Time | Speedup vs Baseline |
|---|---|---|---|
| DFlash, single server | TP=2, depth=1 | 0.132s | 1.63x |
| DFlash, dual server | TP=2, depth=2 | 0.094s | 2.28x |
| EAGLE3, single server | TP=2, depth=1 | 0.325s | 1.88x |
| EAGLE3, dual server | TP=2, depth=1 | 0.291s | 2.09x |
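
As a sanity check on the table, the baseline iteration time implied by each row can be recovered from the reported numbers. This assumes that "speedup" means baseline iteration time divided by remote iteration time, which the PR does not state explicitly; `implied_baseline` is a name introduced here for illustration.

```python
# (scenario, iteration time in seconds, speedup) copied from the table above.
rows = [
    ("DFlash, single server", 0.132, 1.63),
    ("DFlash, dual server", 0.094, 2.28),
    ("EAGLE3, single server", 0.325, 1.88),
    ("EAGLE3, dual server", 0.291, 2.09),
]


def implied_baseline(iter_time_s, speedup):
    # Assumption: speedup = baseline_time / remote_time.
    return iter_time_s * speedup


for name, iter_time_s, speedup in rows:
    print(f"{name}: implied baseline ~ {implied_baseline(iter_time_s, speedup):.3f}s")
```

Under that assumption the two DFlash rows agree on a baseline of roughly 0.21s and the two EAGLE3 rows on roughly 0.61s, so the rows are internally consistent.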

Key findings:

  • Disaggregation alone (depth=0) provides 3-6% speedup from eliminating GPU contention
  • Prefetch depth=1 adds 36-53% speedup by overlapping target forward with draft training
  • Dual-server depth=2 achieves peak 2.37x for DFlash by fully pipelining target forward

Checklist

Copilot AI review requested due to automatic review settings May 28, 2026 08:42
@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces train-inference disaggregation (remote training) to SpecForge, allowing the target model to run on a standalone inference server while the draft model trains on a separate GPU. This architecture is supported by a dedicated NCCL transport layer for GPU-to-GPU tensor transfer, a custom binary wire format fallback, and a prefetch pipeline to overlap target inference with draft training. Additionally, it adds support for Qwen3.5 VLM training, custom model-parallel initialization, and updated chat templates. Critical feedback was provided regarding an invalid transformers dependency version constraint, a spacing calculation bug in target layer ID selection, a position ID permutation bug when the batch size is exactly 3, a thread-safety issue with collective distributed calls in the multi-threaded HTTP server, and a potential shutdown hang when broadcasting the exit signal to worker ranks.

Comment thread pyproject.toml
      "torchaudio==2.9.1",
      "torchvision==0.24.1",
    - "transformers==4.57.1",
    + "transformers>=5.2.0",
high

The dependency transformers>=5.2.0 is specified, but Hugging Face transformers is currently on version 4.x. Version 5.x does not exist on PyPI, which will cause pip install to fail. Please adjust the version constraint to a valid 4.x release.

Suggested change
    - "transformers>=5.2.0",
    + "transformers>=4.45.0",

Comment thread scripts/train_dflash.py
Comment on lines +93 to +96
    if num_draft_layers >= len(eligible):
        return eligible[:num_draft_layers]
    step = len(eligible) / num_draft_layers
    return [eligible[int(round(i * step))] for i in range(num_draft_layers)]
high

In _build_target_layer_ids, when layer_types is not None, the spacing step is calculated as step = len(eligible) / num_draft_layers. This does not span the full range of eligible layers and leaves out the last eligible layer (e.g., for num_draft_layers = 2 and len(eligible) = 5, it selects indices 0 and 2 instead of 0 and 4). Spanning the full range is critical for the draft model to capture features from the end of the target model.

Suggested change
    - if num_draft_layers >= len(eligible):
    -     return eligible[:num_draft_layers]
    - step = len(eligible) / num_draft_layers
    - return [eligible[int(round(i * step))] for i in range(num_draft_layers)]
    + if num_draft_layers >= len(eligible):
    +     return eligible[:num_draft_layers]
    + if num_draft_layers == 1:
    +     return [eligible[len(eligible) // 2]]
    + step = (len(eligible) - 1) / (num_draft_layers - 1)
    + return [eligible[int(round(i * step))] for i in range(num_draft_layers)]
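
The difference between the two spacing rules can be checked in isolation. The sketch below copies both versions into standalone functions; the names `select_original` and `select_suggested` are introduced here for illustration and are not in the PR.

```python
def select_original(eligible, num_draft_layers):
    # Spacing as in the PR diff: step = len / n.
    if num_draft_layers >= len(eligible):
        return eligible[:num_draft_layers]
    step = len(eligible) / num_draft_layers
    return [eligible[int(round(i * step))] for i in range(num_draft_layers)]


def select_suggested(eligible, num_draft_layers):
    # Spacing as in the suggested change: step = (len - 1) / (n - 1),
    # so the first and last eligible layers are always included.
    if num_draft_layers >= len(eligible):
        return eligible[:num_draft_layers]
    if num_draft_layers == 1:
        return [eligible[len(eligible) // 2]]
    step = (len(eligible) - 1) / (num_draft_layers - 1)
    return [eligible[int(round(i * step))] for i in range(num_draft_layers)]


eligible = [0, 1, 2, 3, 4]
print(select_original(eligible, 2))   # [0, 2]  (round(2.5) == 2: Python rounds half to even)
print(select_suggested(eligible, 2))  # [0, 4]
```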

Comment thread specforge/core/dflash.py
Comment on lines +285 to +289
    if context_position_ids.ndim == 3 and context_position_ids.shape[0] != 3:
        if context_position_ids.shape[1] == 3:
            context_position_ids = context_position_ids.permute(
                1, 0, 2
            ).contiguous()
high

In forward, the permutation of context_position_ids from [B, 3, S] to [3, B, S] is skipped if context_position_ids.shape[0] == 3. However, if the batch size B is exactly 3, the shape is [3, 3, S], and the permutation is incorrectly skipped. This causes the batch dimension to be mixed up with the RoPE component dimension, leading to incorrect position IDs and training instability/loss explosion. Since the DataLoader collates the batch as the first dimension, the input layout is guaranteed to be [B, 3, S], so we should check if shape[0] == bsz and shape[1] == 3 to safely permute.

Suggested change
    - if context_position_ids.ndim == 3 and context_position_ids.shape[0] != 3:
    -     if context_position_ids.shape[1] == 3:
    -         context_position_ids = context_position_ids.permute(
    -             1, 0, 2
    -         ).contiguous()
    + if context_position_ids.ndim == 3:
    +     if context_position_ids.shape[0] == bsz and context_position_ids.shape[1] == 3:
    +         context_position_ids = context_position_ids.permute(
    +             1, 0, 2
    +         ).contiguous()
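
The batch-size-3 edge case can be reproduced from the shape conditions alone, without any tensors. The two predicates below mirror the conditions in the original and suggested code; the function names and the sequence length 128 are illustrative assumptions.

```python
def should_permute_original(shape):
    # Original condition: 3-D, first dim != 3, second dim == 3.
    return len(shape) == 3 and shape[0] != 3 and shape[1] == 3


def should_permute_suggested(shape, bsz):
    # Suggested condition: 3-D, first dim equals the batch size, second dim == 3.
    return len(shape) == 3 and shape[0] == bsz and shape[1] == 3


for bsz in (2, 3, 4):
    shape = (bsz, 3, 128)  # [B, 3, S] layout, as the review assumes for the input
    print(bsz, should_permute_original(shape), should_permute_suggested(shape, bsz))
# 2 True True
# 3 False True
# 4 True True
```

Note that with `bsz == 3` a tensor that is already in `[3, B, S]` layout has the same shape as one in `[B, 3, S]` layout, so the suggested check is only safe under the review's stated assumption that the input always arrives as `[B, 3, S]`.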

Comment on lines +687 to +763
def _route_request_synced(app, path, body):
    """Route a request, broadcasting across tp ranks so all ranks participate in the
    forward pass (required by NCCL allreduce inside SGLang model_runner).

    For heavy forward-pass endpoints a rank-0-only serialisation path is used
    so that worker ranks do not waste CPU serialising results that are discarded.
    """
    if not dist.is_initialized() or dist.get_world_size() == 1:
        try:
            return _route_request(app, path, body)
        except Exception:
            logger.exception("Error handling %s", path)
            return json.dumps({"error": "Internal server error"}).encode(), 500

    # /init_nccl must run ONLY on rank 0 — it sets up a separate 2-rank NCCL
    # group (server rank 0 + training client rank 1). TP worker ranks must
    # NOT participate.
    if path == "/init_nccl":
        return _route_request(app, path, body)

    rank = dist.get_rank()
    synced_path, body_bytes = _broadcast_request(path, body)

    # Heavy endpoints: all ranks execute the forward pass (NCCL requires it),
    # but only rank 0 needs post-processing (target_p, serialisation).
    # Non-rank-0 workers skip post-processing to avoid blocking the next step.
    if synced_path == "/generate_eagle3_data":
        out = None
        try:
            out = app._run_generate_eagle3_data(
                body_bytes, rank_only_forward=(rank != 0)
            )
        except Exception:
            logger.exception("Error handling %s (rank %d)", synced_path, rank)
        if rank == 0:
            if out is not None:
                return app._serialize_response(out)
            return (
                json.dumps(
                    {"error": "Internal server error during forward pass"}
                ).encode(),
                500,
            )
        return None

    if synced_path == "/generate_dflash_data":
        out = None
        try:
            out = app._run_generate_dflash_data(
                body_bytes, rank_only_forward=(rank != 0)
            )
        except Exception:
            logger.exception("Error handling %s (rank %d)", synced_path, rank)
        if rank == 0:
            if out is not None:
                return app._serialize_response(out)
            return (
                json.dumps(
                    {"error": "Internal server error during forward pass"}
                ).encode(),
                500,
            )
        return None

    # Lightweight endpoints: broadcast + execute on all ranks
    try:
        result = _route_request(app, synced_path, body_bytes)
    except Exception:
        logger.exception("Error handling %s (rank %d)", synced_path, rank)
        result = None

    if rank == 0:
        if result is not None:
            return result
        return json.dumps({"error": "Internal server error"}).encode(), 500
    return None

high

In remote_target_server.py, _route_request_synced performs collective PyTorch distributed calls (dist.broadcast_object_list and dist.broadcast) to synchronize requests with worker ranks. However, the HTTP server is multi-threaded (ThreadingHTTPServer), and collective calls are not thread-safe. While heavy forward passes are protected by _forward_semaphore, lightweight endpoints (like set_vocab_mapping or get_model_info) bypass the semaphore and can execute _route_request_synced concurrently on separate threads. This will corrupt the collective communication state and cause deadlocks or desynchronization crashes. We should introduce a global lock to serialize all synchronized requests.

_synced_request_lock = threading.Lock()


def _route_request_synced(app, path, body):
    """Route a request, broadcasting across tp ranks so all ranks participate in the
    forward pass (required by NCCL allreduce inside SGLang model_runner).

    For heavy forward-pass endpoints a rank-0-only serialisation path is used
    so that worker ranks do not waste CPU serialising results that are discarded.
    """
    if not dist.is_initialized() or dist.get_world_size() == 1:
        try:
            return _route_request(app, path, body)
        except Exception:
            logger.exception("Error handling %s", path)
            return json.dumps({"error": "Internal server error"}).encode(), 500

    # /init_nccl must run ONLY on rank 0 — it sets up a separate 2-rank NCCL
    # group (server rank 0 + training client rank 1).  TP worker ranks must
    # NOT participate.
    if path == "/init_nccl":
        return _route_request(app, path, body)

    with _synced_request_lock:
        rank = dist.get_rank()
        synced_path, body_bytes = _broadcast_request(path, body)

        # Heavy endpoints: all ranks execute the forward pass (NCCL requires it),
        # but only rank 0 needs post-processing (target_p, serialisation).
        # Non-rank-0 workers skip post-processing to avoid blocking the next step.
        if synced_path == "/generate_eagle3_data":
            out = None
            try:
                out = app._run_generate_eagle3_data(
                    body_bytes, rank_only_forward=(rank != 0)
                )
            except Exception:
                logger.exception("Error handling %s (rank %d)", synced_path, rank)
            if rank == 0:
                if out is not None:
                    return app._serialize_response(out)
                return (
                    json.dumps(
                        {"error": "Internal server error during forward pass"}
                    ).encode(),
                    500,
                )
            return None

        if synced_path == "/generate_dflash_data":
            out = None
            try:
                out = app._run_generate_dflash_data(
                    body_bytes, rank_only_forward=(rank != 0)
                )
            except Exception:
                logger.exception("Error handling %s (rank %d)", synced_path, rank)
            if rank == 0:
                if out is not None:
                    return app._serialize_response(out)
                return (
                    json.dumps(
                        {"error": "Internal server error during forward pass"}
                    ).encode(),
                    500,
                )
            return None

        # Lightweight endpoints: broadcast + execute on all ranks
        try:
            result = _route_request(app, synced_path, body_bytes)
        except Exception:
            logger.exception("Error handling %s (rank %d)", synced_path, rank)
            result = None

        if rank == 0:
            if result is not None:
                return result
            return json.dumps({"error": "Internal server error"}).encode(), 500
        return None
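
The effect of the global lock can be demonstrated without `torch.distributed`. In the sketch below, `fake_collective` is a hypothetical stand-in for the broadcast calls, and a counter records how many threads are inside it at once; every name other than `_synced_request_lock` is introduced for illustration.

```python
import threading
import time

_synced_request_lock = threading.Lock()
_counter_lock = threading.Lock()
_active = 0       # threads currently inside the fake collective
_max_active = 0   # high-water mark of concurrent entries


def fake_collective():
    # Stand-in for a collective call that must not run concurrently.
    global _active, _max_active
    with _counter_lock:
        _active += 1
        _max_active = max(_max_active, _active)
    time.sleep(0.01)
    with _counter_lock:
        _active -= 1


def handle_request(use_lock):
    if use_lock:
        with _synced_request_lock:  # serialize all synced requests
            fake_collective()
    else:
        fake_collective()


def run(use_lock, n_threads=8):
    """Return the maximum number of threads seen inside the collective at once."""
    global _active, _max_active
    _active = _max_active = 0
    threads = [
        threading.Thread(target=handle_request, args=(use_lock,))
        for _ in range(n_threads)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return _max_active


print("with lock:", run(use_lock=True))      # always 1
print("without lock:", run(use_lock=False))  # typically greater than 1
```

One design consequence of the suggested change, as written, is that the lock is held for the whole forward pass, so lightweight requests queue behind heavy ones.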

Comment on lines +191 to +196
    if dist.get_world_size() > 1:
        from specforge.modeling.target.remote_target_server import (
            _SENTINEL_EXIT,
        )

        dist.broadcast_object_list([_SENTINEL_EXIT], src=0)
medium

During shutdown, rank 0 attempts to broadcast _SENTINEL_EXIT to worker ranks using dist.broadcast_object_list. However, if worker ranks have already exited or are in the process of exiting (e.g., due to receiving SIGTERM from torchrun at the same time), this collective call will hang or raise a distributed communication error. Wrapping this call in a try...except block ensures a clean and robust shutdown without hanging.

Suggested change
    - if dist.get_world_size() > 1:
    -     from specforge.modeling.target.remote_target_server import (
    -         _SENTINEL_EXIT,
    -     )
    -     dist.broadcast_object_list([_SENTINEL_EXIT], src=0)
    + if dist.get_world_size() > 1:
    +     from specforge.modeling.target.remote_target_server import (
    +         _SENTINEL_EXIT,
    +     )
    +     try:
    +         dist.broadcast_object_list([_SENTINEL_EXIT], src=0)
    +     except Exception:
    +         logger.warning("Failed to broadcast exit signal to worker ranks (they may have already exited).")
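
A `try...except` handles a raised communication error, but it cannot interrupt a call that blocks forever. One possible complement, sketched here as an assumption rather than anything in the PR, is to run the best-effort call in a daemon thread and stop waiting after a timeout. `run_with_timeout` is a hypothetical helper.

```python
import threading
import time


def run_with_timeout(fn, timeout_s):
    """Run fn in a daemon thread.

    Returns True only if fn finished within timeout_s without raising.
    A call that is still blocked is abandoned, not cancelled.
    """
    done = threading.Event()
    errors = []

    def target():
        try:
            fn()
        except Exception as exc:  # record the failure instead of crashing the thread
            errors.append(exc)
        finally:
            done.set()

    threading.Thread(target=target, daemon=True).start()
    finished = done.wait(timeout_s)
    return finished and not errors


print(run_with_timeout(lambda: None, 1.0))              # True
print(run_with_timeout(lambda: time.sleep(0.5), 0.05))  # False (timed out)
```

In the shutdown path, `fn` would be a closure around the `dist.broadcast_object_list` call. Whether a process with a thread stuck inside a collective then exits cleanly depends on the backend, so this would need to be verified on the actual setup.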

Copilot AI left a comment

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR introduces train-inference disaggregation for SpecForge, deploying the target model as a standalone HTTP/NCCL server while draft model training runs separately. It adds remote backend support for both Eagle3 and DFlash training paths, plus several SGLang patches and Qwen3.5 VLM support.

Changes:

  • New remote target server/client with NCCL data plane and custom wire-format fallback for tensor transport
  • Training scripts (train_eagle3.py, train_dflash.py) gain remote backend + async target prefetch pipelining
  • SGLang backend patches updated for newer SGLang API; Qwen3.5/VLM handling for DFlash; FSDP/distributed cleanup improvements

Reviewed changes

Copilot reviewed 31 out of 31 changed files in this pull request and generated 19 comments.

| File | Description |
|---|---|
| tests/test_modeling/test_target/test_nccl_transport.py | New unit + GPU + multiprocess integration tests for NCCL transport |
| specforge/modeling/target/_nccl_transport.py | New NCCL transport (TCPStore rendezvous, send/recv, metadata encode/decode) |
| specforge/modeling/target/_tensor_wire.py | New compact binary tensor wire format |
| specforge/modeling/target/remote_target_server.py | HTTP server wrapping SGLang target model with TP broadcast + NCCL send |
| specforge/modeling/target/remote_target_client.py | Remote client with NCCL recv, async prefetch, TP broadcast |
| specforge/modeling/target/eagle3_target_model.py / dflash_target_model.py | Add remote backend hooks, rank-only-forward path, Qwen3.5 VLM support |
| specforge/modeling/target/sglang_backend/{patch,utils,model_runner}.py | SGLang version-compat updates (removed multi-item scoring, new group params) |
| specforge/modeling/target/__init__.py | Export remote target classes |
| specforge/modeling/target/custom_backend/{llama,llama4,phi3,gpt_oss}.py | Remove check_model_inputs decorator import/usage |
| specforge/core/eagle3.py / dflash.py | Support precomputed position_mask; multimodal position_ids in DFlash |
| specforge/distributed.py | GPU id env override; safer destroy_distributed |
| specforge/utils.py | Remote config helpers; hf_config_dict plumbing; preformatted text row passthrough |
| specforge/data/{template,preprocessing}.py | Register qwen3.5 templates; pass empty tools for preformatted |
| specforge/args.py | Add RemoteBackendArgs; rename SGLang piecewise flag |
| scripts/{train_eagle3,train_dflash,launch_target_server}.py | Add remote backend, prefetch pipeline, server launcher |
| pyproject.toml | Bump transformers>=5.2.0 |
| docs/, examples/, configs/ | New remote training docs, qwen3.5 example, dflash config |

Comments suppressed due to low confidence (1)

tests/test_modeling/test_target/test_nccl_transport.py:1

  • The as e clause is unused (only traceback.format_exc() is used). Either drop the as e to silence linters (except Exception:) or include repr(e) in the error payload. The same applies to the other _gpu_* helper functions in this file.

Comment on lines +436 to +442
    except (requests.ConnectionError, requests.Timeout) as exc:
        if attempt < self.max_retries:
            time.sleep(2**attempt)
            continue
        raise RuntimeError(
            f"Remote request failed after {self.max_retries + 1} attempts: {exc}"
        ) from exc
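
The retry schedule in this snippet can be exercised on its own. The sketch below reproduces the same control flow with the standard library only: the built-in `ConnectionError` and `TimeoutError` stand in for the `requests` exceptions, and `post_with_retry`, `send`, and the injectable `sleep` are hypothetical names added so the delays can be observed without waiting.

```python
import time


def post_with_retry(send, max_retries=3, sleep=time.sleep):
    """Call send(), retrying on connection errors with 2**attempt second delays."""
    for attempt in range(max_retries + 1):
        try:
            return send()
        except (ConnectionError, TimeoutError) as exc:
            if attempt < max_retries:
                sleep(2 ** attempt)  # 1s, 2s, 4s, ...
                continue
            raise RuntimeError(
                f"Remote request failed after {max_retries + 1} attempts: {exc}"
            ) from exc


# Usage: a sender that fails twice, then succeeds; delays are recorded, not slept.
calls = {"n": 0}
delays = []


def flaky_send():
    calls["n"] += 1
    if calls["n"] <= 2:
        raise ConnectionError("server not ready")
    return b"ok"


print(post_with_retry(flaky_send, sleep=delays.append), delays)  # b'ok' [1, 2]
```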
Comment on lines +359 to +376
    def _request(self, endpoint: str, payload: bytes) -> bytes:
        """POST payload to endpoint with exponential-backoff retry."""
        url = f"{self.url}/{endpoint.lstrip('/')}"
        last_exc = None

        for attempt in range(self.max_retries + 1):
            try:
                resp = self._session.post(
                    url,
                    data=payload,
                    timeout=self.timeout,
                    headers={"Content-Type": "application/octet-stream"},
                )
                resp.raise_for_status()
                return resp.content
            except (requests.ConnectionError, requests.Timeout) as exc:
                last_exc = exc
                if attempt < self.max_retries:
enable_torch_compile: bool = True,
nccl_port: int = None,
host: str = "0.0.0.0",
attention_backend: str | None = None,
Comment thread scripts/train_eagle3.py
prefetch_queue = []
torch_profiler = None

def next_prefetch_boundary() -> int | float:
Comment thread scripts/train_dflash.py
if dist.get_rank() == 0:
progress_bar = tqdm(
train_dataloader, desc=f"Training Epoch {epoch}", leave=True
def next_prefetch_boundary() -> int | float:
Comment on lines +24 to +25
QWEN3_5_MODEL_TYPES = {"qwen3_5", "qwen3_5_moe"}
VLM_MODEL_TYPES = QWEN3_5_MODEL_TYPES
Comment thread scripts/train_eagle3.py
tracker.close()
destroy_distributed()
# Save final checkpoint if training ended without saving
if args.save_interval <= 0 or global_step % args.save_interval != 0:
Comment on lines +245 to +248
from torch.distributed.distributed_c10d import (
_unregister_process_group,
_world,
)
Comment thread specforge/utils.py
"hf_config_dict": hf_config_dict,
"server_model_path": server_model_path,
}
).encode("utf-8")
Comment thread scripts/train_dflash.py
Comment on lines +160 to +166
(
"sliding_attention"
if sliding_window is not None and layer_idx >= max_window_layers
else "full_attention"
)
for layer_idx in range(num_hidden_layers)
]
@moehanabi moehanabi force-pushed the remote_train_rebase branch 2 times, most recently from 7499e80 to f12b8eb Compare May 28, 2026 08:52
@moehanabi moehanabi force-pushed the remote_train_rebase branch from f12b8eb to b0f94b4 Compare May 28, 2026 08:53
@moehanabi moehanabi force-pushed the remote_train_rebase branch from 960b2ac to f2dbaa5 Compare June 1, 2026 12:28
@moehanabi
Contributor Author

Moved to #573 to rebase onto the latest commit.

@moehanabi moehanabi closed this Jun 2, 2026
3 participants