Merged
Changes from 7 commits
1 change: 1 addition & 0 deletions .github/workflows/1xL4_tests.yml
@@ -51,3 +51,4 @@ jobs:
pytest test/dtypes/test_affine_quantized_float.py --verbose -s
./test/float8/test_everything_single_gpu.sh
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py
python test/kernel/test_blockwise_triton.py --verbose -s
2 changes: 1 addition & 1 deletion benchmarks/benchmark_blockwise_scaled_linear_triton.py
@@ -13,7 +13,7 @@
from triton.testing import do_bench

from torchao.float8.float8_utils import compute_error
from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
from torchao.kernel.blockwise_quantization import (
blockwise_fp8_gemm,
fp8_blockwise_act_quant,
fp8_blockwise_weight_quant,
4 changes: 2 additions & 2 deletions test/dtypes/test_affine_quantized_float.py
@@ -152,7 +152,7 @@ def test_invalid_granularity(self):
def test_mismatched_granularity(self):
with pytest.raises(
ValueError,
match="Different granularities for activation and weight are not supported",
match="Unsupported granularity types",
):
Float8DynamicActivationFloat8WeightConfig(
granularity=(PerTensor(), PerRow())
@@ -165,7 +165,7 @@ def test_unsupported_granularity(self):
class UnsupportedGranularity:
pass

with pytest.raises(ValueError, match="Invalid granularity types"):
with pytest.raises(ValueError, match="Unsupported granularity types"):
Float8DynamicActivationFloat8WeightConfig(
granularity=(UnsupportedGranularity(), UnsupportedGranularity()),
)
2 changes: 1 addition & 1 deletion
@@ -11,7 +11,7 @@

triton = pytest.importorskip("triton", reason="Triton required to run this test")

from torchao.prototype.blockwise_fp8_inference.blockwise_quantization import (
from torchao.kernel.blockwise_quantization import (
blockwise_fp8_gemm,
fp8_blockwise_act_quant,
fp8_blockwise_weight_dequant,
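For context, a minimal sketch of how the relocated kernels might be exercised from their new home; the argument names, the block_size default, and the (tensor, scale) return convention are assumptions for illustration, not taken from this diff:

import torch

from torchao.kernel.blockwise_quantization import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
)

# Hypothetical shapes: an (M, K) activation against an (N, K) weight.
x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
w = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)

# Assumed API: each quantize helper returns (fp8_tensor, scale), with 1x128
# blocks for activations and 128x128 blocks for weights.
x_fp8, x_scale = fp8_blockwise_act_quant(x, block_size=128)
w_fp8, w_scale = fp8_blockwise_weight_quant(w, block_size=128)

# Assumed GEMM entry point: quantized operands plus their block scales.
y = blockwise_fp8_gemm(x_fp8, x_scale, w_fp8, w_scale, block_size=128)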
47 changes: 38 additions & 9 deletions test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -18,6 +18,7 @@
from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
PerBlock,
PerRow,
PerTensor,
quantize_,
@@ -64,7 +65,10 @@ def setUp(self):
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@common_utils.parametrize(
"granularity",
[PerTensor(), PerRow(), (PerBlock((1, 128)), PerBlock((128, 128)))],
)
@common_utils.parametrize(
"kernel_preference",
[KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
@@ -74,7 +78,7 @@
"sizes",
[
((128,), 256, 128),
((32, 128), 64, 256),
((32, 128), 256, 512),
],
)
def test_fp8_linear_variants(
@@ -86,13 +90,21 @@ def test_fp8_linear_variants(
kernel_preference: KernelPreference,
sizes: Tuple,
):
if (
isinstance(granularity, PerTensor)
and kernel_preference == KernelPreference.FBGEMM
):
return unittest.skip(
"per tensor with fbgemm kernel preferece does not work yet"
)
if isinstance(granularity, PerTensor):
if kernel_preference is KernelPreference.FBGEMM:
return unittest.skip(
"per tensor with fbgemm kernel preference does not work yet"
)
elif mode == "weight-only":
return unittest.skip("unimplemented")

elif granularity == (PerBlock((1, 128)), PerBlock((128, 128))):
if dtype is torch.float32:
return unittest.skip("unimplemented")
elif mode == "weight-only":
return unittest.skip("unimplemented")
elif kernel_preference is KernelPreference.FBGEMM:
return unittest.skip("unimplemented")

error_message = None
if isinstance(granularity, PerRow):
@@ -137,6 +149,20 @@ def test_fp8_linear_variants(

quantize_(quantized_model, config)

# ensure weight scaling is what we expect
qs1 = quantized_model.linear1.weight.scale
qs2 = quantized_model.linear2.weight.scale
if granularity == PerTensor():
assert qs1.shape == (1, 1)
assert qs2.shape == (1, 1)
elif granularity == PerRow():
assert qs1.shape == (N, 1)
assert qs2.shape == (K, 1)
else:
assert granularity == (PerBlock((1, 128)), PerBlock((128, 128)))
assert qs1.shape == (N // 128, K // 128)
assert qs2.shape == (K // 128, N // 128)

if compile:
quantized_model = torch.compile(quantized_model, fullgraph=True)

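The scale-shape checks in the hunk above follow from the block size: the scale tensor carries one value per block along each weight dimension. A small standalone illustration (plain Python, not part of the test):

# For the ((32, 128), 256, 512) size case, linear1.weight is (N, K) = (256, 512)
# and linear2.weight is (K, N) = (512, 256).
N, K = 256, 512

def scale_shape(weight_shape, block_size):
    # one scale value per block along each dimension
    return tuple(dim // blk for dim, blk in zip(weight_shape, block_size))

assert scale_shape((N, K), (N, K)) == (1, 1)                    # PerTensor()
assert scale_shape((N, K), (1, K)) == (N, 1)                    # PerRow()
assert scale_shape((N, K), (128, 128)) == (N // 128, K // 128)  # PerBlock((128, 128))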
@@ -294,6 +320,7 @@ def test_slice_and_copy_similar_to_vllm(self, granularity):
self._test_slice_and_copy_similar_to_vllm(config)

@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
@unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
def test_bmm(self):
# only support per row quantization
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
@@ -406,6 +433,7 @@ def test_cat(self, granularity, sizes):
self.assertEqual(cat_qweight2.scale, ref_scale)

@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
@unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
def test_moe_weight_reshape_ops(self):
# only per row quantization is supported for bmm
granularity = PerRow()
@@ -416,6 +444,7 @@ def test_moe_weight_reshape_ops(self):
# that should be moved here after v1 config is deprecated:
# https://github.com/pytorch/ao/issues/2649
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
@unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
def test_expected_gpu_kernel_fbgemm(self):
"""Making sure KernelPreference.FBGEMM calls correct quantize and gemm kernels
and the bias add happens in the gemm kernel for per row quantization
75 changes: 59 additions & 16 deletions torchao/float8/inference.py
@@ -7,13 +7,15 @@
Defines an nn module designed to be used during inference
"""

import math
from typing import List, NamedTuple, Optional, Tuple, Union

import torch

from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul
from torchao.float8.types import FP8Granularity
from torchao.quantization.granularity import (
PerBlock,
PerRow,
PerTensor,
)
@@ -196,6 +198,36 @@ def _is_tensorwise_scaled(x: torch.Tensor) -> bool:
)


def _is_1_128_scaled(x: torch.Tensor) -> bool:
"""Checks if a quantized tensor is scaled with a block size of 1x128
Args:
x: quantized tensor (should have `block_size` attribute)
"""
assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
b = x.block_size
return len(b) >= 2 and math.prod(b[:-1]) == 1 and b[-1] == 128


def _is_128_128_scaled(x: torch.Tensor) -> bool:
"""Checks if a quantized tensor is scaled with a block size of 128x128
Args:
x: quantized tensor (should have `block_size` attribute)
"""
assert hasattr(x, "block_size"), "Expecting input to have `block_size` attribute"
b = x.block_size
return len(b) == 2 and b[0] == 128 and b[1] == 128


def _granularity_is_a_1_128_w_128_128(
g: Union[
FP8Granularity,
Tuple[FP8Granularity, FP8Granularity],
list[FP8Granularity],
],
) -> bool:
return len(g) == 2 and g[0] == PerBlock((1, 128)) and g[1] == PerBlock((128, 128))


def _normalize_granularity(
granularity: Optional[
Union[
@@ -211,22 +243,23 @@ def _normalize_granularity(
elif isinstance(granularity, (PerTensor, PerRow)):
processed_granularity = (granularity, granularity)
elif isinstance(granularity, (tuple, list)) and len(granularity) == 2:
if not (
isinstance(granularity[0], (PerTensor, PerRow))
and isinstance(granularity[1], (PerTensor, PerRow))
):
raise ValueError(
f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported."
)
is_per_tensor = isinstance(granularity[0], PerTensor) and isinstance(
granularity[1], PerTensor
)
is_per_row = isinstance(granularity[0], PerRow) and isinstance(
granularity[1], PerRow
)
is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularity)

if not (is_per_tensor or is_per_row or is_a_1_128_w_128_128):
raise ValueError(f"Unsupported granularity types: {granularity}.")
if not isinstance(granularity[0], type(granularity[1])):
raise ValueError(
f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported."
f"Different granularities for activation and weight are not supported: {granularity}."
)
processed_granularity = tuple(granularity)
else:
raise ValueError(
f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
)
raise ValueError(f"Invalid granularity specification: {granularity}.")
return processed_granularity


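Under this logic, the accepted inputs normalize as sketched below; treat this as expected behavior for illustration (the None-to-per-tensor default comes from the unchanged branch above this hunk and is assumed here):

from torchao.float8.inference import _normalize_granularity
from torchao.quantization.granularity import PerBlock, PerRow, PerTensor

# assumed: None falls back to per-tensor scaling for both activation and weight
assert _normalize_granularity(None) == (PerTensor(), PerTensor())

# a single granularity is applied to both activation and weight
assert _normalize_granularity(PerRow()) == (PerRow(), PerRow())

# the new asymmetric pair is accepted as-is
assert _normalize_granularity((PerBlock((1, 128)), PerBlock((128, 128)))) == (
    PerBlock((1, 128)),
    PerBlock((128, 128)),
)

# any other mix raises ValueError, e.g. (PerTensor(), PerRow()) hits
# "Unsupported granularity types: ..."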
@@ -243,12 +276,22 @@ def _check_hardware_support(
AssertionError: If hardware doesn't support the requested granularity
ValueError: If invalid granularity type is provided
"""
for _granularity in granularities:
if not isinstance(_granularity, (PerTensor, PerRow)):
raise ValueError(
f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported."
)
is_per_tensor = isinstance(granularities[0], PerTensor) and isinstance(
granularities[1], PerTensor
)
is_per_row = isinstance(granularities[0], PerRow) and isinstance(
granularities[1], PerRow
)
is_a_1_128_w_128_128 = _granularity_is_a_1_128_w_128_128(granularities)

if is_per_tensor or is_per_row:
assert is_sm_at_least_89() or is_MI300(), (
"Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+."
)
elif is_a_1_128_w_128_128:
# TODO(future PR): look into AMD support
assert is_sm_at_least_89(), (
"Float8 1x128 activation and 128x128 weight scaling requires CUDA compute capability ≥8.9."
)
else:
raise ValueError(f"Invalid granularities {granularities}.")
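
Putting the pieces together, the new pair plugs into the existing config path the same way as PerTensor or PerRow; a minimal sketch based on the test above (module shape and device chosen so the 1x128 activation and 128x128 weight blocks tile evenly):

import torch

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerBlock,
    quantize_,
)

model = torch.nn.Linear(512, 256, dtype=torch.bfloat16, device="cuda")

config = Float8DynamicActivationFloat8WeightConfig(
    granularity=(PerBlock((1, 128)), PerBlock((128, 128))),
)
quantize_(model, config)

# one scale entry per 128x128 weight block: (256 // 128, 512 // 128)
assert model.weight.scale.shape == (2, 4)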