Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 26 additions & 21 deletions src/transformers/utils/auto_docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,29 +1288,34 @@ def _process_parameter_type(param, param_name, func):
param_name (`str`): The name of the parameter
func (`function`): The function the parameter belongs to
"""
optional = False
if param.annotation != inspect.Parameter.empty:
param_type = param.annotation
if "typing" in str(param_type):
param_type = "".join(str(param_type).split("typing.")).replace("transformers.", "~")
elif hasattr(param_type, "__module__"):
param_type = f"{param_type.__module__.replace('transformers.', '~').replace('builtins', '')}.{param.annotation.__name__}"
if param_type[0] == ".":
param_type = param_type[1:]
# a parameter for a function might have three basic elements: name, type hint, default value
# for example, it would be like (there are some whitespaces)
# age: int = 18 (if a param has no type hint, it would be like age=18)
parameter_str: str = str(param)
# see if there is type hint for given parameter
if ":" in parameter_str:
name, type_hint_and_default_value = parameter_str.split(":")
type_hint_and_default_value = type_hint_and_default_value.strip()
if "=" in type_hint_and_default_value:
type_hint, default_value = type_hint_and_default_value.split("=")
type_hint = type_hint.strip()
return type_hint, True
else:
if False:
print(
f"[ERROR] {param_type} for {param_name} of {func.__qualname__} in file {func.__code__.co_filename} has an invalid type"
)
if "ForwardRef" in param_type:
param_type = re.sub(r"ForwardRef\('([\w.]+)'\)", r"\1", param_type)
if "Optional" in param_type:
param_type = re.sub(r"Optional\[(.*?)\]", r"\1", param_type)
optional = True
type_hint = type_hint_and_default_value
type_hint = type_hint.strip()
if "Optional" in type_hint:
return type_hint, True
else:
return type_hint, False
else:
param_type = ""

return param_type, optional
# if there is no type hint
# see if there is a default value
if "=" in parameter_str:
# keep this line for debugging, if necessary
name, default_value = parameter_str.split("=")
return "", True
else:
return "", False


def _get_parameter_info(param_name, documented_params, source_args_dict, param_type, optional):
Expand Down