 #
 # -----------------------------------------------------------------------------

-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Dict, Any

 import torch
 from torch import nn
@@ -103,21 +103,70 @@ def eager_attention_forward(
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
     scaling: float,
+    num_kv_blocks: Optional[torch.Tensor] = None,
+    cache_kwargs: Optional[Dict[str, Any]] = None,
+    layer_idx: Optional[int] = None,
+    past_key_value: Optional[Cache] = None,
     **kwargs,
 ):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)

-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        attn_weights = torch.where(
-            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
-        )
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attn_weights
+    if num_kv_blocks is not None:
+        # Initialize the result tensor
+        output = torch.zeros_like(query)
+
+        # Perform blockwise computation over K and V
+        M = torch.full((query.shape[0], query.shape[1], query.shape[2]), float(MIN_MASKED_ATTENTION_VALUE))  # Running row maximum
+        D = torch.zeros((query.shape[0], query.shape[1], query.shape[2]))  # Running softmax denominator
+
+        past_seen_tokens = cache_kwargs.get("past_seen_tokens")
+        position_ids = cache_kwargs.get("position_ids")
+        num_kv_blocks = torch.tensor(num_kv_blocks, dtype=torch.int)
+        block_size = -(-past_seen_tokens // num_kv_blocks)  # ceil(past_seen_tokens / num_kv_blocks)
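+        # Stream over the cached K/V one block at a time, maintaining the running row maximum M,
+        # the running softmax denominator D, and a rescaled partial output (online-softmax,
+        # FlashAttention-style), so the full attention matrix is never materialized at once.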
+        for j in range(num_kv_blocks):
+            start_index = j * block_size
+            end_index = (j + 1) * block_size
+            K_block, V_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs)
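+            # read_only_blockedKV returns the [start_index, end_index) slice of this layer's cached K/V.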
+            K_block_states = repeat_kv(K_block, module.num_key_value_groups)
+            V_block_states = repeat_kv(V_block, module.num_key_value_groups)
+            past_seen_tokens_start = start_index
+            past_seen_tokens_end = torch.where(past_seen_tokens < end_index, past_seen_tokens, end_index)
+            causal_mask_block = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens_end, start_index=past_seen_tokens_start)
+
+            # Compute attention scores for the block
+            attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling
+            if attention_mask is not None:
+                attn_weights_block = torch.where(causal_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights_block)
+
+            # Update the running row maximum
+            prevM = M
+            M = torch.max(prevM, torch.max(attn_weights_block, axis=-1).values)
+            deltaM = prevM - M
+
+            currentExp = torch.exp(attn_weights_block - M.unsqueeze(-1))  # Subtract the running max M from every score in the block
+
+            # Update the running denominator
+            prevD = D
+            D = prevD * torch.exp(deltaM) + currentExp.sum(axis=-1)
+
+            P = currentExp / D.unsqueeze(-1)
+
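+            # Rescale the output accumulated from previous blocks (prevD / D corrects for the new
+            # denominator, exp(deltaM) for the updated maximum), then add this block's contribution.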
+            prevO = output
+            output = ((prevD / D).unsqueeze(-1)) * prevO * torch.exp(deltaM.unsqueeze(-1)) + torch.matmul(P, V_block_states)  # Keep this accumulation in higher precision.
+        attn_output = output.transpose(1, 2).contiguous()
+
+    else:  # regular attention
+        key_states = repeat_kv(key, module.num_key_value_groups)
+        value_states = repeat_kv(value, module.num_key_value_groups)
+
+        attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+        if attention_mask is not None:
+            attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output


 class QEffLlamaAttention(LlamaAttention):
@@ -135,6 +184,7 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        num_kv_blocks: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -150,28 +200,37 @@ def forward(
         value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2)

         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
+        past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
-            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            if num_kv_blocks is not None:
+                cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "past_seen_tokens": past_seen_tokens}
+                past_key_value.write_only(key_states, value_states, self.layer_idx, cache_kwargs)
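+                # Blockwise path: the new K/V are only appended to the cache here; the blocks are
+                # read back inside eager_attention_forward via read_only_blockedKV.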
+            else:
+                cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
+                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         attention_interface = eager_attention_forward

-        attn_output, attn_weights = attention_interface(
+        attn_output = attention_interface(
             self,
             query_states,
             key_states,
             value_states,
             attention_mask,
             scaling=self.scaling,
+            num_kv_blocks=num_kv_blocks,
+            cache_kwargs=cache_kwargs,
+            layer_idx=self.layer_idx,
+            past_key_value=past_key_value,
             **kwargs,
         )
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output, **kwargs)

-        return attn_output, attn_weights
+        return attn_output


 class QEffLlamaDecoderLayer(LlamaDecoderLayer):
@@ -197,7 +256,7 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)

         # Self Attention
-        hidden_states, _ = self.self_attn(
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,