diff --git a/aiocache/__init__.py b/aiocache/__init__.py index 6e8715d00..d69fcce13 100644 --- a/aiocache/__init__.py +++ b/aiocache/__init__.py @@ -30,7 +30,11 @@ _AIOCACHE_CACHES.append(MemcachedCache) del aiomcache -from .decorators import cached, cached_stampede, multi_cached # noqa: E402,I202 +from .decorators import ( + cached, + cached_stampede, # noqa: E402,I202 + multi_cached, +) __all__ = ( "cached", diff --git a/aiocache/backends/memcached.py b/aiocache/backends/memcached.py index 029574165..b39a1fa69 100644 --- a/aiocache/backends/memcached.py +++ b/aiocache/backends/memcached.py @@ -90,7 +90,9 @@ async def _add(self, key, value, ttl=0, _conn=None): except aiomcache.exceptions.ValidationException as e: raise TypeError("aiomcache error: {}".format(str(e))) if not ret: - raise ValueError("Key {} already exists, use .set to update the value".format(key)) + raise ValueError( + "Key {} already exists, use .set to update the value".format(key) + ) return True diff --git a/aiocache/backends/memory.py b/aiocache/backends/memory.py index eddef5307..7ee96216e 100644 --- a/aiocache/backends/memory.py +++ b/aiocache/backends/memory.py @@ -28,7 +28,7 @@ class SimpleMemoryCache(BaseCache[str]): # TODO(PY312): https://peps.python.org/pep-0692/ def __init__(self, **kwargs): # Extract maxsize before passing kwargs to base class - self.maxsize = kwargs.pop('maxsize', None) + self.maxsize = kwargs.pop("maxsize", None) if "serializer" not in kwargs: kwargs["serializer"] = NullSerializer() super().__init__(**kwargs) diff --git a/aiocache/backends/valkey.py b/aiocache/backends/valkey.py index b6f0adb46..f216e9911 100644 --- a/aiocache/backends/valkey.py +++ b/aiocache/backends/valkey.py @@ -43,9 +43,7 @@ class ValkeyCache(BaseCache[str]): NAME = "valkey" - def __init__( - self, config: GlideClientConfiguration, **kwargs - ): + def __init__(self, config: GlideClientConfiguration, **kwargs): self.config = config if "serializer" not in kwargs: diff --git a/aiocache/base.py b/aiocache/base.py index 8aed65706..f23b86edc 100644 --- a/aiocache/base.py +++ b/aiocache/base.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from enum import Enum from types import TracebackType -from typing import Callable, Generic, List, Optional, Set, TYPE_CHECKING, Type, TypeVar +from typing import TYPE_CHECKING, Callable, Generic, List, Optional, Set, Type, TypeVar from aiocache.serializers import StringSerializer @@ -22,7 +22,6 @@ class API: - CMDS: Set[Callable[..., object]] = set() @classmethod @@ -79,7 +78,9 @@ def plugins(cls, func): async def _plugins(self, *args, **kwargs): start = time.monotonic() for plugin in self.plugins: - await getattr(plugin, "pre_{}".format(func.__name__))(self, *args, **kwargs) + await getattr(plugin, "pre_{}".format(func.__name__))( + self, *args, **kwargs + ) ret = await func(self, *args, **kwargs) @@ -152,7 +153,9 @@ def plugins(self, value): @API.aiocache_enabled(fake_return=True) @API.timeout @API.plugins - async def add(self, key, value, ttl=SENTINEL, dumps_fn=None, namespace=None, _conn=None): + async def add( + self, key, value, ttl=SENTINEL, dumps_fn=None, namespace=None, _conn=None + ): """ Stores the value in the given key with ttl if specified. Raises an error if the key already exists. @@ -205,9 +208,13 @@ async def get(self, key, default=None, loads_fn=None, namespace=None, _conn=None loads = loads_fn or self.serializer.loads ns_key = self.build_key(key, namespace) - value = loads(await self._get(ns_key, encoding=self.serializer.encoding, _conn=_conn)) + value = loads( + await self._get(ns_key, encoding=self.serializer.encoding, _conn=_conn) + ) - logger.debug("GET %s %s (%.4f)s", ns_key, value is not None, time.monotonic() - start) + logger.debug( + "GET %s %s (%.4f)s", ns_key, value is not None, time.monotonic() - start + ) return value if value is not None else default @abstractmethod @@ -262,7 +269,14 @@ async def _multi_get(self, keys, encoding, _conn=None): @API.timeout @API.plugins async def set( - self, key, value, ttl=SENTINEL, dumps_fn=None, namespace=None, _cas_token=None, _conn=None + self, + key, + value, + ttl=SENTINEL, + dumps_fn=None, + namespace=None, + _cas_token=None, + _conn=None, ): """ Stores the value in the given key with ttl if specified @@ -284,7 +298,11 @@ async def set( ns_key = self.build_key(key, namespace) res = await self._set( - ns_key, dumps(value), ttl=self._get_ttl(ttl), _cas_token=_cas_token, _conn=_conn + ns_key, + dumps(value), + ttl=self._get_ttl(ttl), + _cas_token=_cas_token, + _conn=_conn, ) logger.debug("SET %s %d (%.4f)s", ns_key, True, time.monotonic() - start) @@ -298,7 +316,9 @@ async def _set(self, key, value, ttl, _cas_token=None, _conn=None): @API.aiocache_enabled(fake_return=True) @API.timeout @API.plugins - async def multi_set(self, pairs, ttl=SENTINEL, dumps_fn=None, namespace=None, _conn=None): + async def multi_set( + self, pairs, ttl=SENTINEL, dumps_fn=None, namespace=None, _conn=None + ): """ Stores multiple values in the given keys. @@ -538,8 +558,10 @@ async def __aenter__(self): return self async def __aexit__( - self, exc_type: Optional[Type[BaseException]], - exc: Optional[BaseException], tb: Optional[TracebackType] + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], ) -> None: await self.close() @@ -562,7 +584,9 @@ def __getattr__(self, name): @classmethod def _inject_conn(cls, cmd_name): async def _do_inject_conn(self, *args, **kwargs): - return await getattr(self._cache, cmd_name)(*args, _conn=self._conn, **kwargs) + return await getattr(self._cache, cmd_name)( + *args, _conn=self._conn, **kwargs + ) return _do_inject_conn diff --git a/aiocache/decorators.py b/aiocache/decorators.py index 3322afce6..ed3d4c633 100644 --- a/aiocache/decorators.py +++ b/aiocache/decorators.py @@ -52,7 +52,13 @@ async def wrapper(*args, **kwargs): return wrapper async def decorator( - self, f, *args, cache_read=True, cache_write=True, aiocache_wait_for_write=True, **kwargs + self, + f, + *args, + cache_read=True, + cache_write=True, + aiocache_wait_for_write=True, + **kwargs, ): key = self.get_cache_key(f, args, kwargs) @@ -237,11 +243,19 @@ async def wrapper(*args, **kwargs): return wrapper async def decorator( - self, f, *args, cache_read=True, cache_write=True, aiocache_wait_for_write=True, **kwargs + self, + f, + *args, + cache_read=True, + cache_write=True, + aiocache_wait_for_write=True, + **kwargs, ): missing_keys = [] partial = {} - orig_keys, cache_keys, new_args, args_index = self.get_cache_keys(f, args, kwargs) + orig_keys, cache_keys, new_args, args_index = self.get_cache_keys( + f, args, kwargs + ) if cache_read: values = await self.get_from_cache(*cache_keys) @@ -303,7 +317,10 @@ async def get_from_cache(self, *keys): async def set_in_cache(self, result, fn, fn_args, fn_kwargs): try: await self.cache.multi_set( - [(self.key_builder(k, fn, *fn_args, **fn_kwargs), v) for k, v in result.items()], + [ + (self.key_builder(k, fn, *fn_args, **fn_kwargs), v) + for k, v in result.items() + ], ttl=self.ttl, ) except Exception: diff --git a/aiocache/plugins.py b/aiocache/plugins.py index efbb43113..4b6395317 100644 --- a/aiocache/plugins.py +++ b/aiocache/plugins.py @@ -43,9 +43,9 @@ async def do_save_time(self, client, *args, took=0, **kwargs): previous_min = client.profiling.get("{}_min".format(method)) client.profiling["{}_total".format(method)] = previous_total + 1 - client.profiling["{}_avg".format(method)] = previous_avg + (took - previous_avg) / ( - previous_total + 1 - ) + client.profiling["{}_avg".format(method)] = previous_avg + ( + took - previous_avg + ) / (previous_total + 1) client.profiling["{}_max".format(method)] = max(took, previous_max) client.profiling["{}_min".format(method)] = ( min(took, previous_min) if previous_min else took diff --git a/aiocache/serializers/__init__.py b/aiocache/serializers/__init__.py index c7499335b..8c0e4b6ae 100644 --- a/aiocache/serializers/__init__.py +++ b/aiocache/serializers/__init__.py @@ -20,6 +20,14 @@ del msgpack +try: + import msgspec +except ImportError: + logger.debug("msgspec not installed, MsgspecSerializer unavailable") +else: + from .serializers import MsgspecSerializer + + del msgspec __all__ = [ "BaseSerializer", @@ -28,4 +36,5 @@ "PickleSerializer", "JsonSerializer", "MsgPackSerializer", + "MsgspecSerializer", ] diff --git a/aiocache/serializers/serializers.py b/aiocache/serializers/serializers.py index 58a5b61b1..4cfd2a98a 100644 --- a/aiocache/serializers/serializers.py +++ b/aiocache/serializers/serializers.py @@ -1,9 +1,20 @@ import logging import pickle # noqa: S403 from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import ( + Any, + Callable, + Generic, + Literal, + Optional, + Type, + TypeVar, + Union, + overload, +) logger = logging.getLogger(__name__) +_T = TypeVar("_T") try: import ujson as json # noqa: I900 @@ -17,12 +28,22 @@ msgpack = None logger.debug("msgpack not installed, MsgPackSerializer unavailable") +try: + from msgspec.msgpack import Decoder as MsgspecDecoder + from msgspec.msgpack import Encoder as MsgspecEncoder +except ImportError: + MsgspecEncoder = None + + # Only here as extended typehinting + class MsgspecDecoder(Generic[_T]): + def decode(self, buf: Union[bytes, bytearray, memoryview[int]]) -> _T: ... + + logger.debug("msgspec not installed, MsgspecSerlizer unavailable") _NOT_SET = object() class BaseSerializer(ABC): - DEFAULT_ENCODING: Optional[str] = "utf-8" def __init__(self, *args, encoding=_NOT_SET, **kwargs): @@ -197,3 +218,73 @@ def loads(self, value): if value is None: return None return msgpack.loads(value, raw=raw, use_list=self.use_list) + + +class MsgspecSerializer(BaseSerializer, Generic[_T]): + @overload + def __init__( + self: "MsgspecSerializer[Any]", + enc_hook: Optional[Callable[[Any], Any]] = None, + decimal_format: Literal["string", "number"] = "string", + uuid_format: Literal["canonical", "hex", "bytes"] = "canonical", + order: Literal["deterministic", "sorted"] | None = None, + struct_type: None = None, + strict: bool = True, + dec_hook: Optional[Callable[[Type, Any], Any]] = None, + ext_hook: Optional[Callable[[int, memoryview[int]], Any]] = None, + ) -> None: ... + + @overload + def __init__( + self: "MsgspecSerializer[_T]", + enc_hook: Optional[Callable[[Any], Any]] = None, + decimal_format: Literal["string", "number"] = "string", + uuid_format: Literal["canonical", "hex", "bytes"] = "canonical", + order: Literal["deterministic", "sorted"] | None = None, + struct_type: Type[_T] = None, + strict: bool = True, + dec_hook: Optional[Callable[[type, Any], Any]] = None, + ext_hook: Optional[Callable[[int, memoryview[int]], Any]] = None, + ) -> None: ... + + def __init__( + self, + enc_hook: Optional[Callable[[Any], Any]] = None, + decimal_format: Literal["string", "number"] = "string", + uuid_format: Literal["canonical", "hex", "bytes"] = "canonical", + order: Literal["deterministic", "sorted"] | None = None, + struct_type: Type[_T] | None = None, + strict: bool = True, + dec_hook: Optional[Callable[[type, Any], Any]] = None, + ext_hook: Optional[Callable[[int, memoryview[int]], Any]] = None, + ): + if MsgspecEncoder is None: + raise RuntimeError("msgspec not installed, MsgspecSerializer unavailable") + + self.encoder = MsgspecEncoder( + enc_hook=enc_hook, + decimal_format=decimal_format, + uuid_format=uuid_format, + order=order, + ) + + if struct_type is not None: + self.decoder: "MsgspecDecoder[_T]" = MsgspecDecoder( + type=struct_type, dec_hook=dec_hook, ext_hook=ext_hook, strict=strict + ) + else: + self.decoder: "MsgspecDecoder[Any]" = MsgspecDecoder( + dec_hook=dec_hook, ext_hook=ext_hook, strict=strict + ) + + @overload + def dumps(self, value: _T) -> bytes: ... + + @overload + def dumps(self, value: Any) -> bytes: ... + + def dumps(self, value: _T | Any) -> bytes: + return self.encoder.encode(value) + + def loads(self, value: bytes) -> _T: + return self.decoder.decode(value) diff --git a/tests/acceptance/test_serializers.py b/tests/acceptance/test_serializers.py index 694f0a8b6..046382c1f 100644 --- a/tests/acceptance/test_serializers.py +++ b/tests/acceptance/test_serializers.py @@ -4,7 +4,7 @@ import pytest from marshmallow import Schema, fields, post_load - +from msgspec import Struct try: import ujson as json # noqa: I900 except ImportError: @@ -16,6 +16,7 @@ NullSerializer, PickleSerializer, StringSerializer, + MsgspecSerializer ) from ..utils import Keys @@ -154,3 +155,16 @@ async def test_get_set_alt_serializer_class(self, cache): cache.serializer = my_serializer await cache.set(Keys.KEY, my_obj) assert await cache.get(Keys.KEY) == my_serializer.loads(my_serializer.dumps(my_obj)) + + +class MsgspecTest(Struct): + a: int + b: str + +class TestMsgspecSerializer: + async def test_serlize_struct(self, cache): + my_serializer = MsgspecSerializer(type=MsgspecTest) + my_obj = MsgspecTest(1, "testing") + cache.serializer = my_serializer + await cache.set(Keys.KEY, my_obj) + assert await cache.get(Keys.KEY) == my_serializer.loads(my_serializer.dumps(my_obj))