diff --git a/jsons/_lizers_impl.py b/jsons/_lizers_impl.py index 749a3dc..3dd0755 100644 --- a/jsons/_lizers_impl.py +++ b/jsons/_lizers_impl.py @@ -4,6 +4,7 @@ This module contains functionality for setting and getting serializers and deserializers. """ +import types from typing import Optional, Dict, Sequence, Union from jsons._cache import cached @@ -156,4 +157,7 @@ def _get_parents(cls: type, lizers: list) -> list: parents.append(cls_) except (TypeError, AttributeError): pass # Some types do not support `issubclass` (e.g. Union). + if not parents and isinstance(naked_cls, types.UnionType) and \ + Union in lizers: + parents = [Union] return parents diff --git a/tests/test_union.py b/tests/test_union.py index 5dd521a..beea164 100644 --- a/tests/test_union.py +++ b/tests/test_union.py @@ -86,6 +86,27 @@ def __init__(self, y: int): with self.assertRaises(SerializationError): jsons.dump(A(1), Union[B], strict=True) + def test_dump_union_syntax(self): + class A: + def __init__(self, x: int | float): + self.x = x + + dumped = jsons.dump(A(1)) + expected = {'x': 1} + self.assertDictEqual(expected, dumped) + + dumped2 = jsons.dump(A(1), strict=True) + expected2 = {'x': 1} + self.assertDictEqual(expected2, dumped2) + + dumped = jsons.dump(A(2.0)) + expected = {'x': 2.0} + self.assertDictEqual(expected, dumped) + + dumped2 = jsons.dump(A(2.0), strict=True) + expected2 = {'x': 2.0} + self.assertDictEqual(expected2, dumped2) + def test_fail(self): with self.assertRaises(SerializationError) as err: jsons.dump('nope', Union[int, float]) @@ -126,6 +147,18 @@ def __init__(self, x: Union[datetime.datetime, A]): with self.assertRaises(DeserializationError): jsons.load({'x': 'no match in the union'}, C).x + def test_load_union_syntax(self): + class A: + def __init__(self, x: int | float): + self.x = x + + self.assertEqual(1, jsons.load({'x': 1}, A).x) + self.assertEqual(2.0, jsons.load({'x': 2.0}, A).x) + + # Test Union with invalid value. + with self.assertRaises(DeserializationError): + jsons.load({'x': 'no match in the union'}, A).x + def test_load_none(self): class C: def __init__(self, x: int, y: Optional[int]):