[FSDP] improve compiler passes#424
Code Review
This pull request enhances the SimpleFSDP compilation pipeline by introducing several new FX graph optimization passes, including identity operation removal and view normalization. It also adds support for advanced bucketing strategies, such as manual transformer block reordering and asynchronous tensor parallel collectives. Feedback from the review identifies a potential deadlock risk when enabling symmetric memory within a compiler pass and points out AttributeError bugs in the Inductor backend where GraphModule objects are incorrectly handled. Additionally, the logic for identifying identity slices was flagged as brittle for graphs involving dynamic shapes.
def _inductor_autobucketing_pass(gm: torch.fx.Graph) -> torch.fx.GraphModule:
    return schedule_overlap_bucketing_from_inductor_configs(gm.owning_module)
    return schedule_overlap_bucketing(gm.owning_module, collective_bucketing=True)
In Inductor, post_grad_custom_post_pass receives a GraphModule as its argument, not a Graph. Since GraphModule does not have an owning_module attribute, gm.owning_module will raise an AttributeError. You should pass gm directly to schedule_overlap_bucketing.
- def _inductor_autobucketing_pass(gm: torch.fx.Graph) -> torch.fx.GraphModule:
-     return schedule_overlap_bucketing_from_inductor_configs(gm.owning_module)
-     return schedule_overlap_bucketing(gm.owning_module, collective_bucketing=True)
+ def _inductor_autobucketing_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+     return schedule_overlap_bucketing(gm, collective_bucketing=True)
def _inductor_tb_pass(gm: torch.fx.Graph) -> torch.fx.GraphModule:
    return manual_overlap_bucketing(
        gm.owning_module, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=True
    )
Similar to the autobucketing pass, gm here is a GraphModule. Accessing gm.owning_module will fail. Pass gm directly to manual_overlap_bucketing.
- def _inductor_tb_pass(gm: torch.fx.Graph) -> torch.fx.GraphModule:
-     return manual_overlap_bucketing(
-         gm.owning_module, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=True
-     )
+ def _inductor_tb_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+     return manual_overlap_bucketing(
+         gm, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=True
+     )
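If it is uncertain whether the hook hands the pass a Graph or a GraphModule (worth verifying against the installed PyTorch version), a defensive helper can accept either. The sketch below is a minimal illustration, not code from this PR: `as_graph_module` is a hypothetical name, and it assumes that a `torch.fx.Graph` exposes `owning_module` while a `torch.fx.GraphModule` does not. The `FakeGraph` and `FakeGraphModule` classes are stand-ins so the example runs without PyTorch.

```python
def as_graph_module(obj):
    """Return the GraphModule whether `obj` is a Graph or a GraphModule.

    Assumption: only Graph-like objects expose `owning_module`;
    GraphModule-like objects are returned unchanged.
    """
    if hasattr(obj, "owning_module"):
        owner = obj.owning_module
        if owner is None:
            # A graph that is not attached to any module cannot be resolved.
            raise ValueError("graph has no owning module")
        return owner
    return obj


# Stand-ins for torch.fx.Graph / torch.fx.GraphModule (hypothetical, for illustration).
class FakeGraph:
    def __init__(self, owner=None):
        self.owning_module = owner


class FakeGraphModule:
    def __init__(self):
        self.graph = FakeGraph(self)
```

A pass written as `schedule_overlap_bucketing(as_graph_module(gm), collective_bucketing=True)` would then behave the same for either argument type, at the cost of hiding which type the hook actually delivers.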
input_node = args[0]
dim = args[1] if len(args) > 1 else 0
start = args[2] if len(args) > 2 else 0
end = args[3] if len(args) > 3 else sys.maxsize
step = args[4] if len(args) > 4 else 1
if start != 0 or step != 1:
    continue
The logic for identifying identity slices is brittle because it assumes dim, start, end, and step are concrete integers. In FX graphs, these can be torch.fx.Node objects (e.g., when using dynamic shapes or computed indices). Comparing a Node to an integer (like start != 0) will not work as intended. Consider adding isinstance(..., int) checks to ensure these are literals before performing the comparison.
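One way to apply the suggested `isinstance`-style guard is sketched below. It is a hedged illustration rather than the PR's implementation: `is_identity_slice` and `SymbolicValue` are hypothetical names, the statically known `dim_size` is assumed to be supplied by the caller, and any argument that is not a plain `int` (a node, a symbolic size, `None`) is conservatively treated as "not an identity slice".

```python
import sys


class SymbolicValue:
    """Stand-in for a non-literal argument such as an fx Node (hypothetical)."""


def is_identity_slice(args, dim_size):
    """Return True only when literal int args select the whole dimension."""
    dim = args[1] if len(args) > 1 else 0
    start = args[2] if len(args) > 2 else 0
    end = args[3] if len(args) > 3 else sys.maxsize
    step = args[4] if len(args) > 4 else 1
    # Reject anything that is not a plain int; bool is excluded on purpose.
    if not all(type(v) is int for v in (dim, start, end, step, dim_size)):
        return False
    # Identity means: start at 0, unit step, and end at or past the dim size.
    return start == 0 and step == 1 and end >= dim_size
```

Returning `False` for unknown values keeps the pass from removing a slice it cannot prove to be a no-op, which matters more than catching every removable case.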
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>