@@ -165,6 +165,7 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
165165
166166class MistralTokenizer (TokenizerBase ):
167167 def __init__ (self , tokenizer : "TransformersMistralTokenizer" ) -> None :
168+ from mistral_common .protocol .instruct .validator import ValidationMode
168169 from mistral_common .tokens .tokenizers .sentencepiece import (
169170 SentencePieceTokenizer ,
170171 )
@@ -175,6 +176,14 @@ def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
175176 self .instruct = self .mistral .instruct_tokenizer
176177 self .tokenizer = self .instruct .tokenizer
177178
179+ mode = self .mistral ._chat_completion_request_validator ._mode
180+ if mode != ValidationMode .test :
181+ raise ValueError (
182+ "Mistral tokenizer must be in test mode. Make sure to "
183+ "set `mode='ValidationMode.test'` when creating the "
184+ "Mistral tokenizer."
185+ )
186+
178187 _mistral_version_str = str (self .tokenizer .version .value )
179188 self .version : int = int (_mistral_version_str .split ("v" )[- 1 ])
180189
@@ -205,14 +214,15 @@ def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
def from_pretrained(
    cls, path_or_repo_id: str, *, revision: str | None = None
) -> "MistralTokenizer":
    """Load a Transformers Mistral tokenizer and wrap it in ``cls``.

    Args:
        path_or_repo_id: Forwarded unchanged to the Transformers
            ``from_pretrained`` loader.
        revision: Revision to load. ``None`` is replaced by ``"main"``.

    Returns:
        An instance of ``cls`` built around the loaded tokenizer.
    """
    # Function-scope imports, matching how __init__ imports mistral_common.
    from mistral_common.protocol.instruct.validator import ValidationMode
    from transformers.tokenization_mistral_common import (
        MistralCommonTokenizer as TransformersMistralTokenizer,
    )

    resolved_revision = revision if revision is not None else "main"

    # The underlying tokenizer is always created in test validation mode,
    # which is the mode the constructor checks for.
    hf_tokenizer = TransformersMistralTokenizer.from_pretrained(
        path_or_repo_id,
        revision=resolved_revision,
        mode=ValidationMode.test,
    )
    return cls(hf_tokenizer)
218228
@@ -339,20 +349,14 @@ def encode(
339349 max_length : int | None = None ,
340350 add_special_tokens : bool | None = None ,
341351 ) -> list [int ]:
342- if add_special_tokens is not None :
343- return self .transformers_tokenizer .encode (
344- text ,
345- truncation = truncation ,
346- max_length = max_length ,
347- add_special_tokens = add_special_tokens ,
348- )
349- else :
350- encoded = self .tokenizer .encode (text , bos = True , eos = False )
352+ encoded = self .tokenizer .encode (
353+ text , bos = add_special_tokens is not False , eos = False
354+ )
351355
352- if truncation is not False and max_length is not None :
353- return encoded [:max_length ]
354- else :
355- return encoded
356+ if truncation is not False and max_length is not None :
357+ return encoded [:max_length ]
358+ else :
359+ return encoded
356360
357361 def apply_chat_template (
358362 self ,
@@ -383,6 +387,9 @@ def apply_chat_template(
383387 )
384388
385389 def decode (self , ids : list [int ] | int , skip_special_tokens : bool = True ) -> str :
390+ if isinstance (ids , int ):
391+ ids = [ids ]
392+
386393 return self .transformers_tokenizer .decode (
387394 ids , skip_special_tokens = skip_special_tokens
388395 )
0 commit comments