feat(tools): add backward graph generation and validation tools #711
base: develop
Changes from all commits
67a81d6
ea832a1
210a51f
af8fd20
db5f971
81a157a
| Original file line number | Diff line number | Diff line change |
|---|---|---|
@@ -27,14 +27,18 @@ def __call__(self): | |
| module, forward_inputs = get_torch_module_and_inputs( | ||
| self.model_path, use_dummy_inputs=False, device=self.device | ||
| ) | ||
| module.train() | ||
| module.eval() | ||
Collaborator
Contributor, Author
`model.eval()` does not disable gradient computation; only `torch.no_grad()` / `torch.inference_mode()` do that. `eval` only changes the forward behavior of specific layers (dropout → identity, BatchNorm → running stats instead of batch stats), and backpropagation works as normal. Using eval mode is also the better choice here.
Collaborator
No. When the backward graph is generated, these operators should use
Dayuxiaoshui marked this conversation as resolved.
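The factual part of the exchange above (what `eval()` changes and what it leaves alone) can be checked with a small standalone sketch. The toy `nn.Sequential` model and all variable names below are assumptions made for illustration and are not taken from this PR. The sketch does not settle which mode backward graph generation should use.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model containing the two layer types discussed in the thread.
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.BatchNorm1d(8),
    nn.Dropout(p=0.5),
    nn.Linear(8, 1),
)
bn = model[1]
x = torch.randn(16, 4, requires_grad=True)

# eval(): dropout becomes identity and BatchNorm reads its running stats,
# but autograd stays enabled.
model.eval()
grad_enabled_in_eval = torch.is_grad_enabled()
out_a = model(x)
out_b = model(x)
eval_is_deterministic = torch.equal(out_a, out_b)
eval_leaves_running_stats = int(bn.num_batches_tracked) == 0

out_a.sum().backward()
input_has_grad = x.grad is not None
params_have_grad = all(p.grad is not None for p in model.parameters())

# no_grad(): this is what actually turns graph recording off.
with torch.no_grad():
    out_no_grad = model(x)
no_grad_output_requires_grad = out_no_grad.requires_grad

# train(): BatchNorm normalizes with batch stats and updates its running stats.
model.train()
running_mean_before = bn.running_mean.clone()
model(x)
train_updates_running_stats = not torch.equal(running_mean_before, bn.running_mean)
```

The two modes therefore record different operator behavior for dropout and BatchNorm, which is the point the reviewers disagree on.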
| eval_forward_dir = os.path.join( | ||
| self.output_dir, "eval_forward", self.rel_model_path | ||
| ) | ||
| if not os.path.exists(eval_forward_dir): | ||
| shutil.copytree(self.model_path, eval_forward_dir) | ||
| forward_inputs = [ | ||
| inp.detach().clone() if isinstance(inp, torch.Tensor) else inp | ||
| for inp in forward_inputs | ||
| ] | ||
| forward_inputs = self.set_requires_grad_for_forward_inputs( | ||
| self.model_path, module, forward_inputs | ||
| ) | ||
Call the functions in `hash_util.py`.