From b8c5a4b1fa3750a805449bffd16c029aef70f686 Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Thu, 13 Nov 2025 18:05:56 -0600 Subject: [PATCH 01/10] Adding support for BlockedKV attention in CasualLM models Signed-off-by: Vaibhav Verma --- QEfficient/customop/__init__.py | 5 +- QEfficient/customop/ctx_scatter_gather.py | 25 +++++ QEfficient/customop/ctx_scatter_gather_cb.py | 38 +++++++ QEfficient/transformers/cache_utils.py | 64 +++++++++++ .../transformers/modeling_attn_mask_utils.py | 3 +- .../models/llama/modeling_llama.py | 101 ++++++++++++++---- .../transformers/models/modeling_auto.py | 5 + .../transformers/models/pytorch_transforms.py | 18 ++++ QEfficient/utils/constants.py | 1 + .../models/test_causal_lm_models.py | 38 ++++++- 10 files changed, 273 insertions(+), 25 deletions(-) diff --git a/QEfficient/customop/__init__.py b/QEfficient/customop/__init__.py index ff0709f82..f06a8f212 100644 --- a/QEfficient/customop/__init__.py +++ b/QEfficient/customop/__init__.py @@ -5,9 +5,10 @@ # # ----------------------------------------------------------------------------- -from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFunc3D, CtxScatterFunc, CtxScatterFunc3D +from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFuncBlockedKV, CtxGatherFunc3D, CtxScatterFunc, CtxScatterFunc3D from QEfficient.customop.ctx_scatter_gather_cb import ( CtxGatherFuncCB, + CtxGatherFuncBlockedKVCB, CtxGatherFuncCB3D, CtxScatterFuncCB, CtxScatterFuncCB3D, @@ -16,12 +17,14 @@ __all__ = [ "CtxGatherFunc", + "CtxGatherFuncBlockedKV", "CtxScatterFunc", "CtxGatherFunc3D", "CtxScatterFunc3D", "CustomRMSNormAIC", "GemmaCustomRMSNormAIC", "CtxGatherFuncCB", + "CtxGatherFuncBlockedKVCB", "CtxScatterFuncCB", "CtxGatherFuncCB3D", "CtxScatterFuncCB3D", diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index 269ccb0be..a054c36c0 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -145,3 +145,28 @@ def setup_context(ctx, inputs, outputs): @staticmethod def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value: return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data) + +@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +def CtxGatherBlockedKV(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: + ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) + return ops.GatherND(data, ctx_indices, batch_dims=2) + + +class CtxGatherFuncBlockedKV(torch.autograd.Function): + """ + Function to gather only the valid key values from KV-cache. + """ + + @staticmethod + def forward(data: torch.Tensor, ctx_indices: torch.Tensor): + batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) + head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + return data[batch_indices, head_indices, ctx_indices] + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxGatherBlockedKV, data, ctx_indices).setTypeAs(data) diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py index cc9693716..8a06bc2b1 100644 --- a/QEfficient/customop/ctx_scatter_gather_cb.py +++ b/QEfficient/customop/ctx_scatter_gather_cb.py @@ -139,6 +139,44 @@ def symbolic( return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data) +@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +def CtxGatherBlockedKVCB( + data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32 +) -> onnxscript.FLOAT: + batch_size = ops.Gather(ops.Shape(batch_index), [0]) + num_heads = ops.Gather(ops.Shape(data), [1]) + ctx_len = ops.Gather(ops.Shape(ctx_indices), [2]) + + # Expanded shape to create indices + zero = ops.Constant(value_ints=[0]) + one = ops.Constant(value_ints=[1]) + exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0) + + # Create indices + batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape) + head_idx = ops.Expand(ops.Unsqueeze(ops.Range(zero, num_heads, one), [0, 2, 3]), exp_shape) + ctx_idx = ops.Expand(ops.Unsqueeze(ctx_indices, [3]), exp_shape) + indices = ops.Concat(batch_idx, head_idx, ctx_idx, axis=3) + + return ops.GatherND(data, indices) + + +class CtxGatherFuncBlockedKVCB(torch.autograd.Function): + @staticmethod + def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor): + batch_indices = batch_index.view(-1, 1, 1) + head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + return data[batch_indices, head_indices, ctx_indices] + + @staticmethod + def setup_context(ctx, inputs, outputs): + pass + + @staticmethod + def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value: + return g.onnxscript_op(CtxGatherBlockedKVCB, data, batch_index, ctx_indices).setTypeAs(data) + + @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGatherCB3D( data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32 diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 292fe0487..44e7d777a 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -14,8 +14,10 @@ from QEfficient.customop import ( CtxGatherFunc, + CtxGatherFuncBlockedKV, CtxGatherFunc3D, CtxGatherFuncCB, + CtxGatherFuncBlockedKVCB, CtxGatherFuncCB3D, CtxScatterFunc, CtxScatterFunc3D, @@ -85,6 +87,49 @@ def read_only(self, cache_kwargs): v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + + def read_only_blockedKV(self, start_index, end_index, cache_kwargs): + """ + Reads the `key_states` and `value_states` for the layer for each KV block. + + Parameters: + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + start_index (`int`): + Start index of the K/V block to read + + end_index (`int`): + End index of the K/V block to read + + Return: + A tuple containing the updated key and value states. + """ + # Gather + k_out, v_out = self.keys, self.values + position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index", None) + batch, num_kv_heads, _, _ = k_out.shape + ctx_indices = torch.arange(start=start_index, end=end_index)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + if batch_index is not None: + k_out = CtxGatherFuncBlockedKVCB.apply(k_out, batch_index, ctx_indices) + v_out = CtxGatherFuncBlockedKVCB.apply(v_out, batch_index, ctx_indices) + else: + ctx_indices = ctx_indices.expand(batch, num_kv_heads, ctx_indices.shape[-1]) + k_out = CtxGatherFuncBlockedKV.apply(k_out, ctx_indices) + v_out = CtxGatherFuncBlockedKV.apply(v_out, ctx_indices) + + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out def write_only(self, key_states, value_states, cache_kwargs): @@ -284,6 +329,25 @@ def read_only(self, layer_idx, cache_kwargs): """ return self.layers[layer_idx].read_only(cache_kwargs) + def read_only_blockedKV(self, start_index, end_index, layer_idx, cache_kwargs): + """ + Reads the `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + start_index (`int`): + Start index of the K/V block to read + end_index (`int`): + End index of the K/V block to read + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + return self.layers[layer_idx].read_only_blockedKV(start_index, end_index, cache_kwargs) + def write_only(self, key_states, value_states, layer_idx, cache_kwargs): """ Write in the cache with the new `key_states` and `value_states` for the layer `layer_idx`. diff --git a/QEfficient/transformers/modeling_attn_mask_utils.py b/QEfficient/transformers/modeling_attn_mask_utils.py index 4faedba33..64a9d9b80 100644 --- a/QEfficient/transformers/modeling_attn_mask_utils.py +++ b/QEfficient/transformers/modeling_attn_mask_utils.py @@ -13,6 +13,7 @@ def _create_causal_mask( position_ids, target_length, + start_index: Optional[torch.Tensor] = torch.tensor(0, dtype=torch.int), sliding_window: Optional[int] = None, ): """ @@ -40,7 +41,7 @@ def _create_causal_mask( attention_mask = attention_mask.unsqueeze(1) else: query_indices = position_ids.unsqueeze(-1) - kv_indices = torch.arange(target_length).view(1, 1, -1) + kv_indices = torch.arange(start=start_index, end=target_length).view(1, 1, -1) attention_mask = kv_indices > query_indices attention_mask = attention_mask.unsqueeze(1) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 73b947dba..a705f1b9c 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Dict, Any import torch from torch import nn @@ -103,21 +103,70 @@ def eager_attention_forward( value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, + num_kv_blocks: Optional[torch.Tensor] = None, + cache_kwargs: Optional[Dict[str, Any]] = None, + layer_idx: int = None, + past_key_value: Optional[Cache] = None, **kwargs, ): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights + if num_kv_blocks is not None: + # Initialize result tensor + output = torch.zeros_like(query) + + # Perform blockwise computation for K and V + M = torch.full((query.shape[0], query.shape[1], query.shape[2],),float(MIN_MASKED_ATTENTION_VALUE)) # Running Maximum + D = torch.zeros((query.shape[0], query.shape[1], query.shape[2],)) + + past_seen_tokens = cache_kwargs.get("past_seen_tokens") + position_ids = cache_kwargs.get("position_ids") + num_kv_blocks = torch.tensor(num_kv_blocks, dtype=torch.int) + block_size = -(-past_seen_tokens // num_kv_blocks) + for j in range(num_kv_blocks): + start_index = j*block_size + end_index = (j+1)*block_size + K_block, V_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs) + K_block_states = repeat_kv(K_block, module.num_key_value_groups) + V_block_states = repeat_kv(V_block, module.num_key_value_groups) + past_seen_tokens_start = start_index + past_seen_tokens_end = torch.where(past_seen_tokens < end_index, past_seen_tokens, end_index) + causal_mask_block = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens_end, start_index=past_seen_tokens_start) + + # Compute attention scores for the block + attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights_block = torch.where(causal_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights_block) + + # Update Running row maximum + prevM = M + M = torch.max(prevM, torch.max(attn_weights_block,axis=-1).values) + deltaM = prevM - M + + currentExp = torch.exp(attn_weights_block - M.unsqueeze(-1)) # Subract M from each column of attn_weights_block + + #update Running denominator + prevD = D + D = prevD*torch.exp(deltaM) + currentExp.sum(axis = -1) + + P = currentExp/D.unsqueeze(-1) + + prevO = output + output = ((prevD/D).unsqueeze(-1))*prevO*torch.exp(deltaM.unsqueeze(-1)) + torch.matmul(P, V_block_states) # This in higher precision. + attn_output = output.transpose(1, 2).contiguous() + + else: #regular attention + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output class QEffLlamaAttention(LlamaAttention): @@ -136,6 +185,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + num_kv_blocks: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -151,31 +201,40 @@ def forward( value_states = self.v_proj(hidden_states, **kwargs).view(hidden_shape).transpose(1, 2) kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position) + past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0 cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} - if comp_ctx_lengths is not None: - attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] - cache_kwargs["CCL"] = attention_mask.shape[-1] - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + if num_kv_blocks is not None: + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "past_seen_tokens": past_seen_tokens} + past_key_value.write_only(key_states, value_states, self.layer_idx, cache_kwargs) + else: + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs["CCL"] = attention_mask.shape[-1] + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward - attn_output, attn_weights = attention_interface( + attn_output = attention_interface( self, query_states, key_states, value_states, attention_mask, scaling=self.scaling, + num_kv_blocks=num_kv_blocks, + cache_kwargs=cache_kwargs, + layer_idx=self.layer_idx, + past_key_value=past_key_value, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output, **kwargs) - return attn_output, attn_weights + return attn_output class QEffLlamaDecoderLayer(LlamaDecoderLayer): @@ -202,7 +261,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, _ = self.self_attn( + hidden_states = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index cbff5be91..e65efe172 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -42,6 +42,7 @@ from QEfficient.generation.vlm_generation import VisionLanguageGeneration from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH from QEfficient.transformers.models.pytorch_transforms import ( + BlockedKVAttentionTransform, CustomOpsTransform, KVCacheExternalModuleMapperTransform, KVCacheTransform, @@ -2336,6 +2337,7 @@ def __init__( - **return_pdfs** (bool): If True, returns probability distributions along with sampled tokens. For Speculative Decoding Target Language Models, this is always True. - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling. + - **num_kv_blocks** (int): Number of K/V blocks for BlockedKV attention implementation. **kwargs : Additional keyword arguments passed to the base class constructor. @@ -2384,6 +2386,9 @@ def __init__( if self.is_tlm: self.model.qaic_config["return_pdfs"] = True + if self.model.qaic_config["num_kv_blocks"] is not None: + BlockedKVAttentionTransform.apply(model, num_kv_blocks=self.model.qaic_config["num_kv_blocks"]) + @property def model_name(self) -> str: """ diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 62a873b9e..0258c8e83 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +from functools import partial import warnings from types import MethodType from typing import Callable, Optional, Tuple, Union @@ -847,3 +848,20 @@ def get_decoder_layer_classes_for_export(model: nn.Module) -> set: model_decoder_classes.add(module.__class__) return model_decoder_classes + + +class BlockedKVAttentionTransform: + _module_mapping = { + QEffLlamaAttention, + } + + @classmethod + def apply(cls, model: nn.Module, num_kv_blocks) -> Tuple[nn.Module, bool]: + transformed = False + for module in model.modules(): + if type(module) in cls._module_mapping: + repl_module = type(module) + module.__class__ = repl_module + module.forward = MethodType(partial(repl_module.forward, num_kv_blocks=num_kv_blocks), module) + transformed = True # Set to True if at least one transformation occurs + return model, transformed diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 1504bdae5..9dd43b5ab 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -139,6 +139,7 @@ class Constants: MAX_QPC_LIMIT = 30 MAX_RETRIES = 10 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download NUM_SPECULATIVE_TOKENS = 2 + NUM_KV_BLOCKS = 8 MAX_TOP_K_IDS = ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS SAMPLER_OPS = { "repetition_penalties", diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 321a466ab..5717d3990 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -67,6 +67,10 @@ "Qwen/Qwen2-0.5B", ] +test_models_blockedKV = [ + "meta-llama/Llama-3.3-70B-Instruct", +] + def get_custom_n_layers(model_name): """ @@ -147,6 +151,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( qnn_config: Optional[str] = None, config: Optional[AutoConfig] = None, pytorch_hf_tokens: Optional[list] = None, + qaic_config: Optional[dict] = None, ): """ Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. @@ -179,7 +184,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( is_tlm = False if num_speculative_tokens is None else True qeff_model = QEFFAutoModelForCausalLM( - copy.deepcopy(model_hf), is_tlm=is_tlm, pretrained_model_name_or_path=model_name + copy.deepcopy(model_hf), is_tlm=is_tlm, pretrained_model_name_or_path=model_name, qaic_config=qaic_config ) pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) @@ -243,7 +248,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_hf_tokens = [pytorch_hf_tokens for _ in range(full_batch_size)] qeff_model = QEFFAutoModelForCausalLM( - model_hf, continuous_batching=True, is_tlm=is_tlm, pretrained_model_name_or_path=model_name + model_hf, continuous_batching=True, is_tlm=is_tlm, pretrained_model_name_or_path=model_name, qaic_config=qaic_config ) onnx_model_path = qeff_model.export() @@ -488,3 +493,32 @@ def test_prefiill_only_pytorch_vs_kv_vs_ort_vs_ai100_qnn(): check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name, n_layer=n_layer, prefill_only=False, enable_qnn=True, qnn_config=qnn_config_json_path ) + + +@pytest.mark.parametrize("model_name", test_models_blockedKV) +def test_causal_blockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model for KV blocking, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + n_layer = get_custom_n_layers(model_name) + + qaic_config = dict(num_kv_blocks=Constants.NUM_KV_BLOCKS) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, n_layer=n_layer, qaic_config=qaic_config + ) + +@pytest.mark.parametrize("model_name", test_models_blockedKV) +def test_causal_nonBlockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model for KV blocking, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + n_layer = get_custom_n_layers(model_name) + + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, n_layer=n_layer + ) + From 415c1cecf1c176cde714aabfc46ce83b44c95556 Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Fri, 14 Nov 2025 01:02:05 -0600 Subject: [PATCH 02/10] Updated num_kv_blocks checking within qaic_config to use .get() Signed-off-by: Vaibhav Verma --- QEfficient/transformers/models/modeling_auto.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index e65efe172..928c9f08f 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2386,8 +2386,9 @@ def __init__( if self.is_tlm: self.model.qaic_config["return_pdfs"] = True - if self.model.qaic_config["num_kv_blocks"] is not None: - BlockedKVAttentionTransform.apply(model, num_kv_blocks=self.model.qaic_config["num_kv_blocks"]) + num_kv_blocks = self.model.qaic_config.get("num_kv_blocks", None) + if num_kv_blocks is not None: + BlockedKVAttentionTransform.apply(model, num_kv_blocks=num_kv_blocks) @property def model_name(self) -> str: From cb1fd076a212cfe8a4b2132d560413cdb1abca39 Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Fri, 14 Nov 2025 01:52:06 -0600 Subject: [PATCH 03/10] Updated modeling_auto.py to handle num_kv_blocks=None case gracefully Signed-off-by: Vaibhav Verma --- QEfficient/transformers/models/modeling_auto.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 928c9f08f..ccb8301db 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2386,7 +2386,11 @@ def __init__( if self.is_tlm: self.model.qaic_config["return_pdfs"] = True - num_kv_blocks = self.model.qaic_config.get("num_kv_blocks", None) + if self.model.qaic_config is not None: + num_kv_blocks = self.model.qaic_config.get("num_kv_blocks", None) + else: + num_kv_blocks = None + if num_kv_blocks is not None: BlockedKVAttentionTransform.apply(model, num_kv_blocks=num_kv_blocks) From e926853ba8317708963a0cbfbab372a59be88c60 Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Fri, 14 Nov 2025 17:12:55 -0600 Subject: [PATCH 04/10] Fix to satisfy where op needing a tensor condition and arange needing number indices Signed-off-by: Vaibhav Verma --- .../models/llama/modeling_llama.py | 85 +++++++++++++------ 1 file changed, 57 insertions(+), 28 deletions(-) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index a705f1b9c..9e4e3f846 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import List, Optional, Tuple, Union, Dict, Any +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -109,58 +109,83 @@ def eager_attention_forward( past_key_value: Optional[Cache] = None, **kwargs, ): - if num_kv_blocks is not None: # Initialize result tensor output = torch.zeros_like(query) - + # Perform blockwise computation for K and V - M = torch.full((query.shape[0], query.shape[1], query.shape[2],),float(MIN_MASKED_ATTENTION_VALUE)) # Running Maximum - D = torch.zeros((query.shape[0], query.shape[1], query.shape[2],)) + M = torch.full( + ( + query.shape[0], + query.shape[1], + query.shape[2], + ), + float(MIN_MASKED_ATTENTION_VALUE), + ) # Running Maximum + D = torch.zeros( + ( + query.shape[0], + query.shape[1], + query.shape[2], + ) + ) past_seen_tokens = cache_kwargs.get("past_seen_tokens") position_ids = cache_kwargs.get("position_ids") - num_kv_blocks = torch.tensor(num_kv_blocks, dtype=torch.int) block_size = -(-past_seen_tokens // num_kv_blocks) for j in range(num_kv_blocks): - start_index = j*block_size - end_index = (j+1)*block_size + start_index = j * block_size + end_index = (j + 1) * block_size K_block, V_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs) K_block_states = repeat_kv(K_block, module.num_key_value_groups) V_block_states = repeat_kv(V_block, module.num_key_value_groups) past_seen_tokens_start = start_index - past_seen_tokens_end = torch.where(past_seen_tokens < end_index, past_seen_tokens, end_index) - causal_mask_block = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens_end, start_index=past_seen_tokens_start) - - # Compute attention scores for the block + past_seen_tokens_end = torch.where( + torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int), + past_seen_tokens, + end_index, + ) + causal_mask_block = _create_causal_mask( + position_ids=position_ids, target_length=past_seen_tokens_end, start_index=past_seen_tokens_start + ) + + # Compute attention scores for the block attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling if attention_mask is not None: - attn_weights_block = torch.where(causal_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights_block) - + attn_weights_block = torch.where( + causal_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights_block + ) + # Update Running row maximum prevM = M - M = torch.max(prevM, torch.max(attn_weights_block,axis=-1).values) - deltaM = prevM - M - - currentExp = torch.exp(attn_weights_block - M.unsqueeze(-1)) # Subract M from each column of attn_weights_block - - #update Running denominator + M = torch.max(prevM, torch.max(attn_weights_block, axis=-1).values) + deltaM = prevM - M + + currentExp = torch.exp( + attn_weights_block - M.unsqueeze(-1) + ) # Subract M from each column of attn_weights_block + + # update Running denominator prevD = D - D = prevD*torch.exp(deltaM) + currentExp.sum(axis = -1) - - P = currentExp/D.unsqueeze(-1) - + D = prevD * torch.exp(deltaM) + currentExp.sum(axis=-1) + + P = currentExp / D.unsqueeze(-1) + prevO = output - output = ((prevD/D).unsqueeze(-1))*prevO*torch.exp(deltaM.unsqueeze(-1)) + torch.matmul(P, V_block_states) # This in higher precision. + output = ((prevD / D).unsqueeze(-1)) * prevO * torch.exp(deltaM.unsqueeze(-1)) + torch.matmul( + P, V_block_states + ) # This in higher precision. attn_output = output.transpose(1, 2).contiguous() - else: #regular attention + else: # regular attention key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: - attn_weights = torch.where(attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights) + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_output = torch.matmul(attn_weights, value_states) @@ -207,7 +232,11 @@ def forward( if past_key_value is not None: if num_kv_blocks is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "past_seen_tokens": past_seen_tokens} + cache_kwargs = { + "batch_index": batch_index, + "position_ids": position_ids, + "past_seen_tokens": past_seen_tokens, + } past_key_value.write_only(key_states, value_states, self.layer_idx, cache_kwargs) else: cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} From a1915c9081742367d97eeaa51bc20bb592a03a27 Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Fri, 14 Nov 2025 17:51:16 -0600 Subject: [PATCH 05/10] Minor fix for _create_causal_mask arg order Signed-off-by: Vaibhav Verma --- QEfficient/transformers/modeling_attn_mask_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/modeling_attn_mask_utils.py b/QEfficient/transformers/modeling_attn_mask_utils.py index 64a9d9b80..629c10dd6 100644 --- a/QEfficient/transformers/modeling_attn_mask_utils.py +++ b/QEfficient/transformers/modeling_attn_mask_utils.py @@ -13,8 +13,8 @@ def _create_causal_mask( position_ids, target_length, - start_index: Optional[torch.Tensor] = torch.tensor(0, dtype=torch.int), sliding_window: Optional[int] = None, + start_index: Optional[int] = 0, ): """ A utility attention mask class that allows one to: From 40af1718d10792a86bd8d76fa4d240127c60ad52 Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Mon, 17 Nov 2025 13:29:09 -0600 Subject: [PATCH 06/10] removing gated llama3.3-70B model from test + lint/format Signed-off-by: Vaibhav Verma --- QEfficient/customop/__init__.py | 10 ++++++++-- QEfficient/customop/ctx_scatter_gather.py | 1 + QEfficient/transformers/cache_utils.py | 4 ++-- .../transformers/models/pytorch_transforms.py | 2 +- .../models/test_causal_lm_models.py | 19 ++++++++++--------- 5 files changed, 22 insertions(+), 14 deletions(-) diff --git a/QEfficient/customop/__init__.py b/QEfficient/customop/__init__.py index f06a8f212..35830aa91 100644 --- a/QEfficient/customop/__init__.py +++ b/QEfficient/customop/__init__.py @@ -5,10 +5,16 @@ # # ----------------------------------------------------------------------------- -from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFuncBlockedKV, CtxGatherFunc3D, CtxScatterFunc, CtxScatterFunc3D +from QEfficient.customop.ctx_scatter_gather import ( + CtxGatherFunc, + CtxGatherFunc3D, + CtxGatherFuncBlockedKV, + CtxScatterFunc, + CtxScatterFunc3D, +) from QEfficient.customop.ctx_scatter_gather_cb import ( - CtxGatherFuncCB, CtxGatherFuncBlockedKVCB, + CtxGatherFuncCB, CtxGatherFuncCB3D, CtxScatterFuncCB, CtxScatterFuncCB3D, diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index a054c36c0..c7dc8639a 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -146,6 +146,7 @@ def setup_context(ctx, inputs, outputs): def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value: return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data) + @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGatherBlockedKV(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 44e7d777a..20a092c94 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -14,10 +14,10 @@ from QEfficient.customop import ( CtxGatherFunc, - CtxGatherFuncBlockedKV, CtxGatherFunc3D, - CtxGatherFuncCB, + CtxGatherFuncBlockedKV, CtxGatherFuncBlockedKVCB, + CtxGatherFuncCB, CtxGatherFuncCB3D, CtxScatterFunc, CtxScatterFunc3D, diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 0258c8e83..02dab3549 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -5,8 +5,8 @@ # # ----------------------------------------------------------------------------- -from functools import partial import warnings +from functools import partial from types import MethodType from typing import Callable, Optional, Tuple, Union diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 5717d3990..70df7e3b8 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -68,7 +68,8 @@ ] test_models_blockedKV = [ - "meta-llama/Llama-3.3-70B-Instruct", + # "meta-llama/Llama-3.3-70B-Instruct", + "meta-llama/Llama-3.2-1B", ] @@ -248,7 +249,11 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_hf_tokens = [pytorch_hf_tokens for _ in range(full_batch_size)] qeff_model = QEFFAutoModelForCausalLM( - model_hf, continuous_batching=True, is_tlm=is_tlm, pretrained_model_name_or_path=model_name, qaic_config=qaic_config + model_hf, + continuous_batching=True, + is_tlm=is_tlm, + pretrained_model_name_or_path=model_name, + qaic_config=qaic_config, ) onnx_model_path = qeff_model.export() @@ -505,9 +510,8 @@ def test_causal_blockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name): n_layer = get_custom_n_layers(model_name) qaic_config = dict(num_kv_blocks=Constants.NUM_KV_BLOCKS) - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( - model_name=model_name, n_layer=n_layer, qaic_config=qaic_config - ) + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config) + @pytest.mark.parametrize("model_name", test_models_blockedKV) def test_causal_nonBlockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name): @@ -518,7 +522,4 @@ def test_causal_nonBlockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name): """ n_layer = get_custom_n_layers(model_name) - check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( - model_name=model_name, n_layer=n_layer - ) - + check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) From 594fc5951ea6658c7a6040bd1c6bffbf40147f6b Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Fri, 21 Nov 2025 15:54:17 -0600 Subject: [PATCH 07/10] separated eager attention into separate methods + minor fixes Signed-off-by: Vaibhav Verma --- .../models/llama/modeling_llama.py | 153 +++++++++--------- .../transformers/models/modeling_auto.py | 9 +- 2 files changed, 80 insertions(+), 82 deletions(-) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 9e4e3f846..542319ce4 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -103,93 +103,93 @@ def eager_attention_forward( value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, - num_kv_blocks: Optional[torch.Tensor] = None, - cache_kwargs: Optional[Dict[str, Any]] = None, - layer_idx: int = None, - past_key_value: Optional[Cache] = None, **kwargs, ): - if num_kv_blocks is not None: - # Initialize result tensor - output = torch.zeros_like(query) - - # Perform blockwise computation for K and V - M = torch.full( - ( - query.shape[0], - query.shape[1], - query.shape[2], - ), - float(MIN_MASKED_ATTENTION_VALUE), - ) # Running Maximum - D = torch.zeros( - ( - query.shape[0], - query.shape[1], - query.shape[2], - ) + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights ) - past_seen_tokens = cache_kwargs.get("past_seen_tokens") - position_ids = cache_kwargs.get("position_ids") - block_size = -(-past_seen_tokens // num_kv_blocks) - for j in range(num_kv_blocks): - start_index = j * block_size - end_index = (j + 1) * block_size - K_block, V_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs) - K_block_states = repeat_kv(K_block, module.num_key_value_groups) - V_block_states = repeat_kv(V_block, module.num_key_value_groups) - past_seen_tokens_start = start_index - past_seen_tokens_end = torch.where( - torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int), - past_seen_tokens, - end_index, - ) - causal_mask_block = _create_causal_mask( - position_ids=position_ids, target_length=past_seen_tokens_end, start_index=past_seen_tokens_start - ) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() - # Compute attention scores for the block - attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling - if attention_mask is not None: - attn_weights_block = torch.where( - causal_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights_block - ) + return attn_output - # Update Running row maximum - prevM = M - M = torch.max(prevM, torch.max(attn_weights_block, axis=-1).values) - deltaM = prevM - M - currentExp = torch.exp( - attn_weights_block - M.unsqueeze(-1) - ) # Subract M from each column of attn_weights_block +def eager_attention_forward_blockedKV( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + num_kv_blocks: Optional[torch.Tensor] = None, + cache_kwargs: Optional[Dict[str, Any]] = None, + layer_idx: int = None, + past_key_value: Optional[Cache] = None, + **kwargs, +): + # Initialize result tensor + output = torch.zeros_like(query) + + # Initialize Running Maximum + batch_size, num_heads, seq_len, _ = query.shape + current_max = torch.full((batch_size, num_heads, seq_len), float(MIN_MASKED_ATTENTION_VALUE)) + + # Initialize Denominator + current_denominator = torch.zeros(batch_size, num_heads, seq_len) + + past_seen_tokens = cache_kwargs.get("past_seen_tokens") + position_ids = cache_kwargs.get("position_ids") + block_size = -(-past_seen_tokens // num_kv_blocks) + masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32) + + for j in range(num_kv_blocks): + start_index = j * block_size + end_index = (j + 1) * block_size + K_block, V_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs) + K_block_states = repeat_kv(K_block, module.num_key_value_groups) + V_block_states = repeat_kv(V_block, module.num_key_value_groups) + past_seen_tokens_start = start_index + past_seen_tokens_end = torch.where( + torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int), + past_seen_tokens, + end_index, + ) + causal_mask_block = _create_causal_mask( + position_ids=position_ids, target_length=past_seen_tokens_end, start_index=past_seen_tokens_start + ) - # update Running denominator - prevD = D - D = prevD * torch.exp(deltaM) + currentExp.sum(axis=-1) + # Compute attention scores for the block + attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block) - P = currentExp / D.unsqueeze(-1) + # Update Running row maximum + prev_max = current_max + current_max = torch.max(prev_max, attn_weights_block.max(dim=-1).values) + delta_max = prev_max - current_max - prevO = output - output = ((prevD / D).unsqueeze(-1)) * prevO * torch.exp(deltaM.unsqueeze(-1)) + torch.matmul( - P, V_block_states - ) # This in higher precision. - attn_output = output.transpose(1, 2).contiguous() + current_exp = torch.exp( + attn_weights_block - current_max.unsqueeze(-1) + ) # Subract current_max from each column of attn_weights_block - else: # regular attention - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) + # update running denominator + prev_denominator = current_denominator + current_denominator = prev_denominator * torch.exp(delta_max) + current_exp.sum(axis=-1) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - attn_weights = torch.where( - attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights - ) + prob = current_exp / current_denominator.unsqueeze(-1) - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() + prev_output = output + output = ((prev_denominator / current_denominator).unsqueeze(-1)) * prev_output * torch.exp( + delta_max.unsqueeze(-1) + ) + torch.matmul(prob, V_block_states) + attn_output = output.transpose(1, 2).contiguous() return attn_output @@ -245,7 +245,10 @@ def forward( cache_kwargs["CCL"] = attention_mask.shape[-1] key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface = eager_attention_forward + if num_kv_blocks is not None: + attention_interface = eager_attention_forward_blockedKV + else: + attention_interface = eager_attention_forward attn_output = attention_interface( self, diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index ccb8301db..86cba9fcd 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2386,13 +2386,8 @@ def __init__( if self.is_tlm: self.model.qaic_config["return_pdfs"] = True - if self.model.qaic_config is not None: - num_kv_blocks = self.model.qaic_config.get("num_kv_blocks", None) - else: - num_kv_blocks = None - - if num_kv_blocks is not None: - BlockedKVAttentionTransform.apply(model, num_kv_blocks=num_kv_blocks) + if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: + BlockedKVAttentionTransform.apply(model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) @property def model_name(self) -> str: From 15a282b8e112c6f9a24bbd6111927e57be85ea0e Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Tue, 25 Nov 2025 11:55:55 -0600 Subject: [PATCH 08/10] Adding pytest.mark.on_qaic on the test Signed-off-by: Vaibhav Verma --- tests/transformers/models/test_causal_lm_models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index 70df7e3b8..ead636759 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -500,6 +500,7 @@ def test_prefiill_only_pytorch_vs_kv_vs_ort_vs_ai100_qnn(): ) +@pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV) def test_causal_blockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name): """ @@ -513,6 +514,7 @@ def test_causal_blockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name): check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config) +@pytest.mark.on_qaic @pytest.mark.parametrize("model_name", test_models_blockedKV) def test_causal_nonBlockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name): """ From 99ebcd4d2c5c6f850e18dc8e0bf4b1926ddcd92f Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Tue, 25 Nov 2025 18:53:54 -0600 Subject: [PATCH 09/10] recommitting to fix preflight_Qeff picking wrong version of modeling_llama.py Signed-off-by: Vaibhav Verma --- QEfficient/transformers/models/pytorch_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 02dab3549..3985ec027 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -850,7 +850,7 @@ def get_decoder_layer_classes_for_export(model: nn.Module) -> set: return model_decoder_classes -class BlockedKVAttentionTransform: +class BlockedKVAttentionTransform: _module_mapping = { QEffLlamaAttention, } From 81434141127b335e1b4b876de0626fe2aa2f2896 Mon Sep 17 00:00:00 2001 From: Vaibhav Verma Date: Wed, 26 Nov 2025 12:24:09 -0600 Subject: [PATCH 10/10] Reverted to returning attn_weights in eager_attention_forward in modeling_llama.py Signed-off-by: Vaibhav Verma --- .../transformers/models/llama/modeling_llama.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 542319ce4..fb3aed556 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -118,7 +118,7 @@ def eager_attention_forward( attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output + return attn_output, attn_weights def eager_attention_forward_blockedKV( @@ -190,8 +190,9 @@ def eager_attention_forward_blockedKV( delta_max.unsqueeze(-1) ) + torch.matmul(prob, V_block_states) attn_output = output.transpose(1, 2).contiguous() + attn_weights = None - return attn_output + return attn_output, attn_weights class QEffLlamaAttention(LlamaAttention): @@ -250,7 +251,7 @@ def forward( else: attention_interface = eager_attention_forward - attn_output = attention_interface( + attn_output, attn_weights = attention_interface( self, query_states, key_states, @@ -266,7 +267,7 @@ def forward( attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output, **kwargs) - return attn_output + return attn_output, attn_weights class QEffLlamaDecoderLayer(LlamaDecoderLayer): @@ -293,7 +294,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids,