
Commit b2e0bcd

[wip] float8 rowwise quant along row 1 of tensor rank 2
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
ghstack-source-id: f511df3
ghstack-comment-id: 3497584430
Pull-Request: #3303
1 parent 6815e57 commit b2e0bcd

File tree

6 files changed: +125 −11 lines changed


test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 65 additions & 0 deletions
@@ -15,6 +15,7 @@
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import run_tests
 
+from torchao.core.config import config_from_dict, config_to_dict
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Float8WeightOnlyConfig,
@@ -466,6 +467,44 @@ def forward(self, x):
         sqnr = compute_error(original, quantized)
         self.assertTrue(sqnr > 20)
 
+    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
+    @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
+    def test_bmm_weight_in_bkn_layout(self):
+        # Tests rowwise quantization of a 3d weight stored with shape (B, K, N)
+        # and contiguous with that shape. Since the `K` dimension is not last, we
+        # need to specify granularity with `PerRow(1)`.
+
+        # only support per row quantization
+        granularity = [PerRow(), PerRow(1)]
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+
+        class Model(torch.nn.Module):
+            def __init__(self, weight):
+                super().__init__()
+                self.weight = weight
+
+            def forward(self, x):
+                return torch.bmm(x, self.weight)
+
+        dtype = torch.bfloat16
+        device = "cuda"
+
+        B, M, K, N = 10, 32, 128, 256
+
+        input = torch.randn(B, M, K, dtype=dtype, device=device)
+        weight = torch.randn(B, K, N, dtype=dtype, device=device)
+        m = Model(weight).eval()
+        original = m(input)
+        quantize_(m, config, filter_fn=lambda x, fqn: True)
+
+        assert m.weight.scale.shape == (B, 1, N), (
+            f"unexpected scale shape {m.weight.scale.shape}"
+        )
+
+        quantized = m(input)
+        sqnr = compute_error(original, quantized)
+        self.assertTrue(sqnr > 20)
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @common_utils.parametrize(
         "sizes",
@@ -807,6 +846,32 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):
 
         self.assertEqual(sliced_dequantized, sliced_original)
 
+    def test_per_row_config_before_dim(self):
+        """
+        Test that loading a serialized config of `PerRow` from before the
+        `dim` argument was introduced works properly.
+        """
+
+        # create a config with PerRow granularity
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=PerRow(),
+        )
+
+        # serialize it
+        config_ser = config_to_dict(config)
+
+        # manually modify the serialized config to match v1
+        # reference: https://gist.github.com/vkuzo/d347c4f8b8121819483d2d31e79f7335
+        del config_ser["_data"]["granularity"][0]["_data"]["dim"]
+        del config_ser["_data"]["granularity"][1]["_data"]["dim"]
+        assert len(config_ser["_data"]["granularity"][0]["_data"]) == 0
+        assert len(config_ser["_data"]["granularity"][1]["_data"]) == 0
+
+        # load the modified version, verify that granularity is as expected
+        config_deser = config_from_dict(config_ser)
+        assert config_deser.granularity[0].dim == -1
+        assert config_deser.granularity[1].dim == -1
+
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)

test/quantization/test_quant_primitives.py

Lines changed: 25 additions & 0 deletions
@@ -10,6 +10,7 @@
 
 import torch
 
+from torchao.quantization.granularity import PerRow
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
@@ -27,6 +28,7 @@
 # TODO: remove test for utils?
 from torchao.quantization.utils import (
     _quantize_activation_per_token_absmax,
+    get_block_size,
     get_group_qparams_symmetric,
     groupwise_affine_dequantize_tensor_from_qparams,
     groupwise_affine_quantize_tensor_from_qparams,
@@ -844,6 +846,29 @@ def test_float8_blockwise_scaling(self):
         torch.testing.assert_close(scale, ref_scale, atol=0, rtol=0)
         torch.testing.assert_close(data.float(), ref_data.float(), atol=0, rtol=0)
 
+    def test_float8_rowwise_scaling_3d_weight_axis_1(self):
+        """
+        Test scaling a weight with shape (B, K, N) and row-major memory layout
+        across the K dimension.
+        """
+
+        B, K, N = 8, 16, 32
+        hp_tensor = torch.randn(B, K, N, dtype=torch.float)
+
+        granularity = PerRow(1)
+        block_size = get_block_size(hp_tensor.shape, granularity)
+        scale = _choose_scale_float8(
+            hp_tensor,
+            float8_dtype=torch.float8_e4m3fn,
+            block_size=block_size,
+            hp_value_lb=None,
+            hp_value_ub=None,
+        )
+        data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn)
+
+        assert scale.shape == (B, 1, N)
+        assert data.shape == (B, K, N)
+
 
 if __name__ == "__main__":
     unittest.main()
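
Note: for intuition, the rowwise scaling this test exercises can be sketched in
plain PyTorch. This is a minimal sketch of the expected shapes under PerRow(1),
not the `_choose_scale_float8` implementation (which also handles the
`hp_value_lb` / `hp_value_ub` clamping bounds):

import torch

B, K, N = 8, 16, 32
hp_tensor = torch.randn(B, K, N, dtype=torch.float)

# block_size (1, K, 1): reduce abs-max over the K dim, keep B and N
amax = hp_tensor.abs().amax(dim=1, keepdim=True)  # shape (B, 1, N)
f8_max = torch.finfo(torch.float8_e4m3fn).max     # 448.0 for e4m3fn
scale = amax / f8_max
data = (hp_tensor / scale).clamp(-f8_max, f8_max).to(torch.float8_e4m3fn)

assert scale.shape == (B, 1, N)
assert data.shape == (B, K, N)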

torchao/quantization/granularity.py

Lines changed: 17 additions & 8 deletions
@@ -39,12 +39,14 @@ class PerAxis(Granularity):
     This granularity type calculates different quantization parameters
     along a specified axis of the tensor.
 
-    For example if the input tensor is shape [8, 16] and axis=0, then
-    the quantization parameters are calculated for each row of the tensor.
-    Giving a total of 8 quantization parameters.
+    Examples:
+      * input_tensor shape [A, B], axis 0 -> scale_shape [A, 1]
+      * input_tensor shape [A, B], axis 1 -> scale_shape [1, B]
+      * input_tensor shape [A, B, C], axis 1 -> scale_shape [1, B, 1]
 
     Attributes:
-        axis (int): The axis along which reduction is performed.
+        axis (int): The axis which is kept; reduction is performed across all
+            the other axes
     """
 
     axis: int
@@ -76,12 +78,19 @@ class PerRow(Granularity):
     """
     Represents row-wise granularity in quantization.
 
-    This is a special case of per-axis quantization and is unique to Float8 matmuls
-    where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
-    is quantized with a block_size of (1, weight.shape[1]).
+    Examples:
+      * input_tensor shape [A, B], dim 0 -> scale_shape [1, B]
+      * input_tensor shape [A, B], dim 1 -> scale_shape [A, 1]
+      * input_tensor shape [A, B], dim -1 -> scale_shape [A, 1]
+      * input_tensor shape [A, B, C], dim 1 -> scale_shape [A, 1, C]
+
+    Attributes:
+        dim (int): The dim which is reduced across; all other dims are kept
     """
 
-    pass
+    # TODO(before land): any BC concerns with loading old checkpoints
+    # serialized without this arg? investigate this
+    dim: int = -1
 
 
 @dataclass(frozen=True)
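
Note: the scale shapes documented above follow one rule: the PerRow dim is
reduced to a single scale entry and every other dim is kept. A small
hypothetical helper (not part of this diff) makes that concrete:

def perrow_scale_shape(input_shape, dim):
    # PerRow(dim): the given dim collapses to 1, all other dims are kept
    dim = dim % len(input_shape)  # normalize negative dims such as -1
    return tuple(1 if i == dim else s for i, s in enumerate(input_shape))

assert perrow_scale_shape((4, 6), -1) == (4, 1)       # default PerRow()
assert perrow_scale_shape((4, 6), 0) == (1, 6)        # PerRow(0)
assert perrow_scale_shape((2, 4, 6), 1) == (2, 1, 6)  # the new PerRow(1) case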

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 3 additions & 1 deletion
@@ -180,6 +180,8 @@ def from_hp(
         and _is_fbgemm_gpu_genai_available()
         and is_sm_at_least_90()
         and isinstance(granularity, PerRow)
+        # fbgemm path only supports quantizing along the last dim
+        and granularity.dim in (-1, len(hp_tensor.shape) - 1)
         and float8_dtype == torch.float8_e4m3fn
         and hp_value_lb is None
     ):
@@ -438,7 +440,7 @@ def _(func, types, args, kwargs):
 
         res = torch.ops.fbgemm.f8f8bf16_rowwise_batched(
             a_data,
-            b_data.transpose(-2, -1),
+            b_data.transpose(-2, -1).contiguous(),
             a_scale,
             b_scale.transpose(-2, -1),
             b_scale,
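
Note: the added `.contiguous()` matters because `transpose` only swaps strides
and returns a view over the same storage; assuming the fbgemm kernel expects a
standard row-major buffer, the difference is:

import torch

b_data = torch.randn(2, 16, 32)  # (B, K, N), row-major
b_t = b_data.transpose(-2, -1)   # (B, N, K) view, same storage
assert not b_t.is_contiguous()   # only the strides changed
b_t_c = b_t.contiguous()         # materializes a row-major (B, N, K) copy
assert b_t_c.is_contiguous()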

torchao/quantization/utils.py

Lines changed: 5 additions & 1 deletion
@@ -723,8 +723,12 @@ def get_block_size(
             f"Not all shapes in input shape {input_shape} are divisible by block size {block_size}"
         )
         return block_size
-    elif isinstance(granularity, (PerRow, PerToken)):
+    elif isinstance(granularity, PerToken):
         return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
+    elif isinstance(granularity, PerRow):
+        block_size = [1] * len(input_shape)
+        block_size[granularity.dim] = input_shape[granularity.dim]
+        return tuple(block_size)
     elif isinstance(granularity, PerGroup):
         assert input_shape[-1] % granularity.group_size == 0, (
             f"Last dimension of input {input_shape[-1]} is not divisible by group size {granularity.group_size}"

torchao/testing/utils.py

Lines changed: 10 additions & 1 deletion
@@ -444,7 +444,9 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         # making the weight different
         dummy_l.weight = torch.nn.Parameter(
-            dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
+            dummy_l.weight
+            + 1.0
+            + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
             requires_grad=False,
         )
         quantize_(dummy_l, config)
@@ -460,6 +462,13 @@ def _test_slice_and_copy_similar_to_vllm(self, config: AOBaseConfig):
         loaded_weight = dummy_l.weight
         loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
+        # debugging CI failures
+        # TODO(before land): remove this
+        if torch.equal(orig_value, loaded_weight.qdata[0][0]):
+            print("param_data.qdata", param_data.qdata)
+            print("orig_value", orig_value)
+            print("loaded_weight.qdata", loaded_weight.qdata)
+
         # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
         assert not torch.equal(orig_value, loaded_weight.qdata[0][0])
         param_data.copy_(loaded_weight)
