@@ -5,12 +5,12 @@
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import pytest
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
@@ -33,9 +33,19 @@ def get_config(group_size):
)


@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
class Int4PlainInt32Tensor(TestCase):
_MIN_VER = {
"xpu": "2.8.0",
"npu": "2.7.1",
}

def setUp(self):
min_req = type(self)._MIN_VER.get(self.device_type)
if not torch_version_at_least(min_req):
self.skipTest(
f"{self.device_type} requires torch >= {min_req}, current {torch.__version__}"
)

@parametrize(
"sizes",
[
@@ -46,24 +56,35 @@
)
@parametrize("dtype", [torch.bfloat16, torch.half])
@parametrize("group_size", [32, 64, 128])
def test_linear(self, sizes, dtype, group_size):
device = "xpu"
@parametrize("thresholds", [{"xpu": 20, "npu": 10}])
def test_linear(self, device, sizes, dtype, group_size, thresholds):
M, N, K = sizes
if "npu" in device and group_size == K:
pytest.skip(
f"{device} does not support group_size equal to K dimension ({group_size} == {K})"
)
threshold = thresholds.get(device.split(":")[0])

input = torch.randn(*M, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, get_config(group_size))
quantized = linear(input)
self.assertTrue(compute_error(original, quantized) > 20)
self.assertTrue(compute_error(original, quantized) > threshold)

compiled_linear = torch.compile(linear)
quantized_and_compiled = compiled_linear(input)
self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
if "xpu" in device:
compiled_linear = torch.compile(linear)
quantized_and_compiled = compiled_linear(input)
self.assertTrue(compute_error(original, quantized_and_compiled) > threshold)

@parametrize("dtype", [torch.bfloat16, torch.half])
def test_module_path(self, dtype):
linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu")
quantize_(linear, get_config(group_size=128))
def test_module_path(self, device, dtype):
K, N, group_size = 128, 256, 128
if "npu" in device:
group_size = 64

linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
quantize_(linear, get_config(group_size))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4PlainInt32Tensor'>",
@@ -78,13 +99,21 @@ def test_module_path(self, dtype):
"<class 'torchao.quantization.Int4PlainInt32Tensor'>",
)

def test_activation_prescaling(self):
dtype = torch.bfloat16
device = "xpu"
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
@parametrize("dtype", [torch.float16, torch.bfloat16])
@parametrize("thresholds", [{"xpu": 20, "npu": 10}])
def test_activation_prescaling(self, device, dtype, thresholds):
if "xpu" in device and dtype == torch.float16:
pytest.skip(f"{device} test_activation_prescaling don't test {dtype}")

threshold = thresholds.get(device.split(":")[0])
Contributor:
Does `device_type` have `:`? I thought it should only be things like xpu, npu, cuda, not cuda:0.

Contributor Author:
Ah, good catch, you're right! I actually meant to use the function argument `device` there, but forgot to remove `device = self.device_type` from the setup.
`device` can include a suffix like `":0"`, while `device_type` should not. I'll fix that, thanks for pointing it out!

Contributor Author:
fixed.
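
To illustrate the distinction discussed in this thread, here is a minimal sketch (not part of the PR; the class and test names are made up) of how the `device` argument and `self.device_type` typically relate in tests generated by `instantiate_device_type_tests`:

```python
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests


class DeviceStringExample(TestCase):
    # Hypothetical test, only to show the device string convention.
    def test_device_vs_device_type(self, device):
        # `device` is the full device string handed to the test and may carry
        # an index suffix, e.g. "cuda:0" or "xpu:0".
        # `self.device_type` is the bare backend name, e.g. "cuda" or "xpu".
        self.assertEqual(device.split(":")[0], self.device_type)
        self.assertNotIn(":", self.device_type)


instantiate_device_type_tests(DeviceStringExample, globals(), only_for=("cpu",))

if __name__ == "__main__":
    run_tests()
```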

K, N, group_size = 128, 256, 128
if "npu" in device:
group_size = 64

input = torch.randn(1, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, get_config(128))
quantize_(linear, get_config(group_size))
qw = linear.weight
assert isinstance(qw, SupportsActivationPreScaling), (
"Expected int4 tensor supports activation prescaling"
@@ -95,10 +124,12 @@ def test_activation_prescaling(self):
quantized = linear(input)

# making sure activation pre scaling is successfully applied to the activation
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > threshold)


instantiate_parametrized_tests(Int4PlainInt32Tensor)
instantiate_device_type_tests(
Int4PlainInt32Tensor, globals(), only_for=("xpu", "npu"), allow_xpu=True
)


if __name__ == "__main__":
8 changes: 6 additions & 2 deletions torchao/quantization/README.md
@@ -71,8 +71,12 @@ use_hqq = False
quantize_(model, Int4WeightOnlyConfig(group_size=group_size, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq"))
```

Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.

Note:
- The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model; a quick way to gauge this error is sketched below.
- Third-party backend CI status:
  - Ascend NPU (requires torch_npu ≥ 2.7.1)
    [![Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml/badge.svg)](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml)
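
As a rough way to gauge that error, one can compare the float and quantized outputs with the SQNR helper used in torchao's tests (a sketch, assuming `torchao.quantization.utils.compute_error` is available and the default `Int4WeightOnlyConfig` settings work on your device):

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.utils import compute_error  # SQNR in dB, used throughout torchao tests

# Toy example: compare outputs before and after int4 weight-only quantization.
linear = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
ref = linear(x)

quantize_(linear, Int4WeightOnlyConfig(group_size=128))
sqnr = compute_error(ref, linear(x))
print(f"SQNR after int4 quantization: {sqnr:.1f} dB")  # higher is better; ~20 dB is a common test threshold
```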

#### A16W8 Int8 WeightOnly Quantization

```python