Commit 8197fd7
Revert "merge from upstream"
This reverts commit e6ca918.
Parent: e6ca918

File tree: 1 file changed (+58, -14)

litellm/caching/caching_handler.py

@@ -1,5 +1,5 @@
 """
-This contains LLMCachingHandler
+This contains LLMCachingHandler
 
 This exposes two methods:
 - async_get_cache
@@ -18,6 +18,7 @@
 import datetime
 import inspect
 import threading
+from functools import lru_cache, wraps
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -35,11 +36,13 @@
 
 import litellm
 from litellm._logging import print_verbose, verbose_logger
+from litellm._service_logger import ServiceLogging
+from litellm.caching import InMemoryCache
 from litellm.caching.caching import S3Cache
-from litellm.types.caching import CachedEmbedding
 from litellm.litellm_core_utils.logging_utils import (
     _assemble_complete_response_from_streaming_chunks,
 )
+from litellm.types.caching import CachedEmbedding
 from litellm.types.rerank import RerankResponse
 from litellm.types.utils import (
     CallTypes,
@@ -68,7 +71,12 @@ class CachingHandlerResponse(BaseModel):
 
     cached_result: Optional[Any] = None
     final_embedding_cached_response: Optional[EmbeddingResponse] = None
-    embedding_all_elements_cache_hit: bool = False  # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
+    embedding_all_elements_cache_hit: bool = (
+        False  # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
+    )
+
+
+in_memory_cache_obj = InMemoryCache()
 
 
 class LLMCachingHandler:
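
The revert reintroduces a module-level in_memory_cache_obj, so every LLMCachingHandler in the process shares one in-memory store instead of each instance building its own. A toy illustration of that difference (hypothetical names, not litellm's classes):

class ToyCache:
    def __init__(self) -> None:
        self.store: dict = {}


shared_cache = ToyCache()  # module level: one object for the whole process


class SharedHandler:
    def __init__(self) -> None:
        self.cache = shared_cache  # every handler sees the same entries


class PrivateHandler:
    def __init__(self) -> None:
        self.cache = ToyCache()  # fresh, empty cache per handler


a, b = SharedHandler(), SharedHandler()
a.cache.store["key"] = "value"
assert b.cache.store["key"] == "value"  # hit populated by another handler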
@@ -78,11 +86,20 @@ def __init__(
         request_kwargs: Dict[str, Any],
         start_time: datetime.datetime,
     ):
+        from litellm.caching import DualCache, RedisCache
+
         self.async_streaming_chunks: List[ModelResponse] = []
         self.sync_streaming_chunks: List[ModelResponse] = []
         self.request_kwargs = request_kwargs
         self.original_function = original_function
         self.start_time = start_time
+        if litellm.cache is not None and isinstance(litellm.cache.cache, RedisCache):
+            self.dual_cache: Optional[DualCache] = DualCache(
+                redis_cache=litellm.cache.cache,
+                in_memory_cache=in_memory_cache_obj,
+            )
+        else:
+            self.dual_cache = None
         pass
 
     async def _async_get_cache(
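
When the configured backend is Redis, the constructor pairs that shared in-memory object with it in a DualCache. A minimal read-through sketch of the dual-cache idea, assuming the usual L1 (in-process) / L2 (remote) semantics; ToyDualCache and its parameters are illustrative, not litellm's DualCache API:

from typing import Any, Dict, Optional


class ToyDualCache:
    """Read-through pair of caches: L1 in-process dict, L2 remote store."""

    def __init__(self, l2_get, l2_set) -> None:
        self._l1: Dict[str, Any] = {}  # fast, per-process layer
        self._l2_get = l2_get          # e.g. an async Redis GET
        self._l2_set = l2_set          # e.g. an async Redis SET

    async def get(self, key: str) -> Optional[Any]:
        if key in self._l1:              # L1 hit: no network round trip
            return self._l1[key]
        value = await self._l2_get(key)  # fall back to the remote cache
        if value is not None:
            self._l1[key] = value        # backfill L1 for next time
        return value

    async def set(self, key: str, value: Any) -> None:
        self._l1[key] = value            # write both layers
        await self._l2_set(key, value)

Note that the handler keeps dual_cache as None for non-Redis backends, so the later call sites can pass it through as dynamic_cache_object unconditionally.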
@@ -115,10 +132,16 @@ async def _async_get_cache(
         Raises:
             None
         """
+        from litellm.litellm_core_utils.core_helpers import (
+            _get_parent_otel_span_from_kwargs,
+        )
         from litellm.utils import CustomStreamWrapper
 
+        kwargs = kwargs.copy()
         args = args or ()
 
+        parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
+        kwargs["parent_otel_span"] = parent_otel_span
         final_embedding_cached_response: Optional[EmbeddingResponse] = None
         embedding_all_elements_cache_hit: bool = False
         cached_result: Optional[Any] = None
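
kwargs is copied before the OpenTelemetry parent span is injected, so the caller's dict is not mutated as a side effect of a cache read. A short illustration of why the copy matters:

caller_kwargs = {"model": "gpt-4"}

kwargs = caller_kwargs.copy()          # shallow copy: new top-level dict
kwargs["parent_otel_span"] = object()  # only the copy is modified

assert "parent_otel_span" not in caller_kwargs  # caller is unaffected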
@@ -306,13 +329,15 @@ def handle_kwargs_input_list_or_str(self, kwargs: Dict[str, Any]) -> List[str]:
         else:
             raise ValueError("input must be a string or a list")
 
-    def _extract_model_from_cached_results(self, non_null_list: List[Tuple[int, CachedEmbedding]]) -> Optional[str]:
+    def _extract_model_from_cached_results(
+        self, non_null_list: List[Tuple[int, CachedEmbedding]]
+    ) -> Optional[str]:
         """
         Helper method to extract the model name from cached results.
-
+
         Args:
             non_null_list: List of (idx, cr) tuples where cr is the cached result dict
-
+
         Returns:
             Optional[str]: The model name if found, None otherwise
         """
@@ -558,7 +583,12 @@ async def _retrieve_from_cache(
                     preset_cache_key = litellm.cache.get_cache_key(
                         **{**new_kwargs, "input": i}
                     )
-                    tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
+                    tasks.append(
+                        litellm.cache.async_get_cache(
+                            cache_key=preset_cache_key,
+                            dynamic_cache_object=self.dual_cache,
+                        )
+                    )
                 cached_result = await asyncio.gather(*tasks)
                 ## check if cached result is None ##
                 if cached_result is not None and isinstance(cached_result, list):
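
For list-valued embedding input, the cache is queried once per element and asyncio.gather runs the lookups concurrently while preserving input order, which is what lets per-element hits and misses be lined up against the original inputs. A self-contained sketch of that fan-out (async_get is a stand-in for litellm.cache.async_get_cache):

import asyncio
from typing import Any, List, Optional

FAKE_CACHE = {"a": 1, "c": 3}


async def async_get(key: str) -> Optional[Any]:
    await asyncio.sleep(0)  # stand-in for a Redis round trip
    return FAKE_CACHE.get(key)


async def fan_out(inputs: List[str]) -> List[Optional[Any]]:
    # One lookup per input element; gather preserves ordering, so
    # results[i] lines up with inputs[i] even on partial hits.
    tasks = [async_get(i) for i in inputs]
    return await asyncio.gather(*tasks)


print(asyncio.run(fan_out(["a", "b", "c"])))  # [1, None, 3]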
@@ -567,9 +597,14 @@
                     cached_result = None
             else:
                 if litellm.cache._supports_async() is True:
-                    cached_result = await litellm.cache.async_get_cache(**new_kwargs)
+                    ## check if dual cache is supported ##
+                    cached_result = await litellm.cache.async_get_cache(
+                        dynamic_cache_object=self.dual_cache, **new_kwargs
+                    )
                 else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
-                    cached_result = litellm.cache.get_cache(**new_kwargs)
+                    cached_result = litellm.cache.get_cache(
+                        dynamic_cache_object=self.dual_cache, **new_kwargs
+                    )
         return cached_result
 
     def _convert_cached_result_to_model_response(
@@ -735,6 +770,9 @@ async def async_set_cache(
         Raises:
             None
         """
+        from litellm.litellm_core_utils.core_helpers import (
+            _get_parent_otel_span_from_kwargs,
+        )
 
         if litellm.cache is None:
             return
@@ -746,6 +784,8 @@
                 args,
             )
         )
+        parent_otel_span = _get_parent_otel_span_from_kwargs(new_kwargs)
+        new_kwargs["parent_otel_span"] = parent_otel_span
         # [OPTIONAL] ADD TO CACHE
         if self._should_store_result_in_cache(
             original_function=original_function, kwargs=new_kwargs
@@ -764,7 +804,9 @@
                 )  # s3 doesn't support bulk writing. Exclude.
             ):
                 asyncio.create_task(
-                    litellm.cache.async_add_cache_pipeline(result, **new_kwargs)
+                    litellm.cache.async_add_cache_pipeline(
+                        result, dynamic_cache_object=self.dual_cache, **new_kwargs
+                    )
                 )
             elif isinstance(litellm.cache.cache, S3Cache):
                 threading.Thread(
@@ -775,7 +817,9 @@
             else:
                 asyncio.create_task(
                     litellm.cache.async_add_cache(
-                        result.model_dump_json(), **new_kwargs
+                        result.model_dump_json(),
+                        dynamic_cache_object=self.dual_cache,
+                        **new_kwargs,
                     )
                 )
         else:
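
Writes stay off the request path: async backends are written via asyncio.create_task (fire-and-forget), while the sync S3 client is pushed onto a threading.Thread. One caveat with the bare create_task pattern in general (not something this diff addresses) is that a task nobody holds a reference to can be garbage-collected before it finishes; a common defensive sketch:

import asyncio

_background_tasks: set = set()


async def write_to_cache(key: str, value: str) -> None:
    await asyncio.sleep(0)  # stand-in for the real cache write


def fire_and_forget(key: str, value: str) -> None:
    # Must be called with a running event loop. Keeping a reference
    # stops the task from being garbage-collected before it completes;
    # the callback drops the reference once the write is done.
    task = asyncio.create_task(write_to_cache(key, value))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


async def main() -> None:
    fire_and_forget("key", "value")
    await asyncio.sleep(0.01)  # let the background write finish


asyncio.run(main())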
@@ -933,9 +977,9 @@ def _update_litellm_logging_obj_environment(
         }
 
         if litellm.cache is not None:
-            litellm_params[
-                "preset_cache_key"
-            ] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
+            litellm_params["preset_cache_key"] = (
+                litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
+            )
         else:
             litellm_params["preset_cache_key"] = None
 