Commit 6e21a1f

[xpu][test] Port 2 test/quantization_{gptq, quant_primitive} UT files to intel XPU (#3350)
* port 2 test/quantization UTs to intel XPU
* update format
* update format
1 parent 7a2a7b3 commit 6e21a1f

File tree

2 files changed: +30 -16 lines

* test/quantization/test_gptq.py
* test/quantization/test_quant_primitives.py
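Both files follow the same pattern: hard-coded "cuda" strings and torch.cuda.is_available() skips are replaced with a module-level _DEVICE constant taken from torchao.utils.get_current_accelerator_device, plus the backend-agnostic torch.accelerator API. A minimal sketch of that selection pattern, assuming the torchao helper behaves roughly like the stand-in below (its actual implementation is not part of this diff):

import torch

def current_accelerator_device() -> torch.device:
    # Illustrative stand-in for torchao.utils.get_current_accelerator_device:
    # torch.accelerator is the backend-agnostic accelerator API (CUDA, XPU, ...).
    if torch.accelerator.is_available():
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")

_DEVICE = current_accelerator_device()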

test/quantization/test_gptq.py

Lines changed: 22 additions & 12 deletions
@@ -18,13 +18,16 @@
 from torchao._models.llama.tokenizer import get_tokenizer
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.quantization.utils import compute_error
+from torchao.utils import get_current_accelerator_device
 
 torch.manual_seed(0)
 
+_DEVICE = get_current_accelerator_device()
+
 
 class TestGPTQ(TestCase):
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_gptq_quantizer_int4_weight_only(self):
         from torchao._models._eval import (
             LMEvalInputRecorder,
@@ -33,7 +36,7 @@ def test_gptq_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
 
         precision = torch.bfloat16
-        device = "cuda"
+        device = _DEVICE
         checkpoint_path = Path(
             "../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"
         )
@@ -80,15 +83,15 @@ def test_gptq_quantizer_int4_weight_only(self):
         )
         model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
 
-        model = quantizer.quantize(model, *inputs).cuda()
+        model = quantizer.quantize(model, *inputs).to(_DEVICE)
 
         model.reset_caches()
-        with torch.device("cuda"):
+        with torch.device(_DEVICE):
             model.setup_caches(max_batch_size=1, max_seq_length=model.config.block_size)
 
         limit = 1
         result = TransformerEvalWrapper(
-            model.cuda(),
+            model.to(_DEVICE),
             tokenizer,
             model.config.block_size,
             prepare_inputs_for_model,
@@ -104,7 +107,7 @@ def test_gptq_quantizer_int4_weight_only(self):
 
 
 class TestMultiTensorFlow(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_add_tensors(self):
         from torchao.quantization.GPTQ import MultiTensor
 
@@ -116,7 +119,7 @@ def test_multitensor_add_tensors(self):
         self.assertTrue(torch.equal(mt.values[0], tensor1))
         self.assertTrue(torch.equal(mt.values[1], tensor2))
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_pad_unpad(self):
         from torchao.quantization.GPTQ import MultiTensor
 
@@ -127,7 +130,7 @@ def test_multitensor_pad_unpad(self):
         mt.unpad()
         self.assertEqual(mt.count, 1)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_inplace_operation(self):
         from torchao.quantization.GPTQ import MultiTensor
 
@@ -138,7 +141,7 @@ def test_multitensor_inplace_operation(self):
 
 
 class TestMultiTensorInputRecorder(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_multitensor_input_recorder(self):
         from torchao.quantization.GPTQ import MultiTensor, MultiTensorInputRecorder
 
@@ -159,7 +162,7 @@ def test_multitensor_input_recorder(self):
         self.assertTrue(isinstance(MT_input[2][2], MultiTensor))
         self.assertEqual(MT_input[3], torch.float)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not torch.accelerator.is_available(), "Need GPU available")
     def test_gptq_with_input_recorder(self):
         from torchao.quantization.GPTQ import (
             Int4WeightOnlyGPTQQuantizer,
@@ -170,7 +173,7 @@ def test_gptq_with_input_recorder(self):
 
         config = ModelArgs(n_layer=2)
 
-        with torch.device("cuda"):
+        with torch.device(_DEVICE):
             model = Transformer(config)
             model.setup_caches(max_batch_size=2, max_seq_length=100)
         idx = torch.randint(1, 10000, (10, 2, 50)).to(torch.int32)
@@ -191,7 +194,14 @@ def test_gptq_with_input_recorder(self):
 
         args = input_recorder.get_recorded_inputs()
 
-        quantizer = Int4WeightOnlyGPTQQuantizer()
+        if _DEVICE.type == "xpu":
+            from torchao.dtypes import Int4XPULayout
+
+            quantizer = Int4WeightOnlyGPTQQuantizer(
+                device=torch.device("xpu"), layout=Int4XPULayout()
+            )
+        else:
+            quantizer = Int4WeightOnlyGPTQQuantizer()
 
         quantizer.quantize(model, *args)
 
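The only per-backend branch in test_gptq.py is the quantizer construction: when the accelerator is an Intel XPU, the test passes the device and the XPU-specific Int4XPULayout to Int4WeightOnlyGPTQQuantizer instead of relying on the defaults. A hedged sketch of that choice factored into a helper (the helper name is illustrative; the class names come from the diff):

import torch
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer

def make_gptq_quantizer(device: torch.device) -> Int4WeightOnlyGPTQQuantizer:
    # Illustrative helper mirroring the branch in test_gptq_with_input_recorder.
    if device.type == "xpu":
        from torchao.dtypes import Int4XPULayout

        # The XPU path passes its own packed int4 layout and an explicit device.
        return Int4WeightOnlyGPTQQuantizer(device=device, layout=Int4XPULayout())
    # The default constructor keeps the existing CUDA path unchanged.
    return Int4WeightOnlyGPTQQuantizer()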

test/quantization/test_quant_primitives.py

Lines changed: 8 additions & 4 deletions
@@ -36,12 +36,15 @@
 from torchao.utils import (
     check_cpu_version,
     check_xpu_version,
+    get_current_accelerator_device,
     is_fbcode,
 )
 
 _SEED = 1234
 torch.manual_seed(_SEED)
 
+_DEVICE = get_current_accelerator_device()
+
 
 # Helper function to run a function twice
 # and verify that the result is the same.
@@ -592,16 +595,17 @@ def test_choose_qparams_tensor_asym_eps(self):
         self.assertEqual(scale, eps)
 
     @unittest.skipIf(
-        not torch.cuda.is_available(), "skipping when cuda is not available"
+        not torch.accelerator.is_available(), "skipping when gpu is not available"
     )
     def test_get_group_qparams_symmetric_memory(self):
         """Check the memory usage of the op"""
-        weight = torch.randn(1024, 1024).to(device="cuda")
-        original_mem_use = torch.cuda.memory_allocated()
+        weight = torch.randn(1024, 1024).to(device=_DEVICE)
+        device_module = torch.get_device_module(_DEVICE)
+        original_mem_use = device_module.memory_allocated()
         n_bit = 4
         groupsize = 128
         (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
-        after_choose_qparams_mem_use = torch.cuda.memory_allocated()
+        after_choose_qparams_mem_use = device_module.memory_allocated()
         self.assertTrue(after_choose_qparams_mem_use < 1.2 * original_mem_use)
 
     def test_raises(self):
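The memory test becomes device-agnostic through torch.get_device_module, which returns the backend module for a given device (torch.cuda for CUDA, torch.xpu for XPU), so the same memory_allocated() call works on either backend. A small sketch of that pattern (the function name is illustrative):

import torch

def allocated_bytes(device: torch.device) -> int:
    # torch.get_device_module maps a device to its backend module,
    # e.g. torch.cuda or torch.xpu, both of which expose memory_allocated().
    return torch.get_device_module(device).memory_allocated()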
