
Conversation

@vkuzo (Contributor) commented Nov 5, 2025

Summary:

Fixes the Float8Tensor torch.bmm override to match the semantics of the
high precision op. Specifically, input 1 is of shape (B, M, K) and input
2 is of shape (B, K, N).

Previously, the shape expectation for torch.bmm was not consistent between the high precision and quantized versions, which was confusing.

This is important for quantizing LLaMa 4 MoE variants, which use
torch.bmm in the HF implementation.
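
For reference, the high precision torch.bmm semantics this override now matches (a minimal sketch, standard PyTorch only):

```python
import torch

B, M, K, N = 2, 8, 16, 4
x = torch.randn(B, M, K)   # input 1: (B, M, K)
w = torch.randn(B, K, N)   # input 2: (B, K, N)
out = torch.bmm(x, w)      # output:  (B, M, N)
print(out.shape)           # torch.Size([2, 8, 4])
```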

Test Plan:

```
pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -s -x -k bmm
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]

@pytorch-bot bot commented Nov 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3296

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c97b030 with merge base a257166:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

vkuzo added a commit that referenced this pull request Nov 5, 2025
Summary:

Fixes the `Float8Tensor` `torch.bmm` override to match the semantics of the
high precision op. Specifically, input 1 is of shape (B, M, K) and input
2 is of shape (B, K, N).

Previously, the shape expectation for `torch.bmm` was not consistent between the high precision and quantized versions, which was confusing.

This is important for quantizing LLaMa 4 MoE variants, which use
`torch.bmm` in the HF implementation.

Test Plan:

```
pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -s -x -k bmm
```

Reviewers:

Subscribers:

Tasks:

Tags:
ghstack-source-id: 3f8887b
ghstack-comment-id: 3493356198
Pull-Request: #3296
meta-cla bot added the CLA Signed label Nov 5, 2025
@vkuzo added the topic: bug fix label Nov 5, 2025
```diff
 res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
     a_data,
-    b_data,
+    b_data.transpose(-2, -1),
```
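
For context, a sketch of the shapes involved; these are assumptions based on torch.bmm semantics and fbgemm's usual (N, K) weight layout, not taken verbatim from the PR:

```python
# Assumed shapes, illustrative only:
# a_data:                    (B, M, K)  float8 activation data
# b_data:                    (B, K, N)  float8 weight data, torch.bmm layout after this fix
# b_data.transpose(-2, -1):  (B, N, K)  the layout assumed for the rowwise-batched kernel
```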
A Contributor commented on the diff:
will performance be a concern?

@vkuzo (Contributor, Author) replied:

transpose is just metadata, no impact on performance
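
One way to verify this (a minimal sketch, not from the PR): transpose returns a view that shares storage and only swaps strides.

```python
import torch

b = torch.randn(2, 16, 4)
bt = b.transpose(-2, -1)

# Same storage, different strides: the transpose only rewrites metadata.
assert bt.data_ptr() == b.data_ptr()
print(b.shape, b.stride())    # torch.Size([2, 16, 4]) (64, 4, 1)
print(bt.shape, bt.stride())  # torch.Size([2, 4, 16]) (64, 1, 4)
```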

```python
m = Model(weight).eval()
original = m(input)
# we need to transpose the weight first for bmm
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
```
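
To make the transpose above concrete, a standalone sketch; the (B, N, K) stored weight layout here is an assumption for illustration, not confirmed by the PR:

```python
import torch

B, M, K, N = 2, 8, 16, 4
weight = torch.randn(B, N, K)  # assumed stored layout
x = torch.randn(B, M, K)

# Move the weight to (B, K, N) so it matches torch.bmm's second-operand layout.
w_for_bmm = weight.transpose(1, 2).contiguous()
out = torch.bmm(x, w_for_bmm)  # (B, M, N)
```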
@jerryzh168 (Contributor) commented Nov 5, 2025:

this was from llama-models I think: https://github.com/meta-llama/llama-models/blob/0e0b8c519242d5833d8c11bffc1232b77ad7f301/models/llama4/quantization/loader.py#L142, although it's not as important now

but I guess the important thing is how we implement it in a way that can be used by different implementations: would the current fp8 bmm implementation work for the different ways people use bmm?

@vkuzo (Contributor, Author) replied:

this is overriding torch.bmm, so we definitely should match the semantics of torch.bmm in terms of input shapes. It doesn't make sense to do a bmm with shapes that aren't (B, M, K) and (B, K, N). If that breaks llama-models, then they should fix it to match bmm semantics.

@jerryzh168 (Contributor) left a comment:

looks good, discussed offline that we want to do a contiguous() call for the weight if the model weight for bmm is not transposed (will happen in a separate PR)

@vkuzo merged commit 6815e57 into main Nov 6, 2025 (50 checks passed)