|
| 1 | +from typing import Any, Union |
| 2 | + |
| 3 | +import pydantic |
| 4 | + |
| 5 | +import pytest |
| 6 | + |
| 7 | +from pydantic_partial import create_partial_model |
| 8 | +from pydantic_partial._compat import PYDANTIC_V1, PYDANTIC_V2 |
| 9 | + |
| 10 | + |
| 11 | +if PYDANTIC_V1: |
| 12 | + def _field_is_required( |
| 13 | + model: Union[type[pydantic.BaseModel], pydantic.BaseModel], |
| 14 | + field_name: str, |
| 15 | + ) -> bool: |
| 16 | + """Check if a field is required on a pydantic V1 model.""" |
| 17 | + # noinspection PyDeprecation |
| 18 | + return model.__fields__[field_name].required |
| 19 | + |
| 20 | + |
| 21 | + def _field_get_default( |
| 22 | + model: Union[type[pydantic.BaseModel], pydantic.BaseModel], |
| 23 | + field_name: str, |
| 24 | + ) -> tuple[Any, Any]: |
| 25 | + """Return field default info""" |
| 26 | + field_info = model.__fields__[field_name] |
| 27 | + return field_info.default, field_info.default_factory |
| 28 | +elif PYDANTIC_V2: |
| 29 | + def _field_is_required( |
| 30 | + model: Union[type[pydantic.BaseModel], pydantic.BaseModel], |
| 31 | + field_name: str, |
| 32 | + ) -> bool: |
| 33 | + """Check if a field is required on a pydantic V2 model.""" |
| 34 | + return model.model_fields[field_name].is_required() |
| 35 | + |
| 36 | + |
| 37 | + def _field_get_default( |
| 38 | + model: Union[type[pydantic.BaseModel], pydantic.BaseModel], |
| 39 | + field_name: str, |
| 40 | + ) -> tuple[Any, Any]: |
| 41 | + """Return field default info""" |
| 42 | + field_info = model.model_fields[field_name] |
| 43 | + return field_info.default, field_info.default_factory |
| 44 | +else: |
| 45 | + raise DeprecationWarning("Pydantic has to be in version 1 or 2.") |
| 46 | + |
| 47 | + |
| 48 | +class Something(pydantic.BaseModel): |
| 49 | + name: Union[str, None] = "Joe Doe" |
| 50 | + something_else_id: int |
| 51 | + |
| 52 | + |
| 53 | +PartialSomething = create_partial_model(Something, optional=False) |
| 54 | +PartialSomethingOptional = create_partial_model(Something, optional=True) |
| 55 | + |
| 56 | + |
| 57 | +def test_fields_not_required(): |
| 58 | + assert _field_is_required(PartialSomething, "name") is False |
| 59 | + assert _field_is_required(PartialSomething, "something_else_id") is False |
| 60 | + |
| 61 | + |
| 62 | +def test_field_defaults(): |
| 63 | + assert _field_get_default(PartialSomething, "name") == ("Joe Doe", None) |
| 64 | + assert _field_get_default(PartialSomething, "something_else_id") == (None, None) |
| 65 | + |
| 66 | + |
| 67 | +def test_validate_ok(): |
| 68 | + # It shouldn't be necessary to check that the right default values end |
| 69 | + # up in the models. That should already be done by pydantic's own tests. |
| 70 | + # We just check that validation succeeds. |
| 71 | + PartialSomething() |
| 72 | + PartialSomething(name='Jane Doe') |
| 73 | + PartialSomething(name=None) |
| 74 | + PartialSomething(something_else_id=42) |
| 75 | + PartialSomething(name='Jane Doe', something_else_id=42) |
| 76 | + PartialSomething(name=None, something_else_id=42) |
| 77 | + |
| 78 | + |
| 79 | +def test_validate_fail(): |
| 80 | + with pytest.raises(pydantic.ValidationError): |
| 81 | + PartialSomething(something_else_id=None) |
| 82 | + |
| 83 | + |
| 84 | +def test_validate_optional(): |
| 85 | + PartialSomethingOptional(something_else_id=None) |
0 commit comments