@@ -258,14 +258,17 @@ class StaticLayer(CacheLayerMixin):
     Args:
         max_cache_len (`int`):
             Maximum number of tokens that can be stored, used for tensor preallocation.
+        max_batch_size (`int`, *optional*):
+            Maximum batch size that can be stored.
     """
 
     is_compileable = True
     is_sliding = False
 
-    def __init__(self, max_cache_len: int):
+    def __init__(self, max_cache_len: int, max_batch_size: int | None = None):
         super().__init__()
         self.max_cache_len = max_cache_len
+        self.max_batch_size = max_batch_size
 
     def lazy_initialization(self, key_states: torch.Tensor):
         """
@@ -281,26 +284,30 @@ def lazy_initialization(self, key_states: torch.Tensor):
         i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
         not be compiled anyway for performances!
         """
-        self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
+        if self.max_batch_size is None:
+            self.max_batch_size = key_states.shape[0]
+        _, self.num_heads, _, self.head_dim = key_states.shape
         self.dtype, self.device = key_states.dtype, key_states.device
 
-        self.keys = torch.zeros(
+        self.keys_ = torch.zeros(
             (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
             dtype=self.dtype,
             device=self.device,
         )
-        self.values = torch.zeros(
+        self.values_ = torch.zeros(
             (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
             dtype=self.dtype,
             device=self.device,
         )
+        self.keys = self.keys_
+        self.values = self.values_
         # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
         # breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
         # As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
         # prefill explicitly, but this should be avoided!)
         if not is_torchdynamo_compiling():
-            torch._dynamo.mark_static_address(self.keys)
-            torch._dynamo.mark_static_address(self.values)
+            torch._dynamo.mark_static_address(self.keys_)
+            torch._dynamo.mark_static_address(self.values_)
 
         self.is_initialized = True
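The change above splits the storage into full, preallocated `keys_`/`values_` buffers (sized by `max_batch_size`) and the public `keys`/`values` attributes, which start out as views over the whole buffers. A minimal standalone sketch of that lazy-allocation pattern, with illustrative names and shapes (not the transformers implementation itself):

```python
import torch


class LazyStaticBuffer:
    """Toy sketch: allocate full-size key/value buffers on first use, expose them through views."""

    def __init__(self, max_cache_len: int, max_batch_size: int | None = None):
        self.max_cache_len = max_cache_len
        self.max_batch_size = max_batch_size
        self.is_initialized = False

    def lazy_initialization(self, key_states: torch.Tensor):
        # Without an explicit maximum, fall back to the batch size of the first call.
        if self.max_batch_size is None:
            self.max_batch_size = key_states.shape[0]
        _, num_heads, _, head_dim = key_states.shape
        shape = (self.max_batch_size, num_heads, self.max_cache_len, head_dim)
        self.keys_ = torch.zeros(shape, dtype=key_states.dtype, device=key_states.device)
        self.values_ = torch.zeros(shape, dtype=key_states.dtype, device=key_states.device)
        # The full buffers keep a fixed data pointer across decoding steps, which is what
        # `mark_static_address` advertises to the compiler (the real code skips it while tracing).
        torch._dynamo.mark_static_address(self.keys_)
        torch._dynamo.mark_static_address(self.values_)
        # Public views initially cover the whole buffers.
        self.keys, self.values = self.keys_, self.values_
        self.is_initialized = True


layer = LazyStaticBuffer(max_cache_len=16, max_batch_size=4)
layer.lazy_initialization(torch.randn(2, 8, 5, 64))  # first call only has batch size 2
print(layer.keys_.shape)  # torch.Size([4, 8, 16, 64]) -- allocated for the full batch of 4
```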
@@ -331,21 +338,25 @@ def update(
         cache_position = (
             cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
         )
-        k_out = self.keys
-        v_out = self.values
         batch_size = key_states.shape[0]
-        if k_out.shape[0] != batch_size:
-            k_out = k_out[:batch_size]
-            v_out = v_out[:batch_size]
-        # Update the cache
+        # Dynamic slicing: point the views at the slice of the full buffers matching the current batch
+        self.keys = self.keys_[:batch_size]
+        self.values = self.values_[:batch_size]
         try:
-            k_out.index_copy_(2, cache_position, key_states)
-            v_out.index_copy_(2, cache_position, value_states)
+            self.keys.index_copy_(2, cache_position, key_states)
+            self.values.index_copy_(2, cache_position, value_states)
         except NotImplementedError:
-            # Fallback for devices like MPS where index_copy_ might not be supported.
-            k_out[:, :, cache_position] = key_states
-            v_out[:, :, cache_position] = value_states
-        return k_out, v_out
+            self.keys[:, :, cache_position] = key_states
+            self.values[:, :, cache_position] = value_states
+
+        return self.keys, self.values
+
+    def reset(self):
+        if self.is_initialized:
+            self.keys_.zero_()
+            self.values_.zero_()
+            self.keys = self.keys_
+            self.values = self.values_
 
     def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
         """Return the length and offset of the cache, used to generate the attention mask"""
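The update path now re-slices the public views from the full buffers on every call, so a batch smaller than the preallocated maximum writes into the leading slice only, and `reset()` zeroes the buffers and restores the full-batch views. A rough standalone illustration of the slicing behaviour (buffer sizes are illustrative):

```python
import torch

max_batch_size, num_heads, max_cache_len, head_dim = 4, 8, 16, 64
keys_ = torch.zeros(max_batch_size, num_heads, max_cache_len, head_dim)
values_ = torch.zeros_like(keys_)


def update(keys_full, values_full, key_states, value_states, cache_position):
    batch_size = key_states.shape[0]
    # Views into the preallocated buffers, matching the (possibly smaller) current batch.
    keys, values = keys_full[:batch_size], values_full[:batch_size]
    try:
        keys.index_copy_(2, cache_position, key_states)
        values.index_copy_(2, cache_position, value_states)
    except NotImplementedError:
        # Fallback for backends (e.g. MPS) where index_copy_ may not be supported.
        keys[:, :, cache_position] = key_states
        values[:, :, cache_position] = value_states
    return keys, values


# A batch of 2 only touches the first two rows of the batch-4 buffers.
k, v = update(
    keys_, values_,
    torch.randn(2, num_heads, 3, head_dim),
    torch.randn(2, num_heads, 3, head_dim),
    torch.arange(3),
)
print(k.shape)                        # torch.Size([2, 8, 16, 64]) -- sliced view
print(keys_[2:].abs().sum().item())   # 0.0 -- untouched rows stay zero

keys_.zero_(); values_.zero_()        # what reset() does to the full buffers
```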
@@ -1024,6 +1035,8 @@ class StaticCache(Cache):
         offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
             If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
             usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
+        max_batch_size (`int`, *optional*):
+            The maximum batch size that will be used with this cache.
 
     Example:
@@ -1052,6 +1065,7 @@ def __init__(
         max_cache_len: int,
         offloading: bool = False,
         offload_only_non_sliding: bool = True,
+        max_batch_size: int | None = None,
         **kwargs,
     ):
         config = config.get_text_config(decoder=True)
@@ -1071,19 +1085,39 @@ def __init__(
         layers = []
         for layer_type in layer_types:
             if layer_type == "sliding_attention":
-                layer = StaticSlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
+                layer = StaticSlidingWindowLayer(
+                    max_cache_len=max_cache_len, sliding_window=config.sliding_window, max_batch_size=max_batch_size
+                )
             elif layer_type == "chunked_attention":
                 # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
                 # states they should return - only the mask changes to make them different at the end!
                 layer = StaticSlidingWindowLayer(
-                    max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size
+                    max_cache_len=max_cache_len,
+                    sliding_window=config.attention_chunk_size,
+                    max_batch_size=max_batch_size,
                 )
             else:
-                layer = StaticLayer(max_cache_len=max_cache_len)
+                layer = StaticLayer(max_cache_len=max_cache_len, max_batch_size=max_batch_size)
             layers.append(layer)
 
         super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
 
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[dict[str, Any]] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
+
+    def reset(self):
+        for layer in self.layers:
+            layer.reset()
+
+    def __len__(self):
+        return len(self.layers)
+
 
 class QuantizedCache(Cache):
     """
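Putting the pieces together, a hedged end-to-end sketch of the new `max_batch_size` argument on `StaticCache` (the checkpoint name and prompt are placeholders, and passing the cache to `generate` via `past_key_values` is assumed from the surrounding transformers API):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint, any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Preallocate the cache for up to 4 sequences, even though the first call only uses 1.
cache = StaticCache(config=model.config, max_cache_len=512, max_batch_size=4)

inputs = tokenizer("The static cache now supports", return_tensors="pt")
out = model.generate(**inputs, past_key_values=cache, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))

cache.reset()      # zero the full buffers and restore full-batch views in every layer
print(len(cache))  # number of layers held by the cache
```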