Skip to content

Add P-EAGLE training support#575

Open
thyways wants to merge 27 commits into
sgl-project:mainfrom
thyways:P-Eagle
Open

Add P-EAGLE training support#575
thyways wants to merge 27 commits into
sgl-project:mainfrom
thyways:P-Eagle

Conversation

@thyways
Copy link
Copy Markdown
Contributor

@thyways thyways commented Jun 4, 2026

Motivation

Closes #541.

P-EAGLE shows speed comparable to DFlash while also supporting parallel decoding, so SpecForge should support P-EAGLE training.

This PR adds P-EAGLE training support to SpecForge, following the algorithmic direction from:

The implementation adapts the P-EAGLE idea to SpecForge's existing online EAGLE3 training pipeline.

Demo model:

https://huggingface.co/thyways/qwen3_8b_peagle_demo

Modifications

This PR adds P-EAGLE training support.

Main changes:

  • Add PEagleDraftModel, including the P-EAGLE multi-layer draft architecture.
  • Add OnlinePEagleModel with COD parallel sampling, P-EAGLE attention masking, loss, and per-depth accuracy metrics.
  • Add scripts/train_peagle.py for online P-EAGLE training with FSDP, TP/DP support, checkpoint save/resume, mask-token resolution, and wandb logging.
  • Add Qwen3-8B P-EAGLE config and example training script.
  • Register P-EAGLE in model/core exports and draft config loading.
  • Improve conversation normalization for ShareGPT-style from/value messages.
  • Add minimum_valid_tokens filtering to drop samples without trainable tokens.
  • Fix SGLang target data padding for loss_mask.
  • Add regression tests for P-EAGLE metrics, COD sampling, attention masking, parser normalization, and trainable-token filtering.

Implementation notes:

  • P-EAGLE inherits the EAGLE3-style draft model but performs parallel multi-token prediction.
  • COD sampling creates sampled prediction depths with geometric downsampling.
  • A learnable mask_hidden parameter is used for positions that do not have target hidden states in parallel prediction depths.
  • Draft embeddings are trainable by default.

Related Issues

Closes #541.

Related references:

Accuracy Test

The Qwen3-8B P-EAGLE was trained on 8x A100 GPUs.

Training setup:

  • Target model: Qwen/Qwen3-8B
  • Training dataset: jihwan1205/perfectblend-qwen3-8b-regen
  • Training length: 110k steps
  • Draft layers: 4
  • Prediction depths: 5
  • Max length: 4096
  • Learning rate: 1e-4

Training Loss:
image

The trained demo checkpoint is available at:

https://huggingface.co/thyways/qwen3_8b_peagle_demo

SGLang does not yet support P-EAGLE inference, so model-side inference accuracy was not evaluated through SGLang in this PR. For the current validation, I adapted the exported config for vLLM testing:

 {
   "architectures": [
     "Eagle3LlamaForCausalLM"
   ],
   "ptd_token_id": 151669
 }

Benchmark & Profiling

Evaluation setup:

  • Hardware: 1x B200
  • Dataset: MT-Bench
  • EAGLE3 baseline: RedHatAI/Qwen3-8B-speculator.eagle3, k=3
  • DFlash baseline: z-lab/Qwen3-8B-DFlash-b16, k=15
  • P-EAGLE: thyways/qwen3_8b_peagle_demo, k=5
image

Due to limited training resources, this checkpoint is not fully trained yet. However, the current results already validate the effectiveness of the implementation, and longer training is expected to further improve the acceptance quality and speedup, consistent with the P-EAGLE paper.

Checklist

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature] Support for PEAGLE

1 participant