 import torch._inductor.config
 import torch.distributed as dist
 
-from torchchat.distributed.utils import (
+from torchchat.distributed.logging_utils import SingletonLogger
+
+from torchchat.distributed.utils import (
     Color as color,
     CUDATrackTime,
-    init_distributed,
     GPUMemoryMonitor,
+    init_distributed,
 )
-from torchchat.distributed.logging_utils import SingletonLogger
 
 from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
@@ -37,15 +38,6 @@
 from torchchat.utils.quantize import quantize_model
 
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchtune.training import set_default_dtype
-
-
 @dataclass
 class BuilderArgs:
     checkpoint_path: Optional[Union[Path, str]] = None
@@ -188,15 +180,19 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
         sdp_backend_dict = {
-            'math': torch.nn.attention.SDPBackend.MATH,
-            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
-            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
-            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+            "math": torch.nn.attention.SDPBackend.MATH,
+            "flash_attention": torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            "efficient_attention": torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            "cudnn_attention": torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
         }
         attention_backend = sdp_backend_dict[args.attention_backend]
-        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
-            or args.attention_backend == "cudnn_attention"):
-            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+        if args.device == "cpu" and (
+            args.attention_backend == "efficient_attention"
+            or args.attention_backend == "cudnn_attention"
+        ):
+            print(
+                f"Warning: {args.attention_backend} is not supported on CPU. Using math instead."
+            )
             attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
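For reference, a minimal sketch of how one of these SDPBackend values is applied downstream; the tensor shapes and the backend choice here are illustrative, not code from this commit:

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    q = k = v = torch.randn(1, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
    # MATH is the portable reference backend and the only one of the four
    # that runs on CPU, which is why the code above falls back to it.
    with sdpa_kernel([SDPBackend.MATH]):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)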
@@ -294,9 +290,9 @@ def validate_model(
         use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)
 
         if (
-            (is_tiktoken and not use_tiktoken) or
-            (is_hf_tokenizer and not use_hf_tokenizer) or
-            (is_sentencepiece and not use_sentencepiece)
+            (is_tiktoken and not use_tiktoken)
+            or (is_hf_tokenizer and not use_hf_tokenizer)
+            or (is_sentencepiece and not use_sentencepiece)
         ):
             raise RuntimeError(
                 "model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
@@ -404,6 +400,7 @@ def _load_model_gguf(builder_args: BuilderArgs) -> Model:
 
 def _load_checkpoint(builder_args: BuilderArgs):
     if builder_args.params_table and builder_args.params_table.endswith("Tune"):
+        from torchtune.models.convert_weights import meta_to_tune
         print("Loading Tune checkpoint")
         meta_checkpoint = torch.load(
             str(builder_args.checkpoint_path), mmap=True, weights_only=True
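The hunk above combines two techniques: a deferred torchtune import and memory-mapped, weights-only checkpoint loading. A self-contained sketch of both, with a placeholder path:

    import torch

    def load_tune_checkpoint(path: str):
        # Deferred import: torchtune is only required if this branch runs,
        # keeping module import time low and the dependency optional.
        from torchtune.models.convert_weights import meta_to_tune

        # mmap=True maps the file rather than reading it into RAM up front;
        # weights_only=True restricts unpickling to tensor data (safer).
        checkpoint = torch.load(path, mmap=True, weights_only=True)
        return meta_to_tune(checkpoint)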
@@ -456,9 +453,15 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         checkpoint = checkpoint["model"]
 
     if model.config.model_type == ModelType.Flamingo:
+        from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+        from torchtune.models.llama3_2_vision._convert_weights import (
+            llama3_vision_meta_to_tune,
+        )
+        from torchtune.training import set_default_dtype
         # TODO: Refactor this. For now, overwrite the model with the model loaded from params_path
-        with set_default_dtype(builder_args.precision), torch.device(
-            builder_args.device
+        with (
+            set_default_dtype(builder_args.precision),
+            torch.device(builder_args.device),
         ):
             # It doubles the model size in memory, with redundant copies of the initialized weights.
             # model = Model.from_params(builder_args.params_path)
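The rewritten `with` statement uses parenthesized context managers, available since Python 3.10. A self-contained sketch of the pattern using only core PyTorch context managers (torchtune's set_default_dtype is swapped out here for illustration):

    import torch
    import torch.nn as nn

    # Everything constructed in the block is affected: torch.device routes
    # new tensors and parameters to that device, and no_grad() skips
    # autograd tracking during initialization.
    with (
        torch.no_grad(),
        torch.device("cpu"),
    ):
        layer = nn.Linear(16, 16)
    assert layer.weight.device.type == "cpu"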
@@ -494,6 +497,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
         # AOTI-compiled model will load its own weights.
         # Release weights here to avoid OOM
         import gc
+
         if hasattr(model, "model"):
             model.model = None
         gc.collect()
@@ -551,6 +555,7 @@ def _initialize_model(
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing
 
         model.forward = torch._export.aot_load(
@@ -588,6 +593,7 @@ def do_nothing(max_batch_size, max_seq_length):
 
         def do_nothing(max_batch_size, max_seq_length):
             pass
+
         model.setup_caches = do_nothing
 
         model.forward = aoti_compiled_model
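In both branches above, `model.setup_caches = do_nothing` disables cache setup before attaching the compiled forward. The assignment relies on instance-attribute shadowing, sketched here with a hypothetical class:

    class Model:
        def setup_caches(self, max_batch_size, max_seq_length):
            raise RuntimeError("should not run for compiled artifacts")

    def do_nothing(max_batch_size, max_seq_length):
        pass

    model = Model()
    # The instance attribute shadows the class method for this object only.
    # The replacement is a plain function, not a bound method, so it takes
    # no self argument.
    model.setup_caches = do_nothing
    model.setup_caches(1, 1024)  # no-op instead of the class method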
@@ -639,12 +645,15 @@ def do_nothing(max_batch_size, max_seq_length):
         try:
             model = torch.load(builder_args.snapshot_path, weights_only=False)
         except Exception:
-            raise RuntimeError(f"Failed to load torchchat snapshot {builder_args.snapshot_path}")
+            raise RuntimeError(
+                f"Failed to load torchchat snapshot {builder_args.snapshot_path}"
+            )
         # _active_backend() does not allow DSO & AOTI to be true.
         # Choose either.
         from torchchat.utils.build_utils import set_backend
-        set_backend(dso=True, pte=False, aoti_package=False)
-        if (model.config != config):
+
+        set_backend(dso=True, pte=False, aoti_package=False)
+        if model.config != config:
             raise RuntimeError("loaded model architecture mismatch")
         ##
         ## import all libraries with custom kernels and custom operators
@@ -662,7 +671,9 @@ def do_nothing(max_batch_size, max_seq_length):
         logger = SingletonLogger.get_logger()
 
         gpu_memory_monitor = GPUMemoryMonitor("cuda")
-        logger.info(f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}")
+        logger.info(
+            f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}"
+        )
 
         # Model-level config
         if builder_args.params_table:
@@ -673,20 +684,16 @@ def do_nothing(max_batch_size, max_seq_length):
             config = TransformerArgs.from_params(model_config.transformer_args["text"])
         logger.info(f"Transformer Config: {config}")
 
-        #TODO: Move into head of file after solving circular import
-        from torchchat.distributed.checkpoint_utils import (
-            load_model_weights,
-        )
+        # TODO: Move into head of file after solving circular import
+        from torchchat.distributed.checkpoint_utils import load_model_weights
 
         # Validate pipeline degree
         assert config.n_layers % pp_degree == 0
 
         # Create device mesh
         device_mesh = dist.init_device_mesh(
-            "cuda",
-            (pp_degree, tp_degree),
-            mesh_dim_names=("pp", "tp")
-        )
+            "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
+        )
         tp_mesh = device_mesh["tp"]
         pp_mesh = device_mesh["pp"]
         logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
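A sketch of what the 2-D mesh provides on each rank, assuming a job launched with pp_degree * tp_degree processes (a 2x4 mesh over 8 ranks here; the sizes are illustrative):

    import torch.distributed as dist

    # Ranks are laid out row-major: pipeline stage 0 holds ranks 0-3,
    # stage 1 holds ranks 4-7; "tp" varies fastest within a stage.
    device_mesh = dist.init_device_mesh(
        "cuda", (2, 4), mesh_dim_names=("pp", "tp")
    )
    tp_mesh = device_mesh["tp"]  # this rank's 4-way tensor-parallel submesh
    pp_mesh = device_mesh["pp"]  # this rank's 2-way pipeline submesh
    pp_rank = pp_mesh.get_local_rank()  # pipeline stage index of this rank
    tp_rank = tp_mesh.get_local_rank()  # position within the TP group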
@@ -715,7 +722,13 @@ def do_nothing(max_batch_size, max_seq_length):
         # Load weights
         logger.info(f"Loading weights for {pp_rank=} on {device=}")
         with CUDATrackTime() as timer:
-            load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+            load_model_weights(
+                model,
+                builder_args.distribution_path,
+                device,
+                config,
+                builder_args.chpt_from,
+            )
 
         logger.info(
             f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
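CUDATrackTime is torchchat's helper from torchchat.distributed.utils; a hypothetical sketch of CUDA-event-based timing in the same spirit (not its actual implementation):

    import torch

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    # ... enqueue GPU work here, e.g. copying weights onto the device ...
    end.record()
    end.synchronize()  # wait for the recorded event, not the whole device
    print(f"elapsed: {start.elapsed_time(end):.1f} ms")  # GPU time in ms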
@@ -729,7 +742,7 @@ def do_nothing(max_batch_size, max_seq_length):
         # lanes.
         # TODO: bump up the lane count
         pipeline_lanes = 1
-        seqlen_prefill = 1024
+        seqlen_prefill = 1024
         with device:
             model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
 