Skip to content

Commit 6c25f26

Browse files
committed
split up tests and remove common ones that should not be run for each model
1 parent ec13e39 commit 6c25f26

File tree

3 files changed

+1013
-1452
lines changed

3 files changed

+1013
-1452
lines changed
Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
# Sentencepiece backend layer tests
2+
3+
import pickle
4+
import shutil
5+
import tempfile
6+
import unittest
7+
from typing import TYPE_CHECKING
8+
9+
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
10+
from transformers.testing_utils import require_tokenizers
11+
from transformers.tokenization_utils import AddedToken
12+
13+
if TYPE_CHECKING:
14+
pass
15+
16+
17+
class SentencePieceBackendTesterMixin:
    """
    Tests that specifically test the SentencePiece backend.

    Subclasses set the class attributes below (tokenizer classes, checkpoint
    id, feature flags) and inherit the test methods.
    """

    # Slow (Python) and fast (Rust) tokenizer classes under test; set by subclasses.
    tokenizer_class = None
    rust_tokenizer_class = None
    # Whether the SentencePiece-specific tests should run for this model.
    test_sentencepiece = True
    # If True, comparisons against the reference text are done lower-cased.
    test_sentencepiece_ignore_case = False
    # Which backends get_tokenizers() is allowed to return.
    test_slow_tokenizer = True
    test_rust_tokenizer = True
    # Checkpoint identifier passed to from_pretrained(); set by subclasses.
    from_pretrained_id = None
    from_pretrained_kwargs = None

    @classmethod
    def setUpClass(cls) -> None:
        # Scratch directory shared by the whole test class.
        cls.tmpdirname = tempfile.mkdtemp()

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tmpdirname, ignore_errors=True)

    @classmethod
    def get_tokenizer(cls, **kwargs) -> PreTrainedTokenizer:
        """Instantiate the slow (Python) tokenizer from the configured checkpoint."""
        return cls.tokenizer_class.from_pretrained(cls.from_pretrained_id, **kwargs)

    @classmethod
    def get_rust_tokenizer(cls, **kwargs) -> PreTrainedTokenizerFast:
        """Instantiate the fast (Rust) tokenizer from the configured checkpoint."""
        return cls.rust_tokenizer_class.from_pretrained(cls.from_pretrained_id, **kwargs)

    def get_tokenizers(self, fast=True, **kwargs):
        """Return the tokenizers to test, honoring the test_* flags.

        Raises:
            ValueError: if neither backend is enabled for this class.
        """
        if fast and self.test_rust_tokenizer and self.test_slow_tokenizer:
            return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]
        elif fast and self.test_rust_tokenizer:
            return [self.get_rust_tokenizer(**kwargs)]
        elif self.test_slow_tokenizer:
            return [self.get_tokenizer(**kwargs)]
        else:
            raise ValueError("This tokenizer class has no tokenizer to be tested.")

    def test_sentencepiece_tokenize_and_convert_tokens_to_string(self):
        """Test ``_tokenize`` and ``convert_tokens_to_string``."""
        if not self.test_sentencepiece:
            self.skipTest(reason="test_sentencepiece is set to False")

        tokenizer = self.get_tokenizer()
        text = "This is text to test the tokenizer."

        if self.test_sentencepiece_ignore_case:
            text = text.lower()

        tokens = tokenizer.tokenize(text)

        self.assertTrue(len(tokens) > 0)

        # check if converting back to original text works
        reverse_text = tokenizer.convert_tokens_to_string(tokens)

        if self.test_sentencepiece_ignore_case:
            reverse_text = reverse_text.lower()

        self.assertEqual(reverse_text, text)

        # Every special token must survive convert_tokens_to_string verbatim.
        special_tokens = tokenizer.all_special_tokens
        special_tokens_string = tokenizer.convert_tokens_to_string(special_tokens)
        for special_token in special_tokens:
            self.assertIn(special_token, special_tokens_string)

        if self.test_rust_tokenizer:
            # Slow and fast backends must render special tokens identically.
            rust_tokenizer = self.get_rust_tokenizer()
            special_tokens_string_rust = rust_tokenizer.convert_tokens_to_string(special_tokens)
            self.assertEqual(special_tokens_string, special_tokens_string_rust)

    def test_sentencepiece_tokenize_and_decode(self):
        """Slow and fast tokenizers must agree on input ids and on decoding."""
        if not self.test_sentencepiece:
            self.skipTest(reason="test_sentencepiece is set to False")

        text = "This is text to test the tokenizer."
        if self.test_rust_tokenizer:
            tokenizer = self.get_tokenizer()
            rust_tokenizer = self.get_rust_tokenizer()

            slow_ids = tokenizer(text).input_ids
            fast_ids = rust_tokenizer(text).input_ids
            self.assertEqual(slow_ids, fast_ids)

            # Both decodes use slow_ids on purpose: they were just asserted
            # equal to fast_ids, so this isolates the decode step.
            slow_decoded = tokenizer.decode(slow_ids)
            fast_decoded = rust_tokenizer.decode(slow_ids)
            self.assertEqual(slow_decoded, fast_decoded)

    def test_save_sentencepiece_tokenizer(self) -> None:
        """A re-loaded tokenizer can be saved again after its source files are gone."""
        if not self.test_sentencepiece or not self.test_slow_tokenizer:
            self.skipTest(reason="test_sentencepiece or test_slow_tokenizer is set to False")
        # We want to verify that we will be able to save the tokenizer even if the original files that were used to
        # build the tokenizer have been deleted in the meantime.
        text = "This is text to test the tokenizer."

        tokenizer_slow_1 = self.get_tokenizer()
        encoding_tokenizer_slow_1 = tokenizer_slow_1(text)

        tmpdirname_1 = tempfile.mkdtemp()
        tmpdirname_2 = tempfile.mkdtemp()

        tokenizer_slow_1.save_pretrained(tmpdirname_1)
        tokenizer_slow_2 = self.tokenizer_class.from_pretrained(tmpdirname_1)
        encoding_tokenizer_slow_2 = tokenizer_slow_2(text)

        # Delete the directory tokenizer_slow_2 was loaded from before saving
        # it again — this is the scenario under test.
        shutil.rmtree(tmpdirname_1)
        tokenizer_slow_2.save_pretrained(tmpdirname_2)

        tokenizer_slow_3 = self.tokenizer_class.from_pretrained(tmpdirname_2)
        encoding_tokenizer_slow_3 = tokenizer_slow_3(text)
        shutil.rmtree(tmpdirname_2)

        # All three generations must encode the text identically.
        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)

    def test_added_token_are_matched_longest_first(self):
        """Added tokens are matched longest-first regardless of insertion order."""
        tokenizers = self.get_tokenizers(fast=False)
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                try:
                    tokenizer.add_tokens([AddedToken("extra_id_1")])
                    tokenizer.add_tokens([AddedToken("extra_id_100")])
                except Exception:
                    # Canine cannot add tokens which are not codepoints
                    self.skipTest(reason="Cannot add those Added tokens")

                # XXX: This used to split on `extra_id_1` first; we're matching
                # longest first now.
                tokens = tokenizer.tokenize("This is some extra_id_100")
                self.assertIn("extra_id_100", tokens)

        # Same check with the insertion order reversed.
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                tokenizer.add_tokens([AddedToken("extra_id_100")])
                tokenizer.add_tokens([AddedToken("extra_id_1")])

                tokens = tokenizer.tokenize("This is some extra_id_100")
                self.assertIn("extra_id_100", tokens)

    @require_tokenizers
    def test_pickle_added_tokens(self):
        """An AddedToken must survive a pickle round-trip with identical state."""
        tok1 = AddedToken("<s>", rstrip=True, lstrip=True, normalized=False, single_word=True)
        tok2 = pickle.loads(pickle.dumps(tok1))

        self.assertEqual(tok1.__getstate__(), tok2.__getstate__())

    def test_added_tokens_do_lower_case(self):
        """With do_lower_case, added tokens are lower-cased but special tokens are not."""
        tokenizer = self.get_tokenizer(do_lower_case=True)
        if not hasattr(tokenizer, "do_lower_case") or not tokenizer.do_lower_case:
            self.skipTest(reason="Tokenizer does not support do_lower_case")

        special_token = tokenizer.all_special_tokens[0]

        # text2 is the upper-case variant of text; after lower-casing both
        # should tokenize identically.
        text = special_token + " aaaaa bbbbbb low cccccccccdddddddd l " + special_token
        text2 = special_token + " AAAAA BBBBBB low CCCCCCCCCDDDDDDDD l " + special_token

        toks_before_adding = tokenizer.tokenize(text)  # toks before adding new_toks

        new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd", "AAAAA BBBBBB", "CCCCCCCCCDDDDDDDD"]
        added = tokenizer.add_tokens([AddedToken(tok, lstrip=True, rstrip=True) for tok in new_toks])

        toks_after_adding = tokenizer.tokenize(text)
        toks_after_adding2 = tokenizer.tokenize(text2)

        # Rust tokenizers don't lowercase added tokens at the time calling `tokenizer.add_tokens`,
        # while python tokenizers do, so new_toks 0 and 2 would be treated as the same, so do new_toks 1 and 3.
        self.assertIn(added, [2, 4])

        self.assertListEqual(toks_after_adding, toks_after_adding2)
        self.assertTrue(
            len(toks_before_adding) > len(toks_after_adding),  # toks_before_adding should be longer
        )

        # Check that none of the special tokens are lowercased
        sequence_with_special_tokens = "A " + " yEs ".join(tokenizer.all_special_tokens) + " B"
        # Convert the tokenized list to str as some special tokens are tokenized like normal tokens
        # which have a prefix space e.g. the mask token of Albert, and cannot match the original
        # special tokens exactly.
        tokenized_sequence = "".join(tokenizer.tokenize(sequence_with_special_tokens))

        for special_token in tokenizer.all_special_tokens:
            self.assertTrue(special_token in tokenized_sequence or special_token.lower() in tokenized_sequence)

    def test_add_tokens_tokenizer(self):
        """add_tokens/add_special_tokens grow len(tokenizer) but leave vocab_size unchanged."""
        tokenizer = self.get_tokenizer(do_lower_case=False)
        vocab_size = tokenizer.vocab_size
        all_size = len(tokenizer)

        self.assertNotEqual(vocab_size, 0)

        new_toks = [
            AddedToken("aaaaa bbbbbb", rstrip=True, lstrip=True),
            AddedToken("cccccccccdddddddd", rstrip=True, lstrip=True),
        ]
        added_toks = tokenizer.add_tokens(new_toks)
        vocab_size_2 = tokenizer.vocab_size
        all_size_2 = len(tokenizer)

        # Adding regular tokens extends len(tokenizer) but not vocab_size.
        self.assertNotEqual(vocab_size_2, 0)
        self.assertEqual(vocab_size, vocab_size_2)
        self.assertEqual(added_toks, len(new_toks))
        self.assertEqual(all_size_2, all_size + len(new_toks))

        tokens = tokenizer.encode("aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False)

        self.assertGreaterEqual(len(tokens), 4)
        # Added tokens get ids beyond the end of the base vocabulary.
        self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
        self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)

        new_toks_2 = {
            "eos_token": AddedToken(">>>>|||<||<<|<<", rstrip=True, lstrip=True),
            "pad_token": AddedToken("<<<<<|||>|>>>>|>", rstrip=True, lstrip=True),
        }
        added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
        vocab_size_3 = tokenizer.vocab_size
        all_size_3 = len(tokenizer)

        self.assertNotEqual(vocab_size_3, 0)
        self.assertEqual(vocab_size, vocab_size_3)
        self.assertEqual(added_toks_2, len(new_toks_2))
        self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))

        tokens = tokenizer.encode(
            ">>>>|||<||<<|<< aaaaa bbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", add_special_tokens=False
        )

        self.assertGreaterEqual(len(tokens), 6)
        # tokens[0] is the new eos token, tokens[-2] the new pad token; both
        # must map to ids past the base vocabulary.
        self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
        self.assertGreater(tokens[0], tokens[1])

        self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
        self.assertGreater(tokens[-2], tokens[-3])
        self.assertEqual(tokens[0], tokenizer.eos_token_id)
        self.assertEqual(tokens[-2], tokenizer.pad_token_id)

    def test_add_special_tokens(self):
        # Covered by test_add_tokens_tokenizer above; kept to override the
        # common version for each model.
        self.skipTest(reason="Redundant with test_add_tokens_tokenizer")

    def test_add_tokens(self):
        """Fast-tokenizer add_tokens/add_special_tokens return counts and grow the vocab."""
        if not self.test_rust_tokenizer:
            self.skipTest(reason="test_rust_tokenizer is set to False")

        tokenizer_r = self.get_rust_tokenizer()

        vocab_size = len(tokenizer_r)
        self.assertEqual(tokenizer_r.add_tokens(""), 0)
        self.assertEqual(tokenizer_r.add_tokens("testoken"), 1)
        self.assertEqual(tokenizer_r.add_tokens(["testoken1", "testtoken2"]), 2)
        self.assertEqual(len(tokenizer_r), vocab_size + 3)

        self.assertEqual(tokenizer_r.add_special_tokens({}), 0)
        self.assertEqual(tokenizer_r.add_special_tokens({"bos_token": "[BOS]", "eos_token": "[EOS]"}), 2)
        # additional_special_tokens must be a list, not a bare string.
        self.assertRaises(
            AssertionError, tokenizer_r.add_special_tokens, {"additional_special_tokens": "<testtoken1>"}
        )
        self.assertEqual(tokenizer_r.add_special_tokens({"additional_special_tokens": ["<testtoken2>"]}), 1)
        self.assertEqual(
            tokenizer_r.add_special_tokens({"additional_special_tokens": ["<testtoken3>", "<testtoken4>"]}), 2
        )
        self.assertIn("<testtoken3>", tokenizer_r.special_tokens_map["additional_special_tokens"])
        self.assertIsInstance(tokenizer_r.special_tokens_map["additional_special_tokens"], list)
        self.assertGreaterEqual(len(tokenizer_r.special_tokens_map["additional_special_tokens"]), 2)

        # 3 regular tokens + [BOS], [EOS], <testtoken2..4> = 8 new entries.
        self.assertEqual(len(tokenizer_r), vocab_size + 8)

    def test_compare_add_special_tokens(self):
        """add_special_tokens=True adds exactly num_special_tokens_to_add tokens everywhere."""
        if not self.test_rust_tokenizer:
            self.skipTest(reason="test_rust_tokenizer is set to False")

        tokenizer_r = self.get_rust_tokenizer()

        simple_num_special_tokens_to_add = tokenizer_r.num_special_tokens_to_add(pair=False)

        for text in ["", " "]:
            # tokenize()
            no_special_tokens = tokenizer_r.tokenize(text, add_special_tokens=False)
            with_special_tokens = tokenizer_r.tokenize(text, add_special_tokens=True)
            self.assertEqual(
                len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add
            )

            # Single input
            no_special_tokens = tokenizer_r(text, add_special_tokens=False)
            with_special_tokens = tokenizer_r(text, add_special_tokens=True)
            for key in no_special_tokens:
                self.assertEqual(
                    len(no_special_tokens[key]),
                    len(with_special_tokens[key]) - simple_num_special_tokens_to_add,
                )

            # Batched input
            no_special_tokens = tokenizer_r([text, text], add_special_tokens=False)
            with_special_tokens = tokenizer_r([text, text], add_special_tokens=True)
            for key in no_special_tokens:
                for i_no, i_with in zip(no_special_tokens[key], with_special_tokens[key]):
                    self.assertEqual(len(i_no), len(i_with) - simple_num_special_tokens_to_add)

    def test_special_tokens_initialization(self):
        """A special token passed at init must encode to a single id present in outputs."""
        if not self.test_rust_tokenizer:
            self.skipTest(reason="test_rust_tokenizer is set to False")

        added_tokens = [AddedToken("<special>", lstrip=True)]
        tokenizer_r = self.get_rust_tokenizer(additional_special_tokens=added_tokens)
        r_output = tokenizer_r.encode("Hey this is a <special> token")

        special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]

        self.assertTrue(special_token_id in r_output)

    def test_special_token_addition(self):
        """additional_special_tokens replace by default; appended when replace_additional_special_tokens=False."""
        tokenizer = self.get_tokenizer()
        # Create tokenizer and add an additional special token
        tokenizer.add_special_tokens({"additional_special_tokens": ["<tok>"]})
        self.assertEqual(tokenizer.additional_special_tokens, ["<tok>"])
        with tempfile.TemporaryDirectory() as tmp_dir:
            tokenizer.save_pretrained(tmp_dir)
            # Load the above tokenizer and add the same special token a second time
            tokenizer_2 = self.tokenizer_class.from_pretrained(tmp_dir)
            tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>"]})
            self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>"])

            tokenizer_2.add_special_tokens({"additional_special_tokens": ["<tok>", "<other>"]})
            self.assertEqual(tokenizer_2.additional_special_tokens, ["<tok>", "<other>"])

            # Default behavior: the new list replaces the previous one.
            tokenizer_2.add_special_tokens({"additional_special_tokens": ["<other>", "<another>"]})
            self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>"])

            # With replace_additional_special_tokens=False the token is appended.
            tokenizer_2.add_special_tokens(
                {"additional_special_tokens": ["<tok>"]},
                replace_additional_special_tokens=False,
            )
            self.assertEqual(tokenizer_2.additional_special_tokens, ["<other>", "<another>", "<tok>"])

0 commit comments

Comments
 (0)