Conversation

@amanzoni1

What does this PR do?

Fixes the bug where TrainingArguments(fsdp=True, fsdp_config={"fsdp_version": 2, ...}) defaults to FSDP version 1, ignoring the requested version (unless the Accelerator was initialized manually), and drops most of the other params.
Adds dynamic passthrough of all FSDP plugin params (FullyShardedDataParallelPlugin) from fsdp_config to the plugin args (future-proof, no hardcoding).

Changes

  • Extract fsdp_version before the prefix-stripping pass so it is preserved.
  • Skip the fsdp_ prefix-stripping for "fsdp_version" so the key is not mangled.
  • Dynamic copy of all fields from the FullyShardedDataParallelPlugin dataclass (see the sketch after this list).
  • Add test_fsdp_version_2_config for regression.
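
Roughly, the dynamic copy works like the sketch below. It is a simplified, stand-alone illustration rather than the exact diff: fsdp_config and fsdp_plugin_args stand in for self.fsdp_config and the plugin args dict that TrainingArguments builds.

import dataclasses

from accelerate import FullyShardedDataParallelPlugin

# Stand-ins for TrainingArguments.fsdp_config and the dict of plugin args
# already populated with transformers' explicit defaults.
fsdp_config = {"fsdp_version": 2, "reshard_after_forward": True}
fsdp_plugin_args = {"fsdp_version": fsdp_config.get("fsdp_version", 1)}

# Dynamic passthrough: any fsdp_config key that matches a plugin field and has
# not been set yet is forwarded as-is, so new plugin fields need no hardcoding.
plugin_fields = {f.name for f in dataclasses.fields(FullyShardedDataParallelPlugin)}
for key in plugin_fields:
    if key in fsdp_config and key not in fsdp_plugin_args:
        fsdp_plugin_args[key] = fsdp_config[key]

# Expected: {'fsdp_version': 2, 'reshard_after_forward': True}, assuming
# reshard_after_forward is a FullyShardedDataParallelPlugin field.
print(fsdp_plugin_args)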

Repro

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="/tmp/test",
    fsdp=True,
    fsdp_config={
        "fsdp_version": 2,
        "reshard_after_forward": True,
    },
)

# Call the internal method to test argument parsing
plugin_args = args._process_fsdp_args()

# Expected output after this PR:
print(f"fsdp_version: {plugin_args['fsdp_version']}")                      # Before: 1, After: 2
print(f"reshard_after_forward: {plugin_args['reshard_after_forward']}")    # Before: None, After: True

Before submitting

Who can review?

@SunMarc @3outeille

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amanzoni1 amanzoni1 force-pushed the fix-fsdp2-default-version branch 2 times, most recently from f19d66a to bcd3599 Compare December 2, 2025 11:16
Member

@SunMarc SunMarc left a comment

Thanks ! Left a few comments

Comment on lines 2745 to 2762
fsdp_plugin_args["fsdp_version"] = self.fsdp_config.get("fsdp_version", 1)
prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper()
fsdp_plugin_args["forward_prefetch"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower()

sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower()
cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower()
if sync_module_states == "false" and cpu_ram_efficient_loading == "true":
    # In this case, all the processes except the main process would have random weights leading
    # to unexpected behaviour during training, thus throwing error here to prevent it.
    raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')

# we need to set the env here as otherwise we get a warning in accelerate + we need to set it for transformers
fsdp_plugin_args["cpu_ram_efficient_loading"] = cpu_ram_efficient_loading
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading

fsdp_plugin_args["sync_module_states"] = sync_module_states
fsdp_plugin_args["use_orig_params"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
Member

I think it will be better to keep this for now, as we can easily change the default config here instead of doing it in accelerate. Also, some defaults are not the same here vs accelerate. Nevertheless, for keys that are not exposed here but exist in FullyShardedDataParallelPlugin, it's fine to set them.

Comment on lines +2755 to +2786
if key in self.fsdp_config and key not in fsdp_plugin_args:
    fsdp_plugin_args[key] = self.fsdp_config[key]
Member

we can keep this but put it at the end.

Comment on lines +2736 to +2737
# Set deprecated sharding_strategy from CLI (plugin maps to reshard_after_forward)
# Skip if config has explicit reshard_after_forward (prioritize config)
if "reshard_after_forward" not in self.fsdp_config:
fsdp_plugin_args["sharding_strategy"] = fsdp_option
Member

thanks !

if self.fsdp_config is not None and isinstance(self.fsdp_config, dict):
    for k in list(self.fsdp_config.keys()):
-       if k.startswith("fsdp_"):
+       if k.startswith("fsdp_") and k != "fsdp_version":
Member

indeed thanks !

Comment on lines 214 to 244
def test_fsdp_version_2_config(self):
    output_dir = self.get_auto_remove_tmp_dir()
    kwargs = {
        "output_dir": output_dir,
        "train_len": 128,
        "save_steps": 5,
        "learning_rate": 0.1,
        "fsdp": True,
        "fsdp_config": {
            "fsdp_version": 2,
            "reshard_after_forward": True,
        },
    }
    with mockenv_context(**self.dist_env_1_gpu):
        trainer = get_regression_trainer(**kwargs)
        plugin_args = trainer.args._process_fsdp_args()
        self.assertEqual(plugin_args["fsdp_version"], 2)
        self.assertTrue(plugin_args["reshard_after_forward"])

Member

nice

@amanzoni1
Author

Thanks for the feedback @SunMarc!!
Implemented the hardcoded defaults plus the dynamic copy at the end, and needed to add a conditional skip for forward_prefetch in v2 (Accelerate raises a ValueError otherwise: "ValueError: forward_prefetch is not yet implemented in FSDP2, set to None or use fsdp_version=1").
Let me know if that works.
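
For illustration, the guard described above could look roughly like this; it is a sketch, not the exact diff, with fsdp_config and fsdp_plugin_args standing in for self.fsdp_config and the plugin args dict.

# Stand-ins for self.fsdp_config and the plugin args dict built so far.
fsdp_config = {"fsdp_version": 2, "forward_prefetch": True}
fsdp_plugin_args = {"fsdp_version": fsdp_config.get("fsdp_version", 1)}

# forward_prefetch is not implemented in FSDP2 and Accelerate raises a
# ValueError if it is set, so only forward the value for FSDP version 1.
if fsdp_plugin_args["fsdp_version"] != 2:
    fsdp_plugin_args["forward_prefetch"] = str(fsdp_config.get("forward_prefetch", "false")).lower()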

@SunMarc
Member

SunMarc commented Dec 2, 2025

I've merged this PR, which seems to fix a few more things! Feel free to rebase and add your updates; the test in particular is nice!

#42566

@amanzoni1 amanzoni1 force-pushed the fix-fsdp2-default-version branch from 9a8af59 to b6a9f0b Compare December 2, 2025 20:10
@amanzoni1
Author

@SunMarc Thanks for the feedback! Rebased onto the latest main, and ready for review

Member

@SunMarc SunMarc left a comment

Thanks, just a few nits !

Comment on lines 2782 to 2789
# HF-to-plugin map
if (
    "transformer_layer_cls_to_wrap" in self.fsdp_config
    and "transformer_cls_names_to_wrap" not in fsdp_plugin_args
):
    fsdp_plugin_args["transformer_cls_names_to_wrap"] = ",".join(
        self.fsdp_config["transformer_layer_cls_to_wrap"]
    )
Member

not needed anymore

Comment on lines 222 to 231
"fsdp_config": {
"fsdp_version": 2,
"reshard_after_forward": True,
},
Member

Can you add more things to the config so that it is better tested?
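
For illustration, building on the test quoted above, an expanded config with matching assertions might look like this; the two extra keys are assumptions about what fsdp_config accepts for FSDP2, not the final test.

def test_fsdp_version_2_config(self):
    output_dir = self.get_auto_remove_tmp_dir()
    kwargs = {
        "output_dir": output_dir,
        "train_len": 128,
        "save_steps": 5,
        "learning_rate": 0.1,
        "fsdp": True,
        "fsdp_config": {
            "fsdp_version": 2,
            "reshard_after_forward": True,
            "cpu_ram_efficient_loading": True,   # assumed to be an accepted key
            "activation_checkpointing": True,    # assumed to be an accepted key
        },
    }
    with mockenv_context(**self.dist_env_1_gpu):
        trainer = get_regression_trainer(**kwargs)
        plugin_args = trainer.args._process_fsdp_args()
        self.assertEqual(plugin_args["fsdp_version"], 2)
        self.assertTrue(plugin_args["reshard_after_forward"])
        # truthiness checks avoid depending on string vs bool normalization
        self.assertTrue(plugin_args["cpu_ram_efficient_loading"])
        self.assertTrue(plugin_args["activation_checkpointing"])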

@amanzoni1 amanzoni1 force-pushed the fix-fsdp2-default-version branch from 61c0e74 to 2252af4 Compare December 3, 2025 15:33
Member

@SunMarc SunMarc left a comment

Thanks !

@SunMarc SunMarc enabled auto-merge (squash) December 3, 2025 15:50
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

auto-merge was automatically disabled December 3, 2025 16:10

Head branch was pushed to by a user without write access

@amanzoni1 amanzoni1 force-pushed the fix-fsdp2-default-version branch from 3761cdd to ff47d0d Compare December 3, 2025 16:10
@SunMarc SunMarc enabled auto-merge (squash) December 3, 2025 16:53