Skip to content

Commit 9a8af59

Browse files
committed
Update FSDP hardcoded params + dynamic copy per review
1 parent bcd3599 commit 9a8af59

File tree

1 file changed

+26
-28
lines changed

1 file changed

+26
-28
lines changed

src/transformers/training_args.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2676,7 +2676,6 @@ def _process_fsdp_args(self):
26762676
with open(self.fsdp_config, encoding="utf-8") as f:
26772677
self.fsdp_config = json.load(f)
26782678

2679-
fsdp_version = self.fsdp_config.get("fsdp_version", 1)
26802679
if self.fsdp_config is not None and isinstance(self.fsdp_config, dict):
26812680
for k in list(self.fsdp_config.keys()):
26822681
if k.startswith("fsdp_") and k != "fsdp_version":
@@ -2749,14 +2748,28 @@ def _process_fsdp_args(self):
27492748
self.fsdp_config["transformer_layer_cls_to_wrap"]
27502749
)
27512750

2752-
# Pull allowed parameters from fsdp_config
2753-
ALLOWED_FSDP_PARAMS = {f.name for f in fields(FullyShardedDataParallelPlugin)}
2754-
for key in ALLOWED_FSDP_PARAMS:
2755-
if key in self.fsdp_config and key not in fsdp_plugin_args:
2756-
fsdp_plugin_args[key] = self.fsdp_config[key]
2757-
fsdp_plugin_args["fsdp_version"] = fsdp_version
2751+
fsdp_plugin_args["fsdp_version"] = self.fsdp_config.get("fsdp_version", 1)
2752+
prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
2753+
fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper()
2754+
if fsdp_plugin_args["fsdp_version"] == 1:
2755+
# forward_prefetch is not yet implemented in FSDP2
2756+
fsdp_plugin_args["forward_prefetch"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower()
2757+
2758+
sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower()
2759+
cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower()
2760+
if sync_module_states == "false" and cpu_ram_efficient_loading == "true":
2761+
# In this case, all the processes except the main process would have random weights leading
2762+
# to unexpected behaviour during training, thus throwing error here to prevent it.
2763+
raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')
2764+
2765+
# we need to set the env here as otherwise we get a warning in accelerate + we need to set it for transformers
2766+
fsdp_plugin_args["cpu_ram_efficient_loading"] = cpu_ram_efficient_loading
2767+
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading
27582768

2759-
# Special HF-to-plugin map: transformer_layer_cls_to_wrap → joined cls_names
2769+
fsdp_plugin_args["sync_module_states"] = sync_module_states
2770+
fsdp_plugin_args["use_orig_params"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
2771+
2772+
# HF-to-plugin map
27602773
if (
27612774
"transformer_layer_cls_to_wrap" in self.fsdp_config
27622775
and "transformer_cls_names_to_wrap" not in fsdp_plugin_args
@@ -2765,26 +2778,11 @@ def _process_fsdp_args(self):
27652778
self.fsdp_config["transformer_layer_cls_to_wrap"]
27662779
)
27672780

2768-
# Validation: sync_module_states vs cpu_ram_efficient_loading
2769-
sync_states = fsdp_plugin_args.get("sync_module_states", "true")
2770-
cpu_loading = fsdp_plugin_args.get("cpu_ram_efficient_loading", "false")
2771-
2772-
if isinstance(sync_states, str):
2773-
sync_states = sync_states.lower()
2774-
if isinstance(cpu_loading, str):
2775-
cpu_loading = cpu_loading.lower()
2776-
if sync_states == "false" and cpu_loading == "true":
2777-
raise ValueError('`sync_module_states` must be `"true"` if `cpu_ram_efficient_loading` is `"true"`.')
2778-
2779-
# CRITICAL: Set environment variable for cpu_ram_efficient_loading
2780-
if "cpu_ram_efficient_loading" in fsdp_plugin_args:
2781-
cpu_ram_value = fsdp_plugin_args["cpu_ram_efficient_loading"]
2782-
# Handle both bool and string values
2783-
if isinstance(cpu_ram_value, bool):
2784-
cpu_ram_value = str(cpu_ram_value).lower()
2785-
elif isinstance(cpu_ram_value, str):
2786-
cpu_ram_value = cpu_ram_value.lower()
2787-
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_value
2781+
# Pull allowed parameters from fsdp_config
2782+
ALLOWED_FSDP_PARAMS = {f.name for f in fields(FullyShardedDataParallelPlugin)}
2783+
for key in ALLOWED_FSDP_PARAMS:
2784+
if key in self.fsdp_config and key not in fsdp_plugin_args:
2785+
fsdp_plugin_args[key] = self.fsdp_config[key]
27882786

27892787
return fsdp_plugin_args
27902788

0 commit comments

Comments
 (0)