Fix FSDP2 defaulting to version 1 in TrainingArguments; add dynamic plugin param passthrough #42521
Conversation
Force-pushed from f19d66a to bcd3599.
SunMarc left a comment:
Thanks! Left a few comments.
src/transformers/training_args.py (Outdated)
| fsdp_plugin_args["fsdp_version"] = self.fsdp_config.get("fsdp_version", 1) | ||
| prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH") | ||
| fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper() | ||
| fsdp_plugin_args["forward_prefetch"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower() | ||
|
|
||
| sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower() | ||
| cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower() | ||
| if sync_module_states == "false" and cpu_ram_efficient_loading == "true": | ||
| # In this case, all the processes except the main process would have random weights leading | ||
| # to unexpected behaviour during training, thus throwing error here to prevent it. | ||
| raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`') | ||
|
|
||
| # we need to set the env here as otherwise we get a warning in accelerate + we need to set it for transformers | ||
| fsdp_plugin_args["cpu_ram_efficient_loading"] = cpu_ram_efficient_loading | ||
| os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading | ||
|
|
||
| fsdp_plugin_args["sync_module_states"] = sync_module_states | ||
| fsdp_plugin_args["use_orig_params"] = str(self.fsdp_config.get("use_orig_params", "true")).lower() |
I think it will be better to keep this for now, as we can easily change the default config here instead of doing it in accelerate. Also, some defaults are not the same here vs. accelerate. Nevertheless, for keys that are not exposed here but exist in FullyShardedDataParallelPlugin, it's fine to set them.
```python
if key in self.fsdp_config and key not in fsdp_plugin_args:
    fsdp_plugin_args[key] = self.fsdp_config[key]
```
We can keep this, but put it at the end.
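For illustration, a minimal sketch of how the dynamic passthrough shown above could look in isolation: introspect the plugin dataclass and forward any matching `fsdp_config` keys that were not already set explicitly. The helper name `passthrough_plugin_args` is hypothetical; the merged code lives inside `TrainingArguments` and may differ in detail.

```python
# Sketch only: forward fsdp_config keys that match FullyShardedDataParallelPlugin
# fields and were not already set by the explicit handling above.
from dataclasses import fields

from accelerate import FullyShardedDataParallelPlugin


def passthrough_plugin_args(fsdp_config: dict, fsdp_plugin_args: dict) -> dict:
    plugin_field_names = {f.name for f in fields(FullyShardedDataParallelPlugin)}
    for key in plugin_field_names:
        if key in fsdp_config and key not in fsdp_plugin_args:
            fsdp_plugin_args[key] = fsdp_config[key]
    return fsdp_plugin_args
```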
```python
# Set deprecated sharding_strategy from CLI (plugin maps to reshard_after_forward)
# Skip if config has explicit reshard_after_forward (prioritize config)
if "reshard_after_forward" not in self.fsdp_config:
    fsdp_plugin_args["sharding_strategy"] = fsdp_option
```
Thanks!
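As a side note, the snippet above relies on accelerate's plugin translating the deprecated `sharding_strategy` into the FSDP2 `reshard_after_forward` flag. A rough illustration of that mapping, stated as an assumption rather than accelerate's actual code:

```python
# Illustration only (assumed mapping, not copied from accelerate): strategies that
# reshard parameters after the forward pass map to True, the others to False.
_RESHARD_AFTER_FORWARD = {
    "FULL_SHARD": True,
    "HYBRID_SHARD": True,
    "SHARD_GRAD_OP": False,
    "NO_SHARD": False,
}


def sharding_strategy_to_reshard_after_forward(strategy: str) -> bool:
    return _RESHARD_AFTER_FORWARD[strategy.upper()]
```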
```diff
 if self.fsdp_config is not None and isinstance(self.fsdp_config, dict):
     for k in list(self.fsdp_config.keys()):
-        if k.startswith("fsdp_"):
+        if k.startswith("fsdp_") and k != "fsdp_version":
```
Indeed, thanks!
```python
def test_fsdp_version_2_config(self):
    output_dir = self.get_auto_remove_tmp_dir()
    kwargs = {
        "output_dir": output_dir,
        "train_len": 128,
        "save_steps": 5,
        "learning_rate": 0.1,
        "fsdp": True,
        "fsdp_config": {
            "fsdp_version": 2,
            "reshard_after_forward": True,
        },
    }
    with mockenv_context(**self.dist_env_1_gpu):
        trainer = get_regression_trainer(**kwargs)
        plugin_args = trainer.args._process_fsdp_args()
        self.assertEqual(plugin_args["fsdp_version"], 2)
        self.assertTrue(plugin_args["reshard_after_forward"])
```
nice
Thanks for the feedback @SunMarc!!

I've merged this PR, which seems to fix a few more things! Feel free to rebase and add your updates; the test in particular is nice!
Force-pushed from 9a8af59 to b6a9f0b.
@SunMarc Thanks for the feedback! Rebased onto the latest main and ready for review.
SunMarc left a comment:
Thanks, just a few nits!
src/transformers/training_args.py (Outdated)
```python
# HF-to-plugin map
if (
    "transformer_layer_cls_to_wrap" in self.fsdp_config
    and "transformer_cls_names_to_wrap" not in fsdp_plugin_args
):
    fsdp_plugin_args["transformer_cls_names_to_wrap"] = ",".join(
        self.fsdp_config["transformer_layer_cls_to_wrap"]
    )
```
Not needed anymore.
| "fsdp_config": { | ||
| "fsdp_version": 2, | ||
| "reshard_after_forward": True, | ||
| }, |
Can you add more things to the config so that it is better tested?
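For example, the config could be extended with a few more of the keys handled in `training_args.py`. This is a hypothetical sketch with illustrative values; the exact set that landed in the final test is not shown in this thread.

```python
# Hypothetical expanded config for the test; key names come from the options
# handled earlier in training_args.py, values are illustrative only.
expanded_fsdp_config = {
    "fsdp_version": 2,
    "reshard_after_forward": True,
    "cpu_ram_efficient_loading": True,
    "transformer_layer_cls_to_wrap": ["GPT2Block"],  # illustrative layer class name
}
```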
Force-pushed from 61c0e74 to 2252af4.
SunMarc left a comment:
Thanks!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Head branch was pushed to by a user without write access
Force-pushed from 3761cdd to ff47d0d.
What does this PR do?
Fixes the bug where `TrainingArguments(fsdp=True, fsdp_config={"fsdp_version": 2, ...})` defaults to FSDP version 1, ignoring the requested version (unless the Accelerator was initialized manually), and most params were lost.

Adds dynamic passthrough of all FSDP plugin params (`FullyShardedDataParallelPlugin`) from `fsdp_config` to the plugin args (future-proof, no hardcoding).
Changes
- Exclude `fsdp_version` from the `fsdp_` prefix pre-stripping to preserve it.
- Pass remaining `fsdp_config` keys through dynamically, based on the `FullyShardedDataParallelPlugin` dataclass, for all fields.
- Add `test_fsdp_version_2_config` for regression.

Repro
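The original repro script is not preserved in this thread; below is a minimal sketch of the failing configuration, assuming it is launched under a distributed launcher such as `torchrun` or `accelerate launch`.

```python
# Minimal sketch: before the fix, the resulting accelerate FSDP plugin silently
# fell back to FSDP version 1 despite the explicit fsdp_version=2 below.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="tmp_fsdp2",  # illustrative output dir
    fsdp=True,
    fsdp_config={
        "fsdp_version": 2,
        "reshard_after_forward": True,
    },
)
```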
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case: https://discuss.huggingface.co/t/how-to-start-fsdp2-when-using-trainer/151885
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@SunMarc @3outeille
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.