From 02011179eae0ffda2805638b3da34943f0dd0b11 Mon Sep 17 00:00:00 2001 From: Curnane Date: Mon, 25 May 2026 15:56:34 +0800 Subject: [PATCH] feat(train): freeze target components and support param-list in BF16Optimizer - Explicitly freeze target_lm_head and target_embed_tokens before FSDP wrapping to save memory and avoid gradient sync overhead for frozen parameters. - Extend BF16Optimizer to accept a pre-filtered parameter list (list) in addition to a raw nn.Module. This fixes gradient synchronization issues when using FSDP with use_orig_params=True, where the optimizer must operate on the FSDP-wrapped model's parameters rather than the original unwrapped model. - Filter draft_model.* parameters in train_dflash.py and pass the list to BF16Optimizer, ensuring only trainable draft parameters receive gradients. - Maintain backward compatibility: passing an nn.Module still works. --- scripts/train_dflash.py | 16 +++++++++++++++- specforge/optimizer.py | 12 ++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/scripts/train_dflash.py b/scripts/train_dflash.py index 808e928c2..d495d67b9 100755 --- a/scripts/train_dflash.py +++ b/scripts/train_dflash.py @@ -438,6 +438,13 @@ def main(): loss_decay_gamma=args.loss_decay_gamma, ) + # Freeze target components: only train draft model + for param in target_components.lm_head.parameters(): + param.requires_grad = False + for param in target_components.embed_tokens.parameters(): + param.requires_grad = False + print_on_rank0("Frozen target_lm_head and target_embed_tokens") + dflash_model = FSDP( dflash_model, use_orig_params=True, @@ -452,8 +459,15 @@ def main(): start_epoch = ckpt_info[0] global_step = ckpt_info[1] + # BF16Optimizer must operate on the FSDP-wrapped model's draft parameters + # to ensure gradient synchronization works correctly with use_orig_params=True + draft_params = [] + for name, param in dflash_model.named_parameters(): + if param.requires_grad and "draft_model." in name: + draft_params.append(param) + optimizer = BF16Optimizer( - draft_model, + draft_params, lr=args.learning_rate, max_grad_norm=args.max_grad_norm, warmup_ratio=args.warmup_ratio, diff --git a/specforge/optimizer.py b/specforge/optimizer.py index 7bdd3ab8d..7b34b355b 100644 --- a/specforge/optimizer.py +++ b/specforge/optimizer.py @@ -18,8 +18,16 @@ def __init__( # TODO: We should make these parameters configurable # These magic numbers: weight_decay=0.0, max_grad_norm=0.5, total_steps=800k, warmup_steps=12k are copied from # https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/ds_config.json - self.model = model - self.model_params = [p for p in model.parameters() if p.requires_grad] + if isinstance(model, list): + self.model = None + self.model_params = [ + p for p in model if p.requires_grad + ] + else: + self.model = model + self.model_params = [ + p for p in model.parameters() if p.requires_grad + ] self.max_grad_norm = max_grad_norm self.fp32_params = [ p.detach().clone().to(torch.float32) for p in self.model_params