From c902ac3e687465318958ed87b4d71cd89de0e4ce Mon Sep 17 00:00:00 2001 From: Ariel Fridman Date: Fri, 7 Feb 2025 15:31:10 +0200 Subject: [PATCH 1/3] Adds support for unions with no complex types --- typer/main.py | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/typer/main.py b/typer/main.py index 8beb61f7a5..cd11254c21 100644 --- a/typer/main.py +++ b/typer/main.py @@ -700,6 +700,27 @@ def wrapper(**kwargs: Any) -> Any: update_wrapper(wrapper, callback) return wrapper +class UnionParamType(click.ParamType): + @property + def name(self) -> str: # type: ignore + return ' | '.join(_type.name for _type in self._types) + + def __init__(self, types: List[click.ParamType]): + super().__init__() + self._types = types + + def convert(self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context]) -> Any: + # *types, last = self._types + error_messages = [] + for _type in self._types: + try: + return _type.convert(value, param, ctx) + except click.BadParameter as e: + print(type(e)) + error_messages.append(str(e)) + # return last.convert(value, param, ctx) + raise self.fail('\n' + '\nbut also\n'.join(error_messages), param, ctx) + def get_click_type( *, annotation: Any, parameter_info: ParameterInfo @@ -791,6 +812,9 @@ def get_click_type( [item.value for item in annotation], case_sensitive=parameter_info.case_sensitive, ) + elif get_origin(annotation) is not None and is_union(get_origin(annotation)): + types = [get_click_type(annotation=arg, parameter_info=parameter_info) for arg in get_args(annotation)] + return UnionParamType(types) raise RuntimeError(f"Type not yet supported: {annotation}") # pragma: no cover @@ -841,9 +865,14 @@ def get_click_param( if type_ is NoneType: continue types.append(type_) - assert len(types) == 1, "Typer Currently doesn't support Union types" - main_type = types[0] - origin = get_origin(main_type) + if len(types) == 1: + main_type, = types + origin = get_origin(main_type) + else: + for type_ in get_args(main_type): + assert not get_origin(type_), ( + "Union types with complex sub-types are not currently supported" + ) # Handle Tuples and Lists if lenient_issubclass(origin, List): main_type = get_args(main_type)[0] From 639de2b623dcdc83f937e72123147528ba863e1a Mon Sep 17 00:00:00 2001 From: Ariel Fridman Date: Fri, 7 Feb 2025 15:36:00 +0200 Subject: [PATCH 2/3] Adds tests --- tests/test_type_conversion.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/test_type_conversion.py b/tests/test_type_conversion.py index 904a686d2e..0a78c67def 100644 --- a/tests/test_type_conversion.py +++ b/tests/test_type_conversion.py @@ -1,6 +1,6 @@ from enum import Enum from pathlib import Path -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union import click import pytest @@ -50,6 +50,31 @@ def opt(user: str | None = None): assert result.exit_code == 0 assert "User: Camila" in result.output +@pytest.mark.parametrize( + ("value", "expected"), + [ + ("0", "ROOTED!"), + ("12", "ID: 12"), + ("name", "USER: name") + ], +) +def test_union(value, expected): + app = typer.Typer() + + @app.command() + def opt(id_or_name: Union[int, str]): + if isinstance(id_or_name, int): + if id_or_name == 0: + print("ROOTED!") + else: + print(f"ID: {id_or_name}") + else: + print(f"USER: {id_or_name}") + + result = runner.invoke(app, [value]) + assert result.exit_code == 0 + assert expected in result.output + def test_optional_tuple(): app = typer.Typer() From 24bc37c12f0e62fc578a559574e4e5fff0e130a3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Feb 2025 13:46:04 +0000 Subject: [PATCH 3/3] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20for?= =?UTF-8?q?mat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_type_conversion.py | 15 ++++++--------- typer/main.py | 18 ++++++++++++------ 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/test_type_conversion.py b/tests/test_type_conversion.py index 0a78c67def..7934d8f508 100644 --- a/tests/test_type_conversion.py +++ b/tests/test_type_conversion.py @@ -50,17 +50,14 @@ def opt(user: str | None = None): assert result.exit_code == 0 assert "User: Camila" in result.output + @pytest.mark.parametrize( ("value", "expected"), - [ - ("0", "ROOTED!"), - ("12", "ID: 12"), - ("name", "USER: name") - ], + [("0", "ROOTED!"), ("12", "ID: 12"), ("name", "USER: name")], ) def test_union(value, expected): app = typer.Typer() - + @app.command() def opt(id_or_name: Union[int, str]): if isinstance(id_or_name, int): @@ -70,11 +67,11 @@ def opt(id_or_name: Union[int, str]): print(f"ID: {id_or_name}") else: print(f"USER: {id_or_name}") - + result = runner.invoke(app, [value]) assert result.exit_code == 0 - assert expected in result.output - + assert expected in result.output + def test_optional_tuple(): app = typer.Typer() diff --git a/typer/main.py b/typer/main.py index cd11254c21..4f553cb262 100644 --- a/typer/main.py +++ b/typer/main.py @@ -700,16 +700,19 @@ def wrapper(**kwargs: Any) -> Any: update_wrapper(wrapper, callback) return wrapper + class UnionParamType(click.ParamType): @property - def name(self) -> str: # type: ignore - return ' | '.join(_type.name for _type in self._types) + def name(self) -> str: # type: ignore + return " | ".join(_type.name for _type in self._types) def __init__(self, types: List[click.ParamType]): super().__init__() self._types = types - def convert(self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context]) -> Any: + def convert( + self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context] + ) -> Any: # *types, last = self._types error_messages = [] for _type in self._types: @@ -719,7 +722,7 @@ def convert(self, value: Any, param: Optional[click.Parameter], ctx: Optional[cl print(type(e)) error_messages.append(str(e)) # return last.convert(value, param, ctx) - raise self.fail('\n' + '\nbut also\n'.join(error_messages), param, ctx) + raise self.fail("\n" + "\nbut also\n".join(error_messages), param, ctx) def get_click_type( @@ -813,7 +816,10 @@ def get_click_type( case_sensitive=parameter_info.case_sensitive, ) elif get_origin(annotation) is not None and is_union(get_origin(annotation)): - types = [get_click_type(annotation=arg, parameter_info=parameter_info) for arg in get_args(annotation)] + types = [ + get_click_type(annotation=arg, parameter_info=parameter_info) + for arg in get_args(annotation) + ] return UnionParamType(types) raise RuntimeError(f"Type not yet supported: {annotation}") # pragma: no cover @@ -866,7 +872,7 @@ def get_click_param( continue types.append(type_) if len(types) == 1: - main_type, = types + (main_type,) = types origin = get_origin(main_type) else: for type_ in get_args(main_type):