diff --git a/tests/e2e/singlecard/test_aclgraph_accuracy.py b/tests/e2e/singlecard/test_aclgraph_accuracy.py index b1878862b9f..e217f0bbcaa 100644 --- a/tests/e2e/singlecard/test_aclgraph_accuracy.py +++ b/tests/e2e/singlecard/test_aclgraph_accuracy.py @@ -210,4 +210,4 @@ def test_aclgraph_enable(): # after check_and_update_config, mode should be VLLM_COMPILE and piecewise cudagraph NPUPlatform.check_and_update_config(VllmConfig) assert VllmConfig.compilation_config.mode == CompilationMode.VLLM_COMPILE - assert VllmConfig.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE + assert VllmConfig.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE \ No newline at end of file diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 48d23a7c49b..e9b9b1f6200 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -17,6 +17,7 @@ from uuid import uuid4 from vllm.logger import logger +from vllm.triton_utils import HAS_TRITON def check_kv_extra_config(vllm_config): @@ -231,7 +232,10 @@ class AscendCompilationConfig: deployed on Ascend platforms. """ - def __init__(self, fuse_norm_quant: bool = True, **kwargs): + def __init__(self, + fuse_norm_quant: bool = True, + fuse_qknorm_rope: bool = False, + **kwargs): """ Initialize the configuration. @@ -239,11 +243,12 @@ def __init__(self, fuse_norm_quant: bool = True, **kwargs): fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization. When set to True, the system will optimize norm and quant operations. Default: True - + fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization. + Default: False **kwargs: Additional optional parameters for forward compatibility and configuration extension. """ self.fuse_norm_quant = fuse_norm_quant - # Add more compilation related configs here as needed + self.fuse_qknorm_rope = HAS_TRITON or fuse_qknorm_rope class XliteGraphConfig: diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index b4343e76f0c..b00bdabced6 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -209,37 +209,6 @@ def get_mc2_mask(): return _reserved_mc2_mask -def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, - device): - global _cos - global _sin - if _cos is not None: - return - compilation_config = vllm_config.compilation_config - model_config = vllm_config.model_config - if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: - rope_dim = model_config.hf_text_config.qk_rope_head_dim - _cos = torch.ones(max_num_reqs * decode_token_per_req, - 1, - 1, - rope_dim, - dtype=dtype, - device=device) - _sin = torch.zeros(max_num_reqs * decode_token_per_req, - 1, - 1, - rope_dim, - dtype=dtype, - device=device) - else: - _cos = None - _sin = None - - -def get_cos_and_sin(): - return _cos, _sin - - def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig) -> Optional[MoECommType]: """1. 
If expert parallel is not enabled, we use all-gather since MC2 and all-to-all diff --git a/vllm_ascend/attention/mla_cp.py b/vllm_ascend/attention/mla_cp.py index 6aa3848dc72..980f3c5050a 100644 --- a/vllm_ascend/attention/mla_cp.py +++ b/vllm_ascend/attention/mla_cp.py @@ -16,7 +16,6 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import MLAAttentionSpec -from vllm_ascend.ascend_forward_context import get_cos_and_sin from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata, AscendMLAImpl, AscendMLAMetadata, AscendMLAMetadataBuilder, @@ -29,6 +28,7 @@ wait_for_kv_layer_from_connector) from vllm_ascend.compilation.acl_graph import (get_graph_params, update_graph_params_workspaces) +from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.shared_weight_layer import ( is_hidden_layer, reach_layer_for_shared_weight_series) from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch @@ -286,7 +286,7 @@ def build( decode_metadata = None if num_decodes > 0: - cos, sin = get_cos_and_sin() + cos, sin = get_cos_and_sin_mla() # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes + 1].tolist() diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 29643c0505b..e250fdbc0ad 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -22,7 +22,6 @@ from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import get_cos_and_sin from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, maybe_save_kv_layer_to_connector, @@ -32,6 +31,7 @@ from vllm_ascend.compilation.acl_graph import (get_graph_params, get_mtp_graph_params, update_graph_params_workspaces) +from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.shared_weight_layer import ( is_hidden_layer, post_process_after_loading_for_shared_weight_series, reach_layer_for_shared_weight_series, @@ -531,7 +531,7 @@ def build( decode_metadata = None if num_decodes > 0: - cos, sin = get_cos_and_sin() + cos, sin = get_cos_and_sin_mla() # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes + 1].tolist() diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index f6b338a807c..a53e9423412 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -16,12 +16,12 @@ from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import get_cos_and_sin from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, trans_rope_weight, transdata, wait_for_kv_layer_from_connector) +from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.shared_weight_layer import ( is_hidden_layer, post_process_after_loading_for_shared_weight_series, reach_layer_for_shared_weight_series, @@ -187,7 +187,7 @@ def build( cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1] seq_lens = common_attn_metadata.seq_lens[:num_reqs] - cos, sin = get_cos_and_sin() + cos, sin = get_cos_and_sin_mla() assert self.cos_cache is not None 
and self.sin_cache is not None new_cos = self.cos_cache[input_positions][:, None, None] diff --git a/vllm_ascend/compilation/graph_fusion_pass_manager.py b/vllm_ascend/compilation/graph_fusion_pass_manager.py index 2922869453b..e311b2602a7 100644 --- a/vllm_ascend/compilation/graph_fusion_pass_manager.py +++ b/vllm_ascend/compilation/graph_fusion_pass_manager.py @@ -50,4 +50,7 @@ def configure(self, config: VllmConfig): from .passes.norm_quant_fusion_pass import \ AddRMSNormQuantFusionPass self.passes.append(AddRMSNormQuantFusionPass(config)) - # Add more passes here as needed + + if self.ascend_compilation_config.get("fuse_qknorm_rope", True): + from .passes.qknorm_rope_fusion_pass import QKNormRopeFusionPass + self.passes.append(QKNormRopeFusionPass(config)) diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py new file mode 100644 index 00000000000..d0f1aa53296 --- /dev/null +++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py @@ -0,0 +1,293 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging + +import torch +import torch._inductor.pattern_matcher as pm +from torch._inductor.pattern_matcher import (PatternMatcherPass, + PatternPrettyPrinter) +from vllm.attention.layer import Attention +from vllm.compilation.vllm_inductor_pass import VllmInductorPass +from vllm.config import (VllmConfig, get_current_vllm_config, + get_layers_from_vllm_config) + + +class QKNormRopeFusionPattern: + + def __init__(self, + vllm_config, + head_dim, + num_heads, + num_kv_heads, + eps=1e-6): + self.vllm_config = vllm_config + self.head_dim = head_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.eps = eps + vllm_config = get_current_vllm_config() + self.device = vllm_config.device_config.device if vllm_config.device_config else None + + def get_inputs(self): + T = 5 + qkv = torch.empty(T, + self.q_size + 2 * self.kv_size, + dtype=torch.bfloat16, + device="npu") + q_weight = torch.empty(self.head_dim, + dtype=torch.bfloat16, + device="npu") + k_weight = torch.empty(self.head_dim, + dtype=torch.bfloat16, + device="npu") + cos = torch.empty(1, + T, + 1, + self.head_dim, + dtype=torch.bfloat16, + device="npu") + sin = torch.empty(1, + T, + 1, + self.head_dim, + dtype=torch.bfloat16, + device="npu") + return [qkv, q_weight, k_weight, cos, sin] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, + k_weight: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor): + + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, + self.head_dim) + q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, + self.eps) + + k_by_head 
= k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, + self.head_dim) + k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, + self.eps) + + q_flat = q_norm_out.view(q.shape) + q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, + self.head_dim) + + k_flat = k_norm_out.view(k.shape) + k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, + self.head_dim) + + q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb( + q_reshape, k_reshape, cos, sin) + + return q_rope, k_rope, v + + def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, + k_weight: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor): + results = torch.ops.vllm.qkv_rmsnorm_rope( + input=qkv, + q_weight=q_weight, + k_weight=k_weight, + q_hidden_size=self.q_size, + kv_hidden_size=self.kv_size, + head_dim=self.head_dim, + eps=self.eps, + q_bias=None, + k_bias=None, + sin=sin, + cos=cos) + + return results + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class QKNormRopeFusionPatternWithBias: + + def __init__(self, + vllm_config, + head_dim, + num_heads, + num_kv_heads, + eps=1e-6): + self.head_dim = head_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.eps = eps + self.vllm_config = vllm_config + self.device = vllm_config.device_config.device if vllm_config.device_config else None + + def get_inputs(self): + T = 5 + qkv = torch.empty(T, + self.q_size + 2 * self.kv_size, + dtype=torch.bfloat16, + device="npu") + q_weight = torch.empty(self.head_dim, + dtype=torch.bfloat16, + device="npu") + k_weight = torch.empty(self.head_dim, + dtype=torch.bfloat16, + device="npu") + q_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + k_bias = torch.empty(self.head_dim, dtype=torch.bfloat16, device="npu") + cos = torch.empty(1, + T, + 1, + self.head_dim, + dtype=torch.bfloat16, + device="npu") + sin = torch.empty(1, + T, + 1, + self.head_dim, + dtype=torch.bfloat16, + device="npu") + + return [qkv, q_weight, k_weight, q_bias, k_bias, cos, sin] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(qkv: torch.Tensor, q_weight: torch.Tensor, + k_weight: torch.Tensor, q_bias: torch.Tensor, + k_bias: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor): + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, + self.head_dim) + q_norm_out, _ = torch.ops.npu.npu_rms_norm(q_by_head, q_weight, + self.eps) + q_normed = q_norm_out + q_bias + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, + self.head_dim) + k_norm_out, _ = torch.ops.npu.npu_rms_norm(k_by_head, k_weight, + self.eps) + k_normed = k_norm_out + k_bias + + q_flat = q_normed.view(q.shape) + q_reshape = q_flat.contiguous().view(1, q_flat.shape[0], -1, + self.head_dim) + + k_flat = k_normed.view(k.shape) + k_reshape = k_flat.contiguous().view(1, k_flat.shape[0], -1, + self.head_dim) + + q_rope, k_rope = torch.ops.npu.npu_apply_rotary_pos_emb( + q_reshape, k_reshape, cos, sin) + + return q_rope, k_rope, v + + def replacement(qkv: torch.Tensor, q_weight: torch.Tensor, + k_weight: torch.Tensor, q_bias: torch.Tensor, + k_bias: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor): + results = torch.ops.vllm.qkv_rmsnorm_rope( + input=qkv, + q_weight=q_weight, + k_weight=k_weight, + q_hidden_size=self.q_size, + kv_hidden_size=self.kv_size, + 
head_dim=self.head_dim, + eps=self.eps, + q_bias=q_bias, + k_bias=k_bias, + cos=cos, + sin=sin) + return results + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class QKNormRopeFusionPass(VllmInductorPass): + """ + A pass that fuses the QKV split, Q/K RMSNorm and rotary embedding into a single qkv_rmsnorm_rope operator. + """ + + def __init__(self, vllm_config: VllmConfig): + super().__init__(vllm_config) + self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass( + pass_name="qknorm_rope_fusion_pass") + + dtype = vllm_config.model_config.dtype + if dtype not in (torch.bfloat16, torch.float16): + logging.info( + "QKNorm and Rope fusion not enabled: unsupported dtype %s", + dtype) + return + + # Use one attention layer to get metadata (such as head_dim) for QKNormRopeFusionPattern + attn_layers: dict[str, Attention] = get_layers_from_vllm_config( + vllm_config, Attention) + if len(attn_layers) == 0: + logging.info( + "QKNorm and Rope fusion enabled, but no Attention layers were discovered." + ) + return + layer = next(iter(attn_layers.values())) + for epsilon in [1e-6, 1e-5]: + if layer.head_size != 128: + logging.debug( + "QKNorm and Rope fusion not enabled: head_dim %d is not equal to 128", + layer.head_size) + continue + QKNormRopeFusionPattern(vllm_config=vllm_config, + head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon).register( + self.pattern_match_passes) + + QKNormRopeFusionPatternWithBias(vllm_config=vllm_config, + head_dim=layer.head_size, + num_heads=layer.num_heads, + num_kv_heads=layer.num_kv_heads, + eps=epsilon).register( + self.pattern_match_passes) + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.matched_count = self.pattern_match_passes.apply(graph) + logging.debug("Fused %s QKNorm and Rope patterns", self.matched_count) + logging.debug("Patterns registered for replacement:") + pattern_idx = 0 + for pattern_entry in self.pattern_match_passes.patterns.values(): + for p in pattern_entry: + p_str = PatternPrettyPrinter.run(p.pattern) + logging.debug("Pattern %d: %s", pattern_idx, p_str) + pattern_idx += 1 + self.end_and_log() + + def is_applicable(self, runtime_shape): + """ + Check if the pass is applicable for the current configuration. 
+ """ + return True diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index e121f2a442c..aadb6164705 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -16,10 +16,15 @@ # import torch +from vllm.triton_utils import HAS_TRITON import vllm_ascend.ops.fused_moe.fused_moe # noqa import vllm_ascend.ops.layernorm # noqa import vllm_ascend.ops.register_custom_ops # noqa + +if HAS_TRITON: + import vllm_ascend.ops.triton.linearnorm.split_qkv_rmsnorm_rope # noqa + import vllm_ascend.ops.vocab_parallel_embedding # noqa from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.rotary_embedding import ( diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index ef398faef00..7f86047028f 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -20,14 +20,117 @@ import torch import torch_npu -from vllm.forward_context import get_forward_context +from vllm.config import CUDAGraphMode from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding, YaRNScalingRotaryEmbedding) from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import (AscendDeviceType, enable_custom_op, - get_ascend_device_type) + get_ascend_device_type, is_vl_model) + +# Currently, the rope ops used on npu require detached cos && sin as inputs. +# However, RotaryEmbedding in vllm uses cos_sin_cache as a single variable. +# So we have to preprocess cos_sin_cache into cos && sin. In the future, +# we shall implement a new rope op which accepts cos_sin_cache as input. +# NOTE(Angazenn): MLA && SFA models use attn_metadata to pass cos && sin +# to rope in AscendMLA(SFA)Impl. However, since rope is isolated from +# AscendAttentionBackendImpl for GQA models, we cannot pass cos && sin via +# attn_metadata. As a result, rope in GQA models must receive cos && sin +# through a different approach. +_cos_mla: Optional[torch.Tensor] = None +_sin_mla: Optional[torch.Tensor] = None +_cos_sin_cache: Optional[torch.Tensor] = None +_cos: Optional[torch.Tensor] = None +_sin: Optional[torch.Tensor] = None +_cos_slice: Optional[torch.Tensor] = None +_sin_slice: Optional[torch.Tensor] = None + + +def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, + device): + global _cos_mla + global _sin_mla + global _cos + global _sin + + if _cos_mla is not None or \ + _sin_mla is not None or \ + _cos is not None or \ + _sin is not None: + return + + compilation_config = vllm_config.compilation_config + model_config = vllm_config.model_config + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + + if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: + rope_dim = model_config.hf_text_config.qk_rope_head_dim + _cos_mla = torch.ones(max_num_reqs * decode_token_per_req, + 1, + 1, + rope_dim, + dtype=dtype, + device=device) + _sin_mla = torch.zeros(max_num_reqs * decode_token_per_req, + 1, + 1, + rope_dim, + dtype=dtype, + device=device) + elif not is_vl_model(vllm_config) and not vllm_config.model_config.use_mla: + rope_dim = model_config.get_head_size() + # For models using partial rope like Qwen3-Next. 
+ if hasattr(model_config.hf_text_config, "partial_rotary_factor"): + rope_dim = int(rope_dim * + model_config.hf_text_config.partial_rotary_factor) + _cos = torch.ones(1, + max_num_batched_tokens, + 1, + rope_dim, + dtype=dtype, + device=device) + _sin = torch.zeros(1, + max_num_batched_tokens, + 1, + rope_dim, + dtype=dtype, + device=device) + + +def get_cos_and_sin_mla(): + return _cos_mla, _sin_mla + + +def _record_cos_sin_cache(cos_sin_cache): + global _cos_sin_cache + if _cos_sin_cache is not None: + return + _cos_sin_cache = cos_sin_cache + + +def update_cos_sin(positions): + global _cos + global _sin + global _cos_slice + global _sin_slice + + if _cos_sin_cache is None or \ + _cos is None or \ + _sin is None: + return + + num_tokens = positions.size(0) + _cos[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view( + num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[0] + _sin[:, :num_tokens] = _cos_sin_cache.index_select(0, positions).view( + num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)[1] + _cos_slice = _cos[:, :num_tokens] + _sin_slice = _sin[:, :num_tokens] + + +def get_cos_and_sin_slice(): + return _cos_slice, _sin_slice def _custom_rotary_embedding_enabled(query, neox_style, head_size): @@ -65,8 +168,9 @@ def _rope_forward_oot( raise NotImplementedError( "Batched rotary embedding is currently not supported on NPU.") else: - if hasattr(self, "cos") and hasattr(self, "sin") and \ - self.cos is not None and self.sin is not None: + cos, sin = get_cos_and_sin_slice() + if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[ + -1] == 128 and cos is not None and sin is not None: # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation. # This method requires head_size and rotary_dim equal 128 and neox_style is True query = query.contiguous().view(1, query.shape[0], -1, @@ -75,7 +179,7 @@ def _rope_forward_oot( # Although this function modifies in-place, please retain the function's return value. # Otherwise, the graph fusion operation may fail. query, key = torch_npu.npu_apply_rotary_pos_emb( - query, key, self.cos, self.sin) + query, key, cos, sin) elif self.rotary_dim < self.head_size: num_tokens = query.shape[0] query = query.view(num_tokens, -1, self.head_size) @@ -125,10 +229,9 @@ def __init__( is_neox_style: bool, dtype: torch.dtype, ) -> None: - self.cos = None - self.sin = None super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) + _record_cos_sin_cache(self.cos_sin_cache) def forward_oot( self, @@ -141,20 +244,6 @@ def forward_oot( is_neox_style = self.is_neox_style if is_neox_style_override is not None: is_neox_style = is_neox_style_override - forward_context = get_forward_context() - is_first_layer = forward_context.is_first_layer - # Generate cos and sin outside layers to avoid repeated calculation. 
- if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[ - -1] == 128: - if is_first_layer: - cos_sin = self.cos_sin_cache.index_select(0, positions) - last_dim = cos_sin.size()[-1] - cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat( - 1, 1, 2).chunk(2, dim=-2) - # BSNH - self.cos = cos.view(1, -1, 1, last_dim).contiguous() - self.sin = sin.view(1, -1, 1, last_dim).contiguous() - forward_context.is_first_layer = False return _rope_forward_oot(self, positions, query, key, is_neox_style, offsets) @@ -176,8 +265,6 @@ def __init__( beta_fast: int = 32, beta_slow: int = 1, ) -> None: - self.cos = None - self.sin = None extra_kwargs = { "extrapolation_factor": extrapolation_factor, "attn_factor": attn_factor, @@ -186,6 +273,7 @@ def __init__( } super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, scaling_factor, dtype, **extra_kwargs) + _record_cos_sin_cache(self.cos_sin_cache) def forward_oot( self, diff --git a/vllm_ascend/ops/triton/linearnorm/__init__.py b/vllm_ascend/ops/triton/linearnorm/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py new file mode 100644 index 00000000000..58e56ae5ee5 --- /dev/null +++ b/vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py @@ -0,0 +1,305 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# +from typing import Optional + +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +from vllm.utils.torch_utils import direct_register_custom_op + +from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num + + +@triton.jit +def split_qkv_rmsnorm_rope_kernel( + input_ptr, + sin_ptr, + cos_ptr, + q_ptr, + k_ptr, + v_ptr, + q_weight_ptr, + q_bias_ptr, + k_weight_ptr, + k_bias_ptr, + batch_size, + q_hidden_size: tl.constexpr, + kv_hidden_size: tl.constexpr, + total_hidden_size: tl.constexpr, + eps: tl.constexpr, + Q_BLOCK_SIZE: tl.constexpr, + KV_BLOCK_SIZE: tl.constexpr, + BIAS: tl.constexpr, + HEAD_DIM: tl.constexpr, + HALF_HEAD_DIM: tl.constexpr, +): + row_pid = tl.program_id(0) + col_pid = tl.program_id(1) + row_step = tl.num_programs(0) + # q + weight_values = tl.load(q_weight_ptr + tl.arange(0, HEAD_DIM)) + if BIAS: + bias_values = tl.load(q_bias_ptr + tl.arange(0, HEAD_DIM)) + input_offset = row_pid * total_hidden_size + output_offset = row_pid * q_hidden_size + input_offset_step = row_step * total_hidden_size + output_offset_step = row_step * q_hidden_size + for row_idx in tl.range(row_pid, batch_size, row_step): + col_indices = col_pid * Q_BLOCK_SIZE + tl.arange(0, Q_BLOCK_SIZE) + valid_mask = col_indices < q_hidden_size + input_values = (tl.load(input_ptr + input_offset + col_indices, + mask=valid_mask, + other=0.0).to(tl.float32).reshape( + Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)) + squares = input_values * input_values + variances = tl.sum(squares, axis=1) / HEAD_DIM + reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( + Q_BLOCK_SIZE // HEAD_DIM, 1) + normalized_values = (input_values * reciprocal_std + ) # (Q_BLOCK_SIZE//HEAD_DIM, HEAD_DIM) + if BIAS: + normalized_values = (normalized_values * weight_values + + bias_values).to(tl.bfloat16) + else: + normalized_values = (normalized_values * weight_values).to( + tl.bfloat16) + + sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM) + sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM) + cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM) + x1 = tl.extract_slice( + normalized_values, + offsets=(0, 0), + sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + x2 = tl.extract_slice( + normalized_values, + offsets=(0, HALF_HEAD_DIM), + sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + cat_x = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), + dtype=tl.bfloat16) + cat_x = tl.insert_slice( + cat_x, + -x2, + offsets=(0, 0), + sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + cat_x = tl.insert_slice( + cat_x, + x1, + offsets=(0, HALF_HEAD_DIM), + sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + roped_q = cat_x * sin + normalized_values * cos + tl.store( + q_ptr + output_offset + col_indices, + roped_q.reshape(Q_BLOCK_SIZE).to(q_ptr.dtype.element_ty), + mask=valid_mask, + ) + input_offset += input_offset_step + output_offset += output_offset_step + + weight_values = tl.load(k_weight_ptr + tl.arange(0, HEAD_DIM)) + if BIAS: + bias_values = tl.load(k_bias_ptr + tl.arange(0, HEAD_DIM)) + input_offset = row_pid * total_hidden_size + q_hidden_size + output_offset = row_pid * kv_hidden_size + output_offset_step = row_step * kv_hidden_size + for row_idx in tl.range(row_pid, batch_size, row_step): + col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) + valid_mask = col_indices < kv_hidden_size + input_values = (tl.load(input_ptr + input_offset + col_indices, + mask=valid_mask, + 
other=0.0).to(tl.float32).reshape( + KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM)) + squares = input_values * input_values + variances = tl.sum(squares, axis=1) / HEAD_DIM + reciprocal_std = (1 / tl.sqrt(variances + eps)).reshape( + KV_BLOCK_SIZE // HEAD_DIM, 1) + normalized_values = (input_values * reciprocal_std + ) # (KV_BLOCK_SIZE/HEAD_DIM, HEAD_DIM) + if BIAS: + normalized_values = (normalized_values * weight_values + + bias_values).to(tl.bfloat16) + else: + normalized_values = (normalized_values * weight_values).to( + tl.bfloat16) + sc_offsets = row_idx * HEAD_DIM + tl.arange(0, HEAD_DIM) + sin = (tl.load(sin_ptr + sc_offsets)).reshape(1, HEAD_DIM) + cos = (tl.load(cos_ptr + sc_offsets)).reshape(1, HEAD_DIM) + x1 = tl.extract_slice( + normalized_values, + offsets=(0, 0), + sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + x2 = tl.extract_slice( + normalized_values, + offsets=(0, HALF_HEAD_DIM), + sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + cat_x = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), + dtype=tl.bfloat16) + cat_x = tl.insert_slice( + cat_x, + -x2, + offsets=(0, 0), + sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + cat_x = tl.insert_slice( + cat_x, + x1, + offsets=(0, HALF_HEAD_DIM), + sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM), + strides=(1, 1), + ) + roped_k = cat_x * sin + normalized_values * cos + + tl.store( + k_ptr + output_offset + col_indices, + roped_k.to(tl.bfloat16).reshape(KV_BLOCK_SIZE), + mask=valid_mask, + ) + input_offset += input_offset_step + output_offset += output_offset_step + + input_offset = row_pid * total_hidden_size + q_hidden_size + kv_hidden_size + output_offset = row_pid * kv_hidden_size + for _ in tl.range(row_pid, batch_size, row_step): + col_indices = col_pid * KV_BLOCK_SIZE + tl.arange(0, KV_BLOCK_SIZE) + valid_mask = col_indices < kv_hidden_size + input_values = tl.load(input_ptr + input_offset + col_indices, + mask=valid_mask, + other=0.0) + tl.store(v_ptr + output_offset + col_indices, + input_values, + mask=valid_mask) + input_offset += input_offset_step + output_offset += output_offset_step + + +def split_qkv_rmsnorm_rope_impl( + input: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_hidden_size: int, + kv_hidden_size: int, + head_dim: int, + eps: float, + q_bias: Optional[torch.Tensor], + k_bias: Optional[torch.Tensor], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + KV_BLOCK_SIZE = triton.next_power_of_2(head_dim) + assert KV_BLOCK_SIZE == head_dim + assert q_hidden_size % kv_hidden_size == 0 + Q_BLOCK_SIZE = q_hidden_size // kv_hidden_size * head_dim + batch_size = input.shape[0] + total_hidden_size = q_hidden_size + kv_hidden_size * 2 + q_output = torch.empty(batch_size, + q_hidden_size, + device=input.device, + dtype=input.dtype) + k_output = torch.empty(batch_size, + kv_hidden_size, + device=input.device, + dtype=input.dtype) + v_output = torch.empty(batch_size, + kv_hidden_size, + device=input.device, + dtype=input.dtype) + n_cols = kv_hidden_size // KV_BLOCK_SIZE + num_vectorcore = get_vectorcore_num() + assert num_vectorcore % n_cols == 0 + n_rows = num_vectorcore // n_cols + BIAS = q_bias is not None + + split_qkv_rmsnorm_rope_kernel[(n_rows, n_cols, 1)]( + input, + sin, + cos, + q_output, + k_output, + v_output, + q_weight, + q_bias, + k_weight, + k_bias, + batch_size, + q_hidden_size, + kv_hidden_size, + total_hidden_size, + eps, + Q_BLOCK_SIZE, + KV_BLOCK_SIZE, + BIAS, + 
head_dim, + head_dim // 2, + ) + return q_output, k_output, v_output + + +def split_qkv_rmsnorm_rope_impl_fake( + input: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_hidden_size: int, + kv_hidden_size: int, + head_dim: int, + eps: float, + q_bias: Optional[torch.Tensor] = None, + k_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Fake implementation for shape inference during Dynamo/AOT tracing. + # Note: sin and cos are not used in shape computation, but must be present in signature. + batch_size = input.shape[0] + q_output = torch.empty( + batch_size, + q_hidden_size, + device=input.device, + dtype=input.dtype, + ) + k_output = torch.empty( + batch_size, + kv_hidden_size, + device=input.device, + dtype=input.dtype, + ) + v_output = torch.empty( + batch_size, + kv_hidden_size, + device=input.device, + dtype=input.dtype, + ) + return q_output, k_output, v_output + + +direct_register_custom_op(op_name="qkv_rmsnorm_rope", + op_func=split_qkv_rmsnorm_rope_impl, + fake_impl=split_qkv_rmsnorm_rope_impl_fake, + mutates_args=[], + dispatch_key="PrivateUse1") diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 24a846d9bae..ff662b728ba 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -25,6 +25,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.ops.rotary_embedding import update_cos_sin from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType PADDING_SLOT_ID = -1 @@ -143,6 +144,9 @@ def dummy_run(self, aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor=None, dummy_compute_logits=lambda hidden_states: None): + # update global cos, sin + update_cos_sin(self.positions[:num_tokens]) + with set_ascend_forward_context(None, self.vllm_config, in_profile_run=True, @@ -339,6 +343,8 @@ def _propose( builder = self.runner.attn_groups[0][0].get_metadata_builder() attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model()) + # update global cos, sin + update_cos_sin(self.positions[:num_input_tokens]) with set_ascend_forward_context(attn_metadata, self.vllm_config, @@ -444,6 +450,10 @@ def _propose( attn_metadata.attn_mask = attn_mask # Run the model. 
+ + # update global cos, sin + update_cos_sin(self.positions[:input_batch_size]) + with set_ascend_forward_context(attn_metadata, self.vllm_config, num_tokens=input_batch_size): diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index d4b4b25bf74..5a82c021b3b 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -84,12 +84,6 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ascend_forward_context import (MoECommType, - get_mc2_tokens_capacity, - select_moe_comm_method, - set_ascend_forward_context, - set_cos_and_sin, set_mc2_mask, - set_mc2_tokens_capacity) from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, @@ -111,6 +105,7 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess from vllm_ascend.eplb.eplb_updator import EplbUpdator from vllm_ascend.eplb.utils import model_register +from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort from vllm_ascend.sample.logits_processor import build_logitsprocs @@ -125,6 +120,10 @@ is_moe_model, lmhead_tp_enable, vllm_version_is) from vllm_ascend.worker.npu_input_batch import NPUInputBatch +from vllm_ascend.ascend_forward_context import ( # isort: skip + MoECommType, get_mc2_tokens_capacity, select_moe_comm_method, + set_ascend_forward_context, set_mc2_mask, set_mc2_tokens_capacity) + if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput @@ -1145,6 +1144,9 @@ def _prepare_inputs( for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i + # update global cos, sin + update_cos_sin(positions) + if lmhead_tp_enable(): max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len logits_indices = nn.functional.pad( @@ -2107,6 +2109,9 @@ def _dummy_run( else: positions = self.positions.gpu[:num_tokens_padded] + # update global cos, sin + update_cos_sin(positions) + if get_pp_group().is_first_rank: intermediate_tensors = None else:
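
Illustrative sketch (not part of the patch): the eager sequence that QKNormRopeFusionPass rewrites into a single torch.ops.vllm.qkv_rmsnorm_rope call is split-QKV, per-head RMSNorm on Q/K, then neox-style (rotate-half) RoPE. The CPU-only reference below mirrors that sequence so the fused kernel's semantics can be checked without an NPU; the function name and the sizes (num_heads=8, num_kv_heads=2) are hypothetical, while the shapes follow QKNormRopeFusionPattern.get_inputs (T=5, head_dim=128).

import torch


def qkv_rmsnorm_rope_reference(qkv, q_weight, k_weight, cos, sin,
                               q_size, kv_size, head_dim, eps=1e-6):
    """Eager reference for the fused split-QKV + RMSNorm + rotate-half RoPE (sketch)."""
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

    def rms_norm(x, weight):
        # Per-head RMSNorm, matching npu_rms_norm applied on [..., num_heads, head_dim].
        x_by_head = x.view(*x.shape[:-1], -1, head_dim).float()
        var = x_by_head.pow(2).mean(dim=-1, keepdim=True)
        return (x_by_head * torch.rsqrt(var + eps)).to(x.dtype) * weight

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    # Reshape to [1, T, num_heads, head_dim] so cos/sin of shape [1, T, 1, head_dim] broadcast.
    q_heads = rms_norm(q, q_weight).view(1, q.shape[0], -1, head_dim)
    k_heads = rms_norm(k, k_weight).view(1, k.shape[0], -1, head_dim)
    q_rope = q_heads * cos + rotate_half(q_heads) * sin
    k_rope = k_heads * cos + rotate_half(k_heads) * sin
    return q_rope, k_rope, v


# Hypothetical sizes mirroring QKNormRopeFusionPattern.get_inputs.
T, head_dim, num_heads, num_kv_heads = 5, 128, 8, 2
q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim
qkv = torch.randn(T, q_size + 2 * kv_size, dtype=torch.bfloat16)
q_w = torch.ones(head_dim, dtype=torch.bfloat16)
k_w = torch.ones(head_dim, dtype=torch.bfloat16)
cos = torch.randn(1, T, 1, head_dim, dtype=torch.bfloat16)
sin = torch.randn(1, T, 1, head_dim, dtype=torch.bfloat16)
q_ref, k_ref, v_ref = qkv_rmsnorm_rope_reference(qkv, q_w, k_w, cos, sin,
                                                 q_size, kv_size, head_dim)

On an NPU with the new op registered, the same inputs moved to the npu device would instead go through torch.ops.vllm.qkv_rmsnorm_rope(...) as in split_qkv_rmsnorm_rope_impl, which returns q and k flattened to [T, q_size] and [T, kv_size].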
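Illustrative sketch (not part of the patch): how update_cos_sin turns vLLM's packed cos_sin_cache rows into the (1, num_tokens, 1, head_dim) _cos/_sin buffers consumed by npu_apply_rotary_pos_emb and the fused op. The cache construction below is a generic neox-style placeholder (base 10000) used only to give the reshape concrete values; the real cache comes from RotaryEmbedding and is recorded via _record_cos_sin_cache.

import torch

# Hypothetical sizes; real values come from the model config.
max_position, rotary_dim, num_tokens = 4096, 128, 7

# vLLM stores cos_sin_cache as [max_position, rotary_dim]: cos half, then sin half.
inv_freq = 1.0 / (10000**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(max_position).float(), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)

positions = torch.randint(0, max_position, (num_tokens,))

# Same reshape as update_cos_sin(): split each cached row into its cos/sin halves,
# duplicate each half so it spans the full head_dim (neox layout), and lay the
# result out as BSNH so it broadcasts against [1, T, num_heads, head_dim] q/k.
rows = cos_sin_cache.index_select(0, positions)                 # [T, rotary_dim]
cos, sin = rows.view(num_tokens, 2, -1).repeat(1, 1, 2).chunk(2, dim=-2)
cos = cos.reshape(1, num_tokens, 1, rotary_dim)                 # written into _cos[:, :T]
sin = sin.reshape(1, num_tokens, 1, rotary_dim)                 # written into _sin[:, :T]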