 from dataclasses import dataclass
 from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar
-import math
+
 import torch
 import torch_npu
 from torch import nn
 from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_enable_nz, _round_up)
-from vllm_ascend.worker.npu_input_batch import InputBatch
-from vllm_ascend.utils import dispose_tensor, dispose_layer, replace_layer, enable_sp
 from vllm_ascend.ops.shared_weight_layer import (
     is_hidden_layer, post_process_after_loading_for_shared_weight_series,
     reach_layer_for_shared_weight_series,
     register_layer_to_shared_weight_series)
 from vllm_ascend.ops.triton.rope import rope_forward_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               dispose_layer, is_enable_nz, replace_layer)
+                               _round_up, dispose_layer, enable_sp,
+                               is_enable_nz, replace_layer)
 from vllm_ascend.worker.npu_input_batch import InputBatch
-from vllm.forward_context import get_forward_context
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -59,7 +54,6 @@ def get_impl_cls() -> Type["AscendSFAImpl"]:
         return AscendSFAImpl
 
 
-
 @dataclass
 class SfaCpContext:
     num_tokens: int
@@ -73,6 +67,7 @@ class SfaCpContext:
     actual_seq_lengths_query: torch.Tensor
     actual_seq_lengths_key: torch.Tensor
 
+
 @dataclass
 class AscendSFAMetadata:
     """Metadata for MLACommon.
@@ -198,7 +193,7 @@ def build(
             1).unsqueeze(2)
         sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
             1).unsqueeze(2)
-
+
         sfa_cp_context = None
         if self.enable_sfa_cp:
             global_tp_size = get_tp_group().world_size
@@ -214,12 +209,13 @@ def build(
             if pad_size > 0:
                 cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
                 sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
-                slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size), value=-1)
+                slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size),
                                                  value=-1)
             cos = cos[local_start:local_end_with_pad]
             sin = sin[local_start:local_end_with_pad]
             slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
 
-            actual_seq_lengths_query = torch.empty_like(cum_query_lens)
+            actual_seq_lengths_query = torch.empty_like(cum_query_lens)
             actual_seq_lengths_key = torch.empty_like(seq_lens)
             num_segs = cum_query_lens.shape[0]
             last_token = 0
@@ -347,7 +343,7 @@ def __init__(
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
         self.model_config = get_current_vllm_config().model_config
         assert self.indexer is not None, "Indexer is required for DSA."
-
+
         self.enable_sfa_cp = enable_sp() and \
             hasattr(self.model_config.hf_config, "index_topk")
         self.local_num_heads = self.num_heads
@@ -357,7 +353,8 @@ def __init__(
 
         #TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
         self._replace_linear_class_for_sfa_cp()
-        from vllm_ascend.distributed.parallel_state import get_shared_weight_group
+        from vllm_ascend.distributed.parallel_state import \
+            get_shared_weight_group
         register_layer_to_shared_weight_series(
             series_name="q_proj",
             group=get_shared_weight_group(),
@@ -625,23 +622,24 @@ def forward(
 
         ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
         q_pe = self.rope_single(q_pe, cos, sin)
-
+
         actual_seq_lengths_query = attn_metadata.cum_query_lens
         actual_seq_lengths_key = attn_metadata.seq_lens
 
         if self.enable_sfa_cp:
             actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
             actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
-
-        topk_indices = self.indexer_select(x=hidden_states,
-                                           qr=q_c,
-                                           kv_cache=kv_cache,
-                                           attn_metadata=attn_metadata,
-                                           cos=cos,
-                                           sin=sin,
-                                           actual_seq_lengths_query=actual_seq_lengths_query,
-                                           actual_seq_lengths_key=actual_seq_lengths_key,
-                                           need_gather_q_kv=need_gather_q_kv)
+
+        topk_indices = self.indexer_select(
+            x=hidden_states,
+            qr=q_c,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+            cos=cos,
+            sin=sin,
+            actual_seq_lengths_query=actual_seq_lengths_query,
+            actual_seq_lengths_key=actual_seq_lengths_key,
+            need_gather_q_kv=need_gather_q_kv)
         attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
             query=ql_nope,
             key=kv_cache[0],
@@ -751,19 +749,18 @@ def indexer_select(
             sparse_count=2048,
             sparse_mode=3)
         return topk_indices
-
+
     def _replace_linear_class_for_sfa_cp(self):
 
         vllm_config = get_current_vllm_config()
         # Dispose tensor from the original q_proj
         dispose_layer(self.q_proj)
         # Construct the new q_proj using ReplicatedLinear
-        new_q_proj = ReplicatedLinear(
-            self.q_lora_rank,
-            self.local_num_heads * self.qk_head_dim,
-            bias=False,
-            quant_config=vllm_config.quant_config,
-            prefix=self.q_proj.prefix)
+        new_q_proj = ReplicatedLinear(self.q_lora_rank,
+                                      self.local_num_heads * self.qk_head_dim,
+                                      bias=False,
+                                      quant_config=vllm_config.quant_config,
+                                      prefix=self.q_proj.prefix)
         # Replace the q_proj with the new one
         replace_layer(self.q_proj, new_q_proj)
 
@@ -783,13 +780,11 @@ def _replace_linear_class_for_sfa_cp(self):
         dispose_layer(self.o_proj)
         # Construct the new o_proj using ReplicatedLinear
         config = vllm_config.model_config.hf_config
-        new_o_proj = ReplicatedLinear(
-            config.num_attention_heads * config.v_head_dim,
-            config.hidden_size,
-            bias=False,
-            quant_config=vllm_config.quant_config,
-            prefix=self.o_proj.prefix)
+        new_o_proj = ReplicatedLinear(config.num_attention_heads *
+                                      config.v_head_dim,
+                                      config.hidden_size,
+                                      bias=False,
+                                      quant_config=vllm_config.quant_config,
+                                      prefix=self.o_proj.prefix)
         # Replace the o_proj with the new one
         replace_layer(self.o_proj, new_o_proj)
-
-