Skip to content

Commit cd2d5af

Browse files
authored
Merge pull request #247 from python-ellar/route_resolve_alias
fix:Route Resolvers with custom alias and name
2 parents fb6afa4 + 63d5abf commit cd2d5af

File tree

3 files changed

+91
-51
lines changed

3 files changed

+91
-51
lines changed

ellar/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
"""Ellar - Python ASGI web framework for building fast, efficient, and scalable RESTful APIs and server-side applications."""
22

33
__version__ = "0.8.4"
4-

ellar/common/params/resolvers/base.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,15 @@ async def resolve(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult:
4444

4545
@abstractmethod
4646
@t.no_type_check
47-
def create_raw_data(self, data: t.Any) -> t.Dict:
47+
def create_raw_data(
48+
self, data: t.Any, field_name: t.Optional[str] = None
49+
) -> t.Dict:
4850
"""
4951
Creates the raw data for the parameter.
5052
5153
Args:
5254
data: The resolved value of the parameter.
55+
field_name: The name of the field.
5356
5457
Returns:
5558
`dict`: A dictionary containing the raw data.
@@ -62,8 +65,10 @@ def __init__(self, model_field: ModelField, *args: t.Any, **kwargs: t.Any) -> No
6265
RouteParameterModelField, model_field
6366
)
6467

65-
def create_raw_data(self, data: t.Any) -> t.Dict:
66-
return {self.model_field.name: data}
68+
def create_raw_data(
69+
self, data: t.Any, field_name: t.Optional[str] = None
70+
) -> t.Dict:
71+
return {field_name or self.model_field.name: data}
6772

6873
def assert_field_info(self) -> None:
6974
"""
@@ -91,13 +96,21 @@ async def resolve(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult:
9196

9297
@abstractmethod
9398
@t.no_type_check
94-
async def resolve_handle(self, *args: t.Any, **kwargs: t.Any) -> ResolverResult:
99+
async def resolve_handle(
100+
self,
101+
*args: t.Any,
102+
alias: t.Optional[str] = None,
103+
name: t.Optional[str] = None,
104+
**kwargs: t.Any,
105+
) -> ResolverResult:
95106
"""
96107
Resolves the value of the parameter during request processing.
97108
98109
Args:
99110
*args: Additional positional arguments.
100111
**kwargs: Additional keyword arguments.
112+
alias: The alias of the parameter. Optional.
113+
name: The name of the parameter. Optional.
101114
102115
Returns:
103116
`ResolverResult`: A named tuple containing the resolved value, any errors, and the raw data.

ellar/common/params/resolvers/parameter.py

Lines changed: 74 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -43,43 +43,47 @@ def get_received_parameter(
4343
return connection.headers
4444

4545
async def resolve_handle(
46-
self, ctx: IExecutionContext, *args: t.Any, **kwargs: t.Any
46+
self,
47+
ctx: IExecutionContext,
48+
*args: t.Any,
49+
alias: t.Optional[str] = None,
50+
name: t.Optional[str] = None,
51+
**kwargs: t.Any,
4752
) -> ResolverResult:
53+
alias = alias or self.model_field.alias
54+
name = name or self.model_field.name
4855
request_logger.debug(
4956
f"Resolving Header Parameters - '{self.__class__.__name__}'"
5057
)
5158
received_params = self.get_received_parameter(ctx=ctx)
5259
if is_sequence_field(self.model_field):
53-
value = (
54-
received_params.getlist(self.model_field.alias)
55-
or self.model_field.default
56-
)
60+
value = received_params.getlist(alias) or self.model_field.default
5761
else:
58-
value = received_params.get(self.model_field.alias)
62+
value = received_params.get(alias)
5963
self.assert_field_info()
6064
field_info = self.model_field.field_info
6165
values = {}
6266
if value is None:
6367
if self.model_field.required:
64-
errors = [
65-
self.create_error(
66-
loc=(field_info.in_.value, self.model_field.alias)
67-
)
68-
]
69-
return ResolverResult({}, errors, raw_data=self.create_raw_data(value))
68+
errors = [self.create_error(loc=(field_info.in_.value, alias))]
69+
return ResolverResult(
70+
{}, errors, raw_data=self.create_raw_data(value, field_name=name)
71+
)
7072
else:
7173
value = copy.deepcopy(self.model_field.default)
72-
values[self.model_field.name] = value
73-
return ResolverResult(values, [], raw_data=self.create_raw_data(value))
74+
values[name] = value
75+
return ResolverResult(
76+
values, [], raw_data=self.create_raw_data(value, field_name=name)
77+
)
7478

7579
v_, errors_ = self.model_field.validate(
76-
value, values, loc=(field_info.in_.value, self.model_field.alias)
80+
value, values, loc=(field_info.in_.value, alias)
7781
)
7882

7983
return ResolverResult(
80-
data={self.model_field.name: v_},
84+
data={name: v_},
8185
errors=self.validate_error_sequence(errors_),
82-
raw_data=self.create_raw_data(value),
86+
raw_data=self.create_raw_data(value, field_name=name),
8387
)
8488

8589

@@ -99,22 +103,28 @@ def get_received_parameter(cls, ctx: IExecutionContext) -> t.Mapping[str, t.Any]
99103
return connection.path_params
100104

101105
async def resolve_handle(
102-
self, ctx: IExecutionContext, **kwargs: t.Any
106+
self,
107+
ctx: IExecutionContext,
108+
alias: t.Optional[str] = None,
109+
name: t.Optional[str] = None,
110+
**kwargs: t.Any,
103111
) -> ResolverResult:
112+
alias = alias or self.model_field.alias
113+
name = name or self.model_field.name
104114
request_logger.debug(f"Resolving Path Parameters - '{self.__class__.__name__}'")
105115
received_params = self.get_received_parameter(ctx=ctx)
106-
value = received_params.get(str(self.model_field.alias))
116+
value = received_params.get(str(alias))
107117
self.assert_field_info()
108118

109119
v_, errors_ = self.model_field.validate(
110120
value,
111121
{},
112-
loc=(self.model_field.field_info.in_.value, self.model_field.alias),
122+
loc=(self.model_field.field_info.in_.value, alias),
113123
)
114124
return ResolverResult(
115-
data={self.model_field.name: v_},
125+
data={name: v_},
116126
errors=self.validate_error_sequence(errors_),
117-
raw_data=self.create_raw_data(value),
127+
raw_data=self.create_raw_data(value, field_name=name),
118128
)
119129

120130

@@ -127,44 +137,54 @@ def get_received_parameter(cls, ctx: IExecutionContext) -> t.Mapping[str, t.Any]
127137

128138
class WsBodyParameterResolver(BaseRouteParameterResolver):
129139
async def resolve_handle(
130-
self, ctx: IExecutionContext, *args: t.Any, body: t.Any, **kwargs: t.Any
140+
self,
141+
ctx: IExecutionContext,
142+
*args: t.Any,
143+
body: t.Any,
144+
alias: t.Optional[str] = None,
145+
name: t.Optional[str] = None,
146+
**kwargs: t.Any,
131147
) -> t.Tuple:
148+
alias = alias or self.model_field.alias
149+
name = name or self.model_field.name
132150
request_logger.debug(
133151
f"Resolving Websocket Body Parameters - '{self.__class__.__name__}'"
134152
)
135153
embed = getattr(self.model_field.field_info, "embed", False)
136-
received_body = {self.model_field.alias: body}
154+
received_body = {alias: body}
137155
loc = ("body",)
138156
if embed:
139157
received_body = body
140-
loc = ("body", self.model_field.alias) # type:ignore
158+
loc = ("body", alias) # type:ignore
141159
try:
142-
value = received_body.get(self.model_field.alias)
160+
value = received_body.get(alias)
143161

144162
if value is None:
145163
if self.model_field.required:
146164
return ResolverResult(
147165
None,
148166
[self.create_error(loc=loc)],
149-
raw_data=self.create_raw_data(value),
167+
raw_data=self.create_raw_data(value, field_name=name),
150168
)
151169
else:
152170
value = copy.deepcopy(self.model_field.default)
153171
return ResolverResult(
154-
{self.model_field.name: value},
172+
{name: value},
155173
[],
156-
raw_data=self.create_raw_data(value),
174+
raw_data=self.create_raw_data(value, field_name=name),
157175
)
158176

159177
v_, errors_ = self.model_field.validate(value, {}, loc=loc)
160178
return ResolverResult(
161-
data={self.model_field.name: v_},
179+
data={name: v_},
162180
errors=self.validate_error_sequence(errors_),
163-
raw_data=self.create_raw_data(value),
181+
raw_data=self.create_raw_data(value, field_name=name),
164182
)
165183
except AttributeError:
166184
errors = [self.create_error(loc=loc)]
167-
return ResolverResult(None, errors, raw_data=self.create_raw_data(None))
185+
return ResolverResult(
186+
None, errors, raw_data=self.create_raw_data(None, field_name=name)
187+
)
168188

169189

170190
class BodyParameterResolver(WsBodyParameterResolver):
@@ -228,10 +248,10 @@ async def resolve_handle(
228248

229249
class FormParameterResolver(BodyParameterResolver):
230250
async def process_and_validate(
231-
self, *, values: t.Dict, value: t.Any, loc: t.Tuple
251+
self, *, values: t.Dict, value: t.Any, loc: t.Tuple, field_name: str
232252
) -> t.Tuple:
233253
v_, errors_ = self.model_field.validate(value, values, loc=loc)
234-
values[self.model_field.name] = v_
254+
values[field_name] = v_
235255
return ResolverResult(
236256
data=values,
237257
errors=self.validate_error_sequence(errors_),
@@ -257,22 +277,26 @@ async def resolve_handle(
257277
ctx: IExecutionContext,
258278
*args: t.Any,
259279
body: t.Optional[t.Any] = None,
280+
alias: t.Optional[str] = None,
281+
name: t.Optional[str] = None,
260282
**kwargs: t.Any,
261283
) -> t.Tuple:
284+
alias = alias or self.model_field.alias
285+
name = name or self.model_field.name
262286
_body = body or await self.get_request_body(ctx)
263287
embed = getattr(self.model_field.field_info, "embed", False)
264-
received_body = {self.model_field.alias: _body}
288+
received_body = {alias: _body}
265289
loc = ("body",)
266290

267291
if embed:
268292
received_body = _body
269-
loc = ("body", self.model_field.alias) # type:ignore
293+
loc = ("body", alias) # type:ignore
270294

271295
if is_sequence_field(self.model_field) and isinstance(_body, FormData):
272-
loc = ("body", self.model_field.alias) # type: ignore
273-
value = _body.getlist(self.model_field.alias)
296+
loc = ("body", alias) # type: ignore
297+
value = _body.getlist(alias)
274298
else:
275-
value = received_body.get(self.model_field.alias) # type: ignore
299+
value = received_body.get(alias) # type: ignore
276300

277301
if (
278302
value is None
@@ -281,17 +305,21 @@ async def resolve_handle(
281305
):
282306
if self.model_field.required:
283307
return ResolverResult(
284-
None, [self.create_error(loc=loc)], self.create_raw_data(value)
308+
None,
309+
[self.create_error(loc=loc)],
310+
self.create_raw_data(value, field_name=name),
285311
)
286312
else:
287313
value = copy.deepcopy(self.model_field.default)
288314
return ResolverResult(
289-
{self.model_field.name: value},
315+
{name: value},
290316
[],
291-
raw_data=self.create_raw_data(value),
317+
raw_data=self.create_raw_data(value, field_name=name),
292318
)
293319

294-
return await self.process_and_validate(values={}, value=value, loc=loc)
320+
return await self.process_and_validate(
321+
values={}, value=value, loc=loc, field_name=name
322+
)
295323

296324

297325
class FileParameterResolver(FormParameterResolver):
@@ -302,7 +330,7 @@ def __init__(self, *args: t.Any, **kwargs: t.Any):
302330
self._is_byte_list = is_bytes_sequence_annotation(self.model_field.type_)
303331

304332
async def process_and_validate(
305-
self, *, values: t.Dict, value: t.Any, loc: t.Tuple
333+
self, *, values: t.Dict, value: t.Any, loc: t.Tuple, field_name: str
306334
) -> t.Tuple:
307335
if self._is_byte and isinstance(value, StarletteUploadFile):
308336
value = await value.read()
@@ -321,10 +349,10 @@ async def process_fn(
321349
value = serialize_sequence_value(field=self.model_field, value=results)
322350

323351
v_, errors_ = self.model_field.validate(value, values, loc=loc)
324-
values[self.model_field.name] = v_
352+
values[field_name] = v_
325353

326354
return ResolverResult(
327355
data=values,
328356
errors=self.validate_error_sequence(errors_),
329-
raw_data=self.create_raw_data(value),
357+
raw_data=self.create_raw_data(value, field_name=field_name),
330358
)

0 commit comments

Comments
 (0)