
Commit 9e46a31

Adding support for BlockedKV attention in CausalLM models
1 parent ed965fd commit 9e46a31


10 files changed: +269 -22 lines changed


QEfficient/customop/__init__.py

Lines changed: 4 additions & 1 deletion

@@ -5,9 +5,10 @@
 #
 # -----------------------------------------------------------------------------
 
-from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFunc3D, CtxScatterFunc, CtxScatterFunc3D
+from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFuncBlockedKV, CtxGatherFunc3D, CtxScatterFunc, CtxScatterFunc3D
 from QEfficient.customop.ctx_scatter_gather_cb import (
     CtxGatherFuncCB,
+    CtxGatherFuncBlockedKVCB,
     CtxGatherFuncCB3D,
     CtxScatterFuncCB,
     CtxScatterFuncCB3D,
@@ -16,12 +17,14 @@
 
 __all__ = [
     "CtxGatherFunc",
+    "CtxGatherFuncBlockedKV",
     "CtxScatterFunc",
     "CtxGatherFunc3D",
     "CtxScatterFunc3D",
     "CustomRMSNormAIC",
     "GemmaCustomRMSNormAIC",
     "CtxGatherFuncCB",
+    "CtxGatherFuncBlockedKVCB",
     "CtxScatterFuncCB",
     "CtxGatherFuncCB3D",
     "CtxScatterFuncCB3D",

QEfficient/customop/ctx_scatter_gather.py

Lines changed: 25 additions & 0 deletions

@@ -139,3 +139,28 @@ def setup_context(ctx, inputs, outputs):
     @staticmethod
     def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
         return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
+
+
+@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
+def CtxGatherBlockedKV(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
+    ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
+    return ops.GatherND(data, ctx_indices, batch_dims=2)
+
+
+class CtxGatherFuncBlockedKV(torch.autograd.Function):
+    """
+    Function to gather only the valid key values from KV-cache.
+    """
+
+    @staticmethod
+    def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
+        batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
+        head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
+        return data[batch_indices, head_indices, ctx_indices]
+
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        pass
+
+    @staticmethod
+    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
+        return g.onnxscript_op(CtxGatherBlockedKV, data, ctx_indices).setTypeAs(data)
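
A minimal, self-contained sketch (not part of the commit) of the gather that CtxGatherFuncBlockedKV.forward performs: broadcasted fancy indexing that pulls one contiguous block of context positions out of a per-layer KV tensor for every batch entry and head.

import torch

batch, heads, ctx_len, head_dim = 2, 4, 16, 8
data = torch.randn(batch, heads, ctx_len, head_dim)   # stand-in for one layer's keys or values

# Gather cache positions 4..7 (one block) for every batch entry and head.
ctx_indices = torch.arange(4, 8).view(1, 1, -1).expand(batch, heads, -1)
batch_indices = torch.arange(batch).view(-1, 1, 1)
head_indices = torch.arange(heads).view(1, -1, 1)
block = data[batch_indices, head_indices, ctx_indices]

assert block.shape == (batch, heads, 4, head_dim)
assert torch.equal(block, data[:, :, 4:8])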

QEfficient/customop/ctx_scatter_gather_cb.py

Lines changed: 38 additions & 0 deletions

@@ -133,6 +133,44 @@ def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_in
         return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data)
 
 
+@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
+def CtxGatherBlockedKVCB(
+    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
+) -> onnxscript.FLOAT:
+    batch_size = ops.Gather(ops.Shape(batch_index), [0])
+    num_heads = ops.Gather(ops.Shape(data), [1])
+    ctx_len = ops.Gather(ops.Shape(ctx_indices), [2])
+
+    # Expanded shape to create indices
+    zero = ops.Constant(value_ints=[0])
+    one = ops.Constant(value_ints=[1])
+    exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
+
+    # Create indices
+    batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape)
+    head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape)
+    ctx_idx = ops.Expand(ops.Unsqueeze(ctx_indices, [3]), exp_shape)
+    indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3)
+
+    return ops.GatherND(data, indices)
+
+
+class CtxGatherFuncBlockedKVCB(torch.autograd.Function):
+    @staticmethod
+    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
+        batch_indices = batch_index.view(-1, 1, 1)
+        head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
+        return data[batch_indices, head_indices, ctx_indices]
+
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        pass
+
+    @staticmethod
+    def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
+        return g.onnxscript_op(CtxGatherBlockedKVCB, data, batch_index, ctx_indices).setTypeAs(data)
+
+
 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
 def CtxGatherCB3D(
     data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
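
For the continuous-batching variant, a toy sketch (again not from the commit) of what CtxGatherFuncBlockedKVCB.forward computes: batch_index selects which rows of the full cache belong to the current decode batch, while ctx_indices selects the block of context positions.

import torch

full_batch, heads, ctx_len, head_dim = 4, 2, 16, 8
cache = torch.randn(full_batch, heads, ctx_len, head_dim)

batch_index = torch.tensor([[2], [0]])                       # decode batch reads cache rows 2 and 0
ctx_indices = torch.arange(4, 8).view(1, 1, -1).expand(2, heads, -1)

batch_indices = batch_index.view(-1, 1, 1)
head_indices = torch.arange(heads).view(1, -1, 1)
block = cache[batch_indices, head_indices, ctx_indices]

assert block.shape == (2, heads, 4, head_dim)
assert torch.equal(block[0], cache[2, :, 4:8])
assert torch.equal(block[1], cache[0, :, 4:8])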

QEfficient/transformers/cache_utils.py

Lines changed: 64 additions & 0 deletions

@@ -14,8 +14,10 @@
 
 from QEfficient.customop import (
     CtxGatherFunc,
+    CtxGatherFuncBlockedKV,
     CtxGatherFunc3D,
     CtxGatherFuncCB,
+    CtxGatherFuncBlockedKVCB,
     CtxGatherFuncCB3D,
     CtxScatterFunc,
     CtxScatterFunc3D,
@@ -60,6 +62,49 @@ def read_only(self, cache_kwargs):
         v_out = CtxGatherFunc.apply(v_out, ctx_indices)
 
         v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+
+    def read_only_blockedKV(self, start_index, end_index, cache_kwargs):
+        """
+        Reads the `key_states` and `value_states` of the layer for one K/V block.
+
+        Parameters:
+            start_index (`int`):
+                Start index of the K/V block to read.
+            end_index (`int`):
+                End index of the K/V block to read.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+        Return:
+            A tuple containing the gathered key and value states.
+        """
+        # Gather
+        k_out, v_out = self.keys, self.values
+        position_ids = cache_kwargs.get("position_ids")
+        batch_index = cache_kwargs.get("batch_index", None)
+        batch, num_kv_heads, _, _ = k_out.shape
+        ctx_indices = torch.arange(start=start_index, end=end_index)[None, None, ...]
+        gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
+        invalid_mask = ctx_indices > gather_limit
+
+        if torch.onnx.is_in_onnx_export():
+            invalid_idx_value = torch.iinfo(torch.int32).max
+        else:
+            invalid_idx_value = 0
+
+        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+
+        if batch_index is not None:
+            k_out = CtxGatherFuncBlockedKVCB.apply(k_out, batch_index, ctx_indices)
+            v_out = CtxGatherFuncBlockedKVCB.apply(v_out, batch_index, ctx_indices)
+        else:
+            ctx_indices = ctx_indices.expand(batch, num_kv_heads, ctx_indices.shape[-1])
+            k_out = CtxGatherFuncBlockedKV.apply(k_out, ctx_indices)
+            v_out = CtxGatherFuncBlockedKV.apply(v_out, ctx_indices)
+
+        v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
         return k_out, v_out
 
     def write_only(self, key_states, value_states, cache_kwargs):
@@ -262,6 +307,25 @@ def read_only(self, layer_idx, cache_kwargs):
         """
         return self.layers[layer_idx].read_only(cache_kwargs)
 
+    def read_only_blockedKV(self, start_index, end_index, layer_idx, cache_kwargs):
+        """
+        Reads the `key_states` and `value_states` of layer `layer_idx` for one K/V block.
+
+        Parameters:
+            start_index (`int`):
+                Start index of the K/V block to read.
+            end_index (`int`):
+                End index of the K/V block to read.
+            layer_idx (`int`):
+                The index of the layer to read the states from.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+        Return:
+            A tuple containing the gathered key and value states.
+        """
+        return self.layers[layer_idx].read_only_blockedKV(start_index, end_index, cache_kwargs)
+
     def write_only(self, key_states, value_states, layer_idx, cache_kwargs):
         """
         Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
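
A toy sketch (not from the commit) of the masking pattern read_only_blockedKV relies on: block positions beyond the largest position id are flagged invalid, and the gathered values at those positions are zeroed so they contribute nothing downstream.

import torch

position_ids = torch.tensor([[0, 1, 2, 3]])            # one sequence, 4 tokens seen so far
start_index, end_index = 0, 8                          # a K/V block covering positions 0..7
ctx_indices = torch.arange(start_index, end_index)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit              # positions 4..7 are padding here

v_block = torch.ones(1, 1, 8, 2)                       # stand-in for gathered value states
v_block = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0), v_block)
assert torch.all(v_block[..., 4:, :] == 0) and torch.all(v_block[..., :4, :] == 1)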

QEfficient/transformers/modeling_attn_mask_utils.py

Lines changed: 2 additions & 1 deletion

@@ -13,6 +13,7 @@
 def _create_causal_mask(
     position_ids,
     target_length,
+    start_index: Optional[torch.Tensor] = torch.tensor(0, dtype=torch.int),
     sliding_window: Optional[int] = None,
 ):
     """
@@ -40,7 +41,7 @@ def _create_causal_mask(
         attention_mask = attention_mask.unsqueeze(1)
     else:
         query_indices = position_ids.unsqueeze(-1)
-        kv_indices = torch.arange(target_length).view(1, 1, -1)
+        kv_indices = torch.arange(start=start_index, end=target_length).view(1, 1, -1)
         attention_mask = kv_indices > query_indices
         attention_mask = attention_mask.unsqueeze(1)
 
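
A minimal sketch (not from the commit) of the blocked causal mask the new start_index argument enables: kv_indices now starts at the block boundary instead of 0, so the mask lines up with the gathered K/V block.

import torch

position_ids = torch.tensor([[4, 5, 6, 7]])            # current query positions
start_index, target_length = 4, 8                      # block boundaries passed by the caller
query_indices = position_ids.unsqueeze(-1)             # (1, 4, 1)
kv_indices = torch.arange(start_index, target_length).view(1, 1, -1)
attention_mask = (kv_indices > query_indices).unsqueeze(1)   # True marks masked-out positions

# attention_mask[0, 0] is strictly upper triangular: the query at position 4 may attend
# only to cache position 4 of this block, the query at position 7 to positions 4..7.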

QEfficient/transformers/models/llama/modeling_llama.py

Lines changed: 77 additions & 18 deletions

@@ -5,7 +5,7 @@
 #
 # -----------------------------------------------------------------------------
 
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Dict, Any
 
 import torch
 from torch import nn
@@ -103,21 +103,70 @@ def eager_attention_forward(
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
     scaling: float,
+    num_kv_blocks: Optional[torch.Tensor] = None,
+    cache_kwargs: Optional[Dict[str, Any]] = None,
+    layer_idx: int = None,
+    past_key_value: Optional[Cache] = None,
     **kwargs,
 ):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
 
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        attn_weights = torch.where(
-            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
-        )
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attn_weights
+    if num_kv_blocks is not None:
+        # Initialize result tensor
+        output = torch.zeros_like(query)
+
+        # Perform blockwise computation over K and V
+        M = torch.full((query.shape[0], query.shape[1], query.shape[2]), float(MIN_MASKED_ATTENTION_VALUE))  # Running row maximum
+        D = torch.zeros((query.shape[0], query.shape[1], query.shape[2]))  # Running softmax denominator
+
+        past_seen_tokens = cache_kwargs.get("past_seen_tokens")
+        position_ids = cache_kwargs.get("position_ids")
+        num_kv_blocks = torch.tensor(num_kv_blocks, dtype=torch.int)
+        block_size = -(-past_seen_tokens // num_kv_blocks)  # Ceiling division
+        for j in range(num_kv_blocks):
+            start_index = j * block_size
+            end_index = (j + 1) * block_size
+            K_block, V_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs)
+            K_block_states = repeat_kv(K_block, module.num_key_value_groups)
+            V_block_states = repeat_kv(V_block, module.num_key_value_groups)
+            past_seen_tokens_start = start_index
+            past_seen_tokens_end = torch.where(past_seen_tokens < end_index, past_seen_tokens, end_index)
+            causal_mask_block = _create_causal_mask(
+                position_ids=position_ids, target_length=past_seen_tokens_end, start_index=past_seen_tokens_start
+            )
+
+            # Compute attention scores for the block
+            attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling
+            if attention_mask is not None:
+                attn_weights_block = torch.where(
+                    causal_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights_block
+                )
+
+            # Update running row maximum
+            prevM = M
+            M = torch.max(prevM, torch.max(attn_weights_block, axis=-1).values)
+            deltaM = prevM - M
+
+            currentExp = torch.exp(attn_weights_block - M.unsqueeze(-1))  # Subtract M from each column of attn_weights_block
+
+            # Update running denominator
+            prevD = D
+            D = prevD * torch.exp(deltaM) + currentExp.sum(axis=-1)
+
+            P = currentExp / D.unsqueeze(-1)
+
+            prevO = output
+            output = ((prevD / D).unsqueeze(-1)) * prevO * torch.exp(deltaM.unsqueeze(-1)) + torch.matmul(P, V_block_states)  # Keep this accumulation in higher precision.
+        attn_output = output.transpose(1, 2).contiguous()
+
+    else:  # Regular attention
+        key_states = repeat_kv(key, module.num_key_value_groups)
+        value_states = repeat_kv(value, module.num_key_value_groups)
+
+        attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+        if attention_mask is not None:
+            attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output
 
 
 class QEffLlamaAttention(LlamaAttention):
@@ -135,6 +184,7 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        num_kv_blocks: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
@@ -150,28 +200,37 @@ def forward(
         value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2)
 
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
+        past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            if num_kv_blocks is not None:
+                cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "past_seen_tokens": past_seen_tokens}
+                past_key_value.write_only(key_states, value_states, self.layer_idx, cache_kwargs)
+            else:
+                cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
+                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         attention_interface = eager_attention_forward
 
-        attn_output, attn_weights = attention_interface(
+        attn_output = attention_interface(
             self,
             query_states,
             key_states,
             value_states,
             attention_mask,
             scaling=self.scaling,
+            num_kv_blocks=num_kv_blocks,
+            cache_kwargs=cache_kwargs,
+            layer_idx=self.layer_idx,
+            past_key_value=past_key_value,
             **kwargs,
         )
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output, **kwargs)
 
-        return attn_output, attn_weights
+        return attn_output
 
 
 class QEffLlamaDecoderLayer(LlamaDecoderLayer):
@@ -197,7 +256,7 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, _ = self.self_attn(
+        hidden_states = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
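
The blockwise branch of eager_attention_forward is an online-softmax (flash-attention style) accumulation: a running row maximum M and running denominator D are updated per K/V block, and the partial output is rescaled before the next block is folded in. Below is a standalone sketch (not from the commit) that checks the same recurrence against a plain softmax over the full score row; here M starts at -inf, whereas the commit seeds it with MIN_MASKED_ATTENTION_VALUE.

import torch

torch.manual_seed(0)
scores = torch.randn(1, 1, 1, 12)                      # one query row, 12 cached positions
values = torch.randn(1, 1, 12, 1)

M = torch.full(scores.shape[:-1], float("-inf"))       # running row maximum
D = torch.zeros(scores.shape[:-1])                     # running softmax denominator
out = torch.zeros(1, 1, 1, 1)

for blk, v_blk in zip(scores.split(4, dim=-1), values.split(4, dim=2)):   # 3 blocks of 4
    prevM, prevD = M, D
    M = torch.maximum(prevM, blk.max(dim=-1).values)
    deltaM = prevM - M
    cur = torch.exp(blk - M.unsqueeze(-1))
    D = prevD * torch.exp(deltaM) + cur.sum(dim=-1)
    P = cur / D.unsqueeze(-1)
    out = (prevD / D).unsqueeze(-1) * out * torch.exp(deltaM.unsqueeze(-1)) + P @ v_blk

full = torch.softmax(scores, dim=-1) @ values
assert torch.allclose(out, full, atol=1e-6)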

QEfficient/transformers/models/modeling_auto.py

Lines changed: 5 additions & 0 deletions

@@ -39,6 +39,7 @@
 from QEfficient.generation.vlm_generation import VisionLanguageGeneration
 from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH
 from QEfficient.transformers.models.pytorch_transforms import (
+    BlockedKVAttentionTransform,
     CustomOpsTransform,
     KVCacheExternalModuleMapperTransform,
     KVCacheTransform,
@@ -2137,6 +2138,7 @@
             - **return_pdfs** (bool): If True, returns probability distributions along with sampled tokens.
               For Speculative Decoding Target Language Models, this is always True.
             - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
+            - **num_kv_blocks** (int): Number of K/V blocks for the BlockedKV attention implementation.
         **kwargs :
             Additional keyword arguments passed to the base class constructor.
 
@@ -2182,6 +2184,9 @@
         if self.is_tlm:
             self.model.qaic_config["return_pdfs"] = True
 
+        if self.model.qaic_config["num_kv_blocks"] is not None:
+            BlockedKVAttentionTransform.apply(model, num_kv_blocks=self.model.qaic_config["num_kv_blocks"])
+
     @property
     def model_name(self) -> str:
         """
