Refactor RMSNorm implementations to use torch.nn.functional.rms_norm #42461
base: main
Conversation
[For maintainers] Suggested jobs to run (before merge): run-slow: aimv2, apertus, arcee, aria, bamba, bitnet, blt, chameleon, clvp, csm, cwm, deepseek_v2, deepseek_v3, dia, diffllama, doge
Hey @mstojkovicTT, thanks for the PR! We definitely want the functions to be a drop-in replacement, so they should return exactly the same dtype as the old functions did. Also, in your tests you're initializing the weights to their default values (all ones). This may be problematic, because of the following scenario:

```python
return self.weight * hidden_states.to(input_dtype)
```

Here is the pytorch default implementation of `rms_norm`, for comparison:
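A pure-Python sketch of what `torch.nn.functional.rms_norm` computes, based on its documented formula (the real op is a fused kernel, so treat this only as a reference for the math; `eps` is shown with an explicit value rather than its documented default):

```python
import torch

def rms_norm_reference(x, normalized_shape, weight=None, eps=1e-6):
    # Normalize over the trailing dims given by normalized_shape:
    # y = x / sqrt(mean(x**2) + eps), optionally scaled by weight.
    dims = tuple(range(-len(normalized_shape), 0))
    rms = torch.sqrt(x.pow(2).mean(dims, keepdim=True) + eps)
    out = x / rms
    if weight is not None:
        out = out * weight
    return out
```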
I did the same testing, just with the additional

```python
with torch.no_grad():
    random_weight = torch.randn(hidden_size, device=device) * 0.05 + 1.0
    hf_module.weight.copy_(random_weight)
    new_module.weight.copy_(random_weight)
```

and everything still works. And also @Rocketknight1, thank you for taking the time to review this!
Yeah, I understand! My guess is that the original code was made to follow the original model implementation, even if it seems weird. Downcasting the hidden states back to the input dtype before multiplying by the weight seems to be part of that original behaviour.
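For instance, with a float32 weight (a module that has not been cast to half precision) and a float16 input, the ordering decides the output dtype purely through type promotion; a minimal standalone example:

```python
import torch

weight = torch.ones(8)                           # float32 parameter
hidden = torch.randn(2, 8, dtype=torch.float16)  # half-precision activations

# Multiplying the downcast activations by a float32 weight promotes back to float32:
print((weight * hidden).dtype)                   # torch.float32
# Casting the weight to the activation dtype instead keeps float16:
print((weight.to(hidden.dtype) * hidden).dtype)  # torch.float16
```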
Ah, that makes sense. What do you think about trying a CI run against the changes for now, and if it breaks something I will take a closer look? @Rocketknight1
Sure, let's see how it goes
run-slow: llama
This comment contains models: ["models/llama"]
CI Results
Model CI Report: ❌ Failed tests
What does this PR do?
Fixes #42398
This PR replaces custom RMSNorm / T5-style norm implementations (e.g. in Llama) that manually compute the variance and scaling with the built-in `torch.nn.functional.rms_norm`. For example, the manual upcast/variance/rsqrt sequence is simplified to a single `rms_norm` call; see the sketch below.
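Roughly, the change looks like the following sketch (the old path mirrors the Llama-style forward; the exact modules and names touched by the PR may differ):

```python
import torch
import torch.nn.functional as F

# Before: manual variance computation and scaling (Llama/T5 style).
def manual_rms_norm(hidden_states, weight, eps):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)

# After: a single call to the built-in op.
def fused_rms_norm(hidden_states, weight, eps):
    return F.rms_norm(hidden_states, (hidden_states.shape[-1],), weight=weight, eps=eps)
```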
This keeps the behavior and epsilon handling the same while reducing the number of ops, which should improve performance for users without requiring any additional changes on their side.
To verify the performance and the numerical stability, I wrote a test comparing the outputs of the old and new implementations.
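A minimal sketch of that kind of comparison, using the current `LlamaRMSNorm` as the old implementation and a direct `F.rms_norm` call as the new one (shapes, tolerances, and the random-weight initialization here are illustrative, per the review discussion above):

```python
import torch
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import LlamaRMSNorm

torch.manual_seed(0)
hidden_size, eps = 4096, 1e-6

# Old path: the existing module, with non-trivial weights copied in.
hf_module = LlamaRMSNorm(hidden_size, eps=eps)
with torch.no_grad():
    hf_module.weight.copy_(torch.randn(hidden_size) * 0.05 + 1.0)

x = torch.randn(2, 128, hidden_size)

y_hf = hf_module(x)
# New path: the built-in op with the same weights; cast so the dtypes match for the comparison.
y_new = F.rms_norm(x, (hidden_size,), weight=hf_module.weight, eps=eps).to(y_hf.dtype)

print(torch.allclose(y_hf, y_new, atol=1e-5), (y_hf - y_new).abs().max().item())
```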
The results show the following:
Note: I have found that when I try `dtype`s lower than `float32`, the old implementation keeps the output at `float32`, but my new one has the `dtype` of the input tensor. That's why I have to cast to `y_hf.dtype` (trying `float64`, for example, makes both implementations output `float64`). This can be changed, depending on what we want to accomplish.

Who can review?
@Rocketknight1