Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion tests/test_type_conversion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from pathlib import Path
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, Union

import click
import pytest
Expand Down Expand Up @@ -51,6 +51,28 @@ def opt(user: str | None = None):
assert "User: Camila" in result.output


@pytest.mark.parametrize(
("value", "expected"),
[("0", "ROOTED!"), ("12", "ID: 12"), ("name", "USER: name")],
)
def test_union(value, expected):
app = typer.Typer()

@app.command()
def opt(id_or_name: Union[int, str]):
if isinstance(id_or_name, int):
Comment on lines +62 to +63
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests should be simpler.

    @app.command()
    def opt(id_or_name: Union[float, int, str]):
        print(f"{id_or_name} ({type(id_or_name)})")

Something like this

if id_or_name == 0:
print("ROOTED!")
else:
print(f"ID: {id_or_name}")
else:
print(f"USER: {id_or_name}")

result = runner.invoke(app, [value])
assert result.exit_code == 0
assert expected in result.output


def test_optional_tuple():
app = typer.Typer()

Expand Down
41 changes: 38 additions & 3 deletions typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,30 @@ def wrapper(**kwargs: Any) -> Any:
return wrapper


class UnionParamType(click.ParamType):
@property
def name(self) -> str: # type: ignore
return " | ".join(_type.name for _type in self._types)

def __init__(self, types: List[click.ParamType]):
super().__init__()
self._types = types

def convert(
self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context]
) -> Any:
# *types, last = self._types
error_messages = []
for _type in self._types:
try:
return _type.convert(value, param, ctx)
except click.BadParameter as e:
print(type(e))
error_messages.append(str(e))
# return last.convert(value, param, ctx)
raise self.fail("\n" + "\nbut also\n".join(error_messages), param, ctx)


def get_click_type(
*, annotation: Any, parameter_info: ParameterInfo
) -> click.ParamType:
Expand Down Expand Up @@ -797,6 +821,12 @@ def get_click_type(
[item.value for item in annotation],
case_sensitive=parameter_info.case_sensitive,
)
elif get_origin(annotation) is not None and is_union(get_origin(annotation)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elif get_origin(annotation) is not None and is_union(get_origin(annotation)):
elif is_union(get_origin(annotation)):

Seems that this should work

types = [
get_click_type(annotation=arg, parameter_info=parameter_info)
for arg in get_args(annotation)
]
return UnionParamType(types)
raise RuntimeError(f"Type not yet supported: {annotation}") # pragma: no cover


Expand Down Expand Up @@ -847,9 +877,14 @@ def get_click_param(
if type_ is NoneType:
continue
types.append(type_)
assert len(types) == 1, "Typer Currently doesn't support Union types"
main_type = types[0]
origin = get_origin(main_type)
if len(types) == 1:
(main_type,) = types
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO, main_type = types[0] is clearer.

Also, comment above should be updated (Handle SomeType | None and Optional[SomeType])

origin = get_origin(main_type)
else:
for type_ in get_args(main_type):
assert not get_origin(type_), (
"Union types with complex sub-types are not currently supported"
)
# Handle Tuples and Lists
if lenient_issubclass(origin, List):
main_type = get_args(main_type)[0]
Expand Down
Loading