@@ -241,6 +241,11 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
241241 speculative_builder_args .pte_path = None
242242 return speculative_builder_args
243243
244+ class TokenizerType (Enum ):
245+ NONE = 0
246+ TIKTOKEN = 1
247+ SENTENCEPIECE = 2
248+ HF_TOKENIZER = 3
244249
245250@dataclass
246251class TokenizerArgs :
@@ -276,12 +281,24 @@ def __post_init__(self):
276281 except :
277282 pass
278283
279- self .is_tiktoken = False
280- self .is_sentencepiece = False
281- self .is_hf_tokenizer = False
282- self .t = None
283284 return
284285
286+ def is_tiktoken (self ) -> bool :
287+ return self .tokenizer_type == TokenizerType .TIKTOKEN
288+
289+ def is_sentencepiece (self ) -> bool :
290+ return self .tokenizer_type == TokenizerType .SENTENCEPIECE
291+
292+ def is_hf_tokenizer (self ) -> bool :
293+ return self .tokenizer_type == TokenizerType .HF_TOKENIZER
294+
295+ def is_tokenizer_none (self ) -> bool :
296+ if self .tokenizer_type != TokenizerType .NONE :
297+ return False
298+
299+ assert self .t is None , "tokenizer_type is NONE but t is not None"
300+ return True
301+
285302 def validate_model (
286303 self ,
287304 model : Optional [Model ],
@@ -290,12 +307,13 @@ def validate_model(
290307 if model is None :
291308 return
292309
293- if sum ([ self .is_tiktoken , self . is_hf_tokenizer , self . is_sentencepiece ]) != 1 :
310+ if self .is_tokenizer_none () :
294311 raise RuntimeError (f"no tokenizer was found at { self .tokenizer_path } " )
295312
296- is_tiktoken = self .is_tiktoken
297- is_sentencepiece = self .is_sentencepiece
298- is_hf_tokenizer = self .is_hf_tokenizer
313+ is_tiktoken = self .is_tiktoken ()
314+ is_sentencepiece = self .is_sentencepiece ()
315+ is_hf_tokenizer = self .is_hf_tokenizer ()
316+
299317 use_tiktoken = model .config .use_tiktoken
300318 use_hf_tokenizer = model .config .use_hf_tokenizer
301319 use_sentencepiece = not (use_tiktoken or use_hf_tokenizer )
0 commit comments