Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion scripts/train_dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,13 @@ def main():
 loss_decay_gamma=args.loss_decay_gamma,
 )

+# Freeze target components: only train draft model
+for param in target_components.lm_head.parameters():
+param.requires_grad = False
+for param in target_components.embed_tokens.parameters():
+param.requires_grad = False
+print_on_rank0("Frozen target_lm_head and target_embed_tokens")
+
 dflash_model = FSDP(
 dflash_model,
 use_orig_params=True,
Expand All @@ -452,8 +459,15 @@ def main():
 start_epoch = ckpt_info[0]
 global_step = ckpt_info[1]

+# BF16Optimizer must operate on the FSDP-wrapped model's draft parameters
+# to ensure gradient synchronization works correctly with use_orig_params=True
+draft_params = []
+for name, param in dflash_model.named_parameters():
+if param.requires_grad and "draft_model." in name:
+draft_params.append(param)
+
 optimizer = BF16Optimizer(
-draft_model,
+draft_params,
 lr=args.learning_rate,
 max_grad_norm=args.max_grad_norm,
 warmup_ratio=args.warmup_ratio,
Expand Down
12 changes: 10 additions & 2 deletions specforge/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,16 @@ def __init__(
 # TODO: We should make these parameters configurable
 # These magic numbers: weight_decay=0.0, max_grad_norm=0.5, total_steps=800k, warmup_steps=12k are copied from
 # https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/ds_config.json
-self.model = model
-self.model_params = [p for p in model.parameters() if p.requires_grad]
+if isinstance(model, list):
+self.model = None
+self.model_params = [
+p for p in model if p.requires_grad
+]
+else:
+self.model = model
+self.model_params = [
+p for p in model.parameters() if p.requires_grad
+]
 self.max_grad_norm = max_grad_norm
 self.fp32_params = [
 p.detach().clone().to(torch.float32) for p in self.model_params
Expand Down