Skip to content

Commit b90c6f7

Browse files
jiqing-fengsarathc-cerebras
authored andcommitted
* fix compressed tensor tests Signed-off-by: jiqing-feng <[email protected]> * update Signed-off-by: jiqing-feng <[email protected]> * update comment Signed-off-by: jiqing-feng <[email protected]> * format Signed-off-by: jiqing-feng <[email protected]> --------- Signed-off-by: jiqing-feng <[email protected]>
1 parent 1caeb1f commit b90c6f7

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

tests/quantization/compressed_tensors_integration/test_compressed_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def _has_nested_attr(obj, attr_path):
8080
if comp_decomp_obj is not None and hasattr(submodule, "weight"):
8181
if "sparse-only" in uncompressed_model:
8282
self.assertTrue(
83-
torch.equal(submodule.weight, comp_decomp_obj.weight),
83+
torch.equal(
84+
submodule.weight.to(torch_device), comp_decomp_obj.weight.to(torch_device)
85+
),
8486
f"Weight mismatch for module '{name}' in sparse-only model.",
8587
)
8688
else:

tests/quantization/compressed_tensors_integration/test_compressed_tensors.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,13 @@
22
import unittest
33

44
from transformers import AutoModelForCausalLM, AutoTokenizer, CompressedTensorsConfig
5-
from transformers.testing_utils import backend_empty_cache, require_compressed_tensors, require_torch, torch_device
5+
from transformers.testing_utils import (
6+
backend_empty_cache,
7+
require_compressed_tensors,
8+
require_deterministic_for_xpu,
9+
require_torch,
10+
torch_device,
11+
)
612
from transformers.utils import is_torch_available
713

814

@@ -47,22 +53,33 @@ def test_config_to_from_dict(self):
4753
self.assertIsInstance(config_from_dict.sparsity_config, SparsityCompressionConfig)
4854

4955
def test_tinyllama_w8a8(self):
50-
expected_out = "<s> Paris is the capital of which country?\n\n**A) 10** Paris is the capital of which country?\n\n**B) 11** Paris is the capital of which country?\n\n**C) 1"
56+
expected_out = [
57+
"<s> Paris is the capital of which country?\n\n**A) 10** Paris is the capital of which country?\n\n**B) 11** Paris is the capital of which country?\n\n**C) 1",
58+
"<s> Paris is the capital of which country?\n\n** 10.** Which country is the capital of which country?\n\n** 11.** Which country is the capital of which country?\n\n** 12.", # XPU
59+
]
5160
self._test_quantized_model(self.tinyllama_w8a8, expected_out)
5261

5362
def test_tinyllama_w4a16(self):
54-
expected_out = "<s> Paris is the capital of which country?\nAnswer: Paris is the capital of France.\nQuestion: Which country is the capital of which city?\nAnswer: The capital of the city of New York is New York.\nQuestion: Which"
63+
expected_out = [
64+
"<s> Paris is the capital of which country?\nAnswer: Paris is the capital of France.\nQuestion: Which country is the capital of which city?\nAnswer: The capital of the city of New York is New York.\nQuestion: Which"
65+
]
5566
self._test_quantized_model(self.tinyllama_w4a16, expected_out)
5667

5768
def test_tinyllama_w8a16(self):
58-
expected_out = "<s> Paris is the capital of which country?\nA. France\nB. Germany\nC. Spain\nD. Italy\nE. Switzerland\nQ10. Which of the following is not a country in the European Union?\nA."
69+
expected_out = [
70+
"<s> Paris is the capital of which country?\nA. France\nB. Germany\nC. Spain\nD. Italy\nE. Switzerland\nQ10. Which of the following is not a country in the European Union?\nA."
71+
]
5972
self._test_quantized_model(self.tinyllama_w8a16, expected_out)
6073

6174
def test_llama_8b_fp8(self):
62-
expected_out = "<|begin_of_text|>Paris is the capital of which country? France\nWhat is the name of the famous art museum in Paris? The Louvre\nWhat is the name of the famous bridge in Paris? Pont des Arts\nWhat is the name of the famous opera? "
75+
expected_out = [
76+
"<|begin_of_text|>Paris is the capital of which country? France\nWhat is the name of the famous art museum in Paris? The Louvre\nWhat is the name of the famous bridge in Paris? Pont des Arts\nWhat is the name of the famous opera? ",
77+
"<|begin_of_text|>Paris is the capital of which country? France\nWhat is the name of the famous art museum in Paris? The Louvre\nWhat is the name of the famous bridge in Paris? Pont des Arts\nWhat is the name of the famous opera", # XPU
78+
]
6379
self._test_quantized_model(self.llama3_8b_fp8, expected_out)
6480

65-
def _test_quantized_model(self, model_name: str, expected_output: str):
81+
@require_deterministic_for_xpu
82+
def _test_quantized_model(self, model_name: str, expected_output: list):
6683
"""Carry out generation"""
6784
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
6885
tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -84,4 +101,4 @@ def _test_quantized_model(self, model_name: str, expected_output: str):
84101
outputs = tokenizer.batch_decode(generated_ids)
85102

86103
self.assertIsNotNone(outputs)
87-
self.assertEqual(outputs[0], expected_output)
104+
self.assertIn(outputs[0], expected_output)

0 commit comments

Comments
 (0)