diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py
index af7468ba14c8..674613cf4edc 100644
--- a/litellm/caching/redis_cache.py
+++ b/litellm/caching/redis_cache.py
@@ -123,9 +123,7 @@ def __init__(
         redis_kwargs.update(kwargs)
         self.redis_client = get_redis_client(**redis_kwargs)
-        self.redis_async_client: Optional[
-            Union[async_redis_client, async_redis_cluster_client]
-        ] = None
+        self.redis_async_client: Optional[Union["async_redis_client", "async_redis_cluster_client"]] = None
         self.redis_kwargs = redis_kwargs
         self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
@@ -150,9 +148,7 @@ def __init__(
                 _ = asyncio.get_running_loop().create_task(self.ping())
             except Exception as e:
                 if "no running event loop" in str(e):
-                    verbose_logger.debug(
-                        "Ignoring async redis ping. No running event loop."
-                    )
+                    verbose_logger.debug("Ignoring async redis ping. No running event loop.")
                 else:
                     verbose_logger.error(
                         "Error connecting to Async Redis client - {}".format(str(e)),
@@ -164,9 +160,7 @@ def __init__(
             if hasattr(self.redis_client, "ping"):
                 self.redis_client.ping()  # type: ignore
         except Exception as e:
-            verbose_logger.error(
-                "Error connecting to Sync Redis client", extra={"error": str(e)}
-            )
+            verbose_logger.error("Error connecting to Sync Redis client", extra={"error": str(e)})
 
         if litellm.default_redis_ttl is not None:
             super().__init__(default_ttl=int(litellm.default_redis_ttl))
@@ -182,18 +176,12 @@ def init_async_client(
         cached_client = in_memory_llm_clients_cache.get_cache(key="async-redis-client")
         if cached_client is not None:
-            redis_async_client = cast(
-                Union[async_redis_client, async_redis_cluster_client], cached_client
-            )
+            redis_async_client = cast(Union[async_redis_client, async_redis_cluster_client], cached_client)
         else:
             # Create new connection pool and client for current event loop
             self.async_redis_conn_pool = get_redis_connection_pool(**self.redis_kwargs)
-            redis_async_client = get_redis_async_client(
-                connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
-            )
-            in_memory_llm_clients_cache.set_cache(
-                key="async-redis-client", value=self.redis_async_client
-            )
+            redis_async_client = get_redis_async_client(connection_pool=self.async_redis_conn_pool, **self.redis_kwargs)
+            in_memory_llm_clients_cache.set_cache(key="async-redis-client", value=self.redis_async_client)
         self.redis_async_client = redis_async_client  # type: ignore
         return redis_async_client
@@ -209,9 +197,7 @@ def check_and_fix_namespace(self, key: str) -> str:
 
     def set_cache(self, key, value, **kwargs):
         ttl = self.get_ttl(**kwargs)
-        print_verbose(
-            f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
-        )
+        print_verbose(f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}")
         key = self.check_and_fix_namespace(key=key)
         try:
             start_time = time.time()
@@ -227,13 +213,9 @@ def set_cache(self, key, value, **kwargs):
             )
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
-            print_verbose(
-                f"litellm.caching.caching: set() - Got exception from REDIS : {str(e)}"
-            )
+            print_verbose(f"litellm.caching.caching: set() - Got exception from REDIS : {str(e)}")
 
-    def increment_cache(
-        self, key, value: int, ttl: Optional[float] = None, **kwargs
-    ) -> int:
+    def increment_cache(self, key, value: int, ttl: Optional[float] = None, **kwargs) -> int:
         _redis_client = self.redis_client
         start_time = time.time()
         set_ttl = self.get_ttl(ttl=ttl)
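Note on the first hunk above: the quoted names in `Optional[Union["async_redis_client", "async_redis_cluster_client"]]` are forward references, resolved only by type checkers, so the annotation stays valid even when the client types are not imported at runtime. A minimal sketch of the pattern; the `TYPE_CHECKING` imports shown here are illustrative, not taken from this diff:

```python
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
    # Resolved only during static analysis; nothing is imported at runtime.
    from redis.asyncio import Redis as async_redis_client
    from redis.asyncio.cluster import RedisCluster as async_redis_cluster_client

class Cache:
    def __init__(self) -> None:
        # String annotations keep this line executable without the imports above.
        self.redis_async_client: Optional[
            Union["async_redis_client", "async_redis_cluster_client"]
        ] = None
```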
@@ -405,9 +387,7 @@ async def async_set_cache(self, key, value, **kwargs):
                 nx=nx,
                 ex=ttl,
             )
-            print_verbose(
-                f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
-            )
+            print_verbose(f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
             end_time = time.time()
             _duration = end_time - start_time
             asyncio.create_task(
@@ -456,9 +436,7 @@ async def _pipeline_helper(
         # Iterate through each key-value pair in the cache_list and set them in the pipeline.
         for cache_key, cache_value in cache_list:
             cache_key = self.check_and_fix_namespace(key=cache_key)
-            print_verbose(
-                f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
-            )
+            print_verbose(f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}")
             json_cache_value = json.dumps(cache_value)
             # Set the value with a TTL if it's provided.
             _td: Optional[timedelta] = None
@@ -473,9 +451,7 @@ async def _pipeline_helper(
         results = await pipe.execute()
         return results
 
-    async def async_set_cache_pipeline(
-        self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs
-    ):
+    async def async_set_cache_pipeline(self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs):
         """
         Use Redis Pipelines for bulk write operations
         """
@@ -486,9 +462,7 @@ async def async_set_cache_pipeline(
         _redis_client = self.init_async_client()
         start_time = time.time()
 
-        print_verbose(
-            f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
-        )
+        print_verbose(f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}")
         cache_value: Any = None
         try:
             async with _redis_client.pipeline(transaction=False) as pipe:
@@ -549,9 +523,7 @@ async def _set_cache_sadd_helper(
         except Exception:
             raise
 
-    async def async_set_cache_sadd(
-        self, key, value: List, ttl: Optional[float], **kwargs
-    ):
+    async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float], **kwargs):
         from redis.asyncio import Redis
 
         start_time = time.time()
@@ -582,12 +554,8 @@ async def async_set_cache_sadd(
         key = self.check_and_fix_namespace(key=key)
         print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
         try:
-            await self._set_cache_sadd_helper(
-                redis_client=_redis_client, key=key, value=value, ttl=ttl
-            )
-            print_verbose(
-                f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}"
-            )
+            await self._set_cache_sadd_helper(redis_client=_redis_client, key=key, value=value, ttl=ttl)
+            print_verbose(f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}")
             end_time = time.time()
             _duration = end_time - start_time
             asyncio.create_task(
@@ -690,9 +658,7 @@ async def async_increment(
             raise e
 
     async def flush_cache_buffer(self):
-        print_verbose(
-            f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}"
-        )
+        print_verbose(f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}")
         await self.async_set_cache_pipeline(self.redis_batch_writing_buffer)
         self.redis_batch_writing_buffer = []
@@ -705,9 +671,7 @@ def _get_cache_logic(self, cached_response: Any):
             # cached_response is in `b{} convert it to ModelResponse
             cached_response = cached_response.decode("utf-8")  # Convert bytes to string
         try:
-            cached_response = json.loads(
-                cached_response
-            )  # Convert string to dictionary
+            cached_response = json.loads(cached_response)  # Convert string to dictionary
         except Exception:
             cached_response = ast.literal_eval(cached_response)
         return cached_response
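For readers following `_get_cache_logic` above: Redis hands back `bytes`, which are decoded to text and parsed as JSON, with `ast.literal_eval` as the fallback for values that were stored as Python-literal strings. A self-contained sketch of the same decode path (the function name here is illustrative):

```python
import ast
import json
from typing import Any

def decode_cached_response(cached_response: Any) -> Any:
    if isinstance(cached_response, bytes):
        # Redis returns raw bytes; convert to str before parsing.
        cached_response = cached_response.decode("utf-8")
    try:
        return json.loads(cached_response)  # normal path: JSON payloads
    except Exception:
        # fallback for Python-literal payloads, e.g. "{'key': 'value'}"
        return ast.literal_eval(cached_response)
```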
@@ -728,15 +692,11 @@ def get_cache(self, key, parent_otel_span: Optional[Span] = None, **kwargs):
                 end_time=end_time,
                 parent_otel_span=parent_otel_span,
             )
-            print_verbose(
-                f"Got Redis Cache: key: {key}, cached_response {cached_response}"
-            )
+            print_verbose(f"Got Redis Cache: key: {key}, cached_response {cached_response}")
             return self._get_cache_logic(cached_response=cached_response)
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
-            verbose_logger.error(
-                "litellm.caching.caching: get() - Got exception from REDIS: ", e
-            )
+            verbose_logger.error("litellm.caching.caching: get() - Got exception from REDIS: ", e)
 
     def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
         """
@@ -807,9 +767,7 @@ def batch_get_cache(
             verbose_logger.error(f"Error occurred in batch get cache - {str(e)}")
             return key_value_dict
 
-    async def async_get_cache(
-        self, key, parent_otel_span: Optional[Span] = None, **kwargs
-    ):
+    async def async_get_cache(self, key, parent_otel_span: Optional[Span] = None, **kwargs):
         from redis.asyncio import Redis
 
         _redis_client: Redis = self.init_async_client()  # type: ignore
@@ -819,9 +777,7 @@ async def async_get_cache(
         try:
             print_verbose(f"Get Async Redis Cache: key: {key}")
             cached_response = await _redis_client.get(key)
-            print_verbose(
-                f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
-            )
+            print_verbose(f"Got Async Redis Cache: key: {key}, cached_response {cached_response}")
             response = self._get_cache_logic(cached_response=cached_response)
 
             end_time = time.time()
@@ -853,9 +809,7 @@ async def async_get_cache(
                     event_metadata={"key": key},
                 )
             )
-            print_verbose(
-                f"litellm.caching.caching: async get() - Got exception from REDIS: {str(e)}"
-            )
+            print_verbose(f"litellm.caching.caching: async get() - Got exception from REDIS: {str(e)}")
 
     async def async_batch_get_cache(
         self,
@@ -959,9 +913,7 @@ def sync_ping(self) -> bool:
                 error=e,
                 call_type=f"sync_ping <- {_get_call_stack_info()}",
             )
-            verbose_logger.error(
-                f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
-            )
+            verbose_logger.error(f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}")
             raise e
 
     async def ping(self) -> bool:
@@ -995,9 +947,7 @@ async def ping(self) -> bool:
                     call_type=f"async_ping <- {_get_call_stack_info()}",
                 )
             )
-            verbose_logger.error(
-                f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
-            )
+            verbose_logger.error(f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}")
             raise e
 
     async def delete_cache_keys(self, keys):
@@ -1051,9 +1001,7 @@ async def _pipeline_increment_helper(
             # Execute the pipeline and return results
             results = await pipe.execute()
             # only return float values
-            verbose_logger.debug(
-                f"Increment ASYNC Redis Cache PIPELINE: results: {results}"
-            )
+            verbose_logger.debug(f"Increment ASYNC Redis Cache PIPELINE: results: {results}")
             return [r for r in results if isinstance(r, float)]
 
     async def async_increment_pipeline(
@@ -1076,9 +1024,7 @@ async def async_increment_pipeline(
         _redis_client: Redis = self.init_async_client()  # type: ignore
         start_time = time.time()
 
-        print_verbose(
-            f"Increment Async Redis Cache Pipeline: increment list: {increment_list}"
-        )
+        print_verbose(f"Increment Async Redis Cache Pipeline: increment list: {increment_list}")
 
         try:
             async with _redis_client.pipeline(transaction=False) as pipe:
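The `async_increment_pipeline` hunks above rely on a non-transactional pipeline: commands are buffered client-side and shipped to Redis in one round trip. A rough sketch of the pattern, assuming a redis-py asyncio client (`bulk_incr` is a hypothetical name, not part of this diff):

```python
from typing import List, Tuple
from redis.asyncio import Redis

async def bulk_incr(client: Redis, increments: List[Tuple[str, float]]) -> List[float]:
    # transaction=False buffers commands without MULTI/EXEC overhead.
    async with client.pipeline(transaction=False) as pipe:
        for key, amount in increments:
            pipe.incrbyfloat(key, amount)
        results = await pipe.execute()
    # INCRBYFLOAT replies parse to floats; keep only the numeric results.
    return [r for r in results if isinstance(r, float)]
```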
@@ -1161,11 +1107,14 @@ async def async_rpush(
             int: The length of the list after the push operation
         """
         _redis_client: Any = self.init_async_client()
-        start_time = time.time()
+        _time = time.time
+        _str = str
+
+        start_time = _time()
         try:
             response = await _redis_client.rpush(key, *values)
             ## LOGGING ##
-            end_time = time.time()
+            end_time = _time()
             _duration = end_time - start_time
             asyncio.create_task(
                 self.service_logger_obj.async_service_success_hook(
@@ -1177,8 +1126,7 @@ async def async_rpush(
             return response
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
-            ## LOGGING ##
-            end_time = time.time()
+            end_time = _time()
             _duration = end_time - start_time
             asyncio.create_task(
                 self.service_logger_obj.async_service_failure_hook(
@@ -1188,14 +1136,10 @@ async def async_rpush(
                     call_type=f"async_rpush <- {_get_call_stack_info()}",
                 )
             )
-            verbose_logger.error(
-                f"LiteLLM Redis Cache RPUSH: - Got exception from REDIS : {str(e)}"
-            )
+            verbose_logger.error(f"LiteLLM Redis Cache RPUSH: - Got exception from REDIS : {_str(e)}")
             raise e
 
-    async def handle_lpop_count_for_older_redis_versions(
-        self, pipe: pipeline, key: str, count: int
-    ) -> List[bytes]:
+    async def handle_lpop_count_for_older_redis_versions(self, pipe: pipeline, key: str, count: int) -> List[bytes]:
         result: List[bytes] = []
         for _ in range(count):
             pipe.lpop(key)
@@ -1228,9 +1172,7 @@ async def async_lpop(
             if count is not None and major_version < 7:
                 # For Redis < 7.0, use pipeline to execute multiple LPOP commands
                 async with _redis_client.pipeline(transaction=False) as pipe:
-                    result = await self.handle_lpop_count_for_older_redis_versions(
-                        pipe, key, count
-                    )
+                    result = await self.handle_lpop_count_for_older_redis_versions(pipe, key, count)
             else:
                 # For Redis >= 7.0 or when count is None, use native LPOP with count
                 result = await _redis_client.lpop(key, count)
@@ -1252,9 +1194,7 @@ async def async_lpop(
                     return result.decode("utf-8")
                 except Exception:
                     return result
-            elif isinstance(result, list) and all(
-                isinstance(item, bytes) for item in result
-            ):
+            elif isinstance(result, list) and all(isinstance(item, bytes) for item in result):
                 try:
                     return [item.decode("utf-8") for item in result]
                 except Exception:
@@ -1273,7 +1213,5 @@ async def async_lpop(
                     call_type=f"async_lpop <- {_get_call_stack_info()}",
                 )
             )
-            verbose_logger.error(
-                f"LiteLLM Redis Cache LPOP: - Got exception from REDIS : {str(e)}"
-            )
+            verbose_logger.error(f"LiteLLM Redis Cache LPOP: - Got exception from REDIS : {str(e)}")
             raise e
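The `async_lpop` changes preserve the existing version gate around LPOP's `count` argument: servers the code classifies as pre-7.0 get `count` single LPOP calls batched in a pipeline instead of one `LPOP key count`. A condensed sketch of that branching, with the function name and `major_version` plumbing being illustrative rather than from the diff:

```python
from typing import Any, List, Optional
from redis.asyncio import Redis

async def lpop_n(client: Redis, key: str, count: Optional[int], major_version: int) -> Any:
    if count is not None and major_version < 7:
        # Older servers: issue `count` single LPOPs in one round trip.
        async with client.pipeline(transaction=False) as pipe:
            for _ in range(count):
                pipe.lpop(key)
            results: List[Any] = await pipe.execute()
        # Drop the Nones returned once the list is drained.
        return [item for item in results if item is not None]
    # Newer servers (or count=None): native LPOP handles it directly.
    return await client.lpop(key, count)
```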
diff --git a/litellm/litellm_core_utils/safe_json_dumps.py b/litellm/litellm_core_utils/safe_json_dumps.py
index c714e36b5f93..d6c65c23a54f 100644
--- a/litellm/litellm_core_utils/safe_json_dumps.py
+++ b/litellm/litellm_core_utils/safe_json_dumps.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Union
+from typing import Any
 
 from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
@@ -10,43 +10,45 @@ def safe_dumps(data: Any, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH) -> str:
     If a circular reference is detected then a marker string is returned.
     """
 
+    # Prebinding for perf
+    _str = str
+    _isinstance = isinstance
+    _id = id
+
     def _serialize(obj: Any, seen: set, depth: int) -> Any:
         # Check for maximum depth.
         if depth > max_depth:
             return "MaxDepthExceeded"
         # Base-case: if it is a primitive, simply return it.
-        if isinstance(obj, (str, int, float, bool, type(None))):
+        if _isinstance(obj, (str, int, float, bool, type(None))):
             return obj
+        ident = _id(obj)
         # Check for circular reference.
-        if id(obj) in seen:
+        if ident in seen:
             return "CircularReference Detected"
-        seen.add(id(obj))
-        result: Union[dict, list, tuple, set, str]
-        if isinstance(obj, dict):
-            result = {}
-            for k, v in obj.items():
-                if isinstance(k, (str)):
-                    result[k] = _serialize(v, seen, depth + 1)
-            seen.remove(id(obj))
-            return result
-        elif isinstance(obj, list):
-            result = [_serialize(item, seen, depth + 1) for item in obj]
-            seen.remove(id(obj))
-            return result
-        elif isinstance(obj, tuple):
-            result = tuple(_serialize(item, seen, depth + 1) for item in obj)
-            seen.remove(id(obj))
-            return result
-        elif isinstance(obj, set):
-            result = sorted([_serialize(item, seen, depth + 1) for item in obj])
-            seen.remove(id(obj))
-            return result
-        else:
-            # Fall back to string conversion for non-serializable objects.
-            try:
-                return str(obj)
-            except Exception:
-                return "Unserializable Object"
+        seen.add(ident)
+        try:
+            if _isinstance(obj, dict):
+                result = {}
+                for k, v in obj.items():
+                    # Only allow str keys (no attempt at conversion for speed & safety)
+                    if _isinstance(k, str):
+                        result[k] = _serialize(v, seen, depth + 1)
+                return result
+            elif _isinstance(obj, list):
+                return [_serialize(item, seen, depth + 1) for item in obj]
+            elif _isinstance(obj, tuple):
+                return tuple(_serialize(item, seen, depth + 1) for item in obj)
+            elif _isinstance(obj, set):
+                return sorted([_serialize(item, seen, depth + 1) for item in obj])
+            else:
+                # Fall back to string conversion for non-serializable objects.
+                try:
+                    return _str(obj)
+                except Exception:
+                    return "Unserializable Object"
+        finally:
+            seen.remove(ident)
 
     safe_data = _serialize(data, set(), 0)
     return json.dumps(safe_data, default=str)
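The `try/finally` rewrite above guarantees `seen.remove(ident)` runs on every exit path, which is what lets shared (non-circular) references serialize fully while genuine cycles are cut. A quick demonstration of both behaviors:

```python
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

shared = {"x": 1}
data = {"a": shared, "b": shared}  # same object referenced twice, no cycle
print(safe_dumps(data))            # {"a": {"x": 1}, "b": {"x": 1}}

cycle: dict = {}
cycle["self"] = cycle              # a genuine cycle
print(safe_dumps(cycle))           # {"self": "CircularReference Detected"}
```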
diff --git a/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py b/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py
index 91d0bee1d3aa..52deb70182ef 100644
--- a/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py
+++ b/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py
@@ -61,8 +61,8 @@ def _should_commit_spend_updates_to_redis() -> bool:
         """
         from litellm.proxy.proxy_server import general_settings
 
-        _use_redis_transaction_buffer: Optional[Union[bool, str]] = (
-            general_settings.get("use_redis_transaction_buffer", False)
+        _use_redis_transaction_buffer: Optional[Union[bool, str]] = general_settings.get(
+            "use_redis_transaction_buffer", False
         )
         if isinstance(_use_redis_transaction_buffer, str):
             _use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer)
@@ -94,9 +94,12 @@ async def _store_transactions_in_redis(
             key=redis_key,
             values=list_of_transactions,
         )
-        await self._emit_new_item_added_to_redis_buffer_event(
-            queue_size=current_redis_buffer_size,
-            service=service_type,
+        # Fire and forget emission of event, don't await
+        asyncio.create_task(
+            self._emit_new_item_added_to_redis_buffer_event(
+                queue_size=current_redis_buffer_size,
+                service=service_type,
+            )
         )
 
     async def store_in_memory_spend_updates_in_redis(
@@ -151,15 +154,11 @@ async def store_in_memory_spend_updates_in_redis(
         ```
         """
         if self.redis_cache is None:
-            verbose_proxy_logger.debug(
-                "redis_cache is None, skipping store_in_memory_spend_updates_in_redis"
-            )
+            verbose_proxy_logger.debug("redis_cache is None, skipping store_in_memory_spend_updates_in_redis")
             return
 
         # Get all transactions
-        db_spend_update_transactions = (
-            await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
-        )
+        db_spend_update_transactions = await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
         daily_spend_update_transactions = (
             await daily_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
         )
@@ -170,12 +169,8 @@ async def store_in_memory_spend_updates_in_redis(
             await daily_tag_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
         )
 
-        verbose_proxy_logger.debug(
-            "ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions
-        )
-        verbose_proxy_logger.debug(
-            "ALL DAILY SPEND UPDATE TRANSACTIONS: %s", daily_spend_update_transactions
-        )
+        verbose_proxy_logger.debug("ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions)
+        verbose_proxy_logger.debug("ALL DAILY SPEND UPDATE TRANSACTIONS: %s", daily_spend_update_transactions)
 
         await self._store_transactions_in_redis(
             transactions=db_spend_update_transactions,
@@ -295,9 +290,7 @@ async def get_all_daily_spend_update_transactions_from_redis_buffer(
         )
         if list_of_transactions is None:
             return None
-        list_of_daily_spend_update_transactions = [
-            json.loads(transaction) for transaction in list_of_transactions
-        ]
+        list_of_daily_spend_update_transactions = [json.loads(transaction) for transaction in list_of_transactions]
         return cast(
             Dict[str, DailyUserSpendTransaction],
             DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
@@ -319,9 +312,7 @@ async def get_all_daily_team_spend_update_transactions_from_redis_buffer(
         )
         if list_of_transactions is None:
             return None
-        list_of_daily_spend_update_transactions = [
-            json.loads(transaction) for transaction in list_of_transactions
-        ]
+        list_of_daily_spend_update_transactions = [json.loads(transaction) for transaction in list_of_transactions]
         return cast(
             Dict[str, DailyTeamSpendTransaction],
             DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
@@ -343,9 +334,7 @@ async def get_all_daily_tag_spend_update_transactions_from_redis_buffer(
         )
         if list_of_transactions is None:
             return None
-        list_of_daily_spend_update_transactions = [
-            json.loads(transaction) for transaction in list_of_transactions
-        ]
+        list_of_daily_spend_update_transactions = [json.loads(transaction) for transaction in list_of_transactions]
         return cast(
             Dict[str, DailyTagSpendTransaction],
             DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
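One behavioral note on the `_store_transactions_in_redis` change above: wrapping the buffer-size event emission in `asyncio.create_task` schedules it without awaiting, so the Redis write path no longer blocks on telemetry. A minimal sketch of the pattern, with illustrative names:

```python
import asyncio

async def emit_buffer_event(queue_size: int, service: str) -> None:
    await asyncio.sleep(0)  # stand-in for the real emission I/O

async def store_transactions(queue_size: int, service: str) -> None:
    # ... rpush the transactions to Redis ...
    # Schedule the emission and return immediately; the caller never awaits it.
    asyncio.create_task(emit_buffer_event(queue_size, service))
```

Per the asyncio docs, fire-and-forget tasks should normally have a reference retained somewhere to protect them from garbage collection mid-flight; this diff follows the same unreferenced `create_task` pattern the logging hooks in `redis_cache.py` already use.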