Commit 1a6e271

static cache bug

1 parent d89e30b

1 file changed: src/transformers/cache_utils.py (55 additions, 21 deletions)
@@ -258,14 +258,17 @@ class StaticLayer(CacheLayerMixin):
     Args:
         max_cache_len (`int`):
             Maximum number of tokens that can be stored, used for tensor preallocation.
+        max_batch_size (`int`, *optional*):
+            Maximum batch size that can be stored.
     """

     is_compileable = True
     is_sliding = False

-    def __init__(self, max_cache_len: int):
+    def __init__(self, max_cache_len: int, max_batch_size: int | None = None):
         super().__init__()
         self.max_cache_len = max_cache_len
+        self.max_batch_size = max_batch_size

     def lazy_initialization(self, key_states: torch.Tensor):
         """
@@ -281,26 +284,30 @@ def lazy_initialization(self, key_states: torch.Tensor):
         i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
         not be compiled anyway for performances!
         """
-        self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
+        if self.max_batch_size is None:
+            self.max_batch_size = key_states.shape[0]
+        _, self.num_heads, _, self.head_dim = key_states.shape
         self.dtype, self.device = key_states.dtype, key_states.device

-        self.keys = torch.zeros(
+        self.keys_ = torch.zeros(
             (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
             dtype=self.dtype,
             device=self.device,
         )
-        self.values = torch.zeros(
+        self.values_ = torch.zeros(
             (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
             dtype=self.dtype,
             device=self.device,
         )
+        self.keys = self.keys_
+        self.values = self.values_
         # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
         # breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
         # As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
         # prefill explicitly, but this should be avoided!)
         if not is_torchdynamo_compiling():
-            torch._dynamo.mark_static_address(self.keys)
-            torch._dynamo.mark_static_address(self.values)
+            torch._dynamo.mark_static_address(self.keys_)
+            torch._dynamo.mark_static_address(self.values_)

         self.is_initialized = True

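The important detail in `lazy_initialization` is that the full-size buffers (`keys_` / `values_`) are the tensors tagged with `torch._dynamo.mark_static_address`, while `keys` / `values` start out as plain references to them and later become batch-sized views. A minimal sketch of the resulting behaviour, assuming the import path of the file touched by this commit and arbitrary toy shapes:

import torch
from transformers.cache_utils import StaticLayer  # path from this diff; illustrative use only

# Explicit max_batch_size: buffers are preallocated for 8 sequences even though the first
# key_states tensor only carries a batch of 2. Shapes are (batch, heads, seq_len, head_dim).
layer = StaticLayer(max_cache_len=128, max_batch_size=8)
layer.lazy_initialization(torch.randn(2, 4, 1, 64))
print(layer.keys_.shape)      # torch.Size([8, 4, 128, 64])

# Without it, the previous behaviour is kept: the batch size is inferred from the first call.
inferred = StaticLayer(max_cache_len=128)
inferred.lazy_initialization(torch.randn(2, 4, 1, 64))
print(inferred.keys_.shape)   # torch.Size([2, 4, 128, 64])

# A batch-sized slice is a view sharing storage with the statically-addressed buffer.
assert layer.keys_[:2].data_ptr() == layer.keys_.data_ptr()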
@@ -331,21 +338,25 @@ def update(
         cache_position = (
             cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
         )
-        k_out = self.keys
-        v_out = self.values
         batch_size = key_states.shape[0]
-        if k_out.shape[0] != batch_size:
-            k_out = k_out[:batch_size]
-            v_out = v_out[:batch_size]
-        # Update the cache
+        # 3. Dynamic Slicing: Update the view to match current batch
+        self.keys = self.keys_[:batch_size]
+        self.values = self.values_[:batch_size]
         try:
-            k_out.index_copy_(2, cache_position, key_states)
-            v_out.index_copy_(2, cache_position, value_states)
+            self.keys.index_copy_(2, cache_position, key_states)
+            self.values.index_copy_(2, cache_position, value_states)
         except NotImplementedError:
-            # Fallback for devices like MPS where index_copy_ might not be supported.
-            k_out[:, :, cache_position] = key_states
-            v_out[:, :, cache_position] = value_states
-        return k_out, v_out
+            self.keys[:, :, cache_position] = key_states
+            self.values[:, :, cache_position] = value_states
+
+        return self.keys, self.values
+
+    def reset(self):
+        if self.is_initialized:
+            self.keys_.zero_()
+            self.values_.zero_()
+            self.keys = self.keys_
+            self.values = self.values_

     def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
         """Return the length and offset of the cache, used to generate the attention mask"""
@@ -1024,6 +1035,8 @@ class StaticCache(Cache):
         offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
             If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
             usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
+        max_batch_size (`int`, *optional*):
+            The maximum batch size that will be used with this Cache.

     Example:

@@ -1052,6 +1065,7 @@ def __init__(
         max_cache_len: int,
         offloading: bool = False,
         offload_only_non_sliding: bool = True,
+        max_batch_size: int | None = None,
         **kwargs,
     ):
         config = config.get_text_config(decoder=True)
@@ -1071,19 +1085,39 @@ def __init__(
         layers = []
         for layer_type in layer_types:
             if layer_type == "sliding_attention":
-                layer = StaticSlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
+                layer = StaticSlidingWindowLayer(
+                    max_cache_len=max_cache_len, sliding_window=config.sliding_window, max_batch_size=max_batch_size
+                )
             elif layer_type == "chunked_attention":
                 # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
                 # states they should return - only the mask changes to make them different at the end!
                 layer = StaticSlidingWindowLayer(
-                    max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size
+                    max_cache_len=max_cache_len,
+                    sliding_window=config.attention_chunk_size,
+                    max_batch_size=max_batch_size,
                 )
             else:
-                layer = StaticLayer(max_cache_len=max_cache_len)
+                layer = StaticLayer(max_cache_len=max_cache_len, max_batch_size=max_batch_size)
             layers.append(layer)

         super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)

+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[dict[str, Any]] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
+
+    def reset(self):
+        for layer in self.layers:
+            layer.reset()
+
+    def __len__(self):
+        return len(self.layers)
+

 class QuantizedCache(Cache):
     """
