From 139cb6bf1813f98d9e3fad05a75eb76d66511e68 Mon Sep 17 00:00:00 2001
From: amanzoni1
Date: Mon, 1 Dec 2025 19:12:02 +0400
Subject: [PATCH 1/4] Fix FSDP v2 defaulting to version 1 in TrainingArguments

---
 src/transformers/training_args.py | 28 +++++++++++++++++++++++++---
 tests/fsdp/test_fsdp.py           | 19 +++++++++++++++++++
 2 files changed, 44 insertions(+), 3 deletions(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 778fffdc312a..29bc68d26685 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -2678,7 +2678,7 @@ def _process_fsdp_args(self):
 
         if self.fsdp_config is not None and isinstance(self.fsdp_config, dict):
             for k in list(self.fsdp_config.keys()):
-                if k.startswith("fsdp_"):
+                if k.startswith("fsdp_") and k != "fsdp_version":
                     v = self.fsdp_config.pop(k)
                     self.fsdp_config[k[5:]] = v
 
@@ -2722,15 +2722,20 @@ def _process_fsdp_args(self):
         # accelerate integration for FSDP
         fsdp_plugin_args = None
         if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
+            from accelerate.utils import FullyShardedDataParallelPlugin
             from accelerate.utils.constants import (
                 FSDP_AUTO_WRAP_POLICY,
                 FSDP_SHARDING_STRATEGY,
             )
 
             fsdp_plugin_args = {}
+            # Handle basic FSDP options from command-line flags
             for fsdp_option in self.fsdp:
                 if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
-                    fsdp_plugin_args["sharding_strategy"] = fsdp_option
+                    # Set deprecated sharding_strategy from CLI (plugin maps to reshard_after_forward)
+                    # Skip if config has explicit reshard_after_forward (prioritize config)
+                    if "reshard_after_forward" not in self.fsdp_config:
+                        fsdp_plugin_args["sharding_strategy"] = fsdp_option
                 elif fsdp_option == FSDPOption.OFFLOAD:
                     fsdp_plugin_args["cpu_offload"] = True
                 elif fsdp_option == FSDPOption.AUTO_WRAP:
@@ -2742,7 +2747,8 @@ def _process_fsdp_args(self):
                         fsdp_plugin_args["transformer_cls_names_to_wrap"] = ",".join(
                             self.fsdp_config["transformer_layer_cls_to_wrap"]
                         )
-            fsdp_version = int(self.fsdp_config.get("version", 1))
+
+            fsdp_version = int(self.fsdp_config.get("fsdp_version", 1))
             fsdp_plugin_args["fsdp_version"] = fsdp_version
             prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
             if fsdp_version == 2:
@@ -2768,12 +2774,28 @@ def _process_fsdp_args(self):
                 # to unexpected behaviour during training, thus throwing error here to prevent it.
                raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')
 
+            # we need to set the env here as otherwise we get a warning in accelerate + we need to set it for transformers
             fsdp_plugin_args["cpu_ram_efficient_loading"] = str_to_bool(cpu_ram_efficient_loading)
             os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading
 
             fsdp_plugin_args["sync_module_states"] = str_to_bool(sync_module_states)
 
+            # HF-to-plugin map
+            if (
+                "transformer_layer_cls_to_wrap" in self.fsdp_config
+                and "transformer_cls_names_to_wrap" not in fsdp_plugin_args
+            ):
+                fsdp_plugin_args["transformer_cls_names_to_wrap"] = ",".join(
+                    self.fsdp_config["transformer_layer_cls_to_wrap"]
+                )
+
+            # Pull allowed parameters from fsdp_config
+            ALLOWED_FSDP_PARAMS = {f.name for f in fields(FullyShardedDataParallelPlugin)}
+            for key in ALLOWED_FSDP_PARAMS:
+                if key in self.fsdp_config and key not in fsdp_plugin_args:
+                    fsdp_plugin_args[key] = self.fsdp_config[key]
+
         return fsdp_plugin_args
diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py
index 7f0cb0482bdb..526ba6edff56 100644
--- a/tests/fsdp/test_fsdp.py
+++ b/tests/fsdp/test_fsdp.py
@@ -211,6 +211,25 @@ def test_fsdp_config(self, sharding_strategy, dtype):
         for k, v in trainer.args.fsdp_config.items():
             self.assertEqual(v, self.fsdp_config[k])
 
+    def test_fsdp_version_2_config(self):
+        output_dir = self.get_auto_remove_tmp_dir()
+        kwargs = {
+            "output_dir": output_dir,
+            "train_len": 128,
+            "save_steps": 5,
+            "learning_rate": 0.1,
+            "fsdp": True,
+            "fsdp_config": {
+                "fsdp_version": 2,
+                "reshard_after_forward": True,
+            },
+        }
+        with mockenv_context(**self.dist_env_1_gpu):
+            trainer = get_regression_trainer(**kwargs)
+            plugin_args = trainer.args._process_fsdp_args()
+            self.assertEqual(plugin_args["fsdp_version"], 2)
+            self.assertTrue(plugin_args["reshard_after_forward"])
+
     @parameterized.expand(params, name_func=_parameterized_custom_name_func)
     @require_torch_multi_accelerator
     @run_first

From f14bd4f6b253ea2e7640be35be213707011bedcd Mon Sep 17 00:00:00 2001
From: amanzoni1
Date: Wed, 3 Dec 2025 00:06:59 +0400
Subject: [PATCH 2/4] Fix FSDP2 params and add test

---
 src/transformers/training_args.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 29bc68d26685..ba6e7dbeaa49 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -2729,7 +2729,6 @@ def _process_fsdp_args(self):
             )
 
             fsdp_plugin_args = {}
-            # Handle basic FSDP options from command-line flags
             for fsdp_option in self.fsdp:
                 if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
                     # Set deprecated sharding_strategy from CLI (plugin maps to reshard_after_forward)

From 1140f6935aec2f8c0f17f3d3c618a32ec6ac1ac9 Mon Sep 17 00:00:00 2001
From: amanzoni1
Date: Wed, 3 Dec 2025 00:14:07 +0400
Subject: [PATCH 3/4] Fix FSDP2 params and add test

---
 src/transformers/training_args.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index ba6e7dbeaa49..f7fae390c047 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -2773,7 +2773,6 @@ def _process_fsdp_args(self):
                 # to unexpected behaviour during training, thus throwing error here to prevent it.
raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`') - # we need to set the env here as otherwise we get a warning in accelerate + we need to set it for transformers fsdp_plugin_args["cpu_ram_efficient_loading"] = str_to_bool(cpu_ram_efficient_loading) os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading From ff47d0ddcbd39e508273147c2355800d5afa56bf Mon Sep 17 00:00:00 2001 From: amanzoni1 Date: Wed, 3 Dec 2025 19:27:33 +0400 Subject: [PATCH 4/4] Better FSDP2 test --- src/transformers/training_args.py | 9 --------- tests/fsdp/test_fsdp.py | 12 ++++++++++++ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index f7fae390c047..58ad379ee8d7 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2779,15 +2779,6 @@ def _process_fsdp_args(self): fsdp_plugin_args["sync_module_states"] = str_to_bool(sync_module_states) - # HF-to-plugin map - if ( - "transformer_layer_cls_to_wrap" in self.fsdp_config - and "transformer_cls_names_to_wrap" not in fsdp_plugin_args - ): - fsdp_plugin_args["transformer_cls_names_to_wrap"] = ",".join( - self.fsdp_config["transformer_layer_cls_to_wrap"] - ) - # Pull allowed parameters from fsdp_config ALLOWED_FSDP_PARAMS = {f.name for f in fields(FullyShardedDataParallelPlugin)} for key in ALLOWED_FSDP_PARAMS: diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py index 526ba6edff56..9c7bd744505e 100644 --- a/tests/fsdp/test_fsdp.py +++ b/tests/fsdp/test_fsdp.py @@ -222,6 +222,12 @@ def test_fsdp_version_2_config(self): "fsdp_config": { "fsdp_version": 2, "reshard_after_forward": True, + "auto_wrap_policy": "transformer_based_wrap", + "transformer_cls_names_to_wrap": ["BertLayer"], + "state_dict_type": "FULL_STATE_DICT", + "activation_checkpointing": True, + "cpu_offload": True, + "limit_all_gathers": True, }, } with mockenv_context(**self.dist_env_1_gpu): @@ -229,6 +235,12 @@ def test_fsdp_version_2_config(self): plugin_args = trainer.args._process_fsdp_args() self.assertEqual(plugin_args["fsdp_version"], 2) self.assertTrue(plugin_args["reshard_after_forward"]) + self.assertEqual(plugin_args["auto_wrap_policy"], "transformer_based_wrap") + self.assertListEqual(plugin_args["transformer_cls_names_to_wrap"], ["BertLayer"]) + self.assertEqual(plugin_args["state_dict_type"], "FULL_STATE_DICT") + self.assertTrue(plugin_args["activation_checkpointing"]) + self.assertTrue(plugin_args["cpu_offload"]) + self.assertTrue(plugin_args["limit_all_gathers"]) @parameterized.expand(params, name_func=_parameterized_custom_name_func) @require_torch_multi_accelerator