 from torchao._models.llama.tokenizer import get_tokenizer
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.quantization.utils import compute_error
+from torchao.utils import get_current_accelerator_device
 
 torch.manual_seed(0)
 
+_DEVICE = get_current_accelerator_device()
+
 
 class TestGPTQ(TestCase):
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_gptq_quantizer_int4_weight_only(self):
         from torchao._models._eval import (
             LMEvalInputRecorder,
@@ -33,7 +36,7 @@ def test_gptq_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
 
         precision = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE
         checkpoint_path = Path(
             "../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"
         )
@@ -80,15 +83,15 @@ def test_gptq_quantizer_int4_weight_only(self):
         )
         model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
 
-        model = quantizer.quantize(model, *inputs).cuda()
+        model = quantizer.quantize(model, *inputs).to(_DEVICE)
 
         model.reset_caches()
-        with torch.device("cuda"):
+        with torch.device(_DEVICE):
             model.setup_caches(max_batch_size=1, max_seq_length=model.config.block_size)
 
         limit = 1
         result = TransformerEvalWrapper(
-            model.cuda(),
+            model.to(_DEVICE),
             tokenizer,
             model.config.block_size,
             prepare_inputs_for_model,
@@ -104,7 +107,7 @@ def test_gptq_quantizer_int4_weight_only(self):
 
 
 class TestMultiTensorFlow(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_add_tensors(self):
         from torchao.quantization.GPTQ import MultiTensor
 
@@ -116,7 +119,7 @@ def test_multitensor_add_tensors(self):
         self.assertTrue(torch.equal(mt.values[0], tensor1))
         self.assertTrue(torch.equal(mt.values[1], tensor2))
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_pad_unpad(self):
         from torchao.quantization.GPTQ import MultiTensor
 
@@ -127,7 +130,7 @@ def test_multitensor_pad_unpad(self):
         mt.unpad()
         self.assertEqual(mt.count, 1)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_inplace_operation(self):
         from torchao.quantization.GPTQ import MultiTensor
 
@@ -138,7 +141,7 @@ def test_multitensor_inplace_operation(self):
 
 
 class TestMultiTensorInputRecorder(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_input_recorder(self):
         from torchao.quantization.GPTQ import MultiTensor, MultiTensorInputRecorder
 
@@ -159,7 +162,7 @@ def test_multitensor_input_recorder(self):
         self.assertTrue(isinstance(MT_input[2][2], MultiTensor))
         self.assertEqual(MT_input[3], torch.float)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_gptq_with_input_recorder(self):
         from torchao.quantization.GPTQ import (
             Int4WeightOnlyGPTQQuantizer,
@@ -170,7 +173,7 @@ def test_gptq_with_input_recorder(self):
 
         config = ModelArgs(n_layer=2)
 
-        with torch.device("cuda"):
+        with torch.device(_DEVICE):
             model = Transformer(config)
             model.setup_caches(max_batch_size=2, max_seq_length=100)
             idx = torch.randint(1, 10000, (10, 2, 50)).to(torch.int32)
@@ -191,7 +194,14 @@ def test_gptq_with_input_recorder(self):
 
         args = input_recorder.get_recorded_inputs()
 
-        quantizer = Int4WeightOnlyGPTQQuantizer()
+        if _DEVICE.type == "xpu":
+            from torchao.dtypes import Int4XPULayout
+
+            quantizer = Int4WeightOnlyGPTQQuantizer(
+                device=torch.device("xpu"), layout=Int4XPULayout()
+            )
+        else:
+            quantizer = Int4WeightOnlyGPTQQuantizer()
 
         quantizer.quantize(model, *args)
 
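For readers following the change: a minimal sketch of what the new get_current_accelerator_device helper is assumed to do, namely return the active accelerator device (cuda, xpu, mps, ...) when one is present and fall back to CPU otherwise. The actual torchao.utils implementation may differ in details; this sketch only uses the stock torch.accelerator API.

    import torch

    def get_current_accelerator_device() -> torch.device:
        # Assumed behavior: pick the current accelerator backend when one is
        # available (e.g. cuda or xpu), otherwise fall back to CPU.
        if torch.accelerator.is_available():
            return torch.accelerator.current_accelerator()
        return torch.device("cpu")

This mirrors the pattern the diff applies throughout: skip conditions switch from torch.cuda.is_available() to the backend-agnostic torch.accelerator.is_available(), and hard-coded .cuda() / "cuda" usages become .to(_DEVICE) / torch.device(_DEVICE) so the tests can run on any supported accelerator.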