Skip to content

Commit ef16e13

Browse files
author
chenxl
committed
fix marshmallow==2.x bug and optimize code.
1 parent 666784a commit ef16e13

File tree

4 files changed

+55
-28
lines changed

4 files changed

+55
-28
lines changed

README.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ $ pip install -U Flasgger-Marshmallow
1818
import logging
1919

2020
from flasgger import Swagger
21+
# use basePath from flasgger_marshmallow import Swagger
2122
from flask import Flask, request
2223
from flask_restful import Api, Resource
2324
from marshmallow import Schema, fields
@@ -104,7 +105,11 @@ class responseHeadersSchema(Schema):
104105
X_RateLimit_Limit = fields.Integer(
105106
required=True, default=1, doc='Request limit per hour',
106107
data_key='X-RateLimit-Limit'
107-
)
108+
) # marshmallow 3
109+
# X_RateLimit_Limit = fields.Integer(
110+
# required=True, default=1, doc='Request limit per hour',
111+
# load_from='X-RateLimit-Limit', dump_to='X-RateLimit-Limit'
112+
# ) # marshmallow 2
108113

109114
class Meta:
110115
strict = True
@@ -151,7 +156,7 @@ class User(Resource):
151156

152157
# 获取校验后的数据
153158
logger.info('%s, %s', type(request.query_schema), request.query_schema)
154-
return {'user_name': '陈小龙'}
159+
return {"count": 1, "page": 1, "users": [{'username': '陈小龙'}]}
155160

156161
@swagger_decorator(query_schema=QueryUserSchema, response_schema={302: RedirectResponseSchema})
157162
def put(self):
@@ -192,7 +197,9 @@ class Username(Resource):
192197

193198
@swagger_decorator(path_schema=UsernamePathSchema,
194199
form_schema=UpdateUserSchema,
195-
response_schema={200: UserDetailResponseSchema})
200+
response_schema={200: UserDetailResponseSchema},
201+
tags=["AAA"]
202+
)
196203
def put(self, username):
197204
"""
198205
更新用户信息

example/flask_swagger.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22

33
from flasgger import Swagger
4+
# use basePath from flasgger_marshmallow import Swagger
45
from flask import Flask, request
56
from flask_restful import Api, Resource
67
from marshmallow import Schema, fields
@@ -87,7 +88,11 @@ class responseHeadersSchema(Schema):
8788
X_RateLimit_Limit = fields.Integer(
8889
required=True, default=1, doc='Request limit per hour',
8990
data_key='X-RateLimit-Limit'
90-
)
91+
) # marshmallow 3
92+
# X_RateLimit_Limit = fields.Integer(
93+
# required=True, default=1, doc='Request limit per hour',
94+
# load_from='X-RateLimit-Limit', dump_to='X-RateLimit-Limit'
95+
# ) # marshmallow 2
9196

9297
class Meta:
9398
strict = True
@@ -134,7 +139,7 @@ def get(self):
134139

135140
# 获取校验后的数据
136141
logger.info('%s, %s', type(request.query_schema), request.query_schema)
137-
return {'user_name': '陈小龙'}
142+
return {"count": 1, "page": 1, "users": [{'username': '陈小龙'}]}
138143

139144
@swagger_decorator(query_schema=QueryUserSchema, response_schema={302: RedirectResponseSchema})
140145
def put(self):

flasgger_marshmallow/__init__.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
import yaml
88
from flasgger.base import Swagger as FSwagger
99
from flasgger.constants import OPTIONAL_FIELDS
10-
from flasgger.constants import OPTIONAL_OAS3_FIELDS
10+
try:
11+
from flasgger.constants import OPTIONAL_OAS3_FIELDS
12+
except :
13+
OPTIONAL_OAS3_FIELDS = [
14+
'components', 'servers'
15+
]
1116
from flasgger.utils import extract_definitions
1217
from flasgger.utils import get_specs
1318
from flasgger.utils import get_vendor_extension_fields
@@ -44,14 +49,27 @@
4449
fields.Int: 'number',
4550
}
4651

47-
if int(marshmallow.__version__.split('.')[1]) == 3:
52+
if int(marshmallow.__version__.split('.')[0]) == 3:
4853
FIELDS_JSON_TYPE_MAP.update({
4954
fields.NaiveDateTime: 'string',
5055
fields.AwareDateTime: 'string',
5156
fields.Tuple: 'array',
5257
})
5358

5459

60+
def is_marsh_v3():
61+
return int(marshmallow.__version__.split('.')[0]) == 3
62+
63+
64+
def data_schema(schema, data):
65+
data = schema().load(data or {})
66+
if not is_marsh_v3():
67+
data = schema().dump(data.data).data
68+
else:
69+
data = schema().dump(data)
70+
return data
71+
72+
5573
class Swagger(FSwagger):
5674

5775
def get_apispecs(self, endpoint='apispec_1'):
@@ -323,9 +341,13 @@ def parse_simple_schema(c_schema, location):
323341
values_real_types.sort(key=value.__class__.__mro__.index)
324342
if not values_real_types:
325343
raise '不支持的%s类型' % str(type(value))
344+
if is_marsh_v3():
345+
name = getattr(value, 'data_key', None) or key
346+
else:
347+
name = getattr(value, 'load_from', None) or key
326348
tmp = {
327349
'in': location,
328-
'name': getattr(value, 'data_key', None) or key,
350+
'name': name,
329351
'type': FIELDS_JSON_TYPE_MAP.get(values_real_types[0]),
330352
'required': value.required if location != 'path' else True,
331353
'description': value.metadata.get('doc', '')
@@ -339,7 +361,10 @@ def parse_json_schema(r_s):
339361
tmp = {}
340362
for key, value in (
341363
r_s.__dict__.get('_declared_fields') or r_s.__dict__.get('declared_fields') or {}).items():
342-
key = getattr(value, 'data_key', None) or key
364+
if is_marsh_v3():
365+
key = getattr(value, 'data_key', None) or key
366+
else:
367+
key = getattr(value, 'load_from', None) or key
343368
if isinstance(value, fields.Nested):
344369
if value.many:
345370
tmp[key] = {
@@ -453,19 +478,11 @@ def wrapper(*args, **kw):
453478
request.path_schema, request.path_schema, request.form_schema = [None] * 3
454479
request.json_schema, request.headers_schema = [None] * 2
455480
try:
456-
if __version__.startswith('3.'):
457-
path_schema and setattr(request, 'path_schema', path_schema().load(path_params or {}))
458-
query_schema and setattr(request, 'query_schema', query_schema().load(query_params or {}))
459-
form_schema and setattr(request, 'form_schema', form_schema().load(form_params or {}))
460-
json_schema and setattr(request, 'json_schema', json_schema().load(json_params or {}))
461-
headers_schema and setattr(request, 'headers_schema', headers_schema().load(dict(header_params)))
462-
else:
463-
path_schema and setattr(request, 'path_schema', path_schema().load(path_params or {}).data)
464-
query_schema and setattr(request, 'query_schema', query_schema().load(query_params or {}).data)
465-
form_schema and setattr(request, 'form_schema', form_schema().load(form_params or {}).data)
466-
json_schema and setattr(request, 'json_schema', json_schema().load(json_params or {}).data)
467-
headers_schema and setattr(request, 'headers_schema',
468-
headers_schema().load(dict(header_params)).data)
481+
path_schema and setattr(request, 'path_schema', data_schema(path_schema, path_params))
482+
query_schema and setattr(request, 'query_schema', data_schema(query_schema, query_params))
483+
form_schema and setattr(request, 'form_schema', data_schema(form_schema, form_params))
484+
json_schema and setattr(request, 'json_schema', data_schema(json_schema, json_params))
485+
headers_schema and setattr(request, 'headers_schema', data_schema(headers_schema, dict(header_params)))
469486
except Exception as e:
470487
return 'request error: %s' % ''.join(
471488
[('%s: %s; ' % (x, ''.join(y))) for x, y in e.messages.items()]), 400
@@ -474,12 +491,10 @@ def wrapper(*args, **kw):
474491
logger.info('response data\ndata: %s\ncode: %s\nheaders: %s\n', data, code, headers)
475492
try:
476493
if response_schema and response_schema.get(code):
477-
data = response_schema.get(code)().load(data or {})
478-
if not __version__.startswith('3.'):
479-
data = data.data
494+
data = data_schema(response_schema.get(code), data)
480495
r_headers_schema = getattr(response_schema.get(code).Meta, 'headers', None)
481496
if r_headers_schema:
482-
headers = r_headers_schema().load(headers or {})
497+
headers = data_schema(r_headers_schema, headers)
483498
except Exception as e:
484499
return 'response error: %s' % ''.join(
485500
[('%s: %s; ' % (x, ''.join(y))) for x, y in e.messages.items()]), 400

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727

2828
URL = 'https://github.com/flask-rabmq/Flasgger-Marshmallow'
2929

30-
VERSION = "0.0.7"
30+
VERSION = "0.0.8"
3131

3232
LICENSE = "MIT"
3333

34-
INSTALL_REQUIRES = ["flask>=1.0.0", "flasgger>=0.9.5", "marshmallow>=2.18.1", "PyYAML"]
34+
INSTALL_REQUIRES = ["flask>=1.0.0, <2.0.0", "Jinja2>=2.10.1, <3.0", "flasgger>=0.9.3", "marshmallow>=2.18.1", "PyYAML"]
3535

3636
setup(
3737
name=NAME,

0 commit comments

Comments
 (0)