Support sharded target logits for EAGLE3 online training #558
Code Review
This pull request implements support for sharding target logits and sequence parallelism (USP) during online data generation for Eagle3 training. It introduces the --shard-target-logits argument and updates the SGLang backend to handle logit and hidden state sharding via all-to-all communication and sequence slicing. Review feedback identifies a critical error in the position_ids calculation for sequence parallel mode and highlights several performance bottlenecks in the sharding utility functions caused by inefficient tensor allocations and copying.
seq_len = input_id.shape[1]
sp_ulysses_size = max(1, sp_size // sp_ring_size)
usp_chunk_size = max(seq_len - ttt_length, 0)
ring_chunk = usp_chunk_size * sp_ulysses_size
ring_start = sp_ring_rank * ring_chunk
kept_position_ids.append(
    torch.arange(
        ring_start,
        ring_start + ring_chunk,
        dtype=torch.long,
        device=input_id.device,
    ).unsqueeze(0)
)
The logic for generating position_ids in sequence parallel mode is incorrect. It uses sp_ring_rank and a ring_chunk size that does not match the local sequence length seq_len. Specifically, ring_chunk is calculated as usp_chunk_size * sp_ulysses_size, which represents the length of the entire Ring chunk, but input_id is only a Ulysses shard of that chunk. This will cause a shape mismatch and incorrect position encodings during training. The position_ids should be based on the absolute offset start calculated from _sp_chunk_bounds using the rank's sp_rank and the local sequence length.
original_seq_len = input_ids.shape[1]
start, _ = _sp_chunk_bounds(original_seq_len, sp_rank, sp_size, ttt_length)
seq_len = input_id.shape[1]
kept_position_ids.append(
    torch.arange(
        start,
        start + seq_len,
        dtype=torch.long,
        device=input_id.device,
    ).unsqueeze(0)
)
This is intentional and follows the existing offline USP preprocessing logic. In USP mode, position_ids are not local input slice positions. They need to match the Ulysses all-to-all expanded sequence length consumed by UspAdapter, i.e. (local_len - ttt_length) * sp_ulysses_size. The online path mirrors process_data_usp so online and offline USP use consistent position ids.
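The length arithmetic described in this reply can be sketched in plain Python. This is only an illustration of the formula `(local_len - ttt_length) * sp_ulysses_size`, following the snippet under review; the helper name `usp_position_range` and the numeric values are hypothetical and are not taken from the PR.

```python
def usp_position_range(local_len, ttt_length, sp_size, sp_ring_size, sp_ring_rank):
    """Return (start, stop) of the position ids for one ring group.

    Hypothetical helper that mirrors the arithmetic in the reviewed snippet.
    In the PR the result feeds torch.arange(start, stop).
    """
    # Number of Ulysses ranks inside each ring group.
    sp_ulysses_size = max(1, sp_size // sp_ring_size)
    # Local tokens that remain after removing the ttt_length extra tokens.
    usp_chunk_size = max(local_len - ttt_length, 0)
    # Sequence length after the Ulysses all-to-all expands the local shard.
    ring_chunk = usp_chunk_size * sp_ulysses_size
    ring_start = sp_ring_rank * ring_chunk
    return ring_start, ring_start + ring_chunk


# Illustrative values only (assumptions, not the PR's configuration).
start, stop = usp_position_range(
    local_len=1031, ttt_length=7, sp_size=8, sp_ring_size=2, sp_ring_rank=1
)
print(start, stop, stop - start)
```

With these assumed values the local shard keeps 1024 tokens, but the position ids span 4096 entries, which is why their length differs from the local `input_id` length.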
Hi @jiapingW could you please help take a look? This PR solves a similar issue as #524, but extends it to the [...] I tested it on DeepSeek-V3 32K online training ([...]) Thanks!
Great. I’ll review it soon.
Motivation
This PR adds a memory-efficient target-logits path for EAGLE3 online training with the SGLang target backend.
Previously, online EAGLE3 training materialized full target logits on every TP rank after SGLang's tensor-parallel logits all-gather. For long-context training, this creates large redundant logits tensors and can lead to OOM. This becomes especially problematic when enabling sequence parallel training, where the draft model only needs the local sequence shard instead of full-sequence logits on every rank.
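To give a rough sense of scale, the following back-of-the-envelope sketch estimates the size of one full-sequence logits tensor against a per-rank sequence shard. The vocabulary size (129,280 for DeepSeek-V3), the bf16 element size, and the sequence-parallel degree of 8 are assumptions chosen for illustration; the estimate ignores the `ttt_length` overlap and any other buffers.

```python
def logits_bytes(seq_len: int, vocab_size: int, bytes_per_elem: int) -> int:
    """Bytes needed to hold one [seq_len, vocab_size] logits tensor."""
    return seq_len * vocab_size * bytes_per_elem


GIB = 1024 ** 3
SEQ_LEN = 32 * 1024    # 32K-token sample, as in the test run
VOCAB_SIZE = 129_280   # assumption: DeepSeek-V3 vocabulary size
BYTES_PER_ELEM = 2     # assumption: logits held in bf16
SP_SIZE = 8            # assumption: illustrative sequence-parallel degree

# Full-sequence logits, as materialized on every rank without sharding.
full = logits_bytes(SEQ_LEN, VOCAB_SIZE, BYTES_PER_ELEM)
# Logits for one rank's even share of the sequence.
local = logits_bytes(SEQ_LEN // SP_SIZE, VOCAB_SIZE, BYTES_PER_ELEM)

full_gib = full / GIB
local_gib = local / GIB
print(f"full: {full_gib:.2f} GiB per rank, sharded: {local_gib:.2f} GiB per rank")
```

Under these assumptions the full tensor is about 7.9 GiB per sample on every rank, while an even 8-way sequence shard is about 1 GiB.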
Modifications
- Add --shard-target-logits to scripts/train_eagle3.py.
- [...] tp_size >= sp_size when tp_size % sp_size == 0.
- [...] --shard-target-logits is enabled.
Related Issues
N/A
Accuracy Test
Tested EAGLE3 online training on DeepSeek-V3 with the SGLang target backend.
Configuration:
The 32K training run proceeds without OOM. Training loss decreases normally and training accuracy increases normally.
Training loss / accuracy curves:
Downstream acceptance length on the related evaluation data: 2.53. This is roughly consistent with the previous non-SP training result we used.
SGLang Inference/evaluation Logs:
Benchmark & Profiling
This PR is primarily intended to reduce target-logits memory usage during EAGLE3 online training.
Observed behavior:
- With --shard-target-logits, each rank only keeps the target logits needed by its local batch shard or local USP sequence shard.
- [...] --sglang-mem-fraction-static 0.8.
No throughput benchmark is included yet.
Checklist