Skip to content

[FSDP] improve compiler passes#424

Draft
mayank31398 wants to merge 12 commits into
mainfrom
fsdp
Draft

[FSDP] improve compiler passes#424
mayank31398 wants to merge 12 commits into
mainfrom
fsdp

Conversation

@mayank31398

Copy link
Copy Markdown
Collaborator

No description provided.

Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request enhances the SimpleFSDP compilation pipeline by introducing several new FX graph optimization passes, including identity operation removal and view normalization. It also adds support for advanced bucketing strategies, such as manual transformer block reordering and asynchronous tensor parallel collectives. Feedback from the review identifies a potential deadlock risk when enabling symmetric memory within a compiler pass and points out AttributeError bugs in the Inductor backend where GraphModule objects are incorrectly handled. Additionally, the logic for identifying identity slices was flagged as brittle for graphs involving dynamic shapes.

Comment thread lm_engine/parallel/simple_fsdp/compile.py
Comment on lines 597 to +598
def _inductor_autobucketing_pass(gm: torch.fx.Graph) -> torch.fx.GraphModule:
return schedule_overlap_bucketing_from_inductor_configs(gm.owning_module)
return schedule_overlap_bucketing(gm.owning_module, collective_bucketing=True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In Inductor, post_grad_custom_post_pass receives a GraphModule as its argument, not a Graph. Since GraphModule does not have an owning_module attribute, gm.owning_module will raise an AttributeError. You should pass gm directly to schedule_overlap_bucketing.

Suggested change
def _inductor_autobucketing_pass(gm: torch.fx.Graph) -> torch.fx.GraphModule:
return schedule_overlap_bucketing_from_inductor_configs(gm.owning_module)
return schedule_overlap_bucketing(gm.owning_module, collective_bucketing=True)
def _inductor_autobucketing_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
return schedule_overlap_bucketing(gm, collective_bucketing=True)

Comment on lines +618 to +621
def _inductor_tb_pass(gm: torch.fx.Graph) -> torch.fx.GraphModule:
return manual_overlap_bucketing(
gm.owning_module, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=True
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the autobucketing pass, gm here is a GraphModule. Accessing gm.owning_module will fail. Pass gm directly to manual_overlap_bucketing.

Suggested change
def _inductor_tb_pass(gm: torch.fx.Graph) -> torch.fx.GraphModule:
return manual_overlap_bucketing(
gm.owning_module, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=True
)
def _inductor_tb_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
return manual_overlap_bucketing(
gm, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=True
)

Comment on lines +128 to +134
input_node = args[0]
dim = args[1] if len(args) > 1 else 0
start = args[2] if len(args) > 2 else 0
end = args[3] if len(args) > 3 else sys.maxsize
step = args[4] if len(args) > 4 else 1
if start != 0 or step != 1:
continue

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for identifying identity slices is brittle because it assumes dim, start, end, and step are concrete integers. In FX graphs, these can be torch.fx.Node objects (e.g., when using dynamic shapes or computed indices). Comparing a Node to an integer (like start != 0) will not work as intended. Consider adding isinstance(..., int) checks to ensure these are literals before performing the comparison.

Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant