138 changes: 38 additions & 100 deletions litellm/caching/redis_cache.py
@@ -123,9 +123,7 @@ def __init__(

redis_kwargs.update(kwargs)
self.redis_client = get_redis_client(**redis_kwargs)
-self.redis_async_client: Optional[
-    Union[async_redis_client, async_redis_cluster_client]
-] = None
+self.redis_async_client: Optional[Union["async_redis_client", "async_redis_cluster_client"]] = None
self.redis_kwargs = redis_kwargs
self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)

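Note: besides collapsing the annotation onto one line, the new version quotes the client type names, turning them into forward references that type checkers resolve but the interpreter never evaluates. A minimal sketch of that pattern, using redis-py's own client classes as stand-ins for litellm's aliases:

```python
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
    # Imported only for type checking; no runtime dependency.
    from redis.asyncio import Redis, RedisCluster

class Cache:
    def __init__(self) -> None:
        # Quoted names are forward references: checked statically,
        # ignored at runtime.
        self.client: Optional[Union["Redis", "RedisCluster"]] = None
```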
@@ -150,9 +148,7 @@ def __init__(
_ = asyncio.get_running_loop().create_task(self.ping())
except Exception as e:
if "no running event loop" in str(e):
-verbose_logger.debug(
-    "Ignoring async redis ping. No running event loop."
-)
+verbose_logger.debug("Ignoring async redis ping. No running event loop.")
else:
verbose_logger.error(
"Error connecting to Async Redis client - {}".format(str(e)),
@@ -164,9 +160,7 @@ def __init__(
if hasattr(self.redis_client, "ping"):
self.redis_client.ping() # type: ignore
except Exception as e:
-verbose_logger.error(
-    "Error connecting to Sync Redis client", extra={"error": str(e)}
-)
+verbose_logger.error("Error connecting to Sync Redis client", extra={"error": str(e)})

if litellm.default_redis_ttl is not None:
super().__init__(default_ttl=int(litellm.default_redis_ttl))
@@ -182,18 +176,12 @@ def init_async_client(

cached_client = in_memory_llm_clients_cache.get_cache(key="async-redis-client")
if cached_client is not None:
-redis_async_client = cast(
-    Union[async_redis_client, async_redis_cluster_client], cached_client
-)
+redis_async_client = cast(Union[async_redis_client, async_redis_cluster_client], cached_client)
else:
# Create new connection pool and client for current event loop
self.async_redis_conn_pool = get_redis_connection_pool(**self.redis_kwargs)
-redis_async_client = get_redis_async_client(
-    connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
-)
-in_memory_llm_clients_cache.set_cache(
-    key="async-redis-client", value=self.redis_async_client
-)
+redis_async_client = get_redis_async_client(connection_pool=self.async_redis_conn_pool, **self.redis_kwargs)
+in_memory_llm_clients_cache.set_cache(key="async-redis-client", value=self.redis_async_client)

self.redis_async_client = redis_async_client # type: ignore
return redis_async_client
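This hunk keeps the existing behavior: reuse a cached async client when one exists, otherwise build a fresh connection pool and client for the current event loop and cache it. A minimal sketch of that reuse pattern, with a plain dict standing in for litellm's `in_memory_llm_clients_cache`:

```python
import redis.asyncio as async_redis

_client_cache: dict = {}

def get_or_create_async_client(**redis_kwargs) -> async_redis.Redis:
    cached = _client_cache.get("async-redis-client")
    if cached is not None:
        return cached
    # Build a pool and client bound to the current event loop.
    pool = async_redis.ConnectionPool(**redis_kwargs)
    client = async_redis.Redis(connection_pool=pool)
    _client_cache["async-redis-client"] = client
    return client
```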
@@ -209,9 +197,7 @@ def check_and_fix_namespace(self, key: str) -> str:

def set_cache(self, key, value, **kwargs):
ttl = self.get_ttl(**kwargs)
-print_verbose(
-    f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
-)
+print_verbose(f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}")
key = self.check_and_fix_namespace(key=key)
try:
start_time = time.time()
@@ -227,13 +213,9 @@ def set_cache(self, key, value, **kwargs):
)
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
-print_verbose(
-    f"litellm.caching.caching: set() - Got exception from REDIS : {str(e)}"
-)
+print_verbose(f"litellm.caching.caching: set() - Got exception from REDIS : {str(e)}")

-def increment_cache(
-    self, key, value: int, ttl: Optional[float] = None, **kwargs
-) -> int:
+def increment_cache(self, key, value: int, ttl: Optional[float] = None, **kwargs) -> int:
_redis_client = self.redis_client
start_time = time.time()
set_ttl = self.get_ttl(ttl=ttl)
@@ -405,9 +387,7 @@ async def async_set_cache(self, key, value, **kwargs):
nx=nx,
ex=ttl,
)
-print_verbose(
-    f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
-)
+print_verbose(f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
@@ -456,9 +436,7 @@ async def _pipeline_helper(
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
for cache_key, cache_value in cache_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
-print_verbose(
-    f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
-)
+print_verbose(f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}")
json_cache_value = json.dumps(cache_value)
# Set the value with a TTL if it's provided.
_td: Optional[timedelta] = None
@@ -473,9 +451,7 @@ async def _pipeline_helper(
results = await pipe.execute()
return results

-async def async_set_cache_pipeline(
-    self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs
-):
+async def async_set_cache_pipeline(self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs):
"""
Use Redis Pipelines for bulk write operations
"""
@@ -486,9 +462,7 @@ async def async_set_cache_pipeline(
_redis_client = self.init_async_client()
start_time = time.time()

-print_verbose(
-    f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
-)
+print_verbose(f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}")
cache_value: Any = None
try:
async with _redis_client.pipeline(transaction=False) as pipe:
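The pipeline methods here batch many SETs into a single round trip. A minimal sketch of the same bulk-write idea against a bare `redis.asyncio.Redis` client (names are illustrative, not litellm's API):

```python
import json
from typing import Any, List, Optional, Tuple

import redis.asyncio as async_redis

async def set_many(
    client: async_redis.Redis,
    cache_list: List[Tuple[str, Any]],
    ttl: Optional[int] = None,
) -> list:
    # transaction=False batches commands without wrapping them in MULTI/EXEC.
    async with client.pipeline(transaction=False) as pipe:
        for key, value in cache_list:
            # ex=None means "no TTL", so the optional ttl can be passed through.
            pipe.set(key, json.dumps(value), ex=ttl)
        return await pipe.execute()
```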
@@ -549,9 +523,7 @@ async def _set_cache_sadd_helper(
except Exception:
raise

-async def async_set_cache_sadd(
-    self, key, value: List, ttl: Optional[float], **kwargs
-):
+async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float], **kwargs):
from redis.asyncio import Redis

start_time = time.time()
@@ -582,12 +554,8 @@ async def async_set_cache_sadd(
key = self.check_and_fix_namespace(key=key)
print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
try:
-await self._set_cache_sadd_helper(
-    redis_client=_redis_client, key=key, value=value, ttl=ttl
-)
-print_verbose(
-    f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}"
-)
+await self._set_cache_sadd_helper(redis_client=_redis_client, key=key, value=value, ttl=ttl)
+print_verbose(f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}")
end_time = time.time()
_duration = end_time - start_time
asyncio.create_task(
@@ -690,9 +658,7 @@ async def async_increment(
raise e

async def flush_cache_buffer(self):
-print_verbose(
-    f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}"
-)
+print_verbose(f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}")
await self.async_set_cache_pipeline(self.redis_batch_writing_buffer)
self.redis_batch_writing_buffer = []

@@ -705,9 +671,7 @@ def _get_cache_logic(self, cached_response: Any):
# cached_response is in `b{} convert it to ModelResponse
cached_response = cached_response.decode("utf-8") # Convert bytes to string
try:
-cached_response = json.loads(
-    cached_response
-) # Convert string to dictionary
+cached_response = json.loads(cached_response) # Convert string to dictionary
except Exception:
cached_response = ast.literal_eval(cached_response)
return cached_response
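`_get_cache_logic` decodes the raw bytes Redis returns and falls back from JSON to `ast.literal_eval` for values that were stored as Python literals rather than JSON. The same fallback chain in isolation:

```python
import ast
import json
from typing import Any

def parse_cached(raw: bytes) -> Any:
    text = raw.decode("utf-8")
    try:
        return json.loads(text)  # normal path: value stored as JSON
    except ValueError:
        # fallback: value stored as a repr()'d Python literal
        return ast.literal_eval(text)
```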
@@ -728,15 +692,11 @@ def get_cache(self, key, parent_otel_span: Optional[Span] = None, **kwargs):
end_time=end_time,
parent_otel_span=parent_otel_span,
)
-print_verbose(
-    f"Got Redis Cache: key: {key}, cached_response {cached_response}"
-)
+print_verbose(f"Got Redis Cache: key: {key}, cached_response {cached_response}")
return self._get_cache_logic(cached_response=cached_response)
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
-verbose_logger.error(
-    "litellm.caching.caching: get() - Got exception from REDIS: ", e
-)
+verbose_logger.error("litellm.caching.caching: get() - Got exception from REDIS: ", e)

def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
"""
@@ -807,9 +767,7 @@ def batch_get_cache(
verbose_logger.error(f"Error occurred in batch get cache - {str(e)}")
return key_value_dict

-async def async_get_cache(
-    self, key, parent_otel_span: Optional[Span] = None, **kwargs
-):
+async def async_get_cache(self, key, parent_otel_span: Optional[Span] = None, **kwargs):
from redis.asyncio import Redis

_redis_client: Redis = self.init_async_client() # type: ignore
@@ -819,9 +777,7 @@ async def async_get_cache(
try:
print_verbose(f"Get Async Redis Cache: key: {key}")
cached_response = await _redis_client.get(key)
-print_verbose(
-    f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
-)
+print_verbose(f"Got Async Redis Cache: key: {key}, cached_response {cached_response}")
response = self._get_cache_logic(cached_response=cached_response)

end_time = time.time()
@@ -853,9 +809,7 @@ async def async_get_cache(
event_metadata={"key": key},
)
)
-print_verbose(
-    f"litellm.caching.caching: async get() - Got exception from REDIS: {str(e)}"
-)
+print_verbose(f"litellm.caching.caching: async get() - Got exception from REDIS: {str(e)}")

async def async_batch_get_cache(
self,
@@ -959,9 +913,7 @@ def sync_ping(self) -> bool:
error=e,
call_type=f"sync_ping <- {_get_call_stack_info()}",
)
-verbose_logger.error(
-    f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
-)
+verbose_logger.error(f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}")
raise e

async def ping(self) -> bool:
@@ -995,9 +947,7 @@ async def ping(self) -> bool:
call_type=f"async_ping <- {_get_call_stack_info()}",
)
)
-verbose_logger.error(
-    f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
-)
+verbose_logger.error(f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}")
raise e

async def delete_cache_keys(self, keys):
@@ -1051,9 +1001,7 @@ async def _pipeline_increment_helper(
# Execute the pipeline and return results
results = await pipe.execute()
# only return float values
-verbose_logger.debug(
-    f"Increment ASYNC Redis Cache PIPELINE: results: {results}"
-)
+verbose_logger.debug(f"Increment ASYNC Redis Cache PIPELINE: results: {results}")
return [r for r in results if isinstance(r, float)]

async def async_increment_pipeline(
@@ -1076,9 +1024,7 @@ async def async_increment_pipeline(
_redis_client: Redis = self.init_async_client() # type: ignore
start_time = time.time()

-print_verbose(
-    f"Increment Async Redis Cache Pipeline: increment list: {increment_list}"
-)
+print_verbose(f"Increment Async Redis Cache Pipeline: increment list: {increment_list}")

try:
async with _redis_client.pipeline(transaction=False) as pipe:
@@ -1161,11 +1107,14 @@ async def async_rpush(
int: The length of the list after the push operation
"""
_redis_client: Any = self.init_async_client()
-start_time = time.time()
+_time = time.time
+_str = str
+
+start_time = _time()
try:
response = await _redis_client.rpush(key, *values)
## LOGGING ##
-end_time = time.time()
+end_time = _time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_success_hook(
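The new `_time = time.time` and `_str = str` bindings look like the standard micro-optimization of hoisting global lookups into locals before a hot path, so each later call is a fast local load instead of a global plus attribute lookup. The general pattern, independent of this file:

```python
import time

def timed_loop(n: int) -> float:
    _time = time.time  # bind once; local lookups are cheaper than time.time
    start = _time()
    total = 0
    for i in range(n):
        total += i
    return _time() - start
```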
@@ -1177,8 +1126,7 @@ async def async_rpush(
return response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
-## LOGGING ##
-end_time = time.time()
+end_time = _time()
_duration = end_time - start_time
asyncio.create_task(
self.service_logger_obj.async_service_failure_hook(
@@ -1188,14 +1136,10 @@ async def async_rpush(
call_type=f"async_rpush <- {_get_call_stack_info()}",
)
)
-verbose_logger.error(
-    f"LiteLLM Redis Cache RPUSH: - Got exception from REDIS : {str(e)}"
-)
+verbose_logger.error(f"LiteLLM Redis Cache RPUSH: - Got exception from REDIS : {_str(e)}")
raise e

-async def handle_lpop_count_for_older_redis_versions(
-    self, pipe: pipeline, key: str, count: int
-) -> List[bytes]:
+async def handle_lpop_count_for_older_redis_versions(self, pipe: pipeline, key: str, count: int) -> List[bytes]:
result: List[bytes] = []
for _ in range(count):
pipe.lpop(key)
@@ -1228,9 +1172,7 @@ async def async_lpop(
if count is not None and major_version < 7:
# For Redis < 7.0, use pipeline to execute multiple LPOP commands
async with _redis_client.pipeline(transaction=False) as pipe:
-result = await self.handle_lpop_count_for_older_redis_versions(
-    pipe, key, count
-)
+result = await self.handle_lpop_count_for_older_redis_versions(pipe, key, count)
else:
# For Redis >= 7.0 or when count is None, use native LPOP with count
result = await _redis_client.lpop(key, count)
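For servers without native `LPOP key count` support, the helper queues `count` single-element LPOPs in one pipeline and executes them together. A minimal sketch of that fallback, assuming an async redis-py client:

```python
from typing import List

import redis.asyncio as async_redis

async def lpop_many(client: async_redis.Redis, key: str, count: int) -> List[bytes]:
    # One round trip: queue `count` LPOPs, then execute them together.
    async with client.pipeline(transaction=False) as pipe:
        for _ in range(count):
            pipe.lpop(key)
        results = await pipe.execute()
    # LPOP returns None once the list is empty; drop those entries.
    return [r for r in results if r is not None]
```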
@@ -1252,9 +1194,7 @@ async def async_lpop(
return result.decode("utf-8")
except Exception:
return result
-elif isinstance(result, list) and all(
-    isinstance(item, bytes) for item in result
-):
+elif isinstance(result, list) and all(isinstance(item, bytes) for item in result):
try:
return [item.decode("utf-8") for item in result]
except Exception:
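The decoding branch above normalizes LPOP results, since redis-py returns raw bytes unless the client was created with `decode_responses=True`. The same normalization on its own:

```python
from typing import Any, List, Union

def normalize_lpop_result(result: Any) -> Union[str, List[str], Any]:
    if isinstance(result, bytes):
        # Single popped value.
        try:
            return result.decode("utf-8")
        except UnicodeDecodeError:
            return result
    if isinstance(result, list) and all(isinstance(item, bytes) for item in result):
        # LPOP with a count returns a list of byte strings.
        return [item.decode("utf-8") for item in result]
    return result
```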
@@ -1273,7 +1213,5 @@ async def async_lpop(
call_type=f"async_lpop <- {_get_call_stack_info()}",
)
)
-verbose_logger.error(
-    f"LiteLLM Redis Cache LPOP: - Got exception from REDIS : {str(e)}"
-)
+verbose_logger.error(f"LiteLLM Redis Cache LPOP: - Got exception from REDIS : {str(e)}")
raise e