77import yaml
88from flasgger .base import Swagger as FSwagger
99from flasgger .constants import OPTIONAL_FIELDS
10- from flasgger .constants import OPTIONAL_OAS3_FIELDS
10+ try :
11+ from flasgger .constants import OPTIONAL_OAS3_FIELDS
12+ except :
13+ OPTIONAL_OAS3_FIELDS = [
14+ 'components' , 'servers'
15+ ]
1116from flasgger .utils import extract_definitions
1217from flasgger .utils import get_specs
1318from flasgger .utils import get_vendor_extension_fields
4449 fields .Int : 'number' ,
4550}
4651
47- if int (marshmallow .__version__ .split ('.' )[1 ]) == 3 :
52+ if int (marshmallow .__version__ .split ('.' )[0 ]) == 3 :
4853 FIELDS_JSON_TYPE_MAP .update ({
4954 fields .NaiveDateTime : 'string' ,
5055 fields .AwareDateTime : 'string' ,
5156 fields .Tuple : 'array' ,
5257 })
5358
5459
60+ def is_marsh_v3 ():
61+ return int (marshmallow .__version__ .split ('.' )[0 ]) == 3
62+
63+
64+ def data_schema (schema , data ):
65+ data = schema ().load (data or {})
66+ if not is_marsh_v3 ():
67+ data = schema ().dump (data .data ).data
68+ else :
69+ data = schema ().dump (data )
70+ return data
71+
72+
5573class Swagger (FSwagger ):
5674
5775 def get_apispecs (self , endpoint = 'apispec_1' ):
@@ -323,9 +341,13 @@ def parse_simple_schema(c_schema, location):
323341 values_real_types .sort (key = value .__class__ .__mro__ .index )
324342 if not values_real_types :
325343 raise '不支持的%s类型' % str (type (value ))
344+ if is_marsh_v3 ():
345+ name = getattr (value , 'data_key' , None ) or key
346+ else :
347+ name = getattr (value , 'load_from' , None ) or key
326348 tmp = {
327349 'in' : location ,
328- 'name' : getattr ( value , 'data_key' , None ) or key ,
350+ 'name' : name ,
329351 'type' : FIELDS_JSON_TYPE_MAP .get (values_real_types [0 ]),
330352 'required' : value .required if location != 'path' else True ,
331353 'description' : value .metadata .get ('doc' , '' )
@@ -339,7 +361,10 @@ def parse_json_schema(r_s):
339361 tmp = {}
340362 for key , value in (
341363 r_s .__dict__ .get ('_declared_fields' ) or r_s .__dict__ .get ('declared_fields' ) or {}).items ():
342- key = getattr (value , 'data_key' , None ) or key
364+ if is_marsh_v3 ():
365+ key = getattr (value , 'data_key' , None ) or key
366+ else :
367+ key = getattr (value , 'load_from' , None ) or key
343368 if isinstance (value , fields .Nested ):
344369 if value .many :
345370 tmp [key ] = {
@@ -453,19 +478,11 @@ def wrapper(*args, **kw):
453478 request .path_schema , request .path_schema , request .form_schema = [None ] * 3
454479 request .json_schema , request .headers_schema = [None ] * 2
455480 try :
456- if __version__ .startswith ('3.' ):
457- path_schema and setattr (request , 'path_schema' , path_schema ().load (path_params or {}))
458- query_schema and setattr (request , 'query_schema' , query_schema ().load (query_params or {}))
459- form_schema and setattr (request , 'form_schema' , form_schema ().load (form_params or {}))
460- json_schema and setattr (request , 'json_schema' , json_schema ().load (json_params or {}))
461- headers_schema and setattr (request , 'headers_schema' , headers_schema ().load (dict (header_params )))
462- else :
463- path_schema and setattr (request , 'path_schema' , path_schema ().load (path_params or {}).data )
464- query_schema and setattr (request , 'query_schema' , query_schema ().load (query_params or {}).data )
465- form_schema and setattr (request , 'form_schema' , form_schema ().load (form_params or {}).data )
466- json_schema and setattr (request , 'json_schema' , json_schema ().load (json_params or {}).data )
467- headers_schema and setattr (request , 'headers_schema' ,
468- headers_schema ().load (dict (header_params )).data )
481+ path_schema and setattr (request , 'path_schema' , data_schema (path_schema , path_params ))
482+ query_schema and setattr (request , 'query_schema' , data_schema (query_schema , query_params ))
483+ form_schema and setattr (request , 'form_schema' , data_schema (form_schema , form_params ))
484+ json_schema and setattr (request , 'json_schema' , data_schema (json_schema , json_params ))
485+ headers_schema and setattr (request , 'headers_schema' , data_schema (headers_schema , dict (header_params )))
469486 except Exception as e :
470487 return 'request error: %s' % '' .join (
471488 [('%s: %s; ' % (x , '' .join (y ))) for x , y in e .messages .items ()]), 400
@@ -474,12 +491,10 @@ def wrapper(*args, **kw):
474491 logger .info ('response data\n data: %s\n code: %s\n headers: %s\n ' , data , code , headers )
475492 try :
476493 if response_schema and response_schema .get (code ):
477- data = response_schema .get (code )().load (data or {})
478- if not __version__ .startswith ('3.' ):
479- data = data .data
494+ data = data_schema (response_schema .get (code ), data )
480495 r_headers_schema = getattr (response_schema .get (code ).Meta , 'headers' , None )
481496 if r_headers_schema :
482- headers = r_headers_schema (). load ( headers or {} )
497+ headers = data_schema ( r_headers_schema , headers )
483498 except Exception as e :
484499 return 'response error: %s' % '' .join (
485500 [('%s: %s; ' % (x , '' .join (y ))) for x , y in e .messages .items ()]), 400
0 commit comments