-
Notifications
You must be signed in to change notification settings - Fork 1.5k
fix: only instantiate CrossAttentionBlock when with_cross_attention=True #8848
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -147,7 +147,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: | |
|
|
||
| # fix the renamed norm blocks first norm2 -> norm_cross_attention , norm3 -> norm2 | ||
| for k in list(old_state_dict.keys()): | ||
| if "norm2" in k: | ||
| if "norm2" in k and k.replace("norm2", "norm_cross_attn") in new_state_dict: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure this fixes the issue. |
||
| new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict.pop(k) | ||
| if "norm3" in k: | ||
| new_state_dict[k.replace("norm3", "norm2")] = old_state_dict.pop(k) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We've had this issue with other classes with conditional components. If we don't take this branch then the members aren't created which causes issues with typing, Torchscript (though this isn't so much a concern anymore), and loading weights. The saved weights for this class will have those for
cross_attneven if it's not used, so loading it with this updated version of the class will raise exceptions about unused keys. We've had to work around this with methods to load old state dicts like this.You would need to look at where this class is used and see if such adaptation is needed, but either way the
norm_cross_attnandcross_attnmembers should always exist. Sincenorm_cross_attnis pretty lightweight I'd instantiate it always, andcross_attnshould benn.Identity.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ericspod Thank you so much for the detailed explanation!
I have updated the implementation based on your feedback:
I also looked at DecoderOnlyTransformer in transformer.py (the only caller that uses with_cross_attention) and updated its load_old_state_dict method to guard the norm2 → norm_cross_attn key remapping so it only fires when the target key exists in the new state dict.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @chhayankjain this class is now what I'd suggest to do, however your fixes for loading the state dictionaries won't resolve the issue in other places where
TransformerBlockis used. For example, in ViT theTransformerBlockinstantiation will havewith_cross_attentionwith its default value of False, so your change to this class meansself.cross_attnwill be missing from the state dict. If you attempt to load old weights you'll get an exception saying there's unused weights. There would need to be aload_old_state_dictadded for this class and others like it which would check if a key withcross_attnis missing from the new state dict and skip it if present in the state dict being loaded. That's what these methods do in existing classes. There would need to be a lot more work done in a number of places to make this change work, I agree it is better to not instantiate things you don't need, but perhaps it's too much work.