Commit bcd3599

Fix FSDP v2 defaulting to version 1 in TrainingArguments
1 parent 57eeb9c commit bcd3599
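
For context, a minimal standalone sketch (not the library code path) of the bug this commit fixes: TrainingArguments strips the `fsdp_` prefix from keys in `fsdp_config`, so a user-supplied `fsdp_version: 2` was renamed to `version` before the later lookup `fsdp_config.get("fsdp_version", 1)`, which then silently fell back to FSDP v1. The config dict below is illustrative.

# Illustrative config; mirrors what a user might pass as TrainingArguments(fsdp_config=...).
cfg = {"fsdp_version": 2, "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"]}

# Old behaviour: every `fsdp_`-prefixed key gets renamed, including `fsdp_version`.
old = dict(cfg)
for k in list(old.keys()):
    if k.startswith("fsdp_"):
        old[k[5:]] = old.pop(k)
print(old.get("fsdp_version", 1))  # 1 -- the requested FSDP v2 is silently downgraded

# Fixed behaviour: read the version before stripping and exempt the key from renaming.
new = dict(cfg)
fsdp_version = new.get("fsdp_version", 1)
for k in list(new.keys()):
    if k.startswith("fsdp_") and k != "fsdp_version":
        new[k[5:]] = new.pop(k)
print(fsdp_version)  # 2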

2 files changed: +64 -20 lines changed


src/transformers/training_args.py

Lines changed: 45 additions & 20 deletions
@@ -2676,9 +2676,10 @@ def _process_fsdp_args(self):
             with open(self.fsdp_config, encoding="utf-8") as f:
                 self.fsdp_config = json.load(f)
 
+        fsdp_version = self.fsdp_config.get("fsdp_version", 1)
         if self.fsdp_config is not None and isinstance(self.fsdp_config, dict):
             for k in list(self.fsdp_config.keys()):
-                if k.startswith("fsdp_"):
+                if k.startswith("fsdp_") and k != "fsdp_version":
                     v = self.fsdp_config.pop(k)
                     self.fsdp_config[k[5:]] = v
 
@@ -2722,15 +2723,20 @@ def _process_fsdp_args(self):
         # accelerate integration for FSDP
         fsdp_plugin_args = None
         if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
+            from accelerate.utils import FullyShardedDataParallelPlugin
             from accelerate.utils.constants import (
                 FSDP_AUTO_WRAP_POLICY,
                 FSDP_SHARDING_STRATEGY,
             )
 
             fsdp_plugin_args = {}
+            # Handle basic FSDP options from command-line flags
             for fsdp_option in self.fsdp:
                 if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
-                    fsdp_plugin_args["sharding_strategy"] = fsdp_option
+                    # Set deprecated sharding_strategy from CLI (plugin maps to reshard_after_forward)
+                    # Skip if config has explicit reshard_after_forward (prioritize config)
+                    if "reshard_after_forward" not in self.fsdp_config:
+                        fsdp_plugin_args["sharding_strategy"] = fsdp_option
                 elif fsdp_option == FSDPOption.OFFLOAD:
                     fsdp_plugin_args["cpu_offload"] = True
                 elif fsdp_option == FSDPOption.AUTO_WRAP:
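
A small sketch (illustrative names, not the Trainer code) of the precedence rule encoded in the hunk above: an explicit `reshard_after_forward` in `fsdp_config` takes priority, and the deprecated CLI sharding strategy is only forwarded when the config does not set it.

from typing import Optional

def resolve_resharding(cli_sharding: Optional[str], fsdp_config: dict) -> dict:
    # Hypothetical helper mirroring the branch above, for illustration only.
    plugin_args = {}
    if cli_sharding is not None and "reshard_after_forward" not in fsdp_config:
        # Legacy CLI value; the Accelerate plugin maps it onto reshard_after_forward.
        plugin_args["sharding_strategy"] = cli_sharding
    if "reshard_after_forward" in fsdp_config:
        plugin_args["reshard_after_forward"] = fsdp_config["reshard_after_forward"]
    return plugin_args

print(resolve_resharding("FULL_SHARD", {}))                               # {'sharding_strategy': 'FULL_SHARD'}
print(resolve_resharding("FULL_SHARD", {"reshard_after_forward": True}))  # {'reshard_after_forward': True}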
@@ -2742,24 +2748,43 @@ def _process_fsdp_args(self):
                         fsdp_plugin_args["transformer_cls_names_to_wrap"] = ",".join(
                             self.fsdp_config["transformer_layer_cls_to_wrap"]
                         )
-            fsdp_plugin_args["fsdp_version"] = self.fsdp_config.get("fsdp_version", 1)
-            prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
-            fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper()
-            fsdp_plugin_args["forward_prefetch"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower()
-
-            sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower()
-            cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower()
-            if sync_module_states == "false" and cpu_ram_efficient_loading == "true":
-                # In this case, all the processes except the main process would have random weights leading
-                # to unexpected behaviour during training, thus throwing error here to prevent it.
-                raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')
-
-            # we need to set the env here as otherwise we get a warning in accelerate + we need to set it for transformers
-            fsdp_plugin_args["cpu_ram_efficient_loading"] = cpu_ram_efficient_loading
-            os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading
-
-            fsdp_plugin_args["sync_module_states"] = sync_module_states
-            fsdp_plugin_args["use_orig_params"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
+
+            # Pull allowed parameters from fsdp_config
+            ALLOWED_FSDP_PARAMS = {f.name for f in fields(FullyShardedDataParallelPlugin)}
+            for key in ALLOWED_FSDP_PARAMS:
+                if key in self.fsdp_config and key not in fsdp_plugin_args:
+                    fsdp_plugin_args[key] = self.fsdp_config[key]
+            fsdp_plugin_args["fsdp_version"] = fsdp_version
+
+            # Special HF-to-plugin map: transformer_layer_cls_to_wrap → joined cls_names
+            if (
+                "transformer_layer_cls_to_wrap" in self.fsdp_config
+                and "transformer_cls_names_to_wrap" not in fsdp_plugin_args
+            ):
+                fsdp_plugin_args["transformer_cls_names_to_wrap"] = ",".join(
+                    self.fsdp_config["transformer_layer_cls_to_wrap"]
+                )
+
+            # Validation: sync_module_states vs cpu_ram_efficient_loading
+            sync_states = fsdp_plugin_args.get("sync_module_states", "true")
+            cpu_loading = fsdp_plugin_args.get("cpu_ram_efficient_loading", "false")
+
+            if isinstance(sync_states, str):
+                sync_states = sync_states.lower()
+            if isinstance(cpu_loading, str):
+                cpu_loading = cpu_loading.lower()
+            if sync_states == "false" and cpu_loading == "true":
+                raise ValueError('`sync_module_states` must be `"true"` if `cpu_ram_efficient_loading` is `"true"`.')
+
+            # CRITICAL: Set environment variable for cpu_ram_efficient_loading
+            if "cpu_ram_efficient_loading" in fsdp_plugin_args:
+                cpu_ram_value = fsdp_plugin_args["cpu_ram_efficient_loading"]
+                # Handle both bool and string values
+                if isinstance(cpu_ram_value, bool):
+                    cpu_ram_value = str(cpu_ram_value).lower()
+                elif isinstance(cpu_ram_value, str):
+                    cpu_ram_value = cpu_ram_value.lower()
+                os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_value
 
         return fsdp_plugin_args
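
The rewritten block above replaces the hand-copied options with a whitelist derived from the plugin dataclass's own field names. A minimal sketch of that pattern with a toy dataclass (a stand-in, not Accelerate's actual FullyShardedDataParallelPlugin):

from dataclasses import dataclass, fields

@dataclass
class ToyPlugin:  # stand-in for accelerate.utils.FullyShardedDataParallelPlugin
    fsdp_version: int = 1
    reshard_after_forward: bool = True
    cpu_ram_efficient_loading: bool = False

user_config = {"fsdp_version": 2, "reshard_after_forward": True, "unknown_key": "ignored"}

# Forward only the keys the plugin actually declares as fields.
allowed = {f.name for f in fields(ToyPlugin)}
plugin_args = {k: v for k, v in user_config.items() if k in allowed}
print(plugin_args)  # {'fsdp_version': 2, 'reshard_after_forward': True}
plugin = ToyPlugin(**plugin_args)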
tests/fsdp/test_fsdp.py

Lines changed: 19 additions & 0 deletions
@@ -211,6 +211,25 @@ def test_fsdp_config(self, sharding_strategy, dtype):
         for k, v in trainer.args.fsdp_config.items():
             self.assertEqual(v, self.fsdp_config[k])
 
+    def test_fsdp_version_2_config(self):
+        output_dir = self.get_auto_remove_tmp_dir()
+        kwargs = {
+            "output_dir": output_dir,
+            "train_len": 128,
+            "save_steps": 5,
+            "learning_rate": 0.1,
+            "fsdp": True,
+            "fsdp_config": {
+                "fsdp_version": 2,
+                "reshard_after_forward": True,
+            },
+        }
+        with mockenv_context(**self.dist_env_1_gpu):
+            trainer = get_regression_trainer(**kwargs)
+            plugin_args = trainer.args._process_fsdp_args()
+            self.assertEqual(plugin_args["fsdp_version"], 2)
+            self.assertTrue(plugin_args["reshard_after_forward"])
+
     @parameterized.expand(params, name_func=_parameterized_custom_name_func)
     @require_torch_multi_accelerator
     @run_first
