40 commits
31361e5  [Graph] [Fusion] Fusion slice and qknorm operator (wxsIcey, Dec 4, 2025)
70ec112  adapt bias is not none (wxsIcey, Dec 4, 2025)
fb9e37e  change to norm rope fusion (wxsIcey, Dec 4, 2025)
78cc0a2  tiny fix (wxsIcey, Dec 4, 2025)
ef734cb  tiny fix (wxsIcey, Dec 4, 2025)
e898867  normalize fusion naming and format code (wxsIcey, Dec 5, 2025)
b20db1d  move special operator to attention metadata builder (wxsIcey, Dec 5, 2025)
7376617  add e2e test (wxsIcey, Dec 7, 2025)
65ce080  remove first layer change (wxsIcey, Dec 8, 2025)
691d54c  tiny fix (wxsIcey, Dec 8, 2025)
6e95239  fix (wxsIcey, Dec 8, 2025)
fef462f  fix (wxsIcey, Dec 8, 2025)
6f72f7f  fix (wxsIcey, Dec 8, 2025)
171539a  remove e2e test (wxsIcey, Dec 9, 2025)
4a745bd  tiny fix (wxsIcey, Dec 9, 2025)
7f5ab08  fix (wxsIcey, Dec 9, 2025)
a9cfb33  fix eagle spec decode (wxsIcey, Dec 11, 2025)
e0c5139  tiny fix (wxsIcey, Dec 11, 2025)
fdb549c  fix triton (wxsIcey, Dec 11, 2025)
ead3622  fix (wxsIcey, Dec 11, 2025)
4a02ac5  install triton (wxsIcey, Dec 11, 2025)
9019f16  fix ut (wxsIcey, Dec 12, 2025)
09019ab  fix (wxsIcey, Dec 12, 2025)
5454d63  fix ut (wxsIcey, Dec 12, 2025)
360fcf5  fix ut (wxsIcey, Dec 14, 2025)
129fcde  resolve conflict (wxsIcey, Dec 14, 2025)
536ead4  change workflow (wxsIcey, Dec 14, 2025)
8599237  tiny fix (wxsIcey, Dec 15, 2025)
5de88fc  fix (wxsIcey, Dec 15, 2025)
fe8a177  fix (wxsIcey, Dec 15, 2025)
2c43712  fix qwen3 next (wxsIcey, Dec 15, 2025)
519aa97  fix eagle (wxsIcey, Dec 15, 2025)
5e22e20  fix eagle (wxsIcey, Dec 15, 2025)
5e6a838  fix mla cp (wxsIcey, Dec 15, 2025)
993744a  fix (wxsIcey, Dec 15, 2025)
31aa1b2  fix e2e (wxsIcey, Dec 16, 2025)
01814e5  recover e2e (wxsIcey, Dec 16, 2025)
9143521  resolve conflict (wxsIcey, Dec 16, 2025)
dd4617f  fix (wxsIcey, Dec 16, 2025)
d90be75  fix (wxsIcey, Dec 16, 2025)
2 changes: 1 addition & 1 deletion tests/e2e/singlecard/test_aclgraph_accuracy.py
@@ -210,4 +210,4 @@ def test_aclgraph_enable():
     # after check_and_update_config, mode should be VLLM_COMPILE and piecewise cudagraph
     NPUPlatform.check_and_update_config(VllmConfig)
     assert VllmConfig.compilation_config.mode == CompilationMode.VLLM_COMPILE
-    assert VllmConfig.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
\ No newline at end of file
+    assert VllmConfig.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
11 changes: 8 additions & 3 deletions vllm_ascend/ascend_config.py
@@ -17,6 +17,7 @@
 from uuid import uuid4
 
 from vllm.logger import logger
+from vllm.triton_utils import HAS_TRITON
 
 
 def check_kv_extra_config(vllm_config):
@@ -231,19 +232,23 @@ class AscendCompilationConfig:
     deployed on Ascend platforms.
     """
 
-    def __init__(self, fuse_norm_quant: bool = True, **kwargs):
+    def __init__(self,
+                 fuse_norm_quant: bool = True,
+                 fuse_qknorm_rope: bool = False,
+                 **kwargs):
         """
         Initialize the configuration.
         Args:
             fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
                 When set to True, the system will optimize norm and quant operations.
                 Default: True
+            fuse_qknorm_rope (bool): Whether to enable the qknorm and rope fusion optimization.
+                Requires Triton; the flag is gated on HAS_TRITON. Default: False
             **kwargs: Additional optional parameters for forward compatibility and configuration extension.
         """
         self.fuse_norm_quant = fuse_norm_quant
-        # Add more compilation related configs here as needed
+        self.fuse_qknorm_rope = fuse_qknorm_rope and HAS_TRITON
 
 
 class XliteGraphConfig:
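A quick sanity check of how the resolved flags behave (a minimal sketch, assuming `vllm_ascend.ascend_config` is importable outside a full deployment; with the gating above, the fusion is active only when the user opts in and Triton is present):

```python
# Illustrative usage of the constructor from the diff above.
from vllm_ascend.ascend_config import AscendCompilationConfig

cfg = AscendCompilationConfig(fuse_qknorm_rope=True)
print(cfg.fuse_norm_quant)   # True, the documented default
print(cfg.fuse_qknorm_rope)  # True only when Triton is available (HAS_TRITON)

cfg_off = AscendCompilationConfig()
print(cfg_off.fuse_qknorm_rope)  # False, the documented default
```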
31 changes: 0 additions & 31 deletions vllm_ascend/ascend_forward_context.py
@@ -209,37 +209,6 @@ def get_mc2_mask():
     return _reserved_mc2_mask
 
 
-def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
-                    device):
-    global _cos
-    global _sin
-    if _cos is not None:
-        return
-    compilation_config = vllm_config.compilation_config
-    model_config = vllm_config.model_config
-    if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
-        rope_dim = model_config.hf_text_config.qk_rope_head_dim
-        _cos = torch.ones(max_num_reqs * decode_token_per_req,
-                          1,
-                          1,
-                          rope_dim,
-                          dtype=dtype,
-                          device=device)
-        _sin = torch.zeros(max_num_reqs * decode_token_per_req,
-                           1,
-                           1,
-                           rope_dim,
-                           dtype=dtype,
-                           device=device)
-    else:
-        _cos = None
-        _sin = None
-
-
-def get_cos_and_sin():
-    return _cos, _sin
-
-
 def select_moe_comm_method(num_tokens: int,
                            vllm_config: VllmConfig) -> Optional[MoECommType]:
     """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
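The cache itself does not go away: the call sites updated below now import `get_cos_and_sin_mla` from `vllm_ascend.ops.rotary_embedding`, a module that is not part of this diff. The following is only a sketch of what the relocated helper presumably looks like, mirroring the deleted logic above; the exact name, signature, and module layout are assumptions:

```python
# Hypothetical sketch of the relocated cache in vllm_ascend/ops/rotary_embedding.py.
import torch

_cos_mla = None
_sin_mla = None


def set_cos_and_sin_mla(max_tokens, rope_dim, dtype, device):
    # Pre-allocate identity cos/sin buffers once, as the deleted
    # set_cos_and_sin did for MLA full-decode ACL graphs.
    global _cos_mla, _sin_mla
    if _cos_mla is not None:
        return
    _cos_mla = torch.ones(max_tokens, 1, 1, rope_dim, dtype=dtype, device=device)
    _sin_mla = torch.zeros(max_tokens, 1, 1, rope_dim, dtype=dtype, device=device)


def get_cos_and_sin_mla():
    return _cos_mla, _sin_mla
```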
4 changes: 2 additions & 2 deletions vllm_ascend/attention/mla_cp.py
@@ -16,7 +16,6 @@
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import MLAAttentionSpec
 
-from vllm_ascend.ascend_forward_context import get_cos_and_sin
 from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
                                           AscendMLAImpl, AscendMLAMetadata,
                                           AscendMLAMetadataBuilder,
@@ -29,6 +28,7 @@
                                          wait_for_kv_layer_from_connector)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
+from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.ops.shared_weight_layer import (
     is_hidden_layer, reach_layer_for_shared_weight_series)
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
@@ -286,7 +286,7 @@ def build(
 
         decode_metadata = None
         if num_decodes > 0:
-            cos, sin = get_cos_and_sin()
+            cos, sin = get_cos_and_sin_mla()
             # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
             actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
                                                        1].tolist()
4 changes: 2 additions & 2 deletions vllm_ascend/attention/mla_v1.py
@@ -22,7 +22,6 @@
 
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ascend_forward_context import get_cos_and_sin
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          maybe_save_kv_layer_to_connector,
@@ -32,6 +31,7 @@
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                get_mtp_graph_params,
                                                update_graph_params_workspaces)
+from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.ops.shared_weight_layer import (
     is_hidden_layer, post_process_after_loading_for_shared_weight_series,
     reach_layer_for_shared_weight_series,
@@ -531,7 +531,7 @@ def build(
 
         decode_metadata = None
         if num_decodes > 0:
-            cos, sin = get_cos_and_sin()
+            cos, sin = get_cos_and_sin_mla()
             # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
             actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes +
                                                        1].tolist()
4 changes: 2 additions & 2 deletions vllm_ascend/attention/sfa_v1.py
@@ -16,12 +16,12 @@
 
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ascend_forward_context import get_cos_and_sin
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          trans_rope_weight, transdata,
                                          wait_for_kv_layer_from_connector)
+from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
 from vllm_ascend.ops.shared_weight_layer import (
     is_hidden_layer, post_process_after_loading_for_shared_weight_series,
     reach_layer_for_shared_weight_series,
@@ -187,7 +187,7 @@ def build(
         cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1]
         seq_lens = common_attn_metadata.seq_lens[:num_reqs]
 
-        cos, sin = get_cos_and_sin()
+        cos, sin = get_cos_and_sin_mla()
 
         assert self.cos_cache is not None and self.sin_cache is not None
         new_cos = self.cos_cache[input_positions][:, None, None]
5 changes: 4 additions & 1 deletion vllm_ascend/compilation/graph_fusion_pass_manager.py
@@ -50,4 +50,7 @@ def configure(self, config: VllmConfig):
             from .passes.norm_quant_fusion_pass import \
                 AddRMSNormQuantFusionPass
             self.passes.append(AddRMSNormQuantFusionPass(config))
-        # Add more passes here as needed
+
+        if self.ascend_compilation_config.get("fuse_qknorm_rope", False):
+            from .passes.qknorm_rope_fusion_pass import QKNormRopeFusionPass
+            self.passes.append(QKNormRopeFusionPass(config))
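The `QKNormRopeFusionPass` itself lives in `.passes.qknorm_rope_fusion_pass` and is not shown in this diff. As a rough illustration of the underlying technique, graph-level pattern replacement, here is a minimal, self-contained `torch.fx` sketch; every name in it is illustrative, and the real pass presumably matches NPU/Triton ops rather than these eager stand-ins:

```python
import torch
import torch.fx
from torch.fx import subgraph_rewriter, symbolic_trace


def rms_norm(x, w, eps=1e-6):
    # Eager stand-in for the per-head q/k RMSNorm.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


@torch.fx.wrap  # traced as one opaque node instead of being inlined
def fused_qknorm_rope(x, w, cos, sin):
    # Eager fallback; a real pass would dispatch to a single fused kernel here.
    xn = rms_norm(x, w)
    return xn * cos + rotate_half(xn) * sin


def pattern(x, w, cos, sin):
    # The unfused subgraph to search for: per-head norm, then RoPE.
    xn = rms_norm(x, w)
    return xn * cos + rotate_half(xn) * sin


def replacement(x, w, cos, sin):
    return fused_qknorm_rope(x, w, cos, sin)


class ToyQKProj(torch.nn.Module):
    def forward(self, q, wq, cos, sin):
        qn = rms_norm(q, wq)
        return qn * cos + rotate_half(qn) * sin


gm = symbolic_trace(ToyQKProj())
matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
assert matches  # the norm-then-rope chain is now one fused_qknorm_rope node
```

The effect mirrors what the pass manager above wires up: after the rewrite, downstream graph capture sees a single fused call per matched q/k projection instead of separate norm and rope launches.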