Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion aiocache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
_AIOCACHE_CACHES.append(MemcachedCache)
del aiomcache

from .decorators import cached, cached_stampede, multi_cached # noqa: E402,I202
from .decorators import (
cached,
cached_stampede, # noqa: E402,I202
multi_cached,
)

__all__ = (
"cached",
Expand Down
4 changes: 3 additions & 1 deletion aiocache/backends/memcached.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ async def _add(self, key, value, ttl=0, _conn=None):
except aiomcache.exceptions.ValidationException as e:
raise TypeError("aiomcache error: {}".format(str(e)))
if not ret:
raise ValueError("Key {} already exists, use .set to update the value".format(key))
raise ValueError(
"Key {} already exists, use .set to update the value".format(key)
)

return True

Expand Down
2 changes: 1 addition & 1 deletion aiocache/backends/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class SimpleMemoryCache(BaseCache[str]):
# TODO(PY312): https://peps.python.org/pep-0692/
def __init__(self, **kwargs):
# Extract maxsize before passing kwargs to base class
self.maxsize = kwargs.pop('maxsize', None)
self.maxsize = kwargs.pop("maxsize", None)
if "serializer" not in kwargs:
kwargs["serializer"] = NullSerializer()
super().__init__(**kwargs)
Expand Down
4 changes: 1 addition & 3 deletions aiocache/backends/valkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ class ValkeyCache(BaseCache[str]):

NAME = "valkey"

def __init__(
self, config: GlideClientConfiguration, **kwargs
):
def __init__(self, config: GlideClientConfiguration, **kwargs):
self.config = config

if "serializer" not in kwargs:
Expand Down
48 changes: 36 additions & 12 deletions aiocache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
from enum import Enum
from types import TracebackType
from typing import Callable, Generic, List, Optional, Set, TYPE_CHECKING, Type, TypeVar
from typing import TYPE_CHECKING, Callable, Generic, List, Optional, Set, Type, TypeVar

from aiocache.serializers import StringSerializer

Expand All @@ -22,7 +22,6 @@


class API:

CMDS: Set[Callable[..., object]] = set()

@classmethod
Expand Down Expand Up @@ -79,7 +78,9 @@ def plugins(cls, func):
async def _plugins(self, *args, **kwargs):
start = time.monotonic()
for plugin in self.plugins:
await getattr(plugin, "pre_{}".format(func.__name__))(self, *args, **kwargs)
await getattr(plugin, "pre_{}".format(func.__name__))(
self, *args, **kwargs
)

ret = await func(self, *args, **kwargs)

Expand Down Expand Up @@ -152,7 +153,9 @@ def plugins(self, value):
@API.aiocache_enabled(fake_return=True)
@API.timeout
@API.plugins
async def add(self, key, value, ttl=SENTINEL, dumps_fn=None, namespace=None, _conn=None):
async def add(
self, key, value, ttl=SENTINEL, dumps_fn=None, namespace=None, _conn=None
):
"""
Stores the value in the given key with ttl if specified. Raises an error if the
key already exists.
Expand Down Expand Up @@ -205,9 +208,13 @@ async def get(self, key, default=None, loads_fn=None, namespace=None, _conn=None
loads = loads_fn or self.serializer.loads
ns_key = self.build_key(key, namespace)

value = loads(await self._get(ns_key, encoding=self.serializer.encoding, _conn=_conn))
value = loads(
await self._get(ns_key, encoding=self.serializer.encoding, _conn=_conn)
)

logger.debug("GET %s %s (%.4f)s", ns_key, value is not None, time.monotonic() - start)
logger.debug(
"GET %s %s (%.4f)s", ns_key, value is not None, time.monotonic() - start
)
return value if value is not None else default

@abstractmethod
Expand Down Expand Up @@ -262,7 +269,14 @@ async def _multi_get(self, keys, encoding, _conn=None):
@API.timeout
@API.plugins
async def set(
self, key, value, ttl=SENTINEL, dumps_fn=None, namespace=None, _cas_token=None, _conn=None
self,
key,
value,
ttl=SENTINEL,
dumps_fn=None,
namespace=None,
_cas_token=None,
_conn=None,
):
"""
Stores the value in the given key with ttl if specified
Expand All @@ -284,7 +298,11 @@ async def set(
ns_key = self.build_key(key, namespace)

res = await self._set(
ns_key, dumps(value), ttl=self._get_ttl(ttl), _cas_token=_cas_token, _conn=_conn
ns_key,
dumps(value),
ttl=self._get_ttl(ttl),
_cas_token=_cas_token,
_conn=_conn,
)

logger.debug("SET %s %d (%.4f)s", ns_key, True, time.monotonic() - start)
Expand All @@ -298,7 +316,9 @@ async def _set(self, key, value, ttl, _cas_token=None, _conn=None):
@API.aiocache_enabled(fake_return=True)
@API.timeout
@API.plugins
async def multi_set(self, pairs, ttl=SENTINEL, dumps_fn=None, namespace=None, _conn=None):
async def multi_set(
self, pairs, ttl=SENTINEL, dumps_fn=None, namespace=None, _conn=None
):
"""
Stores multiple values in the given keys.

Expand Down Expand Up @@ -538,8 +558,10 @@ async def __aenter__(self):
return self

async def __aexit__(
self, exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException], tb: Optional[TracebackType]
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
await self.close()

Expand All @@ -562,7 +584,9 @@ def __getattr__(self, name):
@classmethod
def _inject_conn(cls, cmd_name):
async def _do_inject_conn(self, *args, **kwargs):
return await getattr(self._cache, cmd_name)(*args, _conn=self._conn, **kwargs)
return await getattr(self._cache, cmd_name)(
*args, _conn=self._conn, **kwargs
)

return _do_inject_conn

Expand Down
25 changes: 21 additions & 4 deletions aiocache/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ async def wrapper(*args, **kwargs):
return wrapper

async def decorator(
self, f, *args, cache_read=True, cache_write=True, aiocache_wait_for_write=True, **kwargs
self,
f,
*args,
cache_read=True,
cache_write=True,
aiocache_wait_for_write=True,
**kwargs,
):
key = self.get_cache_key(f, args, kwargs)

Expand Down Expand Up @@ -237,11 +243,19 @@ async def wrapper(*args, **kwargs):
return wrapper

async def decorator(
self, f, *args, cache_read=True, cache_write=True, aiocache_wait_for_write=True, **kwargs
self,
f,
*args,
cache_read=True,
cache_write=True,
aiocache_wait_for_write=True,
**kwargs,
):
missing_keys = []
partial = {}
orig_keys, cache_keys, new_args, args_index = self.get_cache_keys(f, args, kwargs)
orig_keys, cache_keys, new_args, args_index = self.get_cache_keys(
f, args, kwargs
)

if cache_read:
values = await self.get_from_cache(*cache_keys)
Expand Down Expand Up @@ -303,7 +317,10 @@ async def get_from_cache(self, *keys):
async def set_in_cache(self, result, fn, fn_args, fn_kwargs):
try:
await self.cache.multi_set(
[(self.key_builder(k, fn, *fn_args, **fn_kwargs), v) for k, v in result.items()],
[
(self.key_builder(k, fn, *fn_args, **fn_kwargs), v)
for k, v in result.items()
],
ttl=self.ttl,
)
except Exception:
Expand Down
6 changes: 3 additions & 3 deletions aiocache/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ async def do_save_time(self, client, *args, took=0, **kwargs):
previous_min = client.profiling.get("{}_min".format(method))

client.profiling["{}_total".format(method)] = previous_total + 1
client.profiling["{}_avg".format(method)] = previous_avg + (took - previous_avg) / (
previous_total + 1
)
client.profiling["{}_avg".format(method)] = previous_avg + (
took - previous_avg
) / (previous_total + 1)
client.profiling["{}_max".format(method)] = max(took, previous_max)
client.profiling["{}_min".format(method)] = (
min(took, previous_min) if previous_min else took
Expand Down
9 changes: 9 additions & 0 deletions aiocache/serializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@

del msgpack

try:
import msgspec
except ImportError:
logger.debug("msgspec not installed, MsgspecSerializer unavailable")
else:
from .serializers import MsgspecSerializer

del msgspec

__all__ = [
"BaseSerializer",
Expand All @@ -28,4 +36,5 @@
"PickleSerializer",
"JsonSerializer",
"MsgPackSerializer",
"MsgspecSerializer",
]
95 changes: 93 additions & 2 deletions aiocache/serializers/serializers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import logging
import pickle # noqa: S403
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import (
Any,
Callable,
Generic,
Literal,
Optional,
Type,
TypeVar,
Union,
overload,
)

logger = logging.getLogger(__name__)
_T = TypeVar("_T")

try:
import ujson as json # noqa: I900
Expand All @@ -17,12 +28,22 @@
msgpack = None
logger.debug("msgpack not installed, MsgPackSerializer unavailable")

try:
from msgspec.msgpack import Decoder as MsgspecDecoder
from msgspec.msgpack import Encoder as MsgspecEncoder
except ImportError:
MsgspecEncoder = None

# Only here as extended typehinting
class MsgspecDecoder(Generic[_T]):
def decode(self, buf: Union[bytes, bytearray, memoryview[int]]) -> _T: ...

logger.debug("msgspec not installed, MsgspecSerlizer unavailable")

_NOT_SET = object()


class BaseSerializer(ABC):

DEFAULT_ENCODING: Optional[str] = "utf-8"

def __init__(self, *args, encoding=_NOT_SET, **kwargs):
Expand Down Expand Up @@ -197,3 +218,73 @@ def loads(self, value):
if value is None:
return None
return msgpack.loads(value, raw=raw, use_list=self.use_list)


class MsgspecSerializer(BaseSerializer, Generic[_T]):
@overload
def __init__(
self: "MsgspecSerializer[Any]",
enc_hook: Optional[Callable[[Any], Any]] = None,
decimal_format: Literal["string", "number"] = "string",
uuid_format: Literal["canonical", "hex", "bytes"] = "canonical",
order: Literal["deterministic", "sorted"] | None = None,
struct_type: None = None,
strict: bool = True,
dec_hook: Optional[Callable[[Type, Any], Any]] = None,
ext_hook: Optional[Callable[[int, memoryview[int]], Any]] = None,
) -> None: ...

@overload
def __init__(
self: "MsgspecSerializer[_T]",
enc_hook: Optional[Callable[[Any], Any]] = None,
decimal_format: Literal["string", "number"] = "string",
uuid_format: Literal["canonical", "hex", "bytes"] = "canonical",
order: Literal["deterministic", "sorted"] | None = None,
struct_type: Type[_T] = None,
strict: bool = True,
dec_hook: Optional[Callable[[type, Any], Any]] = None,
ext_hook: Optional[Callable[[int, memoryview[int]], Any]] = None,
) -> None: ...

def __init__(
self,
enc_hook: Optional[Callable[[Any], Any]] = None,
decimal_format: Literal["string", "number"] = "string",
uuid_format: Literal["canonical", "hex", "bytes"] = "canonical",
order: Literal["deterministic", "sorted"] | None = None,
struct_type: Type[_T] | None = None,
strict: bool = True,
dec_hook: Optional[Callable[[type, Any], Any]] = None,
ext_hook: Optional[Callable[[int, memoryview[int]], Any]] = None,
):
if MsgspecEncoder is None:
raise RuntimeError("msgspec not installed, MsgspecSerializer unavailable")

self.encoder = MsgspecEncoder(
enc_hook=enc_hook,
decimal_format=decimal_format,
uuid_format=uuid_format,
order=order,
)

if struct_type is not None:
self.decoder: "MsgspecDecoder[_T]" = MsgspecDecoder(
type=struct_type, dec_hook=dec_hook, ext_hook=ext_hook, strict=strict
)
else:
self.decoder: "MsgspecDecoder[Any]" = MsgspecDecoder(
dec_hook=dec_hook, ext_hook=ext_hook, strict=strict
)

@overload
def dumps(self, value: _T) -> bytes: ...

@overload
def dumps(self, value: Any) -> bytes: ...

def dumps(self, value: _T | Any) -> bytes:
return self.encoder.encode(value)

def loads(self, value: bytes) -> _T:
return self.decoder.decode(value)
16 changes: 15 additions & 1 deletion tests/acceptance/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest
from marshmallow import Schema, fields, post_load

from msgspec import Struct
try:
import ujson as json # noqa: I900
except ImportError:
Expand All @@ -16,6 +16,7 @@
NullSerializer,
PickleSerializer,
StringSerializer,
MsgspecSerializer
)
from ..utils import Keys

Expand Down Expand Up @@ -154,3 +155,16 @@ async def test_get_set_alt_serializer_class(self, cache):
cache.serializer = my_serializer
await cache.set(Keys.KEY, my_obj)
assert await cache.get(Keys.KEY) == my_serializer.loads(my_serializer.dumps(my_obj))


class MsgspecTest(Struct):
a: int
b: str

class TestMsgspecSerializer:
async def test_serlize_struct(self, cache):
my_serializer = MsgspecSerializer(type=MsgspecTest)
my_obj = MsgspecTest(1, "testing")
cache.serializer = my_serializer
await cache.set(Keys.KEY, my_obj)
assert await cache.get(Keys.KEY) == my_serializer.loads(my_serializer.dumps(my_obj))