@@ -213,6 +213,11 @@ def forward(self, x):
 
 class TestQAT(TestCase):
     SEED = 123
+    DEVICE = (
+        torch.accelerator.current_accelerator()
+        if torch.accelerator.is_available()
+        else None
+    )
 
     def test_fake_quantize_per_channel_group(self):
         n_bit = 4
@@ -347,7 +352,7 @@ def _set_ptq_weight(
                group_size,
            )
            q_weight = torch.ops.aten._convert_weight_to_int4pack(
-                q_weight.to("cuda"),
+                q_weight.to(self.DEVICE),
                qat_linear.inner_k_tiles,
            )
            ptq_linear.weight = q_weight
@@ -600,13 +605,13 @@ def _assert_close_4w(self, val, ref):
         print(mean_err)
         self.assertTrue(mean_err < 0.05)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     def test_qat_4w_primitives(self):
         n_bit = 4
         group_size = 32
         inner_k_tiles = 8
         scales_precision = torch.bfloat16
-        device = torch.device("cuda")
+        device = self.DEVICE
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256, dtype=dtype, device=device)
@@ -651,13 +656,13 @@ def test_qat_4w_primitives(self):
 
         self._assert_close_4w(qat_out, ptq_out)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     def test_qat_4w_linear(self):
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
         from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
 
         group_size = 128
-        device = torch.device("cuda")
+        device = self.DEVICE
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         qat_linear = Int4WeightOnlyQATLinear(
@@ -692,15 +697,19 @@ def test_qat_4w_quantizer_gradients(self):
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
+    @unittest.skipIf(
+        DEVICE is not None and DEVICE.type == "xpu",
+        "skipped due to https://github.com/intel/torch-xpu-ops/issues/1770",
+    )
     def test_qat_4w_quantizer(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
 
         group_size = 32
         inner_k_tiles = 8
-        device = torch.device("cuda")
         dtype = torch.bfloat16
+        device = self.DEVICE
         torch.manual_seed(self.SEED)
         m = M().to(device).to(dtype)
         m2 = copy.deepcopy(m)
@@ -709,8 +718,7 @@ def test_qat_4w_quantizer(self):
             inner_k_tiles=inner_k_tiles,
         )
         ptq_quantizer = Int4WeightOnlyQuantizer(
-            groupsize=group_size,
-            inner_k_tiles=inner_k_tiles,
+            groupsize=group_size, inner_k_tiles=inner_k_tiles, device=device
         )
         qat_model = qat_quantizer.prepare(m)
         ptq_model = ptq_quantizer.quantize(m2)
@@ -1891,12 +1899,12 @@ def _test_quantize_api_against_ptq(
         torch.manual_seed(self.SEED)
 
         if module_type == "linear":
-            m = M().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
+            m = M().to(dtype).to(self.DEVICE)
+            example_inputs = (m.example_inputs()[0].to(dtype).to(self.DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
         elif module_type == "embedding":
-            m = M3().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].cuda(),)
+            m = M3().to(dtype).to(self.DEVICE)
+            example_inputs = (m.example_inputs()[0].to(self.DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
         else:
             raise ValueError(f"Unknown module type {module_type}")
@@ -1971,7 +1979,7 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat
             target_convert_sqnr=float("inf"),
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     def test_quantize_api_int8_int4(self):
         """
         Test the following:
@@ -1984,7 +1992,7 @@ def test_quantize_api_int8_int4(self):
             target_convert_sqnr=float("inf"),
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, weight_granularity, dtype",
         [
@@ -2009,7 +2017,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
             dtype=dtype,
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, granularity, dtype, module_type",
         [
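
For reference, the device-selection pattern introduced at the top of this diff can be exercised on its own. The snippet below is a minimal standalone sketch, assuming PyTorch 2.6+ where the torch.accelerator API is available; it is not part of the test suite.

import torch

# Sketch of the device-agnostic pattern used above
# (assumes PyTorch >= 2.6, which provides torch.accelerator).
DEVICE = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else None
)

if DEVICE is None:
    print("No accelerator found; GPU-only tests would be skipped.")
else:
    # Tensors placed on DEVICE run on whichever backend is present
    # (CUDA, XPU, MPS, ...), with no hard-coded "cuda" strings.
    x = torch.randn(4, 4, device=DEVICE)
    print(f"Running on {DEVICE.type}: {tuple(x.shape)}")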