Skip to content

feat(train): freeze target components and support param-list in BF16Optimizer#561

Open
curnane-lab wants to merge 1 commit into
sgl-project:mainfrom
curnane-lab:feat/train-freeze-and-optimizer
Open

feat(train): freeze target components and support param-list in BF16Optimizer#561
curnane-lab wants to merge 1 commit into
sgl-project:mainfrom
curnane-lab:feat/train-freeze-and-optimizer

Conversation

@curnane-lab
Copy link
Copy Markdown

@curnane-lab curnane-lab commented May 25, 2026

Motivation

When training DFlash with FSDP(use_orig_params=True), two issues arise:

  1. Parameter reference mismatch: Passing the unwrapped draft_model to BF16Optimizer causes it to call draft_model.parameters(), which returns the original parameter objects. However, FSDP(use_orig_params=True) manages its own parameter references via dflash_model.named_parameters(). The optimizer must operate on the FSDP-managed parameters to ensure gradient synchronization works correctly across ranks.

  2. Defensive freeze for target components: While TargetEmbeddingsAndHead.from_pretrained() sets requires_grad_(False), explicitly freezing lm_head and embed_tokens before FSDP wrapping adds a defensive layer against accidental unfreezing by future code changes.

This PR fixes both issues to improve training stability on both CUDA and NPU.

Modifications

  • specforge/optimizer.py:

    • Extend BF16Optimizer.__init__() to accept either an nn.Module or a pre-filtered list[nn.Parameter]
    • When a list is passed, set self.model = None and filter p.requires_grad
    • Backward compatible: existing code passing BF16Optimizer(draft_model, ...) still works
  • scripts/train_dflash.py:

    • After OnlineDFlashModel creation, explicitly freeze target_components.lm_head and target_components.embed_tokens with requires_grad = False
    • After FSDP wrapping, collect only draft_model.* parameters from dflash_model.named_parameters()
    • Pass the filtered draft_params list to BF16Optimizer instead of the raw draft_model

Related Issues

N/A (new feature)

Accuracy Test

  • Not applicable — no model architecture or kernel changes; this is a training loop refactor.

Benchmark & Profiling

  • Verified backward compatibility: BF16Optimizer(model, ...) still works for non-FSDP training.
  • With use_orig_params=True, optimizer now operates on FSDP-managed parameter references, avoiding potential gradient sync issues.

Checklist

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the training script to freeze target model components and ensures the BF16Optimizer correctly receives draft parameters from the FSDP-wrapped model. It also modifies the optimizer's initialization to support both model objects and parameter lists. A critical issue was identified in specforge/optimizer.py where the use of an undefined variable name model_or_params instead of the function argument model would result in a NameError.

Comment thread specforge/optimizer.py Outdated
…ptimizer

- Explicitly freeze target_lm_head and target_embed_tokens before FSDP
  wrapping to save memory and avoid gradient sync overhead for frozen
  parameters.

- Extend BF16Optimizer to accept a pre-filtered parameter list (list)
  in addition to a raw nn.Module. This fixes gradient synchronization
  issues when using FSDP with use_orig_params=True, where the optimizer
  must operate on the FSDP-wrapped model's parameters rather than the
  original unwrapped model.

- Filter draft_model.* parameters in train_dflash.py and pass the list
  to BF16Optimizer, ensuring only trainable draft parameters receive
  gradients.

- Maintain backward compatibility: passing an nn.Module still works.
@curnane-lab curnane-lab force-pushed the feat/train-freeze-and-optimizer branch from a909467 to 0201117 Compare May 25, 2026 08:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants