feat: add dtype option in LayerNorm class #1359
|
Thanks Carl! The overall direction looks right (storing
The key question that determines the shape of this PR: does CK
Note this only matters on the unfused path ( |
|
Thanks for the fast review @ChuanLi1101. I tested the unfused path as well: layernorm2d_fwd accepts the fp32 weight/bias, and accuracy is good:
I've updated the code so that it only sets the initial LN parameters to fp32, and I removed the unnecessary casting logic in forward. The previously stated performance remains unchanged. |
Motivation
This is a companion PR to ROCm/aiter#3451.
ROCm/aiter#3451 requires LayerNorm weights to be in fp32 for use in the DeepSeek v3.2 fused indexer. Keeping them in fp32 avoids casting the weights from fp32 to bf16, which loses precision. This PR aligns with that requirement by making the LayerNorm weights fp32 for the DeepSeek v3.2 ATOM pipeline.
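For reviewers who want a feel for the precision loss mentioned above, here is a minimal sketch in plain Python that emulates an fp32 to bf16 cast by rounding the fp32 bit pattern to its top 16 bits (round to nearest, ties to even). The helper names `to_fp32` and `to_bf16` are hypothetical and are not part of ATOM or aiter. The sample weight value is made up for illustration and is not taken from the model.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Emulate an fp32 -> bf16 cast for finite inputs.

    bf16 keeps the sign, the 8 exponent bits, and only the top 7 mantissa
    bits of fp32. Hypothetical helper for illustration only: NaN and
    infinity are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # fp32 bit pattern
    # Round to nearest, ties to even, on the 16 bits that bf16 discards.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# A made-up LayerNorm weight close to 1.0, as is typical after initialization.
weight = 1.001

fp32_error = abs(to_fp32(weight) - weight)
bf16_error = abs(to_bf16(weight) - weight)

print(to_bf16(weight))          # -> 1.0 (the 0.001 offset is lost entirely)
print(fp32_error < bf16_error)  # -> True
```

Near 1.0 the spacing between adjacent bf16 values is 2^-7 (about 0.0078), so any weight within roughly 0.0039 of 1.0 collapses to exactly 1.0 after the cast, while fp32 keeps the same value to within about 6e-8.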
Test Plan
Verify performance and accuracy before and after changes from this PR in combination with the companion aiter PR.
Test Result
Performance comparison
Configuration: ISL=1000, OSL=100, CONC=4
Accuracy
FP32 LayerNorm
Main
Submission Checklist