 import torch._dynamo.config
 import torch._inductor.config
 import torch.distributed as dist
+from torchchat.distributed.logging_utils import SingletonLogger
 
-from torchchat.distributed.utils import (
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model
 
-
 from torchtune.models.convert_weights import meta_to_tune
 
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
@@ -188,15 +187,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-                                     or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
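The SDPBackend values mapped above are the standard PyTorch scaled-dot-product-attention backends; for context, here is a minimal sketch of how such a backend is typically applied. The toy tensors and the choice of MATH are illustrative only, not torchchat code (torchchat stores the selected backend on BuilderArgs rather than using it at this point).

```python
# Hedged sketch: restrict scaled_dot_product_attention to one backend.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

backend = SDPBackend.MATH  # e.g. the CPU fallback chosen in the diff above

q = torch.randn(1, 4, 16, 8)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 4, 16, 8)
v = torch.randn(1, 4, 16, 8)

with sdpa_kernel(backend):  # only this backend may be used inside the block
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 4, 16, 8])
```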
@@ -238,11 +241,6 @@ def from_speculative_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         speculative_builder_args.pte_path = None
         return speculative_builder_args
 
-class TokenizerType(Enum):
-    NONE = 0
-    TIKTOKEN = 1
-    SENTENCEPIECE = 2
-    HF_TOKENIZER = 3
 
 @dataclass
 class TokenizerArgs:
@@ -278,19 +276,12 @@ def __post_init__(self):
         except:
             pass
 
-        self.tokenizer_type = TokenizerType.NONE
+        self.is_tiktoken = False
+        self.is_sentencepiece = False
+        self.is_hf_tokenizer = False
         self.t = None
         return
 
-    def is_tiktoken(self) -> bool:
-        return self.tokenizer_type == TokenizerType.TIKTOKEN
-
-    def is_sentencepiece(self) -> bool:
-        return self.tokenizer_type == TokenizerType.SENTENCEPIECE
-
-    def is_hf_tokenizer(self) -> bool:
-        return self.tokenizer_type == TokenizerType.HF_TOKENIZER
-
     def validate_model(
         self,
         model: Optional[Model],
@@ -299,22 +290,20 @@ def validate_model(
         if model is None:
             return
 
-
-        is_tiktoken = self.is_tiktoken()
-        is_sentencepiece = self.is_sentencepiece()
-        is_hf_tokenizer = self.is_hf_tokenizer()
-
-        if sum([is_tiktoken, is_hf_tokenizer, is_sentencepiece]) != 1:
+        if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
             raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
+        is_tiktoken = self.is_tiktoken
+        is_sentencepiece = self.is_sentencepiece
+        is_hf_tokenizer = self.is_hf_tokenizer
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
 
         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -512,6 +501,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
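The hunk above drops the eager weights before the AOTI artifact loads its own copy. A hedged sketch of that release pattern follows; the explicit CUDA cache flush is an optional extra step, not part of this diff.

```python
# Sketch only: free an nn.Module's weights once a compiled artifact owns its own copy.
import gc

import torch


def release_eager_weights(model) -> None:
    """Drop the inner module reference so Python and the CUDA allocator can reclaim it."""
    if hasattr(model, "model"):
        model.model = None  # release the object that holds the parameters
    gc.collect()  # collect the now-unreferenced tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # optional: hand cached blocks back to the driver
```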
@@ -569,6 +559,7 @@ def _initialize_model(
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing
 
         model.forward = torch._export.aot_load(
@@ -606,6 +597,7 @@ def do_nothing(max_batch_size, max_seq_length):
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing
 
         model.forward = aoti_compiled_model
@@ -657,12 +649,15 @@ def do_nothing(max_batch_size, max_seq_length):
         try:
             model = torch.load(builder_args.snapshot_path, weights_only=False)
         except Exception:
-            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+            raise RuntimeError(
+                f"Failed to load torchchat snapshot {builder_args.snapshot_path}"
+            )
         # _active_backend() does not allow DSO & AOTI to be true.
         # Choose either.
         from torchchat.utils.build_utils import set_backend
-        set_backend(dso=True, pte=False, aoti_package=False)
-        if (model.config != config):
+
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if model.config != config:
             raise RuntimeError("loaded model architecture mismatch")
         ##
         ## import all libraries with custom kernels and custom operators
@@ -680,7 +675,9 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
+    logger.info(
+        f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}"
+    )
 
     # Model-level config
     if builder_args.params_table:
@@ -691,20 +688,16 @@ def do_nothing(max_batch_size, max_seq_length):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    #TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import (
-        load_model_weights,
-    )
+    # TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import load_model_weights
 
     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0
 
     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda",
-        (pp_degree, tp_degree),
-        mesh_dim_names=("pp", "tp")
-    )
+        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
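For readers new to device meshes: the 2-D mesh built above gives every rank a (pp, tp) coordinate, and the named sub-meshes are what the pipeline- and tensor-parallel code index into. A standalone sketch, assuming the job is launched with torchrun and using made-up degrees:

```python
# Hedged sketch of the pp x tp mesh layout; run under torchrun with
# nproc-per-node == pp_degree * tp_degree. The degrees are illustrative.
import torch.distributed as dist

pp_degree, tp_degree = 2, 2  # hypothetical; torchchat derives these from BuilderArgs

device_mesh = dist.init_device_mesh(
    "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
)
tp_mesh = device_mesh["tp"]  # 1-D sub-mesh of ranks sharing this pipeline stage
pp_mesh = device_mesh["pp"]  # 1-D sub-mesh spanning the stages for this tp position

# Each rank can read its coordinate along either dimension.
print(f"pp rank {pp_mesh.get_local_rank()}, tp rank {tp_mesh.get_local_rank()}")
```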
@@ -733,7 +726,13 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+        load_model_weights(
+            model,
+            builder_args.distribution_path,
+            device,
+            config,
+            builder_args.chpt_from,
+        )
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -747,7 +746,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    seqlen_prefill = 1024
+    seqlen_prefill = 1024
     with device:
         model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
 