@@ -2676,7 +2676,6 @@ def _process_fsdp_args(self):
26762676 with open (self .fsdp_config , encoding = "utf-8" ) as f :
26772677 self .fsdp_config = json .load (f )
26782678
2679- fsdp_version = self .fsdp_config .get ("fsdp_version" , 1 )
26802679 if self .fsdp_config is not None and isinstance (self .fsdp_config , dict ):
26812680 for k in list (self .fsdp_config .keys ()):
26822681 if k .startswith ("fsdp_" ) and k != "fsdp_version" :
@@ -2749,14 +2748,28 @@ def _process_fsdp_args(self):
27492748 self .fsdp_config ["transformer_layer_cls_to_wrap" ]
27502749 )
27512750
2752- # Pull allowed parameters from fsdp_config
2753- ALLOWED_FSDP_PARAMS = {f .name for f in fields (FullyShardedDataParallelPlugin )}
2754- for key in ALLOWED_FSDP_PARAMS :
2755- if key in self .fsdp_config and key not in fsdp_plugin_args :
2756- fsdp_plugin_args [key ] = self .fsdp_config [key ]
2757- fsdp_plugin_args ["fsdp_version" ] = fsdp_version
2751+ fsdp_plugin_args ["fsdp_version" ] = self .fsdp_config .get ("fsdp_version" , 1 )
2752+ prefetch_policy = self .fsdp_config .get ("backward_prefetch" , "NO_PREFETCH" )
2753+ fsdp_plugin_args ["backward_prefetch" ] = prefetch_policy .upper ()
2754+ if fsdp_plugin_args ["fsdp_version" ] == 1 :
2755+ # forward_prefetch is not yet implemented in FSDP2
2756+ fsdp_plugin_args ["forward_prefetch" ] = str (self .fsdp_config .get ("forward_prefetch" , "false" )).lower ()
2757+
2758+ sync_module_states = str (self .fsdp_config .get ("sync_module_states" , "true" )).lower ()
2759+ cpu_ram_efficient_loading = str (self .fsdp_config .get ("cpu_ram_efficient_loading" , "false" )).lower ()
2760+ if sync_module_states == "false" and cpu_ram_efficient_loading == "true" :
2761+ # In this case, all the processes except the main process would have random weights leading
2762+ # to unexpected behaviour during training, thus throwing error here to prevent it.
2763+ raise ValueError ('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`' )
2764+
2765+ # we need to set the env here as otherwise we get a warning in accelerate + we need to set it for transformers
2766+ fsdp_plugin_args ["cpu_ram_efficient_loading" ] = cpu_ram_efficient_loading
2767+ os .environ ["FSDP_CPU_RAM_EFFICIENT_LOADING" ] = cpu_ram_efficient_loading
27582768
2759- # Special HF-to-plugin map: transformer_layer_cls_to_wrap → joined cls_names
2769+ fsdp_plugin_args ["sync_module_states" ] = sync_module_states
2770+ fsdp_plugin_args ["use_orig_params" ] = str (self .fsdp_config .get ("use_orig_params" , "true" )).lower ()
2771+
2772+ # HF-to-plugin map
27602773 if (
27612774 "transformer_layer_cls_to_wrap" in self .fsdp_config
27622775 and "transformer_cls_names_to_wrap" not in fsdp_plugin_args
@@ -2765,26 +2778,11 @@ def _process_fsdp_args(self):
27652778 self .fsdp_config ["transformer_layer_cls_to_wrap" ]
27662779 )
27672780
2768- # Validation: sync_module_states vs cpu_ram_efficient_loading
2769- sync_states = fsdp_plugin_args .get ("sync_module_states" , "true" )
2770- cpu_loading = fsdp_plugin_args .get ("cpu_ram_efficient_loading" , "false" )
2771-
2772- if isinstance (sync_states , str ):
2773- sync_states = sync_states .lower ()
2774- if isinstance (cpu_loading , str ):
2775- cpu_loading = cpu_loading .lower ()
2776- if sync_states == "false" and cpu_loading == "true" :
2777- raise ValueError ('`sync_module_states` must be `"true"` if `cpu_ram_efficient_loading` is `"true"`.' )
2778-
2779- # CRITICAL: Set environment variable for cpu_ram_efficient_loading
2780- if "cpu_ram_efficient_loading" in fsdp_plugin_args :
2781- cpu_ram_value = fsdp_plugin_args ["cpu_ram_efficient_loading" ]
2782- # Handle both bool and string values
2783- if isinstance (cpu_ram_value , bool ):
2784- cpu_ram_value = str (cpu_ram_value ).lower ()
2785- elif isinstance (cpu_ram_value , str ):
2786- cpu_ram_value = cpu_ram_value .lower ()
2787- os .environ ["FSDP_CPU_RAM_EFFICIENT_LOADING" ] = cpu_ram_value
2781+ # Pull allowed parameters from fsdp_config
2782+ ALLOWED_FSDP_PARAMS = {f .name for f in fields (FullyShardedDataParallelPlugin )}
2783+ for key in ALLOWED_FSDP_PARAMS :
2784+ if key in self .fsdp_config and key not in fsdp_plugin_args :
2785+ fsdp_plugin_args [key ] = self .fsdp_config [key ]
27882786
27892787 return fsdp_plugin_args
27902788
0 commit comments