 import torch._dynamo.config
 import torch._inductor.config
 import torch.distributed as dist
-from torchchat.distributed.logging_utils import SingletonLogger
 
-from torchchat.distributed.utils import (
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    GPUMemoryMonitor,
     init_distributed,
+    GPUMemoryMonitor,
 )
+from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model
 
+
 from torchtune.models.convert_weights import meta_to_tune
 
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
@@ -187,19 +188,15 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            "math": torch.nn.attention.SDPBackend.MATH,
-            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            'math': torch.nn.attention.SDPBackend.MATH,
+            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (
-            args.attention_backend == "efficient_attention"
-            or args.attention_backend == "cudnn_attention"
-        ):
-            print(
-                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
-            )
+        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
+           or args.attention_backend == "cudnn_attention"):
+            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
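
A minimal, standalone sketch of the backend-selection logic in the hunk above. The helper name resolve_attention_backend is hypothetical, and the snippet assumes a PyTorch build that exposes torch.nn.attention.SDPBackend (including CUDNN_ATTENTION); it is an illustration, not part of the change.

    # Hypothetical helper mirroring sdp_backend_dict and the CPU fallback above.
    from torch.nn.attention import SDPBackend

    def resolve_attention_backend(device: str, attention_backend: str) -> SDPBackend:
        sdp_backend_dict = {
            "math": SDPBackend.MATH,
            "flash_attention": SDPBackend.FLASH_ATTENTION,
            "efficient_attention": SDPBackend.EFFICIENT_ATTENTION,
            "cudnn_attention": SDPBackend.CUDNN_ATTENTION,
        }
        backend = sdp_backend_dict[attention_backend]
        # efficient_attention and cudnn_attention require CUDA; on CPU fall back to math.
        if device == "cpu" and attention_backend in ("efficient_attention", "cudnn_attention"):
            print(f"Warning: {attention_backend} is not supported on CPU. Using math instead.")
            backend = SDPBackend.MATH
        return backend
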
@@ -281,6 +278,8 @@ def __post_init__(self):
         except:
             pass
 
+        self.tokenizer_type = TokenizerType.NONE
+        self.t = None
         return
 
     def is_tiktoken(self) -> bool:
@@ -292,13 +291,6 @@ def is_sentencepiece(self) -> bool:
     def is_hf_tokenizer(self) -> bool:
         return self.tokenizer_type == TokenizerType.HF_TOKENIZER
 
-    def is_tokenizer_none(self) -> bool:
-        if self.tokenizer_type != TokenizerType.NONE:
-            return False
-
-        assert self.t is None, "tokenizer_type is NONE but t is not None"
-        return True
-
     def validate_model(
         self,
         model: Optional[Model],
@@ -307,21 +299,22 @@ def validate_model(
         if model is None:
             return
 
-        if self.is_tokenizer_none():
-            raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
 
         is_tiktoken = self.is_tiktoken()
         is_sentencepiece = self.is_sentencepiece()
         is_hf_tokenizer = self.is_hf_tokenizer()
 
+        if sum([is_tiktoken, is_hf_tokenizer, is_sentencepiece]) != 1:
+            raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")
+
         use_tiktoken = model.config.use_tiktoken
         use_hf_tokenizer = model.config.use_hf_tokenizer
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
 
         if (
-            (is_tiktoken and not use_tiktoken)
-            or (is_hf_tokenizer and not use_hf_tokenizer)
-            or (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken) or
+            (is_hf_tokenizer and not use_hf_tokenizer) or
+            (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -519,7 +512,6 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compoiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
-
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -577,7 +569,6 @@ def _initialize_model(
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
-
         model.setup_caches = do_nothing
 
         model.forward = torch._export.aot_load(
@@ -615,7 +606,6 @@ def do_nothing(max_batch_size, max_seq_length):
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
-
         model.setup_caches = do_nothing
 
         model.forward = aoti_compiled_model
@@ -667,15 +657,12 @@ def do_nothing(max_batch_size, max_seq_length):
         try:
             model = torch.load(builder_args.snapshot_path, weights_only=False)
         except Exception:
-            raise RuntimeError(
-                f"Failed to load torchchat snapshot {builder_args.snapshot_path}"
-            )
+            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
         # _active_backend() does not allow DSO & AOTI to be true.
         # Choose either.
         from torchchat.utils.build_utils import set_backend
-
-        set_backend(dso=True, pte=False, aoti_package=False)
-        if model.config != config:
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if (model.config != config):
             raise RuntimeError("loaded model architecture mismatch")
         ##
         ## import all libraries with custom kernels ans custom operators
@@ -693,9 +680,7 @@ def do_nothing(max_batch_size, max_seq_length):
     logger = SingletonLogger.get_logger()
 
     gpu_memory_monitor = GPUMemoryMonitor("cuda")
-    logger.info(
-        f"{color.yellow} {gpu_memory_monitor.get_device_info()} {color.reset}"
-    )
+    logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()} {color.reset}")
 
     # Model-level config
     if builder_args.params_table:
@@ -706,16 +691,20 @@ def do_nothing(max_batch_size, max_seq_length):
     config = TransformerArgs.from_params(model_config.transformer_args["text"])
     logger.info(f"Transformer Config: {config}")
 
-    # TODO: Move into head of file after solving circular import
-    from torchchat.distributed.checkpoint_utils import load_model_weights
+    #TODO: Move into head of file after solving circular import
+    from torchchat.distributed.checkpoint_utils import (
+        load_model_weights,
+    )
 
     # Validate pipeline degree
     assert config.n_layers % pp_degree == 0
 
     # Create device mesh
     device_mesh = dist.init_device_mesh(
-        "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
-    )
+        "cuda",
+        (pp_degree, tp_degree),
+        mesh_dim_names=("pp", "tp")
+    )
     tp_mesh = device_mesh["tp"]
     pp_mesh = device_mesh["pp"]
     logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
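
A hedged, standalone sketch of the two-dimensional device mesh created in the hunk above. The degree values and the torchrun launch are illustrative assumptions, not part of this change; the snippet needs a CUDA machine with pp_degree * tp_degree processes (e.g. torchrun --nproc-per-node 8 mesh_sketch.py).

    # Illustrative 2-D mesh: "pp" groups pipeline stages, "tp" groups tensor-parallel shards.
    from torch.distributed.device_mesh import init_device_mesh

    pp_degree, tp_degree = 2, 4  # assumed degrees; world size must equal pp_degree * tp_degree

    device_mesh = init_device_mesh("cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp"))
    tp_mesh = device_mesh["tp"]
    pp_mesh = device_mesh["pp"]

    # Each rank can read its coordinate along either mesh dimension.
    pp_rank = pp_mesh.get_local_rank()
    tp_rank = tp_mesh.get_local_rank()
    print(f"{pp_rank=} {tp_rank=}")
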
@@ -744,13 +733,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # Load weights
     logger.info(f"Loading weights for {pp_rank=} on {device=}")
     with CUDATrackTime() as timer:
-        load_model_weights(
-            model,
-            builder_args.distribution_path,
-            device,
-            config,
-            builder_args.chpt_from,
-        )
+        load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
 
     logger.info(
         f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
@@ -764,7 +747,7 @@ def do_nothing(max_batch_size, max_seq_length):
     # lanes.
     # TODO: bump up the lane count
     pipeline_lanes = 1
-    seqlen_prefill = 1024
+    seqlen_prefill = 1024 
     with device:
         model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
 
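
A condensed sketch of the tokenizer validation as it reads after this change. validate_tokenizer and its flat boolean parameters are hypothetical stand-ins for the TokenizerArgs is_* helpers and the model.config use_* flags used in the validate_model hunk above.

    def validate_tokenizer(
        is_tiktoken: bool,
        is_hf_tokenizer: bool,
        is_sentencepiece: bool,
        use_tiktoken: bool,
        use_hf_tokenizer: bool,
        tokenizer_path: str,
    ) -> None:
        # Exactly one tokenizer flavor must have been detected at tokenizer_path.
        if sum([is_tiktoken, is_hf_tokenizer, is_sentencepiece]) != 1:
            raise RuntimeError(f"no tokenizer was found at {tokenizer_path}")
        # Anything the model config does not mark as tiktoken or HF is sentencepiece.
        use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
        if (
            (is_tiktoken and not use_tiktoken) or
            (is_hf_tokenizer and not use_hf_tokenizer) or
            (is_sentencepiece and not use_sentencepiece)
        ):
            raise RuntimeError("model-specified tokenizer does not match provided tokenizer")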