Skip to content

Commit 41e437d

Browse files
authored
Merge pull request #98 from nineaiyu/ndev
Ndev
2 parents 46caa47 + 31a4d42 commit 41e437d

File tree

27 files changed

+530
-200
lines changed

27 files changed

+530
-200
lines changed

common/apps.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ def ready(self):
1515
from .celery import heatbeat # noqa
1616
from . import signal_handlers # noqa
1717
from . import tasks # noqa
18-
from .swagger.utils import OpenApiAuthenticationScheme, OpenApiPrimaryKeyRelatedField # noqa
1918
from .signals import django_ready
2019
excludes = ['migrate', 'compilemessages', 'makemigrations', 'stop']
2120
for i in excludes:

common/celery/logger.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ def flush(self):
9797
f.flush()
9898

9999
def handle_task_start(self, task_id):
100-
log_path = get_celery_task_log_path(task_id.split('_')[0])
100+
# log_path = get_celery_task_log_path(task_id.split('_')[0])
101+
log_path = get_celery_task_log_path(task_id)
101102
thread_id = self.get_current_thread_id()
102103
self.task_id_thread_id_mapper[task_id] = thread_id
103104
f = open(log_path, 'ab')

common/core/db/utils.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,27 +107,40 @@ def get_filter_attrs_qs(cls, rules):
107107
return filters
108108

109109

110-
def close_old_connections():
111-
for conn in connections.all():
110+
def close_old_connections(**kwargs):
111+
for conn in connections.all(initialized_only=True):
112112
conn.close_if_unusable_or_obsolete()
113113

114114

115+
# 这个要是在 Django 请求周期外使用的,不能影响 Django 的事务管理, 在 api 中使用会影响 api 事务
115116
@contextmanager
116117
def safe_db_connection():
117-
in_atomic_block = connection.in_atomic_block # 当前是否处于事务中
118-
autocommit = transaction.get_autocommit() # 是否启用了自动提交
119-
created = False
118+
close_old_connections()
119+
yield
120+
close_old_connections()
121+
122+
123+
@contextmanager
124+
def safe_atomic_db_connection(auto_close=False):
125+
"""
126+
通用数据库连接管理器(线程安全、事务感知):
127+
- 在连接不可用时主动重建连接
128+
- 在非事务环境下自动关闭连接(可选)
129+
- 不影响 Django 请求/事务周期
130+
"""
131+
in_atomic = connection.in_atomic_block # 当前是否在事务中
132+
autocommit = transaction.get_autocommit()
133+
recreated = False
120134

121135
try:
122136
if not connection.is_usable():
123137
connection.close()
124138
connection.connect()
125-
created = True
139+
recreated = True
126140
yield
127141
finally:
128-
# 如果不是事务中(API 请求中可能需要提交事务),则关闭连接
129-
if created and not in_atomic_block and autocommit:
130-
print("close connection in safe_db_connection")
142+
# 只在非事务、autocommit 模式下,才考虑主动清理连接
143+
if auto_close or (recreated and not in_atomic and autocommit):
131144
close_old_connections()
132145

133146

common/core/fields.py

Lines changed: 155 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
from django.db.models.fields.files import FieldFile
1414
from django.utils.translation import gettext_lazy as _
1515
from rest_framework import serializers
16-
from rest_framework.fields import ChoiceField
1716
from rest_framework.request import Request
18-
from rest_framework.serializers import RelatedField, MultipleChoiceField
1917

2018
from common.core.filter import get_filter_queryset
2119
from common.fields.utils import get_file_absolute_uri
@@ -33,22 +31,56 @@ def func(obj):
3331
return func(obj)
3432

3533

36-
class LabeledChoiceField(ChoiceField):
34+
class LabeledChoiceField(serializers.ChoiceField):
35+
def __init__(self, **kwargs):
36+
self.attrs = kwargs.pop("attrs", None) or ("value", "label")
37+
super().__init__(**kwargs)
38+
3739
def to_representation(self, key):
3840
if key is None:
3941
return key
4042
label = self.choices.get(key, key)
4143
return {"value": key, "label": label}
4244

4345
def to_internal_value(self, data):
46+
if not data:
47+
return data
4448
if isinstance(data, dict):
4549
data = data.get("value")
4650
if isinstance(data, str) and "(" in data and data.endswith(")"):
4751
data = data.strip(")").split('(')[-1]
4852
return super(LabeledChoiceField, self).to_internal_value(data)
4953

50-
51-
class LabeledMultipleChoiceField(MultipleChoiceField):
54+
def get_schema(self):
55+
"""
56+
为 drf-spectacular 提供 OpenAPI schema
57+
"""
58+
if getattr(self, 'many', False):
59+
return {
60+
'type': 'array',
61+
'items': {
62+
'type': 'object',
63+
'properties': {
64+
'value': {'type': 'string'},
65+
'label': {'type': 'string'}
66+
}
67+
},
68+
'description': getattr(self, 'help_text', ''),
69+
'title': getattr(self, 'label', ''),
70+
}
71+
else:
72+
return {
73+
'type': 'object',
74+
'properties': {
75+
'value': {'type': 'string'},
76+
'label': {'type': 'string'}
77+
},
78+
'description': getattr(self, 'help_text', ''),
79+
'title': getattr(self, 'label', ''),
80+
}
81+
82+
83+
class LabeledMultipleChoiceField(serializers.MultipleChoiceField):
5284
def __init__(self, **kwargs):
5385
super().__init__(**kwargs)
5486
self.choice_mapper = {
@@ -73,7 +105,7 @@ def to_internal_value(self, data):
73105
return data
74106

75107

76-
class BasePrimaryKeyRelatedField(RelatedField):
108+
class BasePrimaryKeyRelatedField(serializers.RelatedField):
77109
"""
78110
Base class for primary key related fields.
79111
"""
@@ -89,7 +121,7 @@ def __init__(self, attrs=None, ignore_field_permission=False, **kwargs):
89121
:param attrs: 默认为 None,返回默认的 pk, 一般需要自定义
90122
:param ignore_field_permission: 忽略字段权限控制
91123
"""
92-
self.attrs = attrs
124+
self.attrs = attrs if attrs else ["pk"]
93125
self.label_format = kwargs.pop("format", None)
94126
self.input_type = kwargs.pop("input_type", None)
95127
self.input_type_prefix = kwargs.pop("input_type_prefix", None)
@@ -221,6 +253,122 @@ def to_internal_value(self, data):
221253
except (TypeError, ValueError):
222254
self.fail("incorrect_type", data_type=type(pk).__name__)
223255

256+
def get_schema(self):
257+
"""
258+
为 drf-spectacular 提供 OpenAPI schema
259+
"""
260+
# 获取字段的基本信息
261+
field_type = 'array' if self.many else 'object'
262+
263+
if field_type == 'array':
264+
# 如果是多对多关系
265+
return {
266+
'type': 'array',
267+
'items': self._get_openapi_item_schema(),
268+
'description': getattr(self, 'help_text', ''),
269+
'title': getattr(self, 'label', ''),
270+
}
271+
else:
272+
# 如果是一对一关系
273+
return {
274+
'type': 'object',
275+
'properties': self._get_openapi_properties_schema(),
276+
'description': getattr(self, 'help_text', ''),
277+
'title': getattr(self, 'label', ''),
278+
}
279+
280+
def _get_openapi_item_schema(self):
281+
"""
282+
获取数组项的 OpenAPI schema
283+
"""
284+
return self._get_openapi_object_schema()
285+
286+
def _get_openapi_object_schema(self):
287+
"""
288+
获取对象的 OpenAPI schema
289+
"""
290+
properties = {}
291+
292+
# 动态分析 attrs 中的属性类型
293+
for attr in self.attrs:
294+
# 尝试从 queryset 的 model 中获取字段信息
295+
field_type = self._infer_field_type(attr)
296+
properties[attr] = {
297+
'type': field_type,
298+
'description': f'{attr} field'
299+
}
300+
301+
return {
302+
'type': 'object',
303+
'properties': properties,
304+
'required': ['id'] if 'id' in self.attrs else []
305+
}
306+
307+
def _infer_field_type(self, attr_name):
308+
"""
309+
智能推断字段类型
310+
"""
311+
try:
312+
# 如果有 queryset,尝试从 model 中获取字段信息
313+
if hasattr(self, 'queryset') and self.queryset is not None:
314+
model = self.queryset.model
315+
if hasattr(model, '_meta') and hasattr(model._meta, 'fields'):
316+
field = model._meta.get_field(attr_name)
317+
if field:
318+
return self._map_django_field_type(field)
319+
except Exception:
320+
pass
321+
322+
# 如果没有 queryset 或无法获取字段信息,使用启发式规则
323+
return self._heuristic_field_type(attr_name)
324+
325+
def _map_django_field_type(self, field):
326+
"""
327+
将 Django 字段类型映射到 OpenAPI 类型
328+
"""
329+
field_type = type(field).__name__
330+
331+
# 整数类型
332+
if 'Integer' in field_type or 'BigInteger' in field_type or 'SmallInteger' in field_type:
333+
return 'integer'
334+
# 浮点数类型
335+
elif 'Float' in field_type or 'Decimal' in field_type:
336+
return 'number'
337+
# 布尔类型
338+
elif 'Boolean' in field_type:
339+
return 'boolean'
340+
# 日期时间类型
341+
elif 'DateTime' in field_type or 'Date' in field_type or 'Time' in field_type:
342+
return 'string'
343+
# 文件类型
344+
elif 'File' in field_type or 'Image' in field_type:
345+
return 'string'
346+
# 其他类型默认为字符串
347+
else:
348+
return 'string'
349+
350+
def _heuristic_field_type(self, attr_name):
351+
"""
352+
启发式推断字段类型
353+
"""
354+
# 基于属性名的启发式规则
355+
356+
if attr_name in ['is_active', 'enabled', 'visible'] or attr_name.startswith('is_'):
357+
return 'boolean'
358+
elif attr_name in ['count', 'number', 'size', 'amount']:
359+
return 'integer'
360+
elif attr_name in ['price', 'rate', 'percentage']:
361+
return 'number'
362+
else:
363+
# 默认返回字符串类型
364+
return 'string'
365+
366+
def _get_openapi_properties_schema(self):
367+
"""
368+
获取对象属性的 OpenAPI schema
369+
"""
370+
return self._get_openapi_object_schema()['properties']
371+
224372

225373
class PhoneField(serializers.CharField):
226374

common/core/serializers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ class Meta:
2828
table_fields = [] # 用于控制前端table的字段展示
2929
tabs = []
3030

31+
def get_field_names(self, declared_fields, info):
32+
"""将默认的id字段 转换为 pk"""
33+
fields = super().get_field_names(declared_fields, info)
34+
if 'id' in fields:
35+
return ['pk'] + [f for f in fields if f != 'id']
36+
return fields
37+
3138
def get_value(self, dictionary):
3239
# We override the default field access in order to support
3340
# nested HTML forms.

common/drf/parsers/csv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
# ~*~ coding: utf-8 ~*~
22
#
3+
from functools import cached_property
34

45
import chardet
56
import unicodecsv
67

78
from .base import BaseFileParser
89
from ..const import CSV_FILE_ESCAPE_CHARS
9-
from ..utils import lazyproperty
1010

1111

1212
class CSVFileParser(BaseFileParser):
1313
media_type = 'text/csv'
1414

15-
@lazyproperty
15+
@cached_property
1616
def match_escape_chars(self):
1717
chars = []
1818
for c in CSV_FILE_ESCAPE_CHARS:

common/drf/utils.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

common/management/commands/services/hands.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,16 @@ def compile_i18n_file():
9494

9595

9696
def download_ip_db(force=False):
97-
db_base_dir = os.path.join(APPS_DIR, 'common', 'utils', 'ip')
9897
db_path_url_mapper = {
99-
('geoip', 'GeoLite2-City.mmdb'): 'https://jms-pkg.oss-cn-beijing.aliyuncs.com/ip/GeoLite2-City.mmdb',
100-
('ipip', 'ipipfree.ipdb'): 'https://jms-pkg.oss-cn-beijing.aliyuncs.com/ip/ipipfree.ipdb'
98+
('system', 'GeoLite2-City.mmdb'): 'https://jms-pkg.oss-cn-beijing.aliyuncs.com/ip/GeoLite2-City.mmdb',
99+
('system', 'ipipfree.ipdb'): 'https://jms-pkg.oss-cn-beijing.aliyuncs.com/ip/ipipfree.ipdb'
101100
}
102101
for p, src in db_path_url_mapper.items():
103-
path = os.path.join(db_base_dir, *p)
102+
path = os.path.join(settings.DATA_DIR, *p)
104103
if not force and os.path.isfile(path) and os.path.getsize(path) > 1000:
105104
continue
106105
logger.info("Download ip db: {}".format(path))
106+
os.makedirs(os.path.dirname(path), exist_ok=True)
107107
download_file(src, path)
108108

109109

common/signal_handlers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from common.celery.decorator import get_after_app_ready_tasks, get_after_app_shutdown_clean_tasks
2424
from common.celery.logger import CeleryThreadTaskFileHandler
2525
from common.celery.utils import get_celery_task_log_path
26+
from common.signals import django_ready
2627
from common.utils import get_logger
2728
from server.utils import get_current_request
2829

@@ -162,3 +163,8 @@ def on_update_set_modifier(sender, instance=None, **kwargs):
162163

163164
if settings.DEBUG_DEV:
164165
request_finished.connect(on_request_finished_logging_db_query)
166+
167+
168+
@receiver(django_ready)
169+
def clear_response_cache(sender, **kwargs):
170+
cache.delete_pattern('magic_cache_response_*')

0 commit comments

Comments
 (0)