Adding support for BlockedKV attention in CausalLM models #618
base: main
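For context, a short summary of the diff below (my own wording and notation, not the PR author's): the change adds a second attention path to the QEff Llama attention module. Instead of one softmax over the whole KV cache, `eager_attention_forward_blockedKV` streams over the cache in `num_kv_blocks` chunks while carrying a running row maximum (`current_max`), a running denominator (`current_denominator`), and a partial output (`output`), which is intended to reproduce the result of a single full-cache softmax. Writing $S_j = \mathrm{scaling} \cdot Q K_j^{\top}$ for the masked scores of block $j$, the per-block update is the standard streaming-softmax recurrence:

$$
m_j = \max\bigl(m_{j-1},\ \operatorname{rowmax}(S_j)\bigr),\qquad
\ell_j = \ell_{j-1}\,e^{m_{j-1}-m_j} + \operatorname{rowsum}\bigl(e^{S_j-m_j}\bigr),\qquad
O_j = \frac{\ell_{j-1}}{\ell_j}\,e^{m_{j-1}-m_j}\,O_{j-1} + \frac{e^{S_j-m_j}}{\ell_j}\,V_j
$$

Here $m$ corresponds to `current_max`, $\ell$ to `current_denominator`, $O$ to `output`, and $m_{j-1}-m_j$ to `delta_max`; after the last block, $O$ matches what `eager_attention_forward` computes over the full cache.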
@@ -5,7 +5,7 @@
 #
 # -----------------------------------------------------------------------------
 
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -113,13 +113,88 @@ def eager_attention_forward(
         attn_weights = torch.where(
             attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
         )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
 
     return attn_output, attn_weights
 
 
+def eager_attention_forward_blockedKV(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    num_kv_blocks: Optional[torch.Tensor] = None,
+    cache_kwargs: Optional[Dict[str, Any]] = None,
+    layer_idx: int = None,
+    past_key_value: Optional[Cache] = None,
+    **kwargs,
+):
+    # Initialize result tensor
+    output = torch.zeros_like(query)
+
+    # Initialize running maximum
+    batch_size, num_heads, seq_len, _ = query.shape
+    current_max = torch.full((batch_size, num_heads, seq_len), float(MIN_MASKED_ATTENTION_VALUE))
+
+    # Initialize denominator
+    current_denominator = torch.zeros(batch_size, num_heads, seq_len)
+
+    past_seen_tokens = cache_kwargs.get("past_seen_tokens")
+    position_ids = cache_kwargs.get("position_ids")
+    block_size = -(-past_seen_tokens // num_kv_blocks)
+    masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)
+
+    for j in range(num_kv_blocks):
+        start_index = j * block_size
+        end_index = (j + 1) * block_size
+        K_block, V_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs)
+        K_block_states = repeat_kv(K_block, module.num_key_value_groups)
+        V_block_states = repeat_kv(V_block, module.num_key_value_groups)
+        past_seen_tokens_start = start_index
+        past_seen_tokens_end = torch.where(
+            torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int),
+            past_seen_tokens,
+            end_index,
+        )
+        causal_mask_block = _create_causal_mask(
+            position_ids=position_ids, target_length=past_seen_tokens_end, start_index=past_seen_tokens_start
+        )
+
+        # Compute attention scores for the block
+        attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling
+        if attention_mask is not None:
+            attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block)
+
+        # Update running row maximum
+        prev_max = current_max
+        current_max = torch.max(prev_max, attn_weights_block.max(dim=-1).values)
+        delta_max = prev_max - current_max
+
+        current_exp = torch.exp(
+            attn_weights_block - current_max.unsqueeze(-1)
+        )  # Subtract current_max from each column of attn_weights_block
+
+        # Update running denominator
+        prev_denominator = current_denominator
+        current_denominator = prev_denominator * torch.exp(delta_max) + current_exp.sum(axis=-1)
+
+        prob = current_exp / current_denominator.unsqueeze(-1)
+
+        prev_output = output
+        output = ((prev_denominator / current_denominator).unsqueeze(-1)) * prev_output * torch.exp(
+            delta_max.unsqueeze(-1)
+        ) + torch.matmul(prob, V_block_states)
+
+    attn_output = output.transpose(1, 2).contiguous()
+    attn_weights = None
+
+    return attn_output, attn_weights
+
+
 class QEffLlamaAttention(LlamaAttention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
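To sanity-check that recurrence in isolation, here is a minimal self-contained sketch (my own illustration, not part of the diff: `full_attention` and `blocked_attention` are made-up names, masking and the QEfficient cache APIs are left out, and `num_kv_blocks` is a plain int). It applies the same running-max / running-denominator updates as `eager_attention_forward_blockedKV` and compares the result with ordinary softmax attention:

```python
import torch


def full_attention(q, k, v, scaling):
    # Reference: a single softmax over the whole KV length.
    scores = torch.matmul(q, k.transpose(-2, -1)) * scaling
    return torch.matmul(torch.softmax(scores, dim=-1), v)


def blocked_attention(q, k, v, scaling, num_kv_blocks):
    # Stream over the KV length in blocks, keeping a running max, denominator and output.
    b, h, s_q, d = q.shape
    s_kv = k.shape[-2]
    block_size = -(-s_kv // num_kv_blocks)  # ceiling division, as in the diff
    out = torch.zeros_like(q)
    cur_max = torch.full((b, h, s_q), float("-inf"))
    cur_den = torch.zeros(b, h, s_q)
    for j in range(num_kv_blocks):
        start, end = j * block_size, min((j + 1) * block_size, s_kv)
        if start >= s_kv:  # ceiling division can overshoot; nothing left to read
            break
        scores = torch.matmul(q, k[..., start:end, :].transpose(-2, -1)) * scaling
        prev_max, cur_max = cur_max, torch.maximum(cur_max, scores.max(dim=-1).values)
        delta = prev_max - cur_max  # <= 0; rescales everything accumulated so far
        exp_scores = torch.exp(scores - cur_max.unsqueeze(-1))
        prev_den, cur_den = cur_den, cur_den * torch.exp(delta) + exp_scores.sum(dim=-1)
        out = (prev_den / cur_den).unsqueeze(-1) * out * torch.exp(delta.unsqueeze(-1)) + torch.matmul(
            exp_scores / cur_den.unsqueeze(-1), v[..., start:end, :]
        )
    return out


torch.manual_seed(0)
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 32, 16)
v = torch.randn(2, 4, 32, 16)
scaling = 16**-0.5
print(torch.allclose(full_attention(q, k, v, scaling), blocked_attention(q, k, v, scaling, 4), atol=1e-5))
# expected to print True
```

The rescaling factor $(\ell_{j-1}/\ell_j)\,e^{m_{j-1}-m_j}$ is what lets the already-accumulated output be reused when a later block raises the running maximum.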
@@ -136,6 +211,7 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        num_kv_blocks: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -151,17 +227,29 @@
         value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2)
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
+        past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
-            if comp_ctx_lengths is not None:
-                attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
-                cache_kwargs["CCL"] = attention_mask.shape[-1]
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            if num_kv_blocks is not None:
Contributor:
Do we need to update the past_seen_tokens? Since we're passing the PKV, can't we read it from the blocked attention method?

Author:
We need the past_seen_tokens before we call the past_key_value.write_only(...) function, since writing to the cache will change the return value of past_key_value.get_seq_length().
+                cache_kwargs = {
+                    "batch_index": batch_index,
+                    "position_ids": position_ids,
+                    "past_seen_tokens": past_seen_tokens,
+                }
+                past_key_value.write_only(key_states, value_states, self.layer_idx, cache_kwargs)
+            else:
+                cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
+                if comp_ctx_lengths is not None:
+                    attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
+                    cache_kwargs["CCL"] = attention_mask.shape[-1]
+                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        attention_interface = eager_attention_forward
+        if num_kv_blocks is not None:
+            attention_interface = eager_attention_forward_blockedKV
+        else:
+            attention_interface = eager_attention_forward
 
         attn_output, attn_weights = attention_interface(
             self,
@@ -170,6 +258,10 @@ def forward(
             value_states,
             attention_mask,
             scaling=self.scaling,
+            num_kv_blocks=num_kv_blocks,
+            cache_kwargs=cache_kwargs,
+            layer_idx=self.layer_idx,
+            past_key_value=past_key_value,
             **kwargs,
         )
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
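One more note on the bookkeeping in the blocked path (a worked example with made-up numbers, not taken from the diff): `block_size = -(-past_seen_tokens // num_kv_blocks)` is ceiling division, so the last block's `end_index` can overshoot the valid cache length, and the `torch.where(...)` that produces `past_seen_tokens_end` clamps the causal-mask target back to `past_seen_tokens`. This is also why, as discussed in the review thread above, `past_seen_tokens` is read from `past_key_value.get_seq_length()` before `write_only(...)` runs: the write changes what `get_seq_length()` returns afterwards.

```python
# Hypothetical values, chosen only to show the block bounds and the clamp.
past_seen_tokens = 10  # cache length captured in forward() before write_only(...)
num_kv_blocks = 4
block_size = -(-past_seen_tokens // num_kv_blocks)  # ceil(10 / 4) = 3

for j in range(num_kv_blocks):
    start_index = j * block_size
    end_index = (j + 1) * block_size
    # Same effect as the torch.where(...) clamp on past_seen_tokens_end in the diff:
    target_length = min(past_seen_tokens, end_index)
    print(j, start_index, end_index, target_length)
# 0 0 3 3
# 1 3 6 6
# 2 6 9 9
# 3 9 12 10   <- last block overshoots; the mask target is clamped to 10
```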