Commit 8c59e66

[Intel GPU] Extend TestQAT module with xpu test cases
Add an xpu mode to the tests in the test_qat.py TestQAT module.
1 parent 0f05b40 commit 8c59e66
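
The commit swaps the hard-coded cuda device in TestQAT for a class-level DEVICE attribute resolved through the torch.accelerator API, so the same tests run on cuda or xpu and skip cleanly elsewhere. A minimal sketch of that resolution pattern, mirroring the attribute added in the diff below (the printed values are illustrative):

```python
import torch

# Resolve whichever accelerator backend this PyTorch build exposes
# (cuda, xpu, ...); fall back to None on CPU-only machines so
# accelerator-only tests can be skipped instead of failing hard.
DEVICE = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else None
)

print(DEVICE)  # e.g. device(type='cuda'), device(type='xpu'), or None
```

Tests then gate themselves with @unittest.skipIf(DEVICE is None, ...) and move tensors with .to(DEVICE) instead of .cuda().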

test/quantization/test_qat.py

Lines changed: 24 additions & 16 deletions
@@ -213,6 +213,11 @@ def forward(self, x):
 
 class TestQAT(TestCase):
     SEED = 123
+    DEVICE = (
+        torch.accelerator.current_accelerator()
+        if torch.accelerator.is_available()
+        else None
+    )
 
     def test_fake_quantize_per_channel_group(self):
         n_bit = 4
@@ -347,7 +352,7 @@ def _set_ptq_weight(
             group_size,
         )
         q_weight = torch.ops.aten._convert_weight_to_int4pack(
-            q_weight.to("cuda"),
+            q_weight.to(self.DEVICE),
             qat_linear.inner_k_tiles,
         )
         ptq_linear.weight = q_weight
@@ -600,13 +605,13 @@ def _assert_close_4w(self, val, ref):
         print(mean_err)
         self.assertTrue(mean_err < 0.05)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     def test_qat_4w_primitives(self):
         n_bit = 4
         group_size = 32
         inner_k_tiles = 8
         scales_precision = torch.bfloat16
-        device = torch.device("cuda")
+        device = self.DEVICE
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256, dtype=dtype, device=device)
@@ -651,13 +656,13 @@ def test_qat_4w_primitives(self):
 
         self._assert_close_4w(qat_out, ptq_out)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     def test_qat_4w_linear(self):
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
         from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
 
         group_size = 128
-        device = torch.device("cuda")
+        device = self.DEVICE
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         qat_linear = Int4WeightOnlyQATLinear(
@@ -692,15 +697,19 @@ def test_qat_4w_quantizer_gradients(self):
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
+    @unittest.skipIf(
+        DEVICE is torch.device("xpu"),
+        "skipped due to https://github.com/intel/torch-xpu-ops/issues/1770",
+    )
     def test_qat_4w_quantizer(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
 
         group_size = 32
         inner_k_tiles = 8
-        device = torch.device("cuda")
         dtype = torch.bfloat16
+        device = self.DEVICE
         torch.manual_seed(self.SEED)
         m = M().to(device).to(dtype)
         m2 = copy.deepcopy(m)
@@ -709,8 +718,7 @@ def test_qat_4w_quantizer(self):
             inner_k_tiles=inner_k_tiles,
         )
         ptq_quantizer = Int4WeightOnlyQuantizer(
-            groupsize=group_size,
-            inner_k_tiles=inner_k_tiles,
+            groupsize=group_size, inner_k_tiles=inner_k_tiles, device=device
         )
         qat_model = qat_quantizer.prepare(m)
         ptq_model = ptq_quantizer.quantize(m2)
@@ -1891,12 +1899,12 @@ def _test_quantize_api_against_ptq(
         torch.manual_seed(self.SEED)
 
         if module_type == "linear":
-            m = M().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
+            m = M().to(dtype).to(self.DEVICE)
+            example_inputs = (m.example_inputs()[0].to(dtype).to(self.DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
         elif module_type == "embedding":
-            m = M3().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].cuda(),)
+            m = M3().to(dtype).to(self.DEVICE)
+            example_inputs = (m.example_inputs()[0].to(self.DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
         else:
             raise ValueError(f"Unknown module type {module_type}")
@@ -1971,7 +1979,7 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat
             target_convert_sqnr=float("inf"),
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     def test_quantize_api_int8_int4(self):
         """
         Test the following:
@@ -1984,7 +1992,7 @@ def test_quantize_api_int8_int4(self):
             target_convert_sqnr=float("inf"),
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, weight_granularity, dtype",
         [
@@ -2009,7 +2017,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
             dtype=dtype,
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, granularity, dtype, module_type",
         [
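
Taken together with the -709,8 +718,7 hunk above, the QAT/PTQ comparison now threads the resolved device through the PTQ quantizer explicitly instead of letting it assume cuda. A condensed sketch of that flow, using only the quantizer names and keyword arguments shown in the diff (model setup elided; `m` and `m2` stand in for the test's modules, and the sketch assumes an accelerator is present, since the tests skip otherwise):

```python
import torch
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
from torchao.quantization.qat import Int4WeightOnlyQATQuantizer

# Same device resolution as the DEVICE attribute added in this commit.
device = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else None
)

qat_quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
# The PTQ quantizer now receives the device explicitly (per the diff).
ptq_quantizer = Int4WeightOnlyQuantizer(
    groupsize=32, inner_k_tiles=8, device=device
)

# As in the test body:
# qat_model = qat_quantizer.prepare(m)
# ptq_model = ptq_quantizer.quantize(m2)
```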
