Skip to content

Commit a404e2c

Browse files
authored
Patch Mistral Tokenizer (#28146)
Signed-off-by: Julien Denize <[email protected]>
1 parent e31946f commit a404e2c

File tree

2 files changed

+42
-22
lines changed

2 files changed

+42
-22
lines changed

tests/tokenization/test_mistral_tokenizer.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -334,20 +334,20 @@ def test_encode_one(self, mistral_tokenizer: MistralTokenizer):
334334

335335
def test_encode(self, mistral_tokenizer: MistralTokenizer):
336336
token_ids = (
337-
[1, 22177, 4304, 2662, 2]
337+
[1, 22177, 4304, 2662]
338338
if mistral_tokenizer.is_tekken
339-
else [1, 23325, 2294, 1686, 2]
339+
else [1, 23325, 2294, 1686]
340340
)
341341

342-
assert mistral_tokenizer.encode("Hello world !") == token_ids[:-1]
343-
assert mistral_tokenizer.encode("Hello world !", max_length=3) == token_ids[:-2]
342+
assert mistral_tokenizer.encode("Hello world !") == token_ids
343+
assert mistral_tokenizer.encode("Hello world !", max_length=3) == token_ids[:-1]
344344
assert (
345345
mistral_tokenizer.encode("Hello world !", truncation=True, max_length=3)
346-
== token_ids[:-2]
346+
== token_ids[:-1]
347347
)
348348
assert (
349349
mistral_tokenizer.encode("Hello world !", truncation=False, max_length=3)
350-
== token_ids[:-1]
350+
== token_ids
351351
)
352352

353353
assert (
@@ -358,7 +358,7 @@ def test_encode(self, mistral_tokenizer: MistralTokenizer):
358358
mistral_tokenizer.encode(
359359
"Hello world !", add_special_tokens=True, max_length=3
360360
)
361-
== token_ids[:-2]
361+
== token_ids[:-1]
362362
)
363363
assert (
364364
mistral_tokenizer.encode(
@@ -368,7 +368,7 @@ def test_encode(self, mistral_tokenizer: MistralTokenizer):
368368
)
369369
assert (
370370
mistral_tokenizer.encode("Hello world !", add_special_tokens=False)
371-
== token_ids[1:-1]
371+
== token_ids[1:]
372372
)
373373

374374
@pytest.mark.parametrize(
@@ -1088,6 +1088,19 @@ def test_decode(
10881088
== expected_tokens[mistral_tokenizer.is_tekken]
10891089
)
10901090

1091+
def test_decode_int(
1092+
self,
1093+
mistral_tokenizer: MistralTokenizer,
1094+
):
1095+
ids = 1
1096+
assert (
1097+
mistral_tokenizer.decode(
1098+
ids,
1099+
skip_special_tokens=False,
1100+
)
1101+
== "<s>"
1102+
)
1103+
10911104
def test_convert_tokens_to_string(self, mistral_tokenizer: MistralTokenizer):
10921105
tokens = (
10931106
[

vllm/transformers_utils/tokenizers/mistral.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
165165

166166
class MistralTokenizer(TokenizerBase):
167167
def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
168+
from mistral_common.protocol.instruct.validator import ValidationMode
168169
from mistral_common.tokens.tokenizers.sentencepiece import (
169170
SentencePieceTokenizer,
170171
)
@@ -175,6 +176,14 @@ def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
175176
self.instruct = self.mistral.instruct_tokenizer
176177
self.tokenizer = self.instruct.tokenizer
177178

179+
mode = self.mistral._chat_completion_request_validator._mode
180+
if mode != ValidationMode.test:
181+
raise ValueError(
182+
"Mistral tokenizer must be in test mode. Make sure to "
183+
"set `mode=ValidationMode.test` when creating the "
184+
"Mistral tokenizer."
185+
)
186+
178187
_mistral_version_str = str(self.tokenizer.version.value)
179188
self.version: int = int(_mistral_version_str.split("v")[-1])
180189

@@ -205,14 +214,15 @@ def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
205214
def from_pretrained(
206215
cls, path_or_repo_id: str, *, revision: str | None = None
207216
) -> "MistralTokenizer":
217+
from mistral_common.protocol.instruct.validator import ValidationMode
208218
from transformers.tokenization_mistral_common import (
209219
MistralCommonTokenizer as TransformersMistralTokenizer,
210220
)
211221

212222
str_revision = "main" if revision is None else revision
213223
return cls(
214224
TransformersMistralTokenizer.from_pretrained(
215-
path_or_repo_id, revision=str_revision
225+
path_or_repo_id, revision=str_revision, mode=ValidationMode.test
216226
)
217227
)
218228

@@ -339,20 +349,14 @@ def encode(
339349
max_length: int | None = None,
340350
add_special_tokens: bool | None = None,
341351
) -> list[int]:
342-
if add_special_tokens is not None:
343-
return self.transformers_tokenizer.encode(
344-
text,
345-
truncation=truncation,
346-
max_length=max_length,
347-
add_special_tokens=add_special_tokens,
348-
)
349-
else:
350-
encoded = self.tokenizer.encode(text, bos=True, eos=False)
352+
encoded = self.tokenizer.encode(
353+
text, bos=add_special_tokens is not False, eos=False
354+
)
351355

352-
if truncation is not False and max_length is not None:
353-
return encoded[:max_length]
354-
else:
355-
return encoded
356+
if truncation is not False and max_length is not None:
357+
return encoded[:max_length]
358+
else:
359+
return encoded
356360

357361
def apply_chat_template(
358362
self,
@@ -383,6 +387,9 @@ def apply_chat_template(
383387
)
384388

385389
def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
390+
if isinstance(ids, int):
391+
ids = [ids]
392+
386393
return self.transformers_tokenizer.decode(
387394
ids, skip_special_tokens=skip_special_tokens
388395
)

0 commit comments

Comments (0)