Add NPU (Ascend) backend support for INT4 weight-only quantization workflow #3172
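For orientation, the workflow under test quantizes a `torch.nn.Linear` to INT4 with the plain-int32 packing path and runs it on an Ascend NPU. Below is a minimal usage sketch, not part of the diff, assuming the `torch_npu` plugin registers the `"npu"` device and that the test's `get_config` helper wraps `Int4WeightOnlyConfig` with the plain-int32 packing format (constructor details vary across torchao versions):

```python
# Hedged sketch: INT4 weight-only quantization on an Ascend NPU.
# Assumes torch_npu is installed and that Int4WeightOnlyConfig accepts
# an int4_packing_format argument selecting the plain-int32 layout.
import torch
import torch_npu  # noqa: F401  (registers the "npu" backend; assumption)
from torchao.quantization import Int4WeightOnlyConfig, quantize_

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="npu")
config = Int4WeightOnlyConfig(group_size=64, int4_packing_format="plain_int32")
quantize_(linear, config)  # weight becomes a torchao Int4PlainInt32Tensor
out = linear(torch.randn(8, 128, dtype=torch.bfloat16, device="npu"))
```

The `group_size=64` choice mirrors the tests below: the NPU kernel does not support a group size equal to the K dimension.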
```diff
@@ -5,12 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 
 import tempfile
-import unittest
 
+import pytest
 import torch
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_utils import (
     TestCase,
-    instantiate_parametrized_tests,
     parametrize,
     run_tests,
 )
```
```diff
@@ -33,9 +33,19 @@ def get_config(group_size):
 )
 
 
-@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
-@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
 class Int4PlainInt32Tensor(TestCase):
+    _MIN_VER = {
+        "xpu": "2.8.0",
+        "npu": "2.7.1",
+    }
+
+    def setUp(self):
+        min_req = type(self)._MIN_VER.get(self.device_type)
+        if not torch_version_at_least(min_req):
+            self.skipTest(
+                f"{self.device_type} requires torch >= {min_req}, current {torch.__version__}"
+            )
+
     @parametrize(
         "sizes",
         [
```
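The `setUp` hook can read `self.device_type` because `instantiate_device_type_tests` (added at the bottom of the file) stamps out one concrete test class per backend. Roughly, and this is a simplified model rather than the actual `torch.testing` internals:

```python
# Simplified sketch of what instantiate_device_type_tests does: one
# generated subclass per selected backend, each carrying a class-level
# device_type, while individual tests receive a concrete device string
# such as "npu:0" as their `device` argument.
for backend in ("xpu", "npu"):
    name = f"Int4PlainInt32Tensor{backend.upper()}"
    globals()[name] = type(name, (Int4PlainInt32Tensor,), {"device_type": backend})
```

This is also why `_MIN_VER.get(self.device_type)` keys off the bare backend name, while the tests below split `device` on `":"` before looking up their threshold.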
```diff
@@ -46,24 +56,35 @@ class Int4PlainInt32Tensor(TestCase):
     )
     @parametrize("dtype", [torch.bfloat16, torch.half])
     @parametrize("group_size", [32, 64, 128])
-    def test_linear(self, sizes, dtype, group_size):
-        device = "xpu"
+    @parametrize("thresholds", [{"xpu": 20, "npu": 10}])
+    def test_linear(self, device, sizes, dtype, group_size, thresholds):
         M, N, K = sizes
+        if "npu" in device and group_size == K:
+            pytest.skip(
+                f"{device} does not support group_size equal to K dimension ({group_size} == {K})"
+            )
+        threshold = thresholds.get(device.split(":")[0])
+
         input = torch.randn(*M, K, dtype=dtype, device=device)
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
         original = linear(input)
         quantize_(linear, get_config(group_size))
         quantized = linear(input)
-        self.assertTrue(compute_error(original, quantized) > 20)
+        self.assertTrue(compute_error(original, quantized) > threshold)
 
-        compiled_linear = torch.compile(linear)
-        quantized_and_compiled = compiled_linear(input)
-        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
+        if "xpu" in device:
+            compiled_linear = torch.compile(linear)
+            quantized_and_compiled = compiled_linear(input)
+            self.assertTrue(compute_error(original, quantized_and_compiled) > threshold)
 
     @parametrize("dtype", [torch.bfloat16, torch.half])
-    def test_module_path(self, dtype):
-        linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu")
-        quantize_(linear, get_config(group_size=128))
+    def test_module_path(self, device, dtype):
+        K, N, group_size = 128, 256, 128
+        if "npu" in device:
+            group_size = 64
+
+        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        quantize_(linear, get_config(group_size))
         self.assertEqual(
             str(type(linear.weight)),
             "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
```
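The `thresholds` values are minimum SQNR in decibels, with a looser 10 dB bar for the NPU kernel than the 20 dB bar kept for XPU. To my understanding, torchao's `compute_error` is a signal-to-quantization-noise ratio along these lines (a sketch; the actual helper may differ in detail):

```python
import torch

def sqnr_db(reference: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    # Ratio of signal energy to quantization-noise energy, in decibels;
    # higher means the quantized output tracks the float reference better.
    signal = torch.linalg.norm(reference)
    noise = torch.linalg.norm(reference - result)
    return 20 * torch.log10(signal / noise)
```

The drop to `group_size = 64` on NPU in `test_module_path` follows from the same constraint encoded in `test_linear`: the Ascend kernel rejects a group size equal to the K dimension, which is 128 here.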
```diff
@@ -78,13 +99,21 @@ def test_module_path(self, dtype):
             "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
         )
 
-    def test_activation_prescaling(self):
-        dtype = torch.bfloat16
-        device = "xpu"
-        input = torch.randn(1, 128, dtype=dtype, device=device)
-        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    @parametrize("thresholds", [{"xpu": 20, "npu": 10}])
+    def test_activation_prescaling(self, device, dtype, thresholds):
+        if "xpu" in device and dtype == torch.float16:
+            pytest.skip(f"{device} test_activation_prescaling don't test {dtype}")
+
+        threshold = thresholds.get(device.split(":")[0])
```
Review thread on this hunk:

> **Contributor:** does `device_type` have …
>
> **Contributor (author):** Ah, good catch, you're right! I actually meant to use the function argument `device` there, but forgot to remove …
>
> **Contributor (author):** fixed.
```diff
+        K, N, group_size = 128, 256, 128
+        if "npu" in device:
+            group_size = 64
+
+        input = torch.randn(1, K, dtype=dtype, device=device)
+        linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, get_config(128))
+        quantize_(linear, get_config(group_size))
         qw = linear.weight
         assert isinstance(qw, SupportsActivationPreScaling), (
             "Expected int4 tensor supports activation prescaling"
```
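For readers unfamiliar with the prescaling hook: `SupportsActivationPreScaling` is torchao's protocol for quantized tensors that carry an `act_pre_scale` factor multiplied into the activation before the INT4 matmul. The shape of the check, with names taken from the test (treat the attribute semantics as an assumption rather than a spec):

```python
# Sketch of the contract: after setting act_pre_scale on the quantized
# weight, the linear should behave as if the input were pre-multiplied,
# i.e. linear(x) ~= original_linear(x) * _ACT_PRE_SCALE.
qw.act_pre_scale = _ACT_PRE_SCALE  # e.g. torch.tensor(2.0, device=device)
quantized = linear(input)
assert compute_error(original * _ACT_PRE_SCALE, quantized) > threshold
```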
|
@@ -95,10 +124,12 @@ def test_activation_prescaling(self): | |
| quantized = linear(input) | ||
|
|
||
| # making sure activation pre scaling is successfully applied to the activation | ||
| self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20) | ||
| self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > threshold) | ||
|
|
||
|
|
||
| instantiate_parametrized_tests(Int4PlainInt32Tensor) | ||
| instantiate_device_type_tests( | ||
| Int4PlainInt32Tensor, globals(), only_for=("xpu", "npu"), allow_xpu=True | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
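Design note on the instantiation change: `instantiate_parametrized_tests` is dropped in favor of `instantiate_device_type_tests`, which handles the `@parametrize` expansion as well and fans the class out across the selected backends. `only_for=("xpu", "npu")` restricts generation to those two devices, and `allow_xpu=True` is needed because the device-type harness excludes XPU unless explicitly allowed. With per-backend class names generated (by the harness's usual convention, something like `Int4PlainInt32TensorNPU`; the exact names are an assumption here), a single backend's tests can be selected by name with the standard `-k` filter when running the file directly.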