Summary:
Fixes the `Float8Tensor` `torch.bmm` override to match the semantics of the
high precision op. Specifically, input 1 is of shape (B, M, K) and input 2
is of shape (B, K, N).
Previously, the override's shape expectation diverged from `torch.bmm`'s,
which was confusing.
This is important for quantizing LLaMa 4 MoE variants, which use
`torch.bmm` in the HF implementation.
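For reference, a minimal sketch of the shape contract that `torch.bmm`
enforces and that the override now follows (tensor names and sizes here are
illustrative, not taken from the PR):
```python
import torch

B, M, K, N = 4, 8, 16, 32

# torch.bmm contract: input 1 is (B, M, K), input 2 is (B, K, N),
# and the output is (B, M, N). The Float8Tensor override now uses
# this same convention instead of its own.
x = torch.randn(B, M, K)
w = torch.randn(B, K, N)
out = torch.bmm(x, w)
assert out.shape == (B, M, N)
```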
Test Plan:
```
pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -s -x -k bmm
```
ghstack-source-id: 9e16572
ghstack-comment-id: 3493356198
Pull-Request: #3296