@@ -2676,9 +2676,10 @@ def _process_fsdp_args(self):
26762676 with open (self .fsdp_config , encoding = "utf-8" ) as f :
26772677 self .fsdp_config = json .load (f )
26782678
2679+ fsdp_version = self .fsdp_config .get ("fsdp_version" , 1 )
26792680 if self .fsdp_config is not None and isinstance (self .fsdp_config , dict ):
26802681 for k in list (self .fsdp_config .keys ()):
2681- if k .startswith ("fsdp_" ):
2682+ if k .startswith ("fsdp_" ) and k != "fsdp_version" :
26822683 v = self .fsdp_config .pop (k )
26832684 self .fsdp_config [k [5 :]] = v
26842685
@@ -2722,15 +2723,20 @@ def _process_fsdp_args(self):
27222723 # accelerate integration for FSDP
27232724 fsdp_plugin_args = None
27242725 if len (self .fsdp ) > 0 and not self .fsdp_config ["xla" ]:
2726+ from accelerate .utils import FullyShardedDataParallelPlugin
27252727 from accelerate .utils .constants import (
27262728 FSDP_AUTO_WRAP_POLICY ,
27272729 FSDP_SHARDING_STRATEGY ,
27282730 )
27292731
27302732 fsdp_plugin_args = {}
2733+ # Handle basic FSDP options from command-line flags
27312734 for fsdp_option in self .fsdp :
27322735 if fsdp_option .upper () in FSDP_SHARDING_STRATEGY :
2733- fsdp_plugin_args ["sharding_strategy" ] = fsdp_option
2736+ # Set deprecated sharding_strategy from CLI (plugin maps to reshard_after_forward)
2737+ # Skip if config has explicit reshard_after_forward (prioritize config)
2738+ if "reshard_after_forward" not in self .fsdp_config :
2739+ fsdp_plugin_args ["sharding_strategy" ] = fsdp_option
27342740 elif fsdp_option == FSDPOption .OFFLOAD :
27352741 fsdp_plugin_args ["cpu_offload" ] = True
27362742 elif fsdp_option == FSDPOption .AUTO_WRAP :
@@ -2742,24 +2748,43 @@ def _process_fsdp_args(self):
27422748 fsdp_plugin_args ["transformer_cls_names_to_wrap" ] = "," .join (
27432749 self .fsdp_config ["transformer_layer_cls_to_wrap" ]
27442750 )
2745- fsdp_plugin_args ["fsdp_version" ] = self .fsdp_config .get ("fsdp_version" , 1 )
2746- prefetch_policy = self .fsdp_config .get ("backward_prefetch" , "NO_PREFETCH" )
2747- fsdp_plugin_args ["backward_prefetch" ] = prefetch_policy .upper ()
2748- fsdp_plugin_args ["forward_prefetch" ] = str (self .fsdp_config .get ("forward_prefetch" , "false" )).lower ()
2749-
2750- sync_module_states = str (self .fsdp_config .get ("sync_module_states" , "true" )).lower ()
2751- cpu_ram_efficient_loading = str (self .fsdp_config .get ("cpu_ram_efficient_loading" , "false" )).lower ()
2752- if sync_module_states == "false" and cpu_ram_efficient_loading == "true" :
2753- # In this case, all the processes except the main process would have random weights leading
2754- # to unexpected behaviour during training, thus throwing error here to prevent it.
2755- raise ValueError ('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`' )
2756-
2757- # we need to set the env here as otherwise we get a warning in accelerate + we need to set it for transformers
2758- fsdp_plugin_args ["cpu_ram_efficient_loading" ] = cpu_ram_efficient_loading
2759- os .environ ["FSDP_CPU_RAM_EFFICIENT_LOADING" ] = cpu_ram_efficient_loading
2760-
2761- fsdp_plugin_args ["sync_module_states" ] = sync_module_states
2762- fsdp_plugin_args ["use_orig_params" ] = str (self .fsdp_config .get ("use_orig_params" , "true" )).lower ()
2751+
2752+ # Pull allowed parameters from fsdp_config
2753+ ALLOWED_FSDP_PARAMS = {f .name for f in fields (FullyShardedDataParallelPlugin )}
2754+ for key in ALLOWED_FSDP_PARAMS :
2755+ if key in self .fsdp_config and key not in fsdp_plugin_args :
2756+ fsdp_plugin_args [key ] = self .fsdp_config [key ]
2757+ fsdp_plugin_args ["fsdp_version" ] = fsdp_version
2758+
2759+ # Special HF-to-plugin map: transformer_layer_cls_to_wrap → joined cls_names
2760+ if (
2761+ "transformer_layer_cls_to_wrap" in self .fsdp_config
2762+ and "transformer_cls_names_to_wrap" not in fsdp_plugin_args
2763+ ):
2764+ fsdp_plugin_args ["transformer_cls_names_to_wrap" ] = "," .join (
2765+ self .fsdp_config ["transformer_layer_cls_to_wrap" ]
2766+ )
2767+
2768+ # Validation: sync_module_states vs cpu_ram_efficient_loading
2769+ sync_states = fsdp_plugin_args .get ("sync_module_states" , "true" )
2770+ cpu_loading = fsdp_plugin_args .get ("cpu_ram_efficient_loading" , "false" )
2771+
2772+ if isinstance (sync_states , str ):
2773+ sync_states = sync_states .lower ()
2774+ if isinstance (cpu_loading , str ):
2775+ cpu_loading = cpu_loading .lower ()
2776+ if sync_states == "false" and cpu_loading == "true" :
2777+ raise ValueError ('`sync_module_states` must be `"true"` if `cpu_ram_efficient_loading` is `"true"`.' )
2778+
2779+ # CRITICAL: Set environment variable for cpu_ram_efficient_loading
2780+ if "cpu_ram_efficient_loading" in fsdp_plugin_args :
2781+ cpu_ram_value = fsdp_plugin_args ["cpu_ram_efficient_loading" ]
2782+ # Handle both bool and string values
2783+ if isinstance (cpu_ram_value , bool ):
2784+ cpu_ram_value = str (cpu_ram_value ).lower ()
2785+ elif isinstance (cpu_ram_value , str ):
2786+ cpu_ram_value = cpu_ram_value .lower ()
2787+ os .environ ["FSDP_CPU_RAM_EFFICIENT_LOADING" ] = cpu_ram_value
27632788
27642789 return fsdp_plugin_args
27652790
0 commit comments