Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions test/quantization/quantize_/workflows/float8/test_float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,25 +444,27 @@ def test_bmm(self):
# only support per row quantization
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

class M(torch.nn.Module):
class Model(torch.nn.Module):
def __init__(self, weight):
super().__init__()
self.weight = weight

def forward(self, x):
return torch.bmm(x, self.weight)
return torch.bmm(x, self.weight.transpose(-2, -1))

dtype = torch.bfloat16
device = "cuda"
input = torch.randn(10, 32, 128, dtype=dtype, device=device)
weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
m = M(weight).eval()

B, M, K, N = 10, 32, 128, 256

input = torch.randn(B, M, K, dtype=dtype, device=device)
weight = torch.randn(B, N, K, dtype=dtype, device=device)
m = Model(weight).eval()
original = m(input)
# we need to transpose the weight first for bmm
m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
Copy link
Contributor

@jerryzh168 jerryzh168 Nov 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was from llama-models I think: https://github.com/meta-llama/llama-models/blob/0e0b8c519242d5833d8c11bffc1232b77ad7f301/models/llama4/quantization/loader.py#L142, although not as important now

but I guess the important thing is how do we implement it in a way that it can be used by different implementations, would current fp8 bmm implementation work for different ways people use bmm?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is overriding torch.bmm, so we definitely should match the semantics of torch.bmm in terms of input shapes. It doesn't make sense to do a bmm with shapes that aren't (B, M, K) and (B, K, N). If that breaks llama-models, then they should fix it to match bmm semantics.

quantize_(m, config, filter_fn=lambda x, fqn: True)
quantized = m(input)
self.assertTrue(compute_error(original, quantized) > 20)
sqnr = compute_error(original, quantized)
self.assertTrue(sqnr > 20)

@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@common_utils.parametrize(
Expand Down
13 changes: 7 additions & 6 deletions torchao/quantization/quantize_/workflows/float8/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,24 +422,25 @@ def _(func, types, args, kwargs):
a_scale = input_tensor.scale

b_data = weight_tensor.qdata
b_scale = weight_tensor.scale.squeeze(-1)
assert b_data.is_contiguous(), "weight for bmm must be contiguous"
b_scale = weight_tensor.scale

assert (
all(x == 1 for x in weight_tensor.block_size[:-1])
and weight_tensor.block_size[-1] == weight_tensor.shape[-1]
weight_tensor.block_size[0] == 1
and weight_tensor.block_size[1] == weight_tensor.shape[1]
and weight_tensor.block_size[2] == 1
), "bmm only works for per row weight quantization"
assert (
all(x == 1 for x in input_tensor.block_size[:-1])
and input_tensor.block_size[-1] == input_tensor.shape[-1]
), "bmm only works for per row activation quantization"

orig_out_features = b_data.shape[-2]
orig_out_features = b_data.shape[-1]

res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
a_data,
b_data,
b_data.transpose(-2, -1),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will performance be a concern?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

transpose is just metadata, no impact on performance

a_scale,
b_scale.transpose(-2, -1),
b_scale,
)
res = res.reshape(*orig_act_size[:-1], orig_out_features)
Expand Down
Loading