17 changes: 14 additions & 3 deletions src/transformers/training_args.py
@@ -2678,7 +2678,7 @@ def _process_fsdp_args(self):

         if self.fsdp_config is not None and isinstance(self.fsdp_config, dict):
             for k in list(self.fsdp_config.keys()):
-                if k.startswith("fsdp_"):
+                if k.startswith("fsdp_") and k != "fsdp_version":

Review comment (Member): indeed, thanks!

                     v = self.fsdp_config.pop(k)
                     self.fsdp_config[k[5:]] = v
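
For reference, a minimal standalone sketch of what the normalization above now does to a user-supplied `fsdp_config` (the dict values are illustrative, not taken from the PR): legacy `fsdp_`-prefixed keys are renamed to their bare form, while `fsdp_version` is left untouched.

```python
# Sketch of the key normalization: "fsdp_"-prefixed keys lose their prefix,
# but "fsdp_version" is preserved so the later fsdp_config.get("fsdp_version", 1)
# lookup still finds it. Values below are illustrative.
fsdp_config = {
    "fsdp_version": 2,                         # preserved as-is
    "fsdp_backward_prefetch": "backward_pre",  # renamed to "backward_prefetch"
    "reshard_after_forward": True,             # already bare, untouched
}

for k in list(fsdp_config.keys()):
    if k.startswith("fsdp_") and k != "fsdp_version":
        fsdp_config[k[5:]] = fsdp_config.pop(k)

print(fsdp_config)
# {'fsdp_version': 2, 'reshard_after_forward': True, 'backward_prefetch': 'backward_pre'}
```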

@@ -2722,6 +2722,7 @@ def _process_fsdp_args(self):
         # accelerate integration for FSDP
         fsdp_plugin_args = None
         if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
+            from accelerate.utils import FullyShardedDataParallelPlugin
             from accelerate.utils.constants import (
                 FSDP_AUTO_WRAP_POLICY,
                 FSDP_SHARDING_STRATEGY,
@@ -2730,7 +2731,10 @@
             fsdp_plugin_args = {}
             for fsdp_option in self.fsdp:
                 if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
-                    fsdp_plugin_args["sharding_strategy"] = fsdp_option
+                    # Set deprecated sharding_strategy from CLI (plugin maps to reshard_after_forward)
+                    # Skip if config has explicit reshard_after_forward (prioritize config)
+                    if "reshard_after_forward" not in self.fsdp_config:
+                        fsdp_plugin_args["sharding_strategy"] = fsdp_option

Review comment (Member) on lines +2734 to +2737: thanks!

                 elif fsdp_option == FSDPOption.OFFLOAD:
                     fsdp_plugin_args["cpu_offload"] = True
                 elif fsdp_option == FSDPOption.AUTO_WRAP:
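
A standalone sketch of the precedence rule introduced here (the helper name and the local strategy set are illustrative stand-ins, not part of the PR or of accelerate's API): a sharding strategy passed via `--fsdp` only populates the deprecated `sharding_strategy` key when `reshard_after_forward` is not set explicitly in `fsdp_config`.

```python
# Illustrative helper mirroring the precedence rule: an explicit
# reshard_after_forward in fsdp_config wins over the CLI-derived,
# deprecated sharding_strategy.
def resolve_sharding(fsdp_options, fsdp_config):
    # Local stand-in for accelerate's FSDP_SHARDING_STRATEGY constant.
    sharding_strategies = {"FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"}
    plugin_args = {}
    for opt in fsdp_options:
        if opt.upper() in sharding_strategies and "reshard_after_forward" not in fsdp_config:
            plugin_args["sharding_strategy"] = opt
    return plugin_args


print(resolve_sharding(["full_shard"], {}))
# {'sharding_strategy': 'full_shard'}
print(resolve_sharding(["full_shard"], {"reshard_after_forward": True}))
# {} -- the config's reshard_after_forward drives resharding instead
```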
@@ -2742,7 +2746,8 @@
fsdp_plugin_args["transformer_cls_names_to_wrap"] = ",".join(
self.fsdp_config["transformer_layer_cls_to_wrap"]
)
fsdp_version = int(self.fsdp_config.get("version", 1))

fsdp_version = int(self.fsdp_config.get("fsdp_version", 1))
fsdp_plugin_args["fsdp_version"] = fsdp_version
prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
if fsdp_version == 2:
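
A tiny sketch of why this lookup key changes together with the `fsdp_version` exception in the normalization loop earlier in the diff: once the key is no longer stripped to `version`, reading the old name would silently fall back to FSDP1.

```python
# With "fsdp_version" now preserved by the normalization loop, the lookup
# must use the preserved key; the old key would miss and default to 1.
fsdp_config = {"fsdp_version": 2}

old_lookup = int(fsdp_config.get("version", 1))       # 1 -> silent fallback to FSDP1
new_lookup = int(fsdp_config.get("fsdp_version", 1))  # 2 -> honors the requested version

print(old_lookup, new_lookup)  # 1 2
```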
@@ -2774,6 +2779,12 @@

fsdp_plugin_args["sync_module_states"] = str_to_bool(sync_module_states)

# Pull allowed parameters from fsdp_config
ALLOWED_FSDP_PARAMS = {f.name for f in fields(FullyShardedDataParallelPlugin)}
for key in ALLOWED_FSDP_PARAMS:
if key in self.fsdp_config and key not in fsdp_plugin_args:
fsdp_plugin_args[key] = self.fsdp_config[key]

return fsdp_plugin_args
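
A self-contained sketch of the whitelisting pattern added here, using a simplified stand-in dataclass instead of accelerate's real `FullyShardedDataParallelPlugin` (the field names below are a small illustrative subset):

```python
from dataclasses import dataclass, fields


# Simplified stand-in for accelerate's FullyShardedDataParallelPlugin; the real
# plugin has many more fields, but the forwarding pattern is the same.
@dataclass
class DemoFSDPPlugin:
    fsdp_version: int = 1
    reshard_after_forward: bool = True
    cpu_offload: bool = False
    activation_checkpointing: bool = False


fsdp_config = {
    "reshard_after_forward": True,
    "activation_checkpointing": True,
    "xla": False,  # not a plugin field, so never forwarded
}
fsdp_plugin_args = {"fsdp_version": 2}  # already set explicitly, so not overwritten

# Forward only keys that are real plugin fields and not already set.
allowed_params = {f.name for f in fields(DemoFSDPPlugin)}
for key in allowed_params:
    if key in fsdp_config and key not in fsdp_plugin_args:
        fsdp_plugin_args[key] = fsdp_config[key]

print(fsdp_plugin_args)
# e.g. {'fsdp_version': 2, 'reshard_after_forward': True, 'activation_checkpointing': True}
# (order of forwarded keys may vary because allowed_params is a set)
```

With accelerate installed, the keys the real whitelist accepts can be listed the same way: `sorted(f.name for f in fields(FullyShardedDataParallelPlugin))`; the exact set depends on the installed accelerate version.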


31 changes: 31 additions & 0 deletions tests/fsdp/test_fsdp.py
@@ -211,6 +211,37 @@ def test_fsdp_config(self, sharding_strategy, dtype):
             for k, v in trainer.args.fsdp_config.items():
                 self.assertEqual(v, self.fsdp_config[k])

+    def test_fsdp_version_2_config(self):
+        output_dir = self.get_auto_remove_tmp_dir()
+        kwargs = {
+            "output_dir": output_dir,
+            "train_len": 128,
+            "save_steps": 5,
+            "learning_rate": 0.1,
+            "fsdp": True,
+            "fsdp_config": {
+                "fsdp_version": 2,
+                "reshard_after_forward": True,
+                "auto_wrap_policy": "transformer_based_wrap",
+                "transformer_cls_names_to_wrap": ["BertLayer"],
+                "state_dict_type": "FULL_STATE_DICT",
+                "activation_checkpointing": True,
+                "cpu_offload": True,
+                "limit_all_gathers": True,
+            },
+        }
+        with mockenv_context(**self.dist_env_1_gpu):
+            trainer = get_regression_trainer(**kwargs)
+            plugin_args = trainer.args._process_fsdp_args()
+            self.assertEqual(plugin_args["fsdp_version"], 2)
+            self.assertTrue(plugin_args["reshard_after_forward"])
+            self.assertEqual(plugin_args["auto_wrap_policy"], "transformer_based_wrap")
+            self.assertListEqual(plugin_args["transformer_cls_names_to_wrap"], ["BertLayer"])
+            self.assertEqual(plugin_args["state_dict_type"], "FULL_STATE_DICT")
+            self.assertTrue(plugin_args["activation_checkpointing"])
+            self.assertTrue(plugin_args["cpu_offload"])
+            self.assertTrue(plugin_args["limit_all_gathers"])

     @parameterized.expand(params, name_func=_parameterized_custom_name_func)
     @require_torch_multi_accelerator
     @run_first
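
Finally, for anyone who wants to exercise this outside the test suite, a hedged usage sketch: the kwargs mirror the new test's configuration, `output_dir` and the wrapped class name are placeholders, and in practice FSDP only engages under a distributed launch (e.g. `accelerate launch` or `torchrun`).

```python
# Usage sketch mirroring the new test: requesting FSDP v2 via TrainingArguments.
# Model/dataset/Trainer setup is omitted; the accepted fsdp_config keys depend
# on the installed accelerate version.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",  # placeholder
    fsdp=True,
    fsdp_config={
        "fsdp_version": 2,
        "reshard_after_forward": True,
        "auto_wrap_policy": "transformer_based_wrap",
        "transformer_cls_names_to_wrap": ["BertLayer"],
        "state_dict_type": "FULL_STATE_DICT",
        "activation_checkpointing": True,
        "cpu_offload": True,
        "limit_all_gathers": True,
    },
)
```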