Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion QEfficient/customop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,15 @@
#
# -----------------------------------------------------------------------------

from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFunc3D, CtxScatterFunc, CtxScatterFunc3D
from QEfficient.customop.ctx_scatter_gather import (
CtxGatherFunc,
CtxGatherFunc3D,
CtxGatherFuncBlockedKV,
CtxScatterFunc,
CtxScatterFunc3D,
)
from QEfficient.customop.ctx_scatter_gather_cb import (
CtxGatherFuncBlockedKVCB,
CtxGatherFuncCB,
CtxGatherFuncCB3D,
CtxScatterFuncCB,
Expand All @@ -16,12 +23,14 @@

__all__ = [
"CtxGatherFunc",
"CtxGatherFuncBlockedKV",
"CtxScatterFunc",
"CtxGatherFunc3D",
"CtxScatterFunc3D",
"CustomRMSNormAIC",
"GemmaCustomRMSNormAIC",
"CtxGatherFuncCB",
"CtxGatherFuncBlockedKVCB",
"CtxScatterFuncCB",
"CtxGatherFuncCB3D",
"CtxScatterFuncCB3D",
Expand Down
26 changes: 26 additions & 0 deletions QEfficient/customop/ctx_scatter_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,29 @@ def setup_context(ctx, inputs, outputs):
@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)


# ONNX Script function registered in the "com.qualcomm.cloud" opset (version 1).
# It gathers entries of `data` at the positions listed in `ctx_indices`, treating the
# first two axes of both inputs as batch axes (see `batch_dims=2` below).
# NOTE(review): presumably `data` is laid out as (batch, heads, ctx_len, head_dim) and
# `ctx_indices` as (batch, heads, n_indices) -- confirm against the callers.
@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGatherBlockedKV(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
    # GatherND expects the last axis of the index tensor to hold coordinates, so add a
    # trailing axis of size 1: each index becomes a one-element coordinate.
    ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
    # With batch_dims=2 the leading two axes are matched element-wise between `data`
    # and `ctx_indices`; the coordinate selects along the third axis of `data`.
    return ops.GatherND(data, ctx_indices, batch_dims=2)


class CtxGatherFuncBlockedKV(torch.autograd.Function):
    """
    Gather the key/value entries selected by ``ctx_indices`` from a KV-cache tensor.

    In eager mode the gather is done with advanced indexing over the first three
    axes of ``data``; during ONNX export the call is lowered to the
    ``CtxGatherBlockedKV`` ONNX Script function via ``symbolic``.
    """

    @staticmethod
    def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
        """
        Index ``data`` along its first three axes.

        The batch and head selectors are shaped (B, 1, 1) and (1, H, 1) so that they
        broadcast against ``ctx_indices``; every remaining axis of ``data`` is kept.
        """
        num_batches, num_heads = data.shape[0], data.shape[1]
        batch_selector = torch.arange(num_batches).view(-1, 1, 1)
        head_selector = torch.arange(num_heads).view(1, -1, 1)
        gathered = data[batch_selector, head_selector, ctx_indices]
        return gathered

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        """Nothing is saved on ``ctx``; this function defines no backward pass."""

    @staticmethod
    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
        """Emit the ``CtxGatherBlockedKV`` op and give the result the type of ``data``."""
        gather_node = g.onnxscript_op(CtxGatherBlockedKV, data, ctx_indices)
        return gather_node.setTypeAs(data)
38 changes: 38 additions & 0 deletions QEfficient/customop/ctx_scatter_gather_cb.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,44 @@ def symbolic(
return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data)


# ONNX Script function registered in the "com.qualcomm.cloud" opset (version 1).
# Variant of the blocked-KV gather that also takes `batch_index`: it builds explicit
# (batch, head, ctx) coordinate triples and gathers them from `data` with GatherND.
# NOTE(review): presumably `batch_index` is (batch, 1) and `ctx_indices` is 3-D with a
# broadcastable head axis -- confirm against the callers.
@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGatherBlockedKVCB(
    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
) -> onnxscript.FLOAT:
    # Sizes are read from the runtime shapes: axis 0 of `batch_index`, axis 1 of `data`
    # and axis 2 of `ctx_indices`.
    batch_size = ops.Gather(ops.Shape(batch_index), [0])
    num_heads = ops.Gather(ops.Shape(data), [1])
    ctx_len = ops.Gather(ops.Shape(ctx_indices), [2])

    # Expanded shape to create indices
    zero = ops.Constant(value_ints=[0])
    one = ops.Constant(value_ints=[1])
    # Target shape (batch_size, num_heads, ctx_len, 1) shared by all three index components.
    exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)

    # Create indices
    # Each component is unsqueezed to rank 4 and expanded to `exp_shape` so the three
    # can be concatenated along the last axis.
    batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape)
    head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape)
    ctx_idx = ops.Expand(ops.Unsqueeze(ctx_indices, [3]), exp_shape)
    # Last axis now has size 3: one (batch, head, ctx) coordinate per output element.
    indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3)

    # Each coordinate triple addresses the first three axes of `data`.
    return ops.GatherND(data, indices)


class CtxGatherFuncBlockedKVCB(torch.autograd.Function):
    """
    Gather key/value entries from a KV-cache tensor using explicit batch indices.

    Same gather as the non-CB variant, except that the rows along the first axis of
    ``data`` are chosen by ``batch_index`` instead of a plain ``arange``. During ONNX
    export the call is lowered to the ``CtxGatherBlockedKVCB`` ONNX Script function.
    """

    @staticmethod
    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
        """
        Index ``data`` along its first three axes.

        ``batch_index`` is flattened to (B, 1, 1) and the head selector is (1, H, 1),
        so both broadcast against ``ctx_indices``.
        """
        num_heads = data.shape[1]
        head_selector = torch.arange(num_heads).view(1, -1, 1)
        batch_selector = batch_index.view(-1, 1, 1)
        gathered = data[batch_selector, head_selector, ctx_indices]
        return gathered

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        """Nothing is saved on ``ctx``; this function defines no backward pass."""

    @staticmethod
    def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
        """Emit the ``CtxGatherBlockedKVCB`` op and give the result the type of ``data``."""
        gather_node = g.onnxscript_op(CtxGatherBlockedKVCB, data, batch_index, ctx_indices)
        return gather_node.setTypeAs(data)


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGatherCB3D(
data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
Expand Down
64 changes: 64 additions & 0 deletions QEfficient/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from QEfficient.customop import (
CtxGatherFunc,
CtxGatherFunc3D,
CtxGatherFuncBlockedKV,
CtxGatherFuncBlockedKVCB,
CtxGatherFuncCB,
CtxGatherFuncCB3D,
CtxScatterFunc,
Expand Down Expand Up @@ -85,6 +87,49 @@ def read_only(self, cache_kwargs):
v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len)

v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)

def read_only_blockedKV(self, start_index, end_index, cache_kwargs):
    """
    Read one block of this layer's cached keys and values without writing to the cache.

    The block covers context positions ``start_index`` (inclusive) to ``end_index``
    (exclusive). Positions greater than the largest position id of a batch row are
    treated as invalid: their gather index is replaced (int32 max while exporting to
    ONNX, 0 otherwise) and the corresponding gathered values are set to zero.
    The gathered keys are returned as-is.

    Parameters:
        start_index (`int`):
            Start index of the K/V block to read.
        end_index (`int`):
            End index of the K/V block to read.
        cache_kwargs (`Dict[str, Any]`):
            Must provide `position_ids`; may provide `batch_index`, which selects the
            gather variant that takes explicit batch indices.

    Return:
        A tuple `(keys, values)` holding the gathered block.
    """
    position_ids = cache_kwargs.get("position_ids")
    batch_index = cache_kwargs.get("batch_index", None)

    cached_keys, cached_values = self.keys, self.values
    batch, num_kv_heads, _, _ = cached_keys.shape

    # Absolute context positions of this block, shaped (1, 1, block_len).
    block_positions = torch.arange(start=start_index, end=end_index)[None, None, ...]
    # Largest position id per batch row; anything beyond it has not been written yet.
    last_valid = position_ids.max(1, keepdim=True).values.unsqueeze(1)
    invalid_mask = block_positions > last_valid

    # Replacement index for invalid positions depends on whether we are exporting.
    invalid_idx_value = torch.iinfo(torch.int32).max if torch.onnx.is_in_onnx_export() else 0
    block_positions = torch.where(invalid_mask, invalid_idx_value, block_positions)

    if batch_index is None:
        block_positions = block_positions.expand(batch, num_kv_heads, block_positions.shape[-1])
        k_out = CtxGatherFuncBlockedKV.apply(cached_keys, block_positions)
        v_out = CtxGatherFuncBlockedKV.apply(cached_values, block_positions)
    else:
        k_out = CtxGatherFuncBlockedKVCB.apply(cached_keys, batch_index, block_positions)
        v_out = CtxGatherFuncBlockedKVCB.apply(cached_values, batch_index, block_positions)

    # Only the values are zeroed at invalid positions.
    v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
    return k_out, v_out

def write_only(self, key_states, value_states, cache_kwargs):
Expand Down Expand Up @@ -284,6 +329,25 @@ def read_only(self, layer_idx, cache_kwargs):
"""
return self.layers[layer_idx].read_only(cache_kwargs)

def read_only_blockedKV(self, start_index, end_index, layer_idx, cache_kwargs):
    """
    Read one K/V block from the cache of layer `layer_idx`.

    This only delegates to the `read_only_blockedKV` method of the selected layer.

    Parameters:
        start_index (`int`):
            Start index of the K/V block to read.
        end_index (`int`):
            End index of the K/V block to read.
        layer_idx (`int`):
            Index into `self.layers` of the layer to read from.
        cache_kwargs (`Dict[str, Any]`, `optional`):
            Passed through unchanged to the layer.

    Return:
        Whatever the layer's `read_only_blockedKV` returns.
    """
    layer_cache = self.layers[layer_idx]
    return layer_cache.read_only_blockedKV(start_index, end_index, cache_kwargs)

def write_only(self, key_states, value_states, layer_idx, cache_kwargs):
"""
Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Expand Down
3 changes: 2 additions & 1 deletion QEfficient/transformers/modeling_attn_mask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def _create_causal_mask(
position_ids,
target_length,
sliding_window: Optional[int] = None,
start_index: Optional[int] = 0,
):
"""
A utility attention mask class that allows one to:
Expand All @@ -40,7 +41,7 @@ def _create_causal_mask(
attention_mask = attention_mask.unsqueeze(1)
else:
query_indices = position_ids.unsqueeze(-1)
kv_indices = torch.arange(target_length).view(1, 1, -1)
kv_indices = torch.arange(start=start_index, end=target_length).view(1, 1, -1)
attention_mask = kv_indices > query_indices
attention_mask = attention_mask.unsqueeze(1)

Expand Down
117 changes: 104 additions & 13 deletions QEfficient/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#
# -----------------------------------------------------------------------------

from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
Expand Down Expand Up @@ -113,11 +113,85 @@ def eager_attention_forward(
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attn_weights
return attn_output


def eager_attention_forward_blockedKV(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    num_kv_blocks: Optional[torch.Tensor] = None,
    cache_kwargs: Optional[Dict[str, Any]] = None,
    layer_idx: Optional[int] = None,
    past_key_value: Optional[Cache] = None,
    **kwargs,
):
    """
    Compute attention block-by-block over the KV cache.

    The cache is read in `num_kv_blocks` blocks. For every block the scores are
    computed against `query`, and a running row maximum and running softmax
    denominator are maintained so that the accumulated output is rescaled as each
    new block is folded in.

    Args:
        module: Attention module; only `module.num_key_value_groups` is read.
        query: Query tensor, unpacked as (batch_size, num_heads, seq_len, _).
        key: Not referenced in the body; keys are read from `past_key_value`.
        value: Not referenced in the body; values are read from `past_key_value`.
        attention_mask: Only checked against `None`. The mask actually applied is
            built per block with `_create_causal_mask`.
        scaling: Factor multiplied onto the raw attention scores.
        num_kv_blocks: Number of blocks; used as the loop bound and as the divisor
            for the block size. Must not be `None` although the default is `None`.
        cache_kwargs: Must contain "past_seen_tokens" and "position_ids"; it is also
            forwarded to `past_key_value.read_only_blockedKV`. Must not be `None`.
        layer_idx: Layer whose cache is read.
        past_key_value: Cache object providing `read_only_blockedKV`.
        **kwargs: Accepted for interface compatibility; not used in the body.

    Returns:
        The accumulated output with axes 1 and 2 swapped, made contiguous.
    """
    # Initialize result tensor
    output = torch.zeros_like(query)

    # Initialize Running Maximum
    batch_size, num_heads, seq_len, _ = query.shape
    current_max = torch.full((batch_size, num_heads, seq_len), float(MIN_MASKED_ATTENTION_VALUE))

    # Initialize Denominator
    current_denominator = torch.zeros(batch_size, num_heads, seq_len)

    past_seen_tokens = cache_kwargs.get("past_seen_tokens")
    position_ids = cache_kwargs.get("position_ids")
    # Ceiling division: the blocks together cover at least `past_seen_tokens` positions.
    block_size = -(-past_seen_tokens // num_kv_blocks)
    masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)

    for j in range(num_kv_blocks):
        start_index = j * block_size
        end_index = (j + 1) * block_size
        K_block, V_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs)
        K_block_states = repeat_kv(K_block, module.num_key_value_groups)
        V_block_states = repeat_kv(V_block, module.num_key_value_groups)
        past_seen_tokens_start = start_index
        # Clamp the end of the block to `past_seen_tokens` (element-wise minimum written
        # with torch.where rather than Python's min).
        past_seen_tokens_end = torch.where(
            torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int),
            past_seen_tokens,
            end_index,
        )
        causal_mask_block = _create_causal_mask(
            position_ids=position_ids, target_length=past_seen_tokens_end, start_index=past_seen_tokens_start
        )

        # Compute attention scores for the block
        attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling
        if attention_mask is not None:
            attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block)

        # Update Running row maximum
        prev_max = current_max
        current_max = torch.max(prev_max, attn_weights_block.max(dim=-1).values)
        # delta_max <= 0; exp(delta_max) rescales quantities computed under the old maximum.
        delta_max = prev_max - current_max

        current_exp = torch.exp(
            attn_weights_block - current_max.unsqueeze(-1)
        )  # Subtract the running row maximum from every score in the block

        # update running denominator
        prev_denominator = current_denominator
        current_denominator = prev_denominator * torch.exp(delta_max) + current_exp.sum(axis=-1)

        # Block probabilities normalised by the denominator accumulated so far.
        prob = current_exp / current_denominator.unsqueeze(-1)

        # Rescale the previously accumulated output to the new maximum and denominator,
        # then add this block's contribution.
        prev_output = output
        output = ((prev_denominator / current_denominator).unsqueeze(-1)) * prev_output * torch.exp(
            delta_max.unsqueeze(-1)
        ) + torch.matmul(prob, V_block_states)
    # NOTE(review): placed after the loop on the assumption that the original indentation
    # (lost in this view) had it there; the result is the same for num_kv_blocks >= 1.
    attn_output = output.transpose(1, 2).contiguous()

    return attn_output


class QEffLlamaAttention(LlamaAttention):
Expand All @@ -136,6 +210,7 @@ def forward(
batch_index: Optional[torch.LongTensor] = None,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
num_kv_blocks: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
Expand All @@ -151,31 +226,47 @@ def forward(
value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2)

kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
if comp_ctx_lengths is not None:
attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
cache_kwargs["CCL"] = attention_mask.shape[-1]
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

attention_interface = eager_attention_forward

attn_output, attn_weights = attention_interface(
if num_kv_blocks is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to update the past_seen_tokens? Since we're passing the PKV, can't we read it from the blocked attention method?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need the past_seen_tokens before we call the past_key_value.write_only(...) function, since writing to the cache will change the return value of past_key_value.get_seq_length().

cache_kwargs = {
"batch_index": batch_index,
"position_ids": position_ids,
"past_seen_tokens": past_seen_tokens,
}
past_key_value.write_only(key_states, value_states, self.layer_idx, cache_kwargs)
else:
cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids}
if comp_ctx_lengths is not None:
attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
cache_kwargs["CCL"] = attention_mask.shape[-1]
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

if num_kv_blocks is not None:
attention_interface = eager_attention_forward_blockedKV
else:
attention_interface = eager_attention_forward

attn_output = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
scaling=self.scaling,
num_kv_blocks=num_kv_blocks,
cache_kwargs=cache_kwargs,
layer_idx=self.layer_idx,
past_key_value=past_key_value,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output, **kwargs)

return attn_output, attn_weights
return attn_output


class QEffLlamaDecoderLayer(LlamaDecoderLayer):
Expand All @@ -202,7 +293,7 @@ def forward(
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, _ = self.self_attn(
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
Expand Down
Loading
Loading