From 01059244ff760337d492b8ef9f8a335491591508 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 2 Dec 2025 11:34:39 -0500 Subject: [PATCH 1/5] make sure the FSDP plugin args are appropriately cast to bools --- src/transformers/training_args.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 04c54343404b..b2ee0ba74883 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2745,7 +2745,7 @@ def _process_fsdp_args(self): fsdp_plugin_args["fsdp_version"] = self.fsdp_config.get("fsdp_version", 1) prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH") fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper() - fsdp_plugin_args["forward_prefetch"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower() + fsdp_plugin_args["forward_prefetch"] = str_to_bool(str(self.fsdp_config.get("forward_prefetch", "false")).lower()) sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower() cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower() @@ -2755,11 +2755,11 @@ def _process_fsdp_args(self): raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`') # we need to set the env here as otherwise we get a warning in accelerate + we need to set it for transformers - fsdp_plugin_args["cpu_ram_efficient_loading"] = cpu_ram_efficient_loading + fsdp_plugin_args["cpu_ram_efficient_loading"] = str_to_bool(cpu_ram_efficient_loading) os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading - fsdp_plugin_args["sync_module_states"] = sync_module_states - fsdp_plugin_args["use_orig_params"] = str(self.fsdp_config.get("use_orig_params", "true")).lower() + fsdp_plugin_args["sync_module_states"] = str_to_bool(sync_module_states) + fsdp_plugin_args["use_orig_params"] = 
str_to_bool(str(self.fsdp_config.get("use_orig_params", "true")).lower()) return fsdp_plugin_args @@ -2771,3 +2771,18 @@ class ParallelMode(Enum): SAGEMAKER_MODEL_PARALLEL = "sagemaker_model_parallel" SAGEMAKER_DATA_PARALLEL = "sagemaker_data_parallel" TPU = "tpu" + + +def str_to_bool(value, to_bool: bool = True) -> int | bool: + """ + Converts a string representation of truth to `True` (1) or `False` (0). + + True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`; + """ + value = value.lower() + if value in ("y", "yes", "t", "true", "on", "1"): + return 1 if not to_bool else True + elif value in ("n", "no", "f", "false", "off", "0"): + return 0 if not to_bool else False + else: + raise ValueError(f"invalid truth value {value}") \ No newline at end of file From db15a5782bf7b4a9735825b1015fc2c4e560a77e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 2 Dec 2025 12:34:37 -0500 Subject: [PATCH 2/5] handle fsdp_version properly --- src/transformers/training_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index b2ee0ba74883..93a3571fe2e8 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2742,7 +2742,7 @@ def _process_fsdp_args(self): fsdp_plugin_args["transformer_cls_names_to_wrap"] = ",".join( self.fsdp_config["transformer_layer_cls_to_wrap"] ) - fsdp_plugin_args["fsdp_version"] = self.fsdp_config.get("fsdp_version", 1) + fsdp_plugin_args["fsdp_version"] = self.fsdp_config.get("version", 1) prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH") fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper() fsdp_plugin_args["forward_prefetch"] = str_to_bool(str(self.fsdp_config.get("forward_prefetch", "false")).lower()) From 05ee4e27e79e587a15e814410e5663e70d8e9ea6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 2 Dec 2025 12:43:10 -0500 Subject: 
[PATCH 3/5] include reshard_after_forward and handle correctly for fsdp1 vs fsdp2 --- src/transformers/training_args.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 93a3571fe2e8..01017a800d80 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2742,10 +2742,16 @@ def _process_fsdp_args(self): fsdp_plugin_args["transformer_cls_names_to_wrap"] = ",".join( self.fsdp_config["transformer_layer_cls_to_wrap"] ) - fsdp_plugin_args["fsdp_version"] = self.fsdp_config.get("version", 1) + fsdp_version = int(self.fsdp_config.get("version", 1)) + fsdp_plugin_args["fsdp_version"] = fsdp_version prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH") - fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper() - fsdp_plugin_args["forward_prefetch"] = str_to_bool(str(self.fsdp_config.get("forward_prefetch", "false")).lower()) + if fsdp_version == 2: + fsdp_plugin_args["reshard_after_forward"] = str_to_bool(str(self.fsdp_config.get("reshard_after_forward", "false")).lower()) + else: + fsdp_plugin_args["forward_prefetch"] = str_to_bool(str(self.fsdp_config.get("forward_prefetch", "false")).lower()) + fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper() + fsdp_plugin_args["reshard_after_forward"] = str(self.fsdp_config.get("reshard_after_forward", "false")).lower() + fsdp_plugin_args["use_orig_params"] = str_to_bool(str(self.fsdp_config.get("use_orig_params", "true")).lower()) sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower() cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower() @@ -2759,7 +2765,6 @@ def _process_fsdp_args(self): os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading fsdp_plugin_args["sync_module_states"] = str_to_bool(sync_module_states) - fsdp_plugin_args["use_orig_params"] = 
str_to_bool(str(self.fsdp_config.get("use_orig_params", "true")).lower()) return fsdp_plugin_args From 52120c3a07a2a5ed569373e13055ce68797063ad Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 2 Dec 2025 12:47:00 -0500 Subject: [PATCH 4/5] lint --- src/transformers/training_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 01017a800d80..e8f67cf4b5d7 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2790,4 +2790,4 @@ def str_to_bool(value, to_bool: bool = True) -> int | bool: elif value in ("n", "no", "f", "false", "off", "0"): return 0 if not to_bool else False else: - raise ValueError(f"invalid truth value {value}") \ No newline at end of file + raise ValueError(f"invalid truth value {value}") From 120df30179c6526eec2c47edf8490ea77fa503e3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 2 Dec 2025 12:49:36 -0500 Subject: [PATCH 5/5] chore: lint --- src/transformers/training_args.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index e8f67cf4b5d7..0eeeedea551a 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -2746,12 +2746,20 @@ def _process_fsdp_args(self): fsdp_plugin_args["fsdp_version"] = fsdp_version prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH") if fsdp_version == 2: - fsdp_plugin_args["reshard_after_forward"] = str_to_bool(str(self.fsdp_config.get("reshard_after_forward", "false")).lower()) + fsdp_plugin_args["reshard_after_forward"] = str_to_bool( + str(self.fsdp_config.get("reshard_after_forward", "false")).lower() + ) else: - fsdp_plugin_args["forward_prefetch"] = str_to_bool(str(self.fsdp_config.get("forward_prefetch", "false")).lower()) + fsdp_plugin_args["forward_prefetch"] = str_to_bool( + str(self.fsdp_config.get("forward_prefetch", 
"false")).lower() + ) fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper() - fsdp_plugin_args["reshard_after_forward"] = str(self.fsdp_config.get("reshard_after_forward", "false")).lower() - fsdp_plugin_args["use_orig_params"] = str_to_bool(str(self.fsdp_config.get("use_orig_params", "true")).lower()) + fsdp_plugin_args["reshard_after_forward"] = str( + self.fsdp_config.get("reshard_after_forward", "false") + ).lower() + fsdp_plugin_args["use_orig_params"] = str_to_bool( + str(self.fsdp_config.get("use_orig_params", "true")).lower() + ) sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower() cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower()