 import os
 import sys
 from dataclasses import dataclass
+from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, Optional, Tuple, Union
 
@@ -237,23 +238,24 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args.pte_path = None
         return speculative_builder_args
 
+class TokenizerType(Enum):
+    NONE = 0
+    TIKTOKEN = 1
+    SENTENCEPIECE = 2
+    HF_TOKENIZER = 3
 
 @dataclass
 class TokenizerArgs:
     tokenizer_path: Optional[Union[Path, str]] = None
-    is_sentencepiece: bool = False
-    is_tiktoken: bool = False
-    is_hf_tokenizer: bool = False
+    tokenizer_type: TokenizerType = TokenizerType.NONE
     t: Optional[Any] = None
 
     def __post_init__(self):
         try:
             from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer
 
             self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
-            self.is_tiktoken = True
-            self.is_sentencepiece = False
-            self.is_hf_tokenizer = False
+            self.tokenizer_type = TokenizerType.TIKTOKEN
             return
         except:
             pass
@@ -262,9 +264,7 @@ def __post_init__(self):
             from sentencepiece import SentencePieceProcessor
 
             self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
-            self.is_tiktoken = False
-            self.is_sentencepiece = True
-            self.is_hf_tokenizer = False
+            self.tokenizer_type = TokenizerType.SENTENCEPIECE
             return
         except:
             pass
@@ -273,19 +273,24 @@ def __post_init__(self):
             from tokenizer.hf_tokenizer import HFTokenizer
 
             self.t = HFTokenizer(str(self.tokenizer_path))
-            self.is_tiktoken = False
-            self.is_sentencepiece = False
-            self.is_hf_tokenizer = True
+            self.tokenizer_type = TokenizerType.HF_TOKENIZER
             return
         except:
             pass
 
-        self.is_tiktoken = False
-        self.is_sentencepiece = False
-        self.is_hf_tokenizer = False
+        self.tokenizer_type = TokenizerType.NONE
         self.t = None
         return
 
+    def is_tiktoken(self) -> bool:
+        return self.tokenizer_type == TokenizerType.TIKTOKEN
+
+    def is_sentencepiece(self) -> bool:
+        return self.tokenizer_type == TokenizerType.SENTENCEPIECE
+
+    def is_hf_tokenizer(self) -> bool:
+        return self.tokenizer_type == TokenizerType.HF_TOKENIZER
+
     def validate_model(
         self,
         model: Optional[Model],
@@ -294,12 +299,14 @@ def validate_model(
         if model is None:
             return
 
-        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
+
+        is_tiktoken = self.is_tiktoken()
+        is_sentencepiece = self.is_sentencepiece()
+        is_hf_tokenizer = self.is_hf_tokenizer()
+
+        if sum([is_tiktoken, is_hf_tokenizer, is_sentencepiece]) != 1:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
-        is_tiktoken = self.is_tiktoken
-        is_sentencepiece = self.is_sentencepiece
-        is_hf_tokenizer = self.is_hf_tokenizer
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
@@ -651,13 +658,13 @@ def do_nothing(max_batch_size, max_seq_length):
             model = torch.load(builder_args.snapshot_path, weights_only=False)
         except Exception:
             raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
-        # _active_backend() does not allow DSO & AOTI to be true.
+        # _active_backend() does not allow DSO & AOTI to be true.
         # Choose either.
         from torchchat.utils.build_utils import set_backend
         set_backend(dso=True, pte=False, aoti_package=False)
         if (model.config != config):
             raise RuntimeError("loaded model architecture mismatch")
-        ##
+        ##
         ## import all libraries with custom kernels ans custom operators
         ## that quantize may be pulling in
         ##
@@ -792,4 +799,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
         return "TikToken"
     if tokenizers:
         return "Tokenizers"
-    return "SentencePiece"
+    return "SentencePiece"
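
For readers skimming the diff, a minimal sketch of what the refactor means at a call site: the three independent boolean fields are replaced by a single TokenizerType value queried through the new predicate methods, so exactly one of the three states can hold at a time. The describe_tokenizer helper and its tokenizer_args argument below are illustrative only and are not part of this PR.

    # Illustrative sketch (not part of the diff above): code that used to read
    # tokenizer_args.is_tiktoken / .is_sentencepiece / .is_hf_tokenizer as
    # attributes now calls the predicate methods backed by the TokenizerType enum.
    def describe_tokenizer(tokenizer_args) -> str:
        if tokenizer_args.is_tiktoken():
            return "TikToken"
        if tokenizer_args.is_hf_tokenizer():
            return "Tokenizers"
        if tokenizer_args.is_sentencepiece():
            return "SentencePiece"
        return "no tokenizer found"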