From 33d7b81a43d4798143848c952c8195618446e84d Mon Sep 17 00:00:00 2001 From: kathyligg Date: Mon, 13 Apr 2026 19:06:13 +0000 Subject: [PATCH 1/4] Fix: Ensure aux hidden states are correctly set for EAGLE3 in Gemma3. Corrects initialization bug in set_aux_hidden_states_layers. --- configs/gemma3-1b-eagle3.json | 4 +-- configs/gemma3-27b-eagle3.json | 32 +++++++++++++++++++ examples/run_gemma3_1b_eagle3_online.sh | 7 ++-- examples/run_gemma3_27b_eagle3_online.sh | 30 +++++++++++++++++ pyproject.toml | 2 +- requirements-rocm.txt | 4 +-- specforge/core/loss.py | 2 +- specforge/modeling/draft/base.py | 2 +- .../modeling/target/custom_backend/gpt_oss.py | 2 -- .../modeling/target/custom_backend/llama.py | 2 -- .../modeling/target/custom_backend/llama4.py | 2 -- .../modeling/target/custom_backend/phi3.py | 2 -- .../modeling/target/eagle3_target_model.py | 26 +++++++-------- 13 files changed, 85 insertions(+), 32 deletions(-) create mode 100644 configs/gemma3-27b-eagle3.json mode change 100644 => 100755 examples/run_gemma3_1b_eagle3_online.sh create mode 100755 examples/run_gemma3_27b_eagle3_online.sh diff --git a/configs/gemma3-1b-eagle3.json b/configs/gemma3-1b-eagle3.json index e5e74eb16..0c6a303cd 100644 --- a/configs/gemma3-1b-eagle3.json +++ b/configs/gemma3-1b-eagle3.json @@ -26,7 +26,7 @@ "transformers_version": "4.50.0", "use_cache": true, "use_sliding_window": false, - "vocab_size": 262145, - "draft_vocab_size": 32000, + "vocab_size": 262144, + "draft_vocab_size": null, "target_model_type": "gemma3_text" } diff --git a/configs/gemma3-27b-eagle3.json b/configs/gemma3-27b-eagle3.json new file mode 100644 index 000000000..d69283ee8 --- /dev/null +++ b/configs/gemma3-27b-eagle3.json @@ -0,0 +1,32 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 2, + "eos_token_id": 1, + "pad_token_id": 0, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 5376, + "initializer_range": 0.02, + "intermediate_size": 8192, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads": 16, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": 512, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.50.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 262208, + "draft_vocab_size": 12288, + "target_model_type": "gemma3_text" +} diff --git a/examples/run_gemma3_1b_eagle3_online.sh b/examples/run_gemma3_1b_eagle3_online.sh old mode 100644 new mode 100755 index a13650695..8b48905af --- a/examples/run_gemma3_1b_eagle3_online.sh +++ b/examples/run_gemma3_1b_eagle3_online.sh @@ -12,7 +12,7 @@ torchrun \ $ROOT_DIR/scripts/train_eagle3.py \ --target-model-path google/gemma-3-1b-it \ --draft-model-config $ROOT_DIR/configs/gemma3-1b-eagle3.json \ - --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --train-data-path $ROOT_DIR/cache/dataset/ultrachat_train.jsonl \ --output-dir $ROOT_DIR/outputs/gemma3-1b-eagle3-sharegpt \ --num-epochs 10 \ --batch-size 1 \ @@ -23,4 +23,7 @@ torchrun \ --cache-dir $ROOT_DIR/cache \ --attention-backend sdpa \ --target-model-backend hf \ - --log-interval 10 + --log-interval 500 \ + --eval-interval 2500 \ + --save-interval 60000 \ + --report-to tensorboard diff --git a/examples/run_gemma3_27b_eagle3_online.sh b/examples/run_gemma3_27b_eagle3_online.sh new file mode 100755 index 000000000..c30dbe3d4 --- /dev/null +++ b/examples/run_gemma3_27b_eagle3_online.sh @@ -0,0 +1,30 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for gemma3-1b +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path google/gemma-3-27b-it \ + --draft-model-config $ROOT_DIR/configs/gemma3-27b-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/ultrachat_train.jsonl \ + --output-dir $ROOT_DIR/outputs/gemma3-27b-eagle3-ultrachat \ + --num-epochs 10 \ + --batch-size 2 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template gemma \ + --cache-dir $ROOT_DIR/cache \ + --attention-backend sdpa \ + --target-model-backend hf \ + --log-interval 500 \ + --eval-interval 2500 \ + --save-interval 60000 \ + --report-to tensorboard \ + --embedding-key=language_model.model.embed_tokens.weight diff --git a/pyproject.toml b/pyproject.toml index 4698a5e81..02f7ec0ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "torch==2.9.1", "torchaudio==2.9.1", "torchvision==0.24.1", - "transformers==4.57.1", + "transformers>=5.0.0", "qwen-vl-utils==0.0.11", "datasets", "setuptools", diff --git a/requirements-rocm.txt b/requirements-rocm.txt index a7c314563..720bfaf03 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -5,7 +5,7 @@ pre-commit torch==2.8.0+rocm6.3 torchaudio==2.8.0+rocm6.3 torchvision==0.23.0+rocm6.3 -transformers==4.57.1 +transformers>=5.0.0 qwen-vl-utils==0.0.11 datasets setuptools @@ -15,6 +15,6 @@ psutil numpy accelerate pydantic -sglang[all]==0.5.4 +sglang[all]==0.5.9 openai-harmony tensorboard diff --git a/specforge/core/loss.py b/specforge/core/loss.py index 30e7fba7d..2aa337692 100644 --- a/specforge/core/loss.py +++ b/specforge/core/loss.py @@ -24,7 +24,7 @@ def _compute_loss(logits, target_p, position_mask): def _calculate_settings(n): # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 - MAX_FUSED_SIZE = 131072 + MAX_FUSED_SIZE = 262208 BLOCK_SIZE = triton.next_power_of_2(n) if BLOCK_SIZE > MAX_FUSED_SIZE: raise RuntimeError( diff --git a/specforge/modeling/draft/base.py b/specforge/modeling/draft/base.py index b5584a759..54bad8601 100644 --- a/specforge/modeling/draft/base.py +++ b/specforge/modeling/draft/base.py @@ -116,7 +116,7 @@ def freeze_embedding(self) -> None: @torch.no_grad() def load_embedding( - self, model_path: str, embedding_key: str = "model.embed_tokens.weight" + self, model_path: str, embedding_key: str = "language_model.embed_tokens.weight" ) -> None: """ Load the embedding of the draft model. diff --git a/specforge/modeling/target/custom_backend/gpt_oss.py b/specforge/modeling/target/custom_backend/gpt_oss.py index b3b4a7972..9910633c8 100644 --- a/specforge/modeling/target/custom_backend/gpt_oss.py +++ b/specforge/modeling/target/custom_backend/gpt_oss.py @@ -36,7 +36,6 @@ from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple -from transformers.utils.generic import check_model_inputs from specforge.distributed import get_tp_group, shard_tensor from specforge.layers import ( @@ -585,7 +584,6 @@ def __init__(self, config: GptOssConfig): # Initialize weights and apply final processing self.post_init() - @check_model_inputs @auto_docstring def forward( self, diff --git a/specforge/modeling/target/custom_backend/llama.py b/specforge/modeling/target/custom_backend/llama.py index 04a3f6c9b..02a1c16c4 100644 --- a/specforge/modeling/target/custom_backend/llama.py +++ b/specforge/modeling/target/custom_backend/llama.py @@ -41,7 +41,6 @@ ) from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, logging -from transformers.utils.generic import check_model_inputs from specforge.distributed import get_tp_group from specforge.layers import ( @@ -275,7 +274,6 @@ def __init__(self, config: LlamaConfig): # Initialize weights and apply final processing self.post_init() - @check_model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/specforge/modeling/target/custom_backend/llama4.py b/specforge/modeling/target/custom_backend/llama4.py index 22f807dae..bccbb19ea 100644 --- a/specforge/modeling/target/custom_backend/llama4.py +++ b/specforge/modeling/target/custom_backend/llama4.py @@ -52,7 +52,6 @@ logging, ) from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import check_model_inputs # [MODIFIED] Import from transformers library from specforge.distributed import get_tp_group, shard_tensor @@ -431,7 +430,6 @@ def __init__(self, config: Llama4TextConfig): self.post_init() @can_return_tuple - @check_model_inputs @auto_docstring def forward( self, diff --git a/specforge/modeling/target/custom_backend/phi3.py b/specforge/modeling/target/custom_backend/phi3.py index 2515701f9..c3ec1adcc 100644 --- a/specforge/modeling/target/custom_backend/phi3.py +++ b/specforge/modeling/target/custom_backend/phi3.py @@ -43,7 +43,6 @@ from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import check_model_inputs from specforge.distributed import get_tp_group from specforge.layers import ( @@ -284,7 +283,6 @@ def __init__(self, config: Phi3Config): # Initialize weights and apply final processing self.post_init() - @check_model_inputs @auto_docstring def forward( self, diff --git a/specforge/modeling/target/eagle3_target_model.py b/specforge/modeling/target/eagle3_target_model.py index 2acf50ba5..63877f402 100644 --- a/specforge/modeling/target/eagle3_target_model.py +++ b/specforge/modeling/target/eagle3_target_model.py @@ -94,19 +94,13 @@ def set_aux_hidden_states_layers( if aux_hidden_states_layers is None: if hasattr(self.model.config, "num_hidden_layers"): num_layers = self.model.config.num_hidden_layers + + elif hasattr(self.model.config, "text_config") and hasattr(self.model.config.text_config, "num_hidden_layers"): + num_layers = self.model.config.text_config.num_hidden_layers else: raise ValueError( f"Failed to set aux hidden states layers as model config {self.model.config} does not have num_hidden_layers" ) - aux_hidden_states_layers = [ - 1, - num_layers // 2 - 1, - num_layers - 4, - ] - self.aux_hidden_states_layers = aux_hidden_states_layers - assert ( - len(self.aux_hidden_states_layers) == 3 - ), "aux_hidden_states_layers is expected to be 3 layers for EAGLE3" class HFEagle3TargetModel(Eagle3TargetModel): @@ -154,18 +148,20 @@ def _get_transformer_layers(self): Helper to find the module list containing the transformer layers. Adapts to common architectures (Llama, Qwen, Mistral, OPT, etc.) """ - if hasattr(self.model, "model") and hasattr(self.model.model, "layers"): - return self.model.model.layers + if hasattr(self.model, "model"): + if hasattr(self.model.model, "layers"): + return self.model.model.layers + elif hasattr(self.model.model, "language_model"): + return self.model.model.language_model.layers elif hasattr(self.model, "layers"): return self.model.layers elif hasattr(self.model, "transformer") and hasattr( self.model.transformer, "h" ): return self.model.transformer.h - else: - raise ValueError( - "Could not locate transformer layers in the model architecture to register hooks." - ) + raise ValueError( + "Could not locate transformer layers in the model architecture to register hooks." + ) @torch.no_grad() def generate_eagle3_data( From 155854b90248d228e1c48560cd06bd4b3a7ef840 Mon Sep 17 00:00:00 2001 From: kathyligg Date: Tue, 14 Apr 2026 00:58:36 +0000 Subject: [PATCH 2/4] Adjust intervals --- examples/run_gemma3_27b_eagle3_online.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/run_gemma3_27b_eagle3_online.sh b/examples/run_gemma3_27b_eagle3_online.sh index c30dbe3d4..8c294fb46 100755 --- a/examples/run_gemma3_27b_eagle3_online.sh +++ b/examples/run_gemma3_27b_eagle3_online.sh @@ -23,8 +23,8 @@ torchrun \ --cache-dir $ROOT_DIR/cache \ --attention-backend sdpa \ --target-model-backend hf \ - --log-interval 500 \ - --eval-interval 2500 \ - --save-interval 60000 \ + --log-interval 100 \ + --eval-interval 500 \ + --save-interval 10000 \ --report-to tensorboard \ --embedding-key=language_model.model.embed_tokens.weight From 9567c0abd269d41bf16c296831eb04fa03ebbd9d Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Wed, 15 Apr 2026 21:52:30 +0000 Subject: [PATCH 3/4] Add gemma4 data regen. --- configs/gemma4-26b-a4b-eagle3.json | 32 +++++ examples/regen_gemma4_26b_data.sh | 174 +++++++++++++++++++++++ examples/run_gemma3_27b_eagle3_online.sh | 8 +- examples/run_gemma4_26b_eagle3_online.sh | 31 ++++ scripts/regenerate_train_data.py | 22 ++- scripts/train_eagle3.py | 67 +++++++-- specforge/data/template.py | 18 ++- specforge/modeling/draft/llama3_eagle.py | 36 +++-- 8 files changed, 360 insertions(+), 28 deletions(-) create mode 100644 configs/gemma4-26b-a4b-eagle3.json create mode 100755 examples/regen_gemma4_26b_data.sh create mode 100755 examples/run_gemma4_26b_eagle3_online.sh diff --git a/configs/gemma4-26b-a4b-eagle3.json b/configs/gemma4-26b-a4b-eagle3.json new file mode 100644 index 000000000..36fa9cc24 --- /dev/null +++ b/configs/gemma4-26b-a4b-eagle3.json @@ -0,0 +1,32 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 2, + "eos_token_id": 1, + "pad_token_id": 0, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2816, + "initializer_range": 0.02, + "intermediate_size": 2112, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": 512, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.50.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 262144, + "draft_vocab_size": 262144, + "target_model_type": "gemma4_text" +} diff --git a/examples/regen_gemma4_26b_data.sh b/examples/regen_gemma4_26b_data.sh new file mode 100755 index 000000000..69a68a0db --- /dev/null +++ b/examples/regen_gemma4_26b_data.sh @@ -0,0 +1,174 @@ +#!/usr/bin/env bash +# Regenerate training data for Gemma4-26B Eagle3. +# +# This script: +# 1. Launches SGLang server(s) for Gemma4-26B on available GPUs. +# 2. Waits for the server(s) to become healthy. +# 3. Runs regenerate_train_data.py with thinking-ratio support. +# 4. Shuts down the server(s) on exit. +# +# Usage: +# bash examples/regen_gemma4_26b_data.sh +# +# Environment variables (override defaults): +# MODEL - HuggingFace model ID (default: google/gemma-4-26b-a4b-it) +# TP_SIZE - Tensor-parallel size (default: 2) +# NUM_SERVERS - Number of server instances (default: 1) +# BASE_PORT - First server port (default: 30000) +# CONCURRENCY - Requests per server (default: 128) +# MAX_TOKENS - Max generation tokens (default: 8192) +# TEMPERATURE - Sampling temperature (default: 0.8) +# THINKING_RATIO - Fraction with thinking (default: 0.7) +# INPUT_FILE - Input JSONL path (required) +# OUTPUT_FILE - Output JSONL path (required) +# NUM_SAMPLES - Max samples to process (default: all) + +set -euo pipefail + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname "$SCRIPT_DIR") + +# ── Configurable defaults ──────────────────────────────────────────────────── +MODEL="${MODEL:-google/gemma-4-26b-a4b-it}" +TP_SIZE="${TP_SIZE:-1}" +NUM_SERVERS="${NUM_SERVERS:-8}" +BASE_PORT="${BASE_PORT:-30000}" +CONCURRENCY="${CONCURRENCY:-128}" +MAX_TOKENS="${MAX_TOKENS:-2048}" +TEMPERATURE="${TEMPERATURE:-1}" +THINKING_RATIO="${THINKING_RATIO:-0.7}" +INPUT_FILE="${INPUT_FILE:-$ROOT_DIR/cache/dataset/ultrachat_train.jsonl}" +OUTPUT_FILE="${OUTPUT_FILE:-$ROOT_DIR/outputs/dataset/ultrachat_regen_gemma4.jsonl}" +NUM_SAMPLES="${NUM_SAMPLES:-}" + +# ── Derived ────────────────────────────────────────────────────────────────── +TOTAL_GPUS=$(( TP_SIZE * NUM_SERVERS )) +AVAIL_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l || echo 0) + +if [ "$AVAIL_GPUS" -lt "$TOTAL_GPUS" ]; then + echo "Error: Need ${TOTAL_GPUS} GPUs (${NUM_SERVERS} servers x TP ${TP_SIZE}) but only ${AVAIL_GPUS} available." + exit 1 +fi + +echo "============================================================" +echo " Gemma4-26B Data Regeneration" +echo "============================================================" +echo " Model: ${MODEL}" +echo " TP size: ${TP_SIZE}" +echo " Servers: ${NUM_SERVERS}" +echo " Ports: ${BASE_PORT}..$(( BASE_PORT + (NUM_SERVERS - 1) * 10 ))" +echo " Concurrency: ${CONCURRENCY} per server" +echo " Max tokens: ${MAX_TOKENS}" +echo " Temperature: ${TEMPERATURE}" +echo " Thinking ratio: ${THINKING_RATIO}" +echo " Input: ${INPUT_FILE}" +echo " Output: ${OUTPUT_FILE}" +echo "============================================================" + +# ── Cleanup on exit ────────────────────────────────────────────────────────── +SERVER_PIDS=() + +cleanup() { + echo "" + echo "Shutting down SGLang server(s)..." + for pid in "${SERVER_PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + kill "$pid" 2>/dev/null || true + fi + done + # Wait briefly then force-kill stragglers + sleep 2 + for pid in "${SERVER_PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + kill -9 "$pid" 2>/dev/null || true + fi + done + echo "All servers stopped." +} +trap cleanup EXIT + +# ── Launch servers ─────────────────────────────────────────────────────────── +SERVER_ADDRESSES=() + +for i in $(seq 0 $(( NUM_SERVERS - 1 ))); do + PORT=$(( BASE_PORT + i * 10 )) + GPU_START=$(( i * TP_SIZE )) + GPU_END=$(( GPU_START + TP_SIZE - 1 )) + CUDA_DEVICES=$(seq -s, "$GPU_START" "$GPU_END") + + echo "Starting server $((i+1))/${NUM_SERVERS} on GPUs ${CUDA_DEVICES}, port ${PORT}..." + + CUDA_VISIBLE_DEVICES="${CUDA_DEVICES}" /home/pyc_google_com/dev/gemma/.venv/bin/python -m sglang.launch_server \ + --model "${MODEL}" \ + --tp "${TP_SIZE}" \ + --port "${PORT}" \ + --host 0.0.0.0 \ + --cuda-graph-max-bs 128 \ + --trust-remote-code --enable-torch-compile \ + > "${ROOT_DIR}/cache/sglang_server_${PORT}.log" 2>&1 & + + SERVER_PIDS+=($!) + SERVER_ADDRESSES+=("localhost:${PORT}") +done + +# ── Wait for servers to be healthy ─────────────────────────────────────────── +echo "" +echo "Waiting for servers to become healthy..." + +wait_for_server() { + local addr=$1 + local max_wait=600 # 10 minutes + local elapsed=0 + while [ $elapsed -lt $max_wait ]; do + if curl -sf "http://${addr}/health" > /dev/null 2>&1; then + return 0 + fi + sleep 5 + elapsed=$(( elapsed + 5 )) + done + return 1 +} + +for addr in "${SERVER_ADDRESSES[@]}"; do + if wait_for_server "$addr"; then + echo " ${addr} is healthy." + else + echo "Error: ${addr} did not become healthy within 10 minutes." + echo "Check logs at: ${ROOT_DIR}/cache/sglang_server_*.log" + exit 1 + fi +done + +echo "All ${NUM_SERVERS} server(s) are ready." +echo "------------------------------------------------------------" + +# ── Build regen command ────────────────────────────────────────────────────── +REGEN_ARGS=( + python3 "${ROOT_DIR}/scripts/regenerate_train_data.py" + --model "${MODEL}" + --is-reasoning-model + --thinking-ratio "${THINKING_RATIO}" + --concurrency "${CONCURRENCY}" + --max-tokens "${MAX_TOKENS}" + --temperature "${TEMPERATURE}" + --server-address "${SERVER_ADDRESSES[@]}" + --input-file-path "${INPUT_FILE}" + --output-file-path "${OUTPUT_FILE}" + --resume +) + +if [ -n "${NUM_SAMPLES}" ]; then + REGEN_ARGS+=(--num-samples "${NUM_SAMPLES}") +fi + +# ── Run regeneration ───────────────────────────────────────────────────────── +echo "Starting data regeneration..." +echo "" + +mkdir -p "$(dirname "${OUTPUT_FILE}")" +"${REGEN_ARGS[@]}" + +echo "" +echo "============================================================" +echo " Done! Output saved to: ${OUTPUT_FILE}" +echo "============================================================" diff --git a/examples/run_gemma3_27b_eagle3_online.sh b/examples/run_gemma3_27b_eagle3_online.sh index 8c294fb46..f72773e8d 100755 --- a/examples/run_gemma3_27b_eagle3_online.sh +++ b/examples/run_gemma3_27b_eagle3_online.sh @@ -15,7 +15,7 @@ torchrun \ --train-data-path $ROOT_DIR/cache/dataset/ultrachat_train.jsonl \ --output-dir $ROOT_DIR/outputs/gemma3-27b-eagle3-ultrachat \ --num-epochs 10 \ - --batch-size 2 \ + --batch-size 8 \ --tp-size $TP_SIZE \ --learning-rate 1e-4 \ --max-length 2048 \ @@ -23,8 +23,8 @@ torchrun \ --cache-dir $ROOT_DIR/cache \ --attention-backend sdpa \ --target-model-backend hf \ - --log-interval 100 \ - --eval-interval 500 \ - --save-interval 10000 \ + --log-interval 500 \ + --eval-interval 2500 \ + --save-interval 5000 \ --report-to tensorboard \ --embedding-key=language_model.model.embed_tokens.weight diff --git a/examples/run_gemma4_26b_eagle3_online.sh b/examples/run_gemma4_26b_eagle3_online.sh new file mode 100755 index 000000000..d3bdb5285 --- /dev/null +++ b/examples/run_gemma4_26b_eagle3_online.sh @@ -0,0 +1,31 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for gemma3-1b +NUM_GPUS=${1:-8} +TP_SIZE=${2:-2} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path google/gemma-4-26b-a4b-it \ + --draft-model-config $ROOT_DIR/configs/gemma4-26b-a4b-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/ultrachat_train.jsonl \ + --output-dir $ROOT_DIR/outputs/gemma4-26b-a4b-eagle3-ultrachat \ + --num-epochs 10 \ + --batch-size 4 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template gemma-4 \ + --cache-dir $ROOT_DIR/cache \ + --attention-backend sdpa \ + --target-model-backend hf \ + --log-interval 500 \ + --eval-interval 2500 \ + --save-interval 10000 \ + --report-to tensorboard \ + --embedding-key=model.language_model.embed_tokens.weight \ + --eval-holdout-ratio 0.05 diff --git a/scripts/regenerate_train_data.py b/scripts/regenerate_train_data.py index d38392b69..cfec92d70 100644 --- a/scripts/regenerate_train_data.py +++ b/scripts/regenerate_train_data.py @@ -1,6 +1,6 @@ """ This script will re-generate the dataset from target model, -which better aligns the draft model with the target model’s output distribution. +which better aligns the draft model with the target model's output distribution. Usage: 1. Set up one or more SGLang servers for the target model. @@ -60,6 +60,15 @@ def parse_arguments(): action="store_true", help="Whether the model is a GPT-OSS model", ) + model_group.add_argument( + "--thinking-ratio", + type=float, + default=None, + help="Fraction of requests sent with thinking enabled (0 to 1). " + "Requires --is-reasoning-model. When set, each request randomly " + "enables or disables thinking based on this ratio. " + "E.g., 0.7 means 70%% of samples use thinking, 30%% do not.", + ) # sampling params sampling_params_group = parser.add_argument_group("sampling parameters") @@ -184,6 +193,9 @@ def build_query_kwargs(args, messages, max_tokens=None): extra_body = {} if args.top_k is not None: extra_body["top_k"] = args.top_k + if args.thinking_ratio is not None: + enable_thinking = random.random() < args.thinking_ratio + extra_body["chat_template_kwargs"] = {"enable_thinking": enable_thinking} if extra_body: query_kwargs["extra_body"] = extra_body if args.is_gpt_oss: @@ -255,11 +267,19 @@ def main(): if args.max_tokens <= 0: raise ValueError("Max tokens must be greater than 0") + if args.thinking_ratio is not None: + if not (0.0 <= args.thinking_ratio <= 1.0): + raise ValueError("--thinking-ratio must be between 0.0 and 1.0") + if not args.is_reasoning_model: + raise ValueError("--thinking-ratio requires --is-reasoning-model") + print(f"Configuration:") print(f" Model path: {args.model}") print(f" Max tokens: {args.max_tokens}") print(f" Concurrency: {args.concurrency}") print(f" Temperature: {args.temperature}") + if args.thinking_ratio is not None: + print(f" Thinking ratio: {args.thinking_ratio:.0%}") print(f" API URL: {args.server_address}") print(f" Input file: {args.input_file_path}") print(f" Output file: {args.output_file_path}") diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 0bd157b39..92166fd2e 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -3,6 +3,7 @@ import math import os import time +from datetime import datetime from argparse import ArgumentParser, Namespace from typing import List, Optional, Tuple, Union @@ -103,6 +104,13 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]: dataset_group.add_argument("--train-hidden-states-path", type=str, default=None) dataset_group.add_argument("--eval-hidden-states-path", type=str, default=None) dataset_group.add_argument("--eval-data-path", type=str, default=None) + dataset_group.add_argument( + "--eval-holdout-ratio", + type=float, + default=None, + help="Fraction of the training dataset to hold out for evaluation (0 to 1). " + "Mutually exclusive with --eval-data-path and --eval-hidden-states-path.", + ) dataset_group.add_argument("--chat-template", type=str, default="llama3") dataset_group.add_argument( "--is-preformatted", @@ -339,6 +347,19 @@ def sanity_check(args: Namespace) -> None: """ args.dp_size = dist.get_world_size() // args.tp_size args.target_batch_size = args.tp_size * args.batch_size + + if args.eval_holdout_ratio is not None: + if not (0 < args.eval_holdout_ratio < 1): + raise ValueError( + f"--eval-holdout-ratio must be between 0 and 1 (exclusive), " + f"got {args.eval_holdout_ratio}" + ) + if args.eval_data_path is not None or args.eval_hidden_states_path is not None: + raise ValueError( + "--eval-holdout-ratio is mutually exclusive with " + "--eval-data-path and --eval-hidden-states-path" + ) + if args.attention_backend == "usp": sp_sanity_check(args) @@ -347,9 +368,9 @@ def sp_sanity_check(args: Namespace) -> None: args.draft_accumulation_steps = ( args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size ) - assert ( - args.batch_size == 1 - ), f"USP only supports batch_size=1, got batch_size={args.batch_size}" + assert args.batch_size == 1, ( + f"USP only supports batch_size=1, got batch_size={args.batch_size}" + ) assert args.sp_ring_size * args.sp_ulysses_size > 1, ( f"USP requires sp_ring_size * sp_ulysses_size > 1. " @@ -491,6 +512,21 @@ def build_dataloaders( use_usp_preprocess=(args.attention_backend == "usp"), ) + # Split a holdout portion from the training set if requested. + eval_eagle3_dataset_from_holdout = None + if args.eval_holdout_ratio is not None and args.eval_holdout_ratio > 0: + split = train_eagle3_dataset.train_test_split( + test_size=args.eval_holdout_ratio, + seed=args.seed, + ) + train_eagle3_dataset = split["train"] + eval_eagle3_dataset_from_holdout = split["test"] + print_on_rank0( + f"Holdout split: {len(train_eagle3_dataset)} train, " + f"{len(eval_eagle3_dataset_from_holdout)} eval " + f"(ratio={args.eval_holdout_ratio})" + ) + train_dataloader = prepare_dp_dataloaders( train_eagle3_dataset, args.target_batch_size, @@ -503,7 +539,13 @@ def build_dataloaders( ), is_vlm=args.is_vlm, ) - if args.eval_data_path is not None or args.eval_hidden_states_path is not None: + + has_eval = ( + args.eval_data_path is not None + or args.eval_hidden_states_path is not None + or eval_eagle3_dataset_from_holdout is not None + ) + if has_eval: if args.eval_data_path is not None: eval_dataset = Dataset.from_generator( generator=safe_conversations_generator, @@ -527,6 +569,8 @@ def build_dataloaders( ttt_length=args.ttt_length, use_usp_preprocess=(args.attention_backend == "usp"), ) + else: + eval_eagle3_dataset = eval_eagle3_dataset_from_holdout eval_dataloader = prepare_dp_dataloaders( eval_eagle3_dataset, args.target_batch_size, @@ -742,6 +786,16 @@ def main(): ) sanity_check(args) + + # Create a datetime subfolder for this run (skip when resuming into an + # existing output directory so that checkpoints stay in the same place). + if not args.resume: + run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + args.output_dir = os.path.join(args.output_dir, run_timestamp) + if dist.get_rank() == 0: + os.makedirs(args.output_dir, exist_ok=True) + dist.barrier() + print_args_with_dots(args) print_with_rank("Initialized distributed environment") @@ -946,10 +1000,7 @@ def main(): # ================================================ # 7.2 Evaluation Step # ================================================ - should_evaluate = ( - args.eval_data_path is not None - or args.eval_hidden_states_path is not None - ) + should_evaluate = eval_dataloader is not None if ( should_evaluate and global_step % (args.eval_interval * args.draft_accumulation_steps) diff --git a/specforge/data/template.py b/specforge/data/template.py index 4dde000fd..c292978d4 100644 --- a/specforge/data/template.py +++ b/specforge/data/template.py @@ -58,9 +58,9 @@ def register(self, name: str, template: ChatTemplate, override: bool = False): template(ChatTemplate): The chat template. override(bool): Whether to override the existing template, default to False """ - assert ( - not override and name not in self.templates - ), f"Chat template for the model type {name} has already been registered" + assert not override and name not in self.templates, ( + f"Chat template for the model type {name} has already been registered" + ) self.templates[name] = template def get(self, name: str) -> ChatTemplate: @@ -324,3 +324,15 @@ def get_all_template_names(self) -> List[str]: enable_thinking=True, ), ) + +TEMPLATE_REGISTRY.register( + name="gemma-4", + template=ChatTemplate( + assistant_header="<|turn>model\n", + user_header="<|turn>user\n", + system_prompt="", + end_of_turn_token="\n", + parser_type="thinking", + enable_thinking=True, + ), +) diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 268142c0c..5d1e59d05 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -925,9 +925,9 @@ def forward( k0 = cache_k[0] v0 = cache_v[0] - assert ( - flash_attn_func is not None - ), "flash_attn is not installed, please install flash_attn if you want to use the flash attention backend" + assert flash_attn_func is not None, ( + "flash_attn is not installed, please install flash_attn if you want to use the flash attention backend" + ) attn_output, lse, _ = flash_attn_func( query_states, k0, @@ -981,9 +981,9 @@ class LlamaUSPFlashAttention(LlamaAttention): def __init__(self, config): super().__init__(config) - assert ( - dist.is_initialized() - ), f"LlamaUSPAttention requires torch.distributed; call init_distributed first." + assert dist.is_initialized(), ( + f"LlamaUSPAttention requires torch.distributed; call init_distributed first." + ) if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): raise NotImplementedError( f"LlamaMutiRotaryEmbedding is currently not supported for LlamaUSPFlashAttention." @@ -1008,7 +1008,6 @@ def forward( output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() local_q_len = q_len @@ -1099,9 +1098,9 @@ def forward( else: acc_lse = lse_ring - assert ( - acc_lse.shape[1] == current_q_len - ), f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}" + assert acc_lse.shape[1] == current_q_len, ( + f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}" + ) acc_out = out_ring @@ -1311,7 +1310,6 @@ def forward( class LlamaForCausalLMEagle3(Eagle3DraftModel): - config_class = LlamaConfig def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: @@ -1340,6 +1338,17 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: config.hidden_size, config.draft_vocab_size, bias=False ) + # Embedding scale factor for target models that use scaled embeddings + # (e.g., Gemma3/Gemma4 multiply by hidden_size**0.5). Set via config + # field ``embed_scale`` or auto-detected from ``target_model_type``. + target_type = getattr(config, "target_model_type", None) or "" + if getattr(config, "embed_scale", None) is not None: + self.embed_scale = config.embed_scale + elif "gemma" in target_type: + self.embed_scale = config.hidden_size**0.5 + else: + self.embed_scale = 1.0 + # create vocab buffers t2d = torch.ones(self.vocab_size, dtype=torch.bool) d2t = torch.zeros(self.draft_vocab_size, dtype=torch.int64) @@ -1403,7 +1412,10 @@ def forward( return hidden_states def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) + embeds = self.embed_tokens(input_ids) + if self.embed_scale != 1.0: + embeds = embeds * self.embed_scale + return embeds def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: # eagle 3 requires hidden states from 3 layers From fa7ba138e4b1f7f8194a577c1663f82981802ee5 Mon Sep 17 00:00:00 2001 From: kathyligg Date: Fri, 17 Apr 2026 00:43:00 +0000 Subject: [PATCH 4/4] feat: Enable global attention for Gemma3/Gemma4 drafter models --- examples/run_gemma3_27b_eagle3_online.sh | 1 + scripts/train_eagle3.py | 6 ++ specforge/modeling/draft/llama3_eagle.py | 81 ++++++++++++++---------- 3 files changed, 55 insertions(+), 33 deletions(-) diff --git a/examples/run_gemma3_27b_eagle3_online.sh b/examples/run_gemma3_27b_eagle3_online.sh index f72773e8d..612e624ea 100755 --- a/examples/run_gemma3_27b_eagle3_online.sh +++ b/examples/run_gemma3_27b_eagle3_online.sh @@ -14,6 +14,7 @@ torchrun \ --draft-model-config $ROOT_DIR/configs/gemma3-27b-eagle3.json \ --train-data-path $ROOT_DIR/cache/dataset/ultrachat_train.jsonl \ --output-dir $ROOT_DIR/outputs/gemma3-27b-eagle3-ultrachat \ + --eval-holdout-ratio 0.03 \ --num-epochs 10 \ --batch-size 8 \ --tp-size $TP_SIZE \ diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 92166fd2e..438f53175 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -402,6 +402,11 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module] # Use provided config file draft_model_config = AutoDraftModelConfig.from_file(args.draft_model_config) + # if the target model is gemma, we should use global attention for the draft model + if "gemma" in getattr(draft_model_config, "target_model_type", "").lower(): + draft_model_config.use_global_attention = True + print_on_rank0("Using global attention for draft model.") + # Handle base ckpt, config file draft_model_last_checkpoint = None is_resume_checkpoint = False @@ -427,6 +432,7 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module] if draft_model_last_checkpoint: draft_model = AutoEagle3DraftModel.from_pretrained( draft_model_last_checkpoint, + config=draft_model_config, attention_backend=args.attention_backend, torch_dtype=torch.bfloat16, ).cuda() diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 5d1e59d05..b73f39e3d 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -523,6 +523,7 @@ def __init__(self, config): self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings + self.use_global_attention = getattr(config, "use_global_attention", False) self.q_proj = nn.Linear( self.hidden_size * 2, self.num_heads * self.head_dim, bias=False @@ -760,6 +761,10 @@ class LlamaFlexAttention(LlamaAttention): - past_key_values: dynamic cache used for storing past key and value states. """ + def __init__(self, config): + super().__init__(config) + self.use_global_attention = getattr(config, "use_global_attention", False) + def forward( self, hidden_states: torch.Tensor, @@ -821,39 +826,45 @@ def forward( cache_kwargs=cache_kwargs, ) - seq_lengths = attention_mask.sum(dim=-1) - # Shrink the attention mask to align with the padding to the right. - # This is equivalent to the shrinking logic in eagle3.py - seq_lengths -= lck - # TODO: Remove the usage of uncompiled create_block_mask after - # https://github.com/pytorch/pytorch/issues/160018 - if q_len <= 128: - create_block_mask_func = create_block_mask - flex_attention_func = flex_attention + if self.use_global_attention: + block_mask = None # Enables full attention else: - create_block_mask_func = compile_friendly_create_block_mask - flex_attention_func = compile_friendly_flex_attention - - block_mask = create_block_mask_func( - mask_mod=generate_eagle3_mask( - seq_lengths=seq_lengths, - Q_LEN=q_len, - KV_LEN=key_cache.shape[-2], - lck=lck, - ), - B=bsz, - H=1, # Rely on broadcast - Q_LEN=q_len, - KV_LEN=key_cache.shape[-2], - device=query_states.device, - ) - attn_output = flex_attention_func( - query=query_states, - key=key_cache.contiguous(), - value=value_cache.contiguous(), - block_mask=block_mask, - enable_gqa=True, - ) + seq_lengths = attention_mask.sum(dim=-1) + # Shrink the attention mask to align with the padding to the right. + # This is equivalent to the shrinking logic in eagle3.py + seq_lengths -= lck + # TODO: Remove the usage of uncompiled create_block_mask after + # https://github.com/pytorch/pytorch/issues/160018 + if q_len <= 128: + create_block_mask_func = create_block_mask + flex_attention_func = flex_attention + else: + create_block_mask_func = compile_friendly_create_block_mask + flex_attention_func = compile_friendly_flex_attention + + if self.use_global_attention: + block_mask = None # This will result in dense attention + else: + block_mask = create_block_mask_func( + mask_mod=generate_eagle3_mask( + seq_lengths=seq_lengths, + Q_LEN=q_len, + KV_LEN=key_cache.shape[-2], + lck=lck, + ), + B=bsz, + H=1, # Rely on broadcast + Q_LEN=q_len, + KV_LEN=key_cache.shape[-2], + device=query_states.device, + ) + attn_output = flex_attention_func( + query=query_states, + key=key_cache.contiguous(), + value=value_cache.contiguous(), + block_mask=block_mask, + enable_gqa=True, + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.head_dim * self.num_heads) attn_output = self.o_proj(attn_output) @@ -869,6 +880,10 @@ class LlamaFlashAttention(LlamaAttention): - cache_hidden: manual cache used for storing past key and value states """ + def __init__(self, config): + super().__init__(config) + self.use_global_attention = getattr(config, "use_global_attention", False) + def forward( self, hidden_states: torch.Tensor, @@ -934,7 +949,7 @@ def forward( v0, dropout_p=0.0, softmax_scale=1.0 / math.sqrt(self.head_dim), - causal=True, + causal=not self.use_global_attention, # Set causal based on the flag return_attn_probs=True, ) lse = lse.transpose(1, 2)