
Support sharded target logits for EAGLE3 online training #558

Open
Yukino256 wants to merge 1 commit into sgl-project:main from Yukino256:pr-online-usp-shard-target-logits



@Yukino256 Yukino256 commented May 25, 2026

Motivation

This PR adds a memory-efficient target-logits path for EAGLE3 online training with the SGLang target backend.

Previously, online EAGLE3 training materialized full target logits on every TP rank after SGLang's tensor-parallel logits all-gather. For long-context training, this creates large redundant logits tensors and can lead to OOM. The problem is most acute with sequence-parallel training, where the draft model needs only its local sequence shard rather than full-sequence logits on every rank.
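A back-of-the-envelope estimate of the tensor sizes involved can be sketched as follows. The vocabulary size (129,280) and the 2-byte element width are assumptions for illustration only; neither is stated in this PR, and `logits_bytes` is a hypothetical helper, not SpecForge code.

```python
def logits_bytes(seq_len: int, vocab_size: int, bytes_per_elem: int = 2, shards: int = 1) -> int:
    """Bytes needed for one sequence's logits, split evenly across `shards` ranks."""
    total = seq_len * vocab_size * bytes_per_elem
    if total % shards != 0:
        raise ValueError("logits do not split evenly across shards")
    return total // shards


GIB = 1024 ** 3
SEQ_LEN = 32768       # --max-length in the tested configuration
VOCAB_SIZE = 129_280  # assumed vocabulary size, not taken from the PR

full = logits_bytes(SEQ_LEN, VOCAB_SIZE)
sharded = logits_bytes(SEQ_LEN, VOCAB_SIZE, shards=8)
print(f"full logits per rank:    {full / GIB:.2f} GiB")
print(f"sharded logits per rank: {sharded / GIB:.2f} GiB")
```

Under these assumptions, a single 32K sequence needs about 7.89 GiB of logits per rank when fully materialized and about 0.99 GiB when split eight ways.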

Modifications

  • Add --shard-target-logits to scripts/train_eagle3.py.
  • For SGLang target models, optionally disable SGLang's target-logits TP all-gather and redistribute logits explicitly.
  • Add two sharded target-logits redistribution paths:
    • Non-SP online training: redistribute from full-batch/local-vocab logits to local-batch/full-vocab logits.
    • USP online training: redistribute from full-sequence/local-vocab logits to local-sequence/full-vocab logits.
  • Support online USP training with tp_size >= sp_size when tp_size % sp_size == 0.
  • Align target hidden states, input IDs, attention masks, loss masks, and position IDs with the local USP sequence shard.
  • Keep the original behavior unchanged unless --shard-target-logits is enabled.
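The non-SP redistribution described above can be illustrated with a single-process simulation. This is a minimal sketch using nested lists, not the PR's implementation (which, per the review summary, uses all-to-all communication between ranks); `shard_vocab` and `redistribute` are hypothetical names introduced here.

```python
from typing import List

Matrix = List[List[float]]  # [rows][cols]


def shard_vocab(full: Matrix, world_size: int) -> List[Matrix]:
    """Split full [batch][vocab] logits into per-rank [batch][vocab // world_size] shards."""
    vocab = len(full[0])
    assert vocab % world_size == 0
    local_vocab = vocab // world_size
    return [
        [row[r * local_vocab:(r + 1) * local_vocab] for row in full]
        for r in range(world_size)
    ]


def redistribute(per_rank: List[Matrix]) -> List[Matrix]:
    """Simulate the exchange: full-batch/local-vocab -> local-batch/full-vocab."""
    world_size = len(per_rank)
    batch = len(per_rank[0])
    assert batch % world_size == 0
    local_batch = batch // world_size
    out = []
    for dst in range(world_size):
        rows = []
        for b in range(dst * local_batch, (dst + 1) * local_batch):
            # Concatenate row b's vocab slices from every source rank, in rank order.
            rows.append([x for src in range(world_size) for x in per_rank[src][b]])
        out.append(rows)
    return out


# Toy example: batch=4, vocab=6, two ranks.
full_logits = [[float(b * 10 + v) for v in range(6)] for b in range(4)]
shards = shard_vocab(full_logits, world_size=2)
local = redistribute(shards)
assert local[0] == full_logits[0:2]  # rank 0 holds rows 0-1 with the full vocab
assert local[1] == full_logits[2:4]  # rank 1 holds rows 2-3 with the full vocab
```

Each rank ends up holding 1/world_size of the rows instead of all of them. Conceptually, the USP path follows the same pattern with the sequence axis in place of the batch axis.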

Related Issues

N/A

Accuracy Test

Tested EAGLE3 online training on DeepSeek-V3 with the SGLang target backend.

Configuration:

--target-model-backend sglang
--tp-size 8
--sp-ring-size 1
--sp-ulysses-size 8
--attention-backend usp
--max-length 32768
--sglang-mem-fraction-static 0.8
--shard-target-logits
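As a quick sanity check, the configuration above satisfies the constraint stated under Modifications (tp_size >= sp_size and tp_size % sp_size == 0). The sketch below assumes sp_size is the product of the ring and Ulysses sizes, which is consistent with the `sp_size // sp_ring_size` expression quoted in the review; `check_parallel_sizes` is a hypothetical helper, not part of the PR.

```python
def check_parallel_sizes(tp_size: int, sp_ring_size: int, sp_ulysses_size: int) -> int:
    """Return sp_size if tp_size >= sp_size and tp_size % sp_size == 0, else raise."""
    if min(tp_size, sp_ring_size, sp_ulysses_size) < 1:
        raise ValueError("all sizes must be positive")
    sp_size = sp_ring_size * sp_ulysses_size  # assumption: sp_size = ring size x Ulysses size
    if tp_size < sp_size or tp_size % sp_size != 0:
        raise ValueError(
            f"tp_size={tp_size} must be a multiple of sp_size={sp_size} and at least as large"
        )
    return sp_size


# Values from the tested configuration: --tp-size 8, --sp-ring-size 1, --sp-ulysses-size 8.
print(check_parallel_sizes(tp_size=8, sp_ring_size=1, sp_ulysses_size=8))
```

For the tested values this returns 8; a layout such as tp_size=8 with sp_size=3 would raise.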

The 32K training run proceeds without OOM. Training loss decreases normally and training accuracy increases normally.

Training loss / accuracy curves:

[image: training loss / accuracy curves]

Downstream acceptance length on the related evaluation data: 2.53, roughly consistent with our previous non-SP training result.

SGLang Inference/evaluation Logs:

[image: SGLang inference/evaluation logs]

Benchmark & Profiling

This PR is primarily intended to reduce target-logits memory usage during EAGLE3 online training.

Observed behavior:

  • Before this change, target logits were materialized redundantly on every TP rank.
  • With --shard-target-logits, each rank only keeps the target logits needed by its local batch shard or local USP sequence shard.
  • In the tested DeepSeek-V3 32K online USP setup, training runs successfully with --sglang-mem-fraction-static 0.8.

No throughput benchmark is included yet.



@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements support for sharding target logits and sequence parallelism (USP) during online data generation for EAGLE3 training. It introduces the --shard-target-logits argument and updates the SGLang backend to handle logit and hidden-state sharding via all-to-all communication and sequence slicing. Review feedback identifies a critical error in the position_ids calculation for sequence-parallel mode and highlights several performance bottlenecks in the sharding utility functions caused by inefficient tensor allocations and copying.

Comment on lines +884 to +896
seq_len = input_id.shape[1]
sp_ulysses_size = max(1, sp_size // sp_ring_size)
usp_chunk_size = max(seq_len - ttt_length, 0)
ring_chunk = usp_chunk_size * sp_ulysses_size
ring_start = sp_ring_rank * ring_chunk
kept_position_ids.append(
    torch.arange(
        ring_start,
        ring_start + ring_chunk,
        dtype=torch.long,
        device=input_id.device,
    ).unsqueeze(0)
)

critical

The logic for generating position_ids in sequence parallel mode is incorrect. It uses sp_ring_rank and a ring_chunk size that does not match the local sequence length seq_len. Specifically, ring_chunk is calculated as usp_chunk_size * sp_ulysses_size, which represents the length of the entire Ring chunk, but input_id is only a Ulysses shard of that chunk. This will cause a shape mismatch and incorrect position encodings during training. The position_ids should be based on the absolute offset start calculated from _sp_chunk_bounds using the rank's sp_rank and the local sequence length.

                    original_seq_len = input_ids.shape[1]
                    start, _ = _sp_chunk_bounds(original_seq_len, sp_rank, sp_size, ttt_length)
                    seq_len = input_id.shape[1]
                    kept_position_ids.append(
                        torch.arange(
                            start,
                            start + seq_len,
                            dtype=torch.long,
                            device=input_id.device,
                        ).unsqueeze(0)
                    )

Author


This is intentional and follows the existing offline USP preprocessing logic. In USP mode, position_ids are not local input slice positions. They need to match the Ulysses all-to-all expanded sequence length consumed by UspAdapter, i.e. (local_len - ttt_length) * sp_ulysses_size. The online path mirrors process_data_usp so online and offline USP use consistent position ids.
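The length relationship described in this reply can be checked with a small sketch that mirrors the arithmetic of the snippet under review, using plain Python lists instead of torch tensors. The function name `usp_position_ids` and the example values are hypothetical; the variable names come from the quoted code.

```python
from typing import List


def usp_position_ids(
    local_len: int,
    ttt_length: int,
    sp_size: int,
    sp_ring_size: int,
    sp_ring_rank: int,
) -> List[int]:
    """Position ids spanning the Ulysses-expanded chunk owned by one ring rank."""
    sp_ulysses_size = max(1, sp_size // sp_ring_size)
    usp_chunk_size = max(local_len - ttt_length, 0)
    ring_chunk = usp_chunk_size * sp_ulysses_size  # length after the Ulysses all-to-all
    ring_start = sp_ring_rank * ring_chunk
    return list(range(ring_start, ring_start + ring_chunk))


# Illustrative values: local shard of 5 tokens, ttt_length=1, sp_size=4 as 2 ring x 2 Ulysses.
ids = usp_position_ids(local_len=5, ttt_length=1, sp_size=4, sp_ring_size=2, sp_ring_rank=1)
assert len(ids) == (5 - 1) * 2  # (local_len - ttt_length) * sp_ulysses_size
assert ids == list(range(8, 16))
```

Here the result has 8 entries while the local input slice has 5 tokens, which is the length difference the reviewer flagged and the author describes as intended.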

Comment thread specforge/modeling/target/sglang_backend/utils.py Outdated
Comment thread specforge/modeling/target/sglang_backend/utils.py Outdated
@Yukino256 Yukino256 force-pushed the pr-online-usp-shard-target-logits branch from 226a1da to 64a52de Compare May 25, 2026 03:50
@Yukino256 Yukino256 force-pushed the pr-online-usp-shard-target-logits branch from 64a52de to 44571b6 Compare May 25, 2026 03:51
@Yukino256
Author

Hi @jiapingW could you please help take a look?

This PR solves a similar issue as #524, but extends it to the online + USP + sharded target logits setting. It avoids redundant target logits across TP ranks and makes the SGLang target logits align with the local USP sequence shard.

I tested it on DeepSeek-V3 32K online training (tp=8, sp=8, mem_fraction_static=0.8). Training runs without OOM, loss/accuracy look normal, and the downstream acceptance length is 2.53, roughly consistent with our previous non-SP result.

Thanks!

@jiapingW
Collaborator

Great. I’ll review it soon.
