float8 inference: fix bmm semantics #3296
Conversation
Stack from ghstack (oldest at bottom):

🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3296
✅ No failures as of commit c97b030 with merge base a257166.
Summary:

Fixes the `Float8Tensor` `torch.bmm` override to match the semantics of the high precision op. Specifically, input 1 is of shape (B, M, K) and input 2 is of shape (B, K, N). Previously, the shape expectation was not consistent between the high precision and quantized versions of `torch.bmm`, which was confusing.

This is important for quantizing LLaMa 4 MoE variants, which use `torch.bmm` in the HF implementation.

Test Plan:

```
pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -s -x -k bmm
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 3f8887b
ghstack-comment-id: 3493356198
Pull-Request: #3296
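To make the expected semantics concrete, here is a minimal illustration of plain `torch.bmm` shape behavior (no quantization involved) that the `Float8Tensor` override is now expected to match:

```python
import torch

B, M, K, N = 4, 8, 16, 32

# torch.bmm expects input 1 of shape (B, M, K) and input 2 of shape (B, K, N)
a = torch.randn(B, M, K)
b = torch.randn(B, K, N)

out = torch.bmm(a, b)
assert out.shape == (B, M, N)
```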
```diff
 res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
     a_data,
-    b_data,
+    b_data.transpose(-2, -1),
```
will performance be a concern?
transpose is just a metadata change, no impact on performance
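A quick standalone sketch confirming this: `Tensor.transpose` returns a view that shares storage with the original tensor, so no data is copied.

```python
import torch

b = torch.randn(4, 16, 32)
bt = b.transpose(-2, -1)

# The transposed tensor is a view: same storage, only shape/strides differ.
assert bt.data_ptr() == b.data_ptr()
assert bt.shape == (4, 32, 16)
assert not bt.is_contiguous()
```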
```
m = Model(weight).eval()
original = m(input)
# we need to transpose the weight first for bmm
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
```
this was from llama-models I think: https://github.com/meta-llama/llama-models/blob/0e0b8c519242d5833d8c11bffc1232b77ad7f301/models/llama4/quantization/loader.py#L142, although it's not as important now
but I guess the important thing is how we implement it in a way that can be used by different implementations: would the current fp8 bmm implementation work for the different ways people use bmm?
this is overriding torch.bmm, so we definitely should match the semantics of torch.bmm in terms of input shapes. It doesn't make sense to do a bmm with shapes that aren't (B, M, K) and (B, K, N). If that breaks llama-models, then they should fix it to match bmm semantics.
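To make the layout question concrete, here is a small hypothetical sketch (shapes and variable names are made up, not taken from llama-models or this PR) contrasting a weight already stored in the bmm-native (B, K, N) layout with one stored as (B, N, K), which has to be transposed to match `torch.bmm` semantics:

```python
import torch

B, M, K, N = 2, 4, 8, 16
x = torch.randn(B, M, K)

# bmm-native layout: weight stored as (B, K, N), usable directly with torch.bmm
w_bmm = torch.randn(B, K, N)
out = torch.bmm(x, w_bmm)

# alternative layout: weight stored as (B, N, K); it must be transposed
# (and made contiguous before quantization) to match torch.bmm semantics
w_alt = torch.randn(B, N, K)
out_alt = torch.bmm(x, w_alt.transpose(1, 2).contiguous())

assert out.shape == out_alt.shape == (B, M, N)
```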
jerryzh168 left a comment:
looks good, we discussed offline that we want to do a `contiguous()` call for the weight if the model weight for bmm is not transposed (will happen in a separate PR)
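A rough sketch of that idea, not the follow-up PR's actual implementation: if the bmm weight arrives as a non-contiguous view (e.g. it was transposed into (B, K, N) from a (B, N, K) checkpoint), call `.contiguous()` before quantization so downstream kernels see a contiguous buffer. The `maybe_make_bmm_weight_contiguous` helper name below is made up for illustration.

```python
import torch

def maybe_make_bmm_weight_contiguous(weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: ensure a 3D bmm weight is a contiguous tensor.

    Assumes the caller has already put the weight into (B, K, N) layout,
    e.g. via weight.transpose(1, 2); here we only materialize that layout.
    """
    if not weight.is_contiguous():
        weight = weight.contiguous()
    return weight

# usage sketch
w = torch.randn(2, 16, 8).transpose(1, 2)   # non-contiguous (B, K, N) view
w = maybe_make_bmm_weight_contiguous(w)
assert w.is_contiguous()
```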