11 changes: 10 additions & 1 deletion QEfficient/customop/__init__.py
@@ -5,8 +5,15 @@
#
# -----------------------------------------------------------------------------

from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFunc3D, CtxScatterFunc, CtxScatterFunc3D
from QEfficient.customop.ctx_scatter_gather import (
CtxGatherFunc,
CtxGatherFunc3D,
CtxGatherFuncBlockedKV,
CtxScatterFunc,
CtxScatterFunc3D,
)
from QEfficient.customop.ctx_scatter_gather_cb import (
CtxGatherFuncBlockedKVCB,
CtxGatherFuncCB,
CtxGatherFuncCB3D,
CtxScatterFuncCB,
@@ -16,12 +23,14 @@

__all__ = [
"CtxGatherFunc",
"CtxGatherFuncBlockedKV",
"CtxScatterFunc",
"CtxGatherFunc3D",
"CtxScatterFunc3D",
"CustomRMSNormAIC",
"GemmaCustomRMSNormAIC",
"CtxGatherFuncCB",
"CtxGatherFuncBlockedKVCB",
"CtxScatterFuncCB",
"CtxGatherFuncCB3D",
"CtxScatterFuncCB3D",
26 changes: 26 additions & 0 deletions QEfficient/customop/ctx_scatter_gather.py
@@ -145,3 +145,29 @@ def setup_context(ctx, inputs, outputs):
@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGatherBlockedKV(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
return ops.GatherND(data, ctx_indices, batch_dims=2)


class CtxGatherFuncBlockedKV(torch.autograd.Function):
"""
Function to gather a block of valid key/value states from the KV cache.
"""

@staticmethod
def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
return data[batch_indices, head_indices, ctx_indices]

@staticmethod
def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxGatherBlockedKV, data, ctx_indices).setTypeAs(data)
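For reference, a minimal sketch of what the eager path of CtxGatherFuncBlockedKV does on dummy tensors; the shapes and values below are illustrative assumptions, not taken from the PR:

import torch
from QEfficient.customop import CtxGatherFuncBlockedKV

# Illustrative shapes: batch=2, kv_heads=4, cache_len=16, head_dim=8
data = torch.randn(2, 4, 16, 8)
# Gather cache positions 4..7 (one KV block) for every batch and head
ctx_indices = torch.arange(4, 8)[None, None, :].expand(2, 4, 4)
block = CtxGatherFuncBlockedKV.apply(data, ctx_indices)
print(block.shape)  # torch.Size([2, 4, 4, 8]) -> (batch, kv_heads, block_len, head_dim)

During ONNX export, the symbolic method maps the same operation onto the custom CtxGatherBlockedKV op, which unsqueezes the indices and applies GatherND with batch_dims=2.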
38 changes: 38 additions & 0 deletions QEfficient/customop/ctx_scatter_gather_cb.py
@@ -139,6 +139,44 @@ def symbolic(
return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data)


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGatherBlockedKVCB(
data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
) -> onnxscript.FLOAT:
batch_size = ops.Gather(ops.Shape(batch_index), [0])
num_heads = ops.Gather(ops.Shape(data), [1])
ctx_len = ops.Gather(ops.Shape(ctx_indices), [2])

# Expanded shape to create indices
zero = ops.Constant(value_ints=[0])
one = ops.Constant(value_ints=[1])
exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)

# Create indices
batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape)
head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape)
ctx_idx = ops.Expand(ops.Unsqueeze(ctx_indices, [3]), exp_shape)
indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3)

return ops.GatherND(data, indices)


class CtxGatherFuncBlockedKVCB(torch.autograd.Function):
@staticmethod
def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
batch_indices = batch_index.view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
return data[batch_indices, head_indices, ctx_indices]

@staticmethod
def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxGatherBlockedKVCB, data, batch_index, ctx_indices).setTypeAs(data)
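A similar sketch for the continuous-batching variant, where batch_index selects which cache rows the active sequences map to; the shapes and slot assignments here are illustrative assumptions:

import torch
from QEfficient.customop import CtxGatherFuncBlockedKVCB

# Illustrative: a cache with 4 slots; 2 active sequences mapped to slots 3 and 1
data = torch.randn(4, 2, 16, 8)           # (cache_slots, kv_heads, cache_len, head_dim)
batch_index = torch.tensor([[3], [1]])    # (active_batch, 1)
ctx_indices = torch.arange(0, 8)[None, None, :].expand(2, 1, 8)  # block [0, 8) per sequence
block = CtxGatherFuncBlockedKVCB.apply(data, batch_index, ctx_indices)
print(block.shape)  # torch.Size([2, 2, 8, 8])

The exported CtxGatherBlockedKVCB op reproduces this by expanding batch, head, and context indices to a common shape and concatenating them into a single index tensor for GatherND.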


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGatherCB3D(
data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
64 changes: 64 additions & 0 deletions QEfficient/transformers/cache_utils.py
@@ -15,6 +15,8 @@
from QEfficient.customop import (
CtxGatherFunc,
CtxGatherFunc3D,
CtxGatherFuncBlockedKV,
CtxGatherFuncBlockedKVCB,
CtxGatherFuncCB,
CtxGatherFuncCB3D,
CtxScatterFunc,
@@ -85,6 +87,49 @@ def read_only(self, cache_kwargs):
v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len)

v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)

def read_only_blockedKV(self, start_index, end_index, cache_kwargs):
"""
Reads the `key_states` and `value_states` of the layer for a single KV block.

Parameters:
start_index (`int`):
Start index of the K/V block to read.

end_index (`int`):
End index of the K/V block to read.

cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

Return:
A tuple containing the key and value states for the requested block.
"""
# Gather
k_out, v_out = self.keys, self.values
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None)
batch, num_kv_heads, _, _ = k_out.shape
ctx_indices = torch.arange(start=start_index, end=end_index)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit

if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0

ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

if batch_index is not None:
k_out = CtxGatherFuncBlockedKVCB.apply(k_out, batch_index, ctx_indices)
v_out = CtxGatherFuncBlockedKVCB.apply(v_out, batch_index, ctx_indices)
else:
ctx_indices = ctx_indices.expand(batch, num_kv_heads, ctx_indices.shape[-1])
k_out = CtxGatherFuncBlockedKV.apply(k_out, ctx_indices)
v_out = CtxGatherFuncBlockedKV.apply(v_out, ctx_indices)

v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out

def write_only(self, key_states, value_states, cache_kwargs):
@@ -284,6 +329,25 @@ def read_only(self, layer_idx, cache_kwargs):
"""
return self.layers[layer_idx].read_only(cache_kwargs)

def read_only_blockedKV(self, start_index, end_index, layer_idx, cache_kwargs):
"""
Reads the `key_states` and `value_states` of the layer `layer_idx` for a single KV block.

Parameters:
start_index (`int`):
Start index of the K/V block to read
end_index (`int`):
End index of the K/V block to read
layer_idx (`int`):
The index of the layer to read the states from.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

Return:
A tuple containing the key and value states for the requested block.
"""
return self.layers[layer_idx].read_only_blockedKV(start_index, end_index, cache_kwargs)

def write_only(self, key_states, value_states, layer_idx, cache_kwargs):
"""
Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
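As a rough illustration of how the [start_index, end_index) ranges passed to read_only_blockedKV can be derived (the ceiling-division block size mirrors the computation in the blocked attention path below; the concrete numbers are made up):

# Illustrative only: 1000 cached tokens split across 4 KV blocks
past_seen_tokens, num_kv_blocks = 1000, 4
block_size = -(-past_seen_tokens // num_kv_blocks)   # ceiling division -> 250
for j in range(num_kv_blocks):
    start_index, end_index = j * block_size, (j + 1) * block_size
    print((start_index, end_index))   # (0, 250), (250, 500), (500, 750), (750, 1000)

When the split is uneven, the last end_index can overshoot past_seen_tokens; the attention code clamps the mask target back to past_seen_tokens with torch.where before building the per-block causal mask.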
3 changes: 2 additions & 1 deletion QEfficient/transformers/modeling_attn_mask_utils.py
@@ -14,6 +14,7 @@ def _create_causal_mask(
position_ids,
target_length,
sliding_window: Optional[int] = None,
start_index: Optional[int] = 0,
):
"""
A utility attention mask class that allows one to:
@@ -40,7 +41,7 @@
attention_mask = attention_mask.unsqueeze(1)
else:
query_indices = position_ids.unsqueeze(-1)
kv_indices = torch.arange(target_length).view(1, 1, -1)
kv_indices = torch.arange(start=start_index, end=target_length).view(1, 1, -1)
attention_mask = kv_indices > query_indices
attention_mask = attention_mask.unsqueeze(1)

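A minimal standalone sketch of what the start_index-aware branch above computes (simplified to the path without an explicit attention_mask; shapes and values are illustrative):

import torch

position_ids = torch.tensor([[5, 6]])               # (batch, query_len)
start_index, target_length = 4, 8                   # KV block covering positions 4..7
query_indices = position_ids.unsqueeze(-1)          # (batch, query_len, 1)
kv_indices = torch.arange(start=start_index, end=target_length).view(1, 1, -1)
attention_mask = (kv_indices > query_indices).unsqueeze(1)  # True where the KV position is in the future
print(attention_mask[0, 0])
# tensor([[False, False,  True,  True],
#         [False, False, False,  True]])

The True entries are the positions that the blocked attention later replaces with MIN_MASKED_ATTENTION_VALUE.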
108 changes: 100 additions & 8 deletions QEfficient/transformers/models/llama/modeling_llama.py
@@ -5,7 +5,7 @@
#
# -----------------------------------------------------------------------------

from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
@@ -113,13 +113,88 @@ def eager_attention_forward(
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attn_weights


def eager_attention_forward_blockedKV(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
num_kv_blocks: Optional[torch.Tensor] = None,
cache_kwargs: Optional[Dict[str, Any]] = None,
layer_idx: Optional[int] = None,
past_key_value: Optional[Cache] = None,
**kwargs,
):
# Initialize result tensor
output = torch.zeros_like(query)

# Initialize Running Maximum
batch_size, num_heads, seq_len, _ = query.shape
current_max = torch.full((batch_size, num_heads, seq_len), float(MIN_MASKED_ATTENTION_VALUE))

# Initialize Denominator
current_denominator = torch.zeros(batch_size, num_heads, seq_len)

past_seen_tokens = cache_kwargs.get("past_seen_tokens")
position_ids = cache_kwargs.get("position_ids")
block_size = -(-past_seen_tokens // num_kv_blocks)
masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)

for j in range(num_kv_blocks):
start_index = j * block_size
end_index = (j + 1) * block_size
K_block, V_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs)
K_block_states = repeat_kv(K_block, module.num_key_value_groups)
V_block_states = repeat_kv(V_block, module.num_key_value_groups)
past_seen_tokens_start = start_index
past_seen_tokens_end = torch.where(
torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int),
past_seen_tokens,
end_index,
)
causal_mask_block = _create_causal_mask(
position_ids=position_ids, target_length=past_seen_tokens_end, start_index=past_seen_tokens_start
)

# Compute attention scores for the block
attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block)

# Update Running row maximum
prev_max = current_max
current_max = torch.max(prev_max, attn_weights_block.max(dim=-1).values)
delta_max = prev_max - current_max

current_exp = torch.exp(
attn_weights_block - current_max.unsqueeze(-1)
) # Subtract current_max from each column of attn_weights_block

# update running denominator
prev_denominator = current_denominator
current_denominator = prev_denominator * torch.exp(delta_max) + current_exp.sum(dim=-1)

prob = current_exp / current_denominator.unsqueeze(-1)

prev_output = output
output = ((prev_denominator / current_denominator).unsqueeze(-1)) * prev_output * torch.exp(
delta_max.unsqueeze(-1)
) + torch.matmul(prob, V_block_states)
attn_output = output.transpose(1, 2).contiguous()
attn_weights = None

return attn_output, attn_weights
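The loop above is a streaming ("online") softmax over KV blocks: it carries a running row maximum and denominator so that, after the last block, the result matches a full softmax(QK^T * scaling) @ V without ever materializing scores for the whole context. A small self-contained check of that recurrence on random tensors (illustrative shapes, no masking or scaling):

import torch

torch.manual_seed(0)
q = torch.randn(1, 1, 2, 4)                      # (batch, heads, q_len, head_dim)
k = torch.randn(1, 1, 8, 4)
v = torch.randn(1, 1, 8, 4)
scores = torch.matmul(q, k.transpose(2, 3))      # full attention scores, (1, 1, 2, 8)

out = torch.zeros_like(q)
m = torch.full(scores.shape[:-1], -1e9)          # running row maximum
d = torch.zeros(scores.shape[:-1])               # running softmax denominator
for j in range(4):                               # 4 KV blocks of 2 positions each
    s = scores[..., 2 * j : 2 * (j + 1)]
    m_prev, m = m, torch.max(m, s.max(dim=-1).values)
    e = torch.exp(s - m.unsqueeze(-1))
    d_prev, d = d, d * torch.exp(m_prev - m) + e.sum(dim=-1)
    out = (d_prev / d).unsqueeze(-1) * out * torch.exp((m_prev - m).unsqueeze(-1)) \
        + torch.matmul(e / d.unsqueeze(-1), v[..., 2 * j : 2 * (j + 1), :])

reference = torch.softmax(scores, dim=-1) @ v
print(torch.allclose(out, reference, atol=1e-5))  # True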


class QEffLlamaAttention(LlamaAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -136,6 +211,7 @@ def forward(
batch_index: Optional[torch.LongTensor] = None,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
num_kv_blocks: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
@@ -151,17 +227,29 @@
value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2)

kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
if comp_ctx_lengths is not None:
attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
cache_kwargs["CCL"] = attention_mask.shape[-1]
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface = eager_attention_forward
if num_kv_blocks is not None:
Review comment (Contributor): Do we need to update past_seen_tokens? Since we're passing the PKV, can't we read it from the blocked attention method?

Reply (Author): We need past_seen_tokens before we call the past_key_value.write_only(...) function, since writing to the cache changes the return value of past_key_value.get_seq_length().
cache_kwargs = {
"batch_index": batch_index,
"position_ids": position_ids,
"past_seen_tokens": past_seen_tokens,
}
past_key_value.write_only(key_states, value_states, self.layer_idx, cache_kwargs)
else:
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
if comp_ctx_lengths is not None:
attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
cache_kwargs["CCL"] = attention_mask.shape[-1]
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

if num_kv_blocks is not None:
attention_interface = eager_attention_forward_blockedKV
else:
attention_interface = eager_attention_forward

attn_output, attn_weights = attention_interface(
self,
Expand All @@ -170,6 +258,10 @@ def forward(
value_states,
attention_mask,
scaling=self.scaling,
num_kv_blocks=num_kv_blocks,
cache_kwargs=cache_kwargs,
layer_idx=self.layer_idx,
past_key_value=past_key_value,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
5 changes: 5 additions & 0 deletions QEfficient/transformers/models/modeling_auto.py
@@ -42,6 +42,7 @@
from QEfficient.generation.vlm_generation import VisionLanguageGeneration
from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH
from QEfficient.transformers.models.pytorch_transforms import (
BlockedKVAttentionTransform,
CustomOpsTransform,
KVCacheExternalModuleMapperTransform,
KVCacheTransform,
@@ -2336,6 +2337,7 @@ def __init__(
- **return_pdfs** (bool): If True, returns probability distributions along with sampled tokens.
For Speculative Decoding Target Language Models, this is always True.
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
- **num_kv_blocks** (int): Number of K/V blocks for BlockedKV attention implementation.
**kwargs :
Additional keyword arguments passed to the base class constructor.

@@ -2384,6 +2386,9 @@
if self.is_tlm:
self.model.qaic_config["return_pdfs"] = True

if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None:
BlockedKVAttentionTransform.apply(model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks"))

@property
def model_name(self) -> str:
"""
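For context, enabling the blocked-KV attention path would look roughly like the sketch below. The only detail taken from this PR is the num_kv_blocks key in qaic_config (which triggers BlockedKVAttentionTransform in __init__); the from_pretrained call and model name are assumptions for illustration:

from QEfficient import QEFFAutoModelForCausalLM

# Hypothetical usage: split every KV-cache read into 4 blocks
model = QEFFAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",            # illustrative model name
    qaic_config={"num_kv_blocks": 4},
)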