test/quantization/test_qat.py (40 changes: 24 additions & 16 deletions)
@@ -213,6 +213,11 @@ def forward(self, x):

class TestQAT(TestCase):
SEED = 123
+DEVICE = (
+    torch.accelerator.current_accelerator()
+    if torch.accelerator.is_available()
+    else None
+)

def test_fake_quantize_per_channel_group(self):
n_bit = 4
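For reference, a minimal standalone sketch (not part of the diff) of the device-detection pattern used for the new DEVICE class attribute; it assumes PyTorch 2.6+, where the torch.accelerator API is available, and the tensor below is purely illustrative:

import torch

# Pick the current accelerator (cuda, xpu, mps, ...) if one is present,
# otherwise fall back to None so device-dependent tests can be skipped.
DEVICE = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else None
)

if DEVICE is None:
    print("no accelerator found; GPU-only tests would be skipped")
else:
    x = torch.randn(4, 4, device=DEVICE)  # tensor lands on the detected accelerator
    print(DEVICE.type, x.device)          # e.g. "cuda" or "xpu"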
@@ -347,7 +352,7 @@ def _set_ptq_weight(
group_size,
)
q_weight = torch.ops.aten._convert_weight_to_int4pack(
q_weight.to("cuda"),
q_weight.to(self.DEVICE),
qat_linear.inner_k_tiles,
)
ptq_linear.weight = q_weight
@@ -600,13 +605,13 @@ def _assert_close_4w(self, val, ref):
print(mean_err)
self.assertTrue(mean_err < 0.05)

-@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+@unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
def test_qat_4w_primitives(self):
n_bit = 4
group_size = 32
inner_k_tiles = 8
scales_precision = torch.bfloat16
-device = torch.device("cuda")
+device = self.DEVICE
dtype = torch.bfloat16
torch.manual_seed(self.SEED)
x = torch.randn(100, 256, dtype=dtype, device=device)
@@ -651,13 +656,13 @@ def test_qat_4w_primitives(self):

self._assert_close_4w(qat_out, ptq_out)

-@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+@unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
def test_qat_4w_linear(self):
from torchao.quantization.GPTQ import WeightOnlyInt4Linear
from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear

group_size = 128
-device = torch.device("cuda")
+device = self.DEVICE
dtype = torch.bfloat16
torch.manual_seed(self.SEED)
qat_linear = Int4WeightOnlyQATLinear(
@@ -692,15 +697,19 @@ def test_qat_4w_quantizer_gradients(self):
quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
self._test_qat_quantized_gradients(quantizer)

-@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+@unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
+@unittest.skipIf(
+    DEVICE is not None and DEVICE.type == "xpu",
+    "skipped due to https://github.com/intel/torch-xpu-ops/issues/1770",
+)
def test_qat_4w_quantizer(self):
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
from torchao.quantization.qat import Int4WeightOnlyQATQuantizer

group_size = 32
inner_k_tiles = 8
-device = torch.device("cuda")
dtype = torch.bfloat16
+device = self.DEVICE
torch.manual_seed(self.SEED)
m = M().to(device).to(dtype)
m2 = copy.deepcopy(m)
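As an aside on the XPU skip condition above (not part of the diff): torch.device values are most robustly compared via their .type attribute, because every torch.device(...) call builds a new object and an index-qualified device does not compare equal to an index-free one. A short sketch, using a hypothetical DEVICE value as a stand-in for torch.accelerator.current_accelerator():

import torch

DEVICE = torch.device("xpu", 0)  # hypothetical stand-in for the detected accelerator

print(DEVICE == torch.device("xpu"))  # False: index 0 vs. no index
print(DEVICE is torch.device("xpu"))  # False: identity check against a fresh object
print(DEVICE.type == "xpu")           # True: robust way to branch or skip on XPU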
@@ -709,8 +718,7 @@ def test_qat_4w_quantizer(self):
inner_k_tiles=inner_k_tiles,
)
ptq_quantizer = Int4WeightOnlyQuantizer(
-groupsize=group_size,
-inner_k_tiles=inner_k_tiles,
+groupsize=group_size, inner_k_tiles=inner_k_tiles, device=device
)
qat_model = qat_quantizer.prepare(m)
ptq_model = ptq_quantizer.quantize(m2)
@@ -1891,12 +1899,12 @@ def _test_quantize_api_against_ptq(
torch.manual_seed(self.SEED)

if module_type == "linear":
-m = M().to(dtype).cuda()
-example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
+m = M().to(dtype).to(self.DEVICE)
+example_inputs = (m.example_inputs()[0].to(dtype).to(self.DEVICE),)
filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
elif module_type == "embedding":
-m = M3().to(dtype).cuda()
-example_inputs = (m.example_inputs()[0].cuda(),)
+m = M3().to(dtype).to(self.DEVICE)
+example_inputs = (m.example_inputs()[0].to(self.DEVICE),)
filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
else:
raise ValueError(f"Unknown module type {module_type}")
@@ -1971,7 +1979,7 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat
target_convert_sqnr=float("inf"),
)

-@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+@unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
def test_quantize_api_int8_int4(self):
"""
Test the following:
@@ -1984,7 +1992,7 @@ def test_quantize_api_int8_int4(self):
target_convert_sqnr=float("inf"),
)

-@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+@unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
@parametrize(
"weight_dtype, weight_granularity, dtype",
[
@@ -2009,7 +2017,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
dtype=dtype,
)

-@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+@unittest.skipIf(DEVICE is None, "skipping when GPU is not available")
@parametrize(
"weight_dtype, granularity, dtype, module_type",
[