diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index 72a2f245cf19..f50c691038a6 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -1288,29 +1288,34 @@ def _process_parameter_type(param, param_name, func): param_name (`str`): The name of the parameter func (`function`): The function the parameter belongs to """ - optional = False - if param.annotation != inspect.Parameter.empty: - param_type = param.annotation - if "typing" in str(param_type): - param_type = "".join(str(param_type).split("typing.")).replace("transformers.", "~") - elif hasattr(param_type, "__module__"): - param_type = f"{param_type.__module__.replace('transformers.', '~').replace('builtins', '')}.{param.annotation.__name__}" - if param_type[0] == ".": - param_type = param_type[1:] + # a parameter for a function might have three basic elements: name, type hint, default value + # for example, it would be like (there are some whitespaces) + # age: int = 18 (if a param has no type hint, it would be like age=18) + parameter_str: str = str(param) + # see if there is type hint for given parameter + if ":" in parameter_str: + name, type_hint_and_default_value = parameter_str.split(":") + type_hint_and_default_value = type_hint_and_default_value.strip() + if "=" in type_hint_and_default_value: + type_hint, default_value = type_hint_and_default_value.split("=") + type_hint = type_hint.strip() + return type_hint, True else: - if False: - print( - f"[ERROR] {param_type} for {param_name} of {func.__qualname__} in file {func.__code__.co_filename} has an invalid type" - ) - if "ForwardRef" in param_type: - param_type = re.sub(r"ForwardRef\('([\w.]+)'\)", r"\1", param_type) - if "Optional" in param_type: - param_type = re.sub(r"Optional\[(.*?)\]", r"\1", param_type) - optional = True + type_hint = type_hint_and_default_value + type_hint = type_hint.strip() + if "Optional" in type_hint: + return type_hint, True + else: + return type_hint, False else: - param_type = "" - - return param_type, optional + # if there is no type hint + # see if there is a default value + if "=" in parameter_str: + # keep this line for debugging, if necessary + name, default_value = parameter_str.split("=") + return "", True + else: + return "", False def _get_parameter_info(param_name, documented_params, source_args_dict, param_type, optional):