diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 04c54343404b..0eeeedea551a 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -2742,10 +2742,24 @@ def _process_fsdp_args(self):
             fsdp_plugin_args["transformer_cls_names_to_wrap"] = ",".join(
                 self.fsdp_config["transformer_layer_cls_to_wrap"]
             )
-        fsdp_plugin_args["fsdp_version"] = self.fsdp_config.get("fsdp_version", 1)
+        fsdp_version = int(self.fsdp_config.get("version", 1))
+        fsdp_plugin_args["fsdp_version"] = fsdp_version
         prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
-        fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper()
-        fsdp_plugin_args["forward_prefetch"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower()
+        if fsdp_version == 2:
+            fsdp_plugin_args["reshard_after_forward"] = str_to_bool(
+                str(self.fsdp_config.get("reshard_after_forward", "false")).lower()
+            )
+        else:
+            fsdp_plugin_args["forward_prefetch"] = str_to_bool(
+                str(self.fsdp_config.get("forward_prefetch", "false")).lower()
+            )
+            fsdp_plugin_args["backward_prefetch"] = prefetch_policy.upper()
+            fsdp_plugin_args["reshard_after_forward"] = str(
+                self.fsdp_config.get("reshard_after_forward", "false")
+            ).lower()
+        fsdp_plugin_args["use_orig_params"] = str_to_bool(
+            str(self.fsdp_config.get("use_orig_params", "true")).lower()
+        )
 
         sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower()
         cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower()
@@ -2755,11 +2769,10 @@ def _process_fsdp_args(self):
             raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')
 
         # we need to set the env here as otherwise we get a warning in accelerate + we need to set it for transformers
-        fsdp_plugin_args["cpu_ram_efficient_loading"] = cpu_ram_efficient_loading
+        fsdp_plugin_args["cpu_ram_efficient_loading"] = str_to_bool(cpu_ram_efficient_loading)
         os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading
 
-        fsdp_plugin_args["sync_module_states"] = sync_module_states
-        fsdp_plugin_args["use_orig_params"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
+        fsdp_plugin_args["sync_module_states"] = str_to_bool(sync_module_states)
 
         return fsdp_plugin_args
 
@@ -2771,3 +2784,18 @@ class ParallelMode(Enum):
     SAGEMAKER_MODEL_PARALLEL = "sagemaker_model_parallel"
     SAGEMAKER_DATA_PARALLEL = "sagemaker_data_parallel"
     TPU = "tpu"
+
+
+def str_to_bool(value, to_bool: bool = True) -> int | bool:
+    """
+    Converts a string representation of truth to `True` (1) or `False` (0).
+
+    True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
+    """
+    value = value.lower()
+    if value in ("y", "yes", "t", "true", "on", "1"):
+        return 1 if not to_bool else True
+    elif value in ("n", "no", "f", "false", "off", "0"):
+        return 0 if not to_bool else False
+    else:
+        raise ValueError(f"invalid truth value {value}")