feat(train): freeze target components and support param-list in BF16Optimizer#561
Open
curnane-lab wants to merge 1 commit into
Open
feat(train): freeze target components and support param-list in BF16Optimizer#561curnane-lab wants to merge 1 commit into
curnane-lab wants to merge 1 commit into
Conversation
Contributor
There was a problem hiding this comment.
Code Review
This pull request updates the training script to freeze target model components and ensures the BF16Optimizer correctly receives draft parameters from the FSDP-wrapped model. It also modifies the optimizer's initialization to support both model objects and parameter lists. A critical issue was identified in specforge/optimizer.py where the use of an undefined variable name model_or_params instead of the function argument model would result in a NameError.
…ptimizer - Explicitly freeze target_lm_head and target_embed_tokens before FSDP wrapping to save memory and avoid gradient sync overhead for frozen parameters. - Extend BF16Optimizer to accept a pre-filtered parameter list (list) in addition to a raw nn.Module. This fixes gradient synchronization issues when using FSDP with use_orig_params=True, where the optimizer must operate on the FSDP-wrapped model's parameters rather than the original unwrapped model. - Filter draft_model.* parameters in train_dflash.py and pass the list to BF16Optimizer, ensuring only trainable draft parameters receive gradients. - Maintain backward compatibility: passing an nn.Module still works.
a909467 to
0201117
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
When training DFlash with
FSDP(use_orig_params=True), two issues arise:Parameter reference mismatch: Passing the unwrapped
draft_modeltoBF16Optimizercauses it to calldraft_model.parameters(), which returns the original parameter objects. However,FSDP(use_orig_params=True)manages its own parameter references viadflash_model.named_parameters(). The optimizer must operate on the FSDP-managed parameters to ensure gradient synchronization works correctly across ranks.Defensive freeze for target components: While
TargetEmbeddingsAndHead.from_pretrained()setsrequires_grad_(False), explicitly freezinglm_headandembed_tokensbefore FSDP wrapping adds a defensive layer against accidental unfreezing by future code changes.This PR fixes both issues to improve training stability on both CUDA and NPU.
Modifications
specforge/optimizer.py:BF16Optimizer.__init__()to accept either annn.Moduleor a pre-filteredlist[nn.Parameter]self.model = Noneand filterp.requires_gradBF16Optimizer(draft_model, ...)still worksscripts/train_dflash.py:OnlineDFlashModelcreation, explicitly freezetarget_components.lm_headandtarget_components.embed_tokenswithrequires_grad = Falsedraft_model.*parameters fromdflash_model.named_parameters()draft_paramslist toBF16Optimizerinstead of the rawdraft_modelRelated Issues
N/A (new feature)
Accuracy Test
Benchmark & Profiling
BF16Optimizer(model, ...)still works for non-FSDP training.use_orig_params=True, optimizer now operates on FSDP-managed parameter references, avoiding potential gradient sync issues.Checklist