 from dataclasses import dataclass
 from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar
-
+import math
 import torch
 import torch_npu
 from torch import nn
 from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.distributed.sfa_sp_context import (get_sfa_sp_context,
                                                     set_sfa_sp_context)
+from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
+                               is_enable_nz, _round_up)
+from vllm_ascend.worker.npu_input_batch import InputBatch
+from vllm_ascend.utils import dispose_tensor, dispose_layer, replace_layer, enable_sp
 from vllm_ascend.ops.shared_weight_layer import (
     is_hidden_layer, post_process_after_loading_for_shared_weight_series,
     reach_layer_for_shared_weight_series,
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                dispose_layer, is_enable_nz, replace_layer)
 from vllm_ascend.worker.npu_input_batch import InputBatch
+from vllm.forward_context import get_forward_context

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -55,6 +59,20 @@ def get_impl_cls() -> Type["AscendSFAImpl"]:
         return AscendSFAImpl


+
+@dataclass
+class SfaCpContext:
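+    """Context-parallel (CP) partitioning info for SFA, filled in by the
+    metadata builder's build(): how the padded token range is split across
+    TP ranks, plus the per-rank slot mapping and query/key sequence lengths
+    derived from that split.
+    """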
+    num_tokens: int
+    num_tokens_pad: int
+    local_start: int
+    local_end: int
+    local_end_with_pad: int
+    pad_size: int
+    local_pad_size: int
+    slot_mapping_cp: torch.Tensor
+    actual_seq_lengths_query: torch.Tensor
+    actual_seq_lengths_key: torch.Tensor
+
 @dataclass
 class AscendSFAMetadata:
     """Metadata for MLACommon.
@@ -85,6 +103,7 @@ class AscendSFAMetadata:
     attn_mask: torch.Tensor = None
     # chunked prefill by default if no attn_states passed
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
+    sfa_cp_context: Optional[SfaCpContext] = None


 M = TypeVar("M", bound=AscendSFAMetadata)
@@ -128,6 +147,9 @@ def __init__(self,
         self.cos_cache = None
         self.sin_cache = None

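+        # SFA context parallelism piggybacks on the sequence-parallel switch
+        # and only applies to models that expose a sparse-attention indexer
+        # (hf_config.index_topk).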
+        self.enable_sfa_cp = enable_sp() and \
+            hasattr(self.model_config.hf_config, "index_topk")
+
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         # No need to reorder for Ascend SFA
@@ -176,6 +198,63 @@ def build(
             1).unsqueeze(2)
         sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
             1).unsqueeze(2)
+
+        sfa_cp_context = None
+        if self.enable_sfa_cp:
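+            # Pad the token count to a multiple of the TP world size and give
+            # each rank an equal, contiguous slice
+            # [local_start, local_end_with_pad) of the padded range.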
+            global_tp_size = get_tp_group().world_size
+            num_tokens = num_actual_tokens
+            num_tokens_pad = _round_up(num_actual_tokens, global_tp_size)
+            num_tokens_per_device = num_tokens_pad // global_tp_size
+            pad_size = num_tokens_pad - num_tokens
+            local_start = get_tp_group().rank_in_group * num_tokens_per_device
+            local_end_with_pad = local_start + num_tokens_per_device
+            local_end = min(local_end_with_pad, num_actual_tokens)
+            local_pad_size = local_end_with_pad - local_end
+
+            if pad_size > 0:
+                cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
+                sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
+                slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size),
+                                                 value=-1)
+            cos = cos[local_start:local_end_with_pad]
+            sin = sin[local_start:local_end_with_pad]
+            slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
+
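+            # Recompute per-request query/key lengths for this rank's shard:
+            # the query length becomes a running total of locally-owned tokens
+            # and the key length drops any tail tokens beyond the local range.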
+            actual_seq_lengths_query = torch.empty_like(cum_query_lens)
+            actual_seq_lengths_key = torch.empty_like(seq_lens)
+            num_segs = cum_query_lens.shape[0]
+            last_token = 0
+            cum = 0
+            for i in range(num_segs):
+                global_start = last_token
+                global_end = cum_query_lens[i].item()
+                last_token = global_end
+
+                # Segment-local bounds; keep the rank-level local_start /
+                # local_end computed above intact for SfaCpContext below.
+                seg_start = max(global_start, local_start)
+                seg_end = min(global_end, local_end_with_pad)
+                num_local_tokens = seg_end - seg_start
+
+                if num_local_tokens > 0:
+                    cum += num_local_tokens
+                    actual_seq_lengths_query[i] = cum
+
+                    offset = global_end - seg_end
+                    actual_seq_lengths_key[i] = seq_lens[i].item() - offset
+                else:
+                    actual_seq_lengths_query[i] = cum
+                    actual_seq_lengths_key[i] = 0
+
+            sfa_cp_context = SfaCpContext(
+                num_tokens=num_tokens,
+                num_tokens_pad=num_tokens_pad,
+                local_start=local_start,
+                local_end=local_end,
+                local_end_with_pad=local_end_with_pad,
+                pad_size=pad_size,
+                local_pad_size=local_pad_size,
+                slot_mapping_cp=slot_mapping_cp,
+                actual_seq_lengths_query=actual_seq_lengths_query,
+                actual_seq_lengths_key=actual_seq_lengths_key,
+            )

         return self.metadata_cls(  # type: ignore
             has_prefill=has_prefill,
@@ -189,7 +268,8 @@ def build(
             attn_state=common_attn_metadata.attn_state,
             block_tables=block_table,
             sin=sin,
-            cos=cos)
+            cos=cos,
+            sfa_cp_context=sfa_cp_context)

     def build_for_graph_capture(
         self,
@@ -265,67 +345,29 @@ def __init__(
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-
+        self.model_config = get_current_vllm_config().model_config
         assert self.indexer is not None, "Indexer is required for DSA."
-
-        self.enable_sfa_cp = ascend_config.enable_sfa_cp
+
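+        # Must stay in sync with the metadata builder's enable_sfa_cp flag:
+        # SP enabled and the model exposes a sparse-attention indexer.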
+        self.enable_sfa_cp = enable_sp() and \
+            hasattr(self.model_config.hf_config, "index_topk")
         self.local_num_heads = self.num_heads
         self.vllm_config = get_current_vllm_config()
         if self.enable_sfa_cp:
             self.local_num_heads = self.num_heads * self.tp_size

-            # Dispose tensor from the original q_proj
-            dispose_layer(self.q_proj)
-            # Construct the new q_proj using ReplicatedLinear
-            new_q_proj = ReplicatedLinear(
-                self.q_lora_rank,
-                self.local_num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=self.vllm_config.quant_config,
-                prefix=self.q_proj.prefix)
-            # Replace the q_proj with the new one
-            replace_layer(self.q_proj, new_q_proj)
-
-            # Dispose tensor from the original kv_b_proj
-            dispose_layer(self.kv_b_proj)
-            # Construct the new kv_b_proj using ReplicatedLinear
-            new_kv_b_proj = ReplicatedLinear(
-                self.kv_lora_rank,
-                self.local_num_heads *
-                (self.qk_nope_head_dim + self.v_head_dim),
-                bias=False,
-                quant_config=self.vllm_config.quant_config,
-                prefix=self.kv_b_proj.prefix)
-            # Replace the kv_b_proj with the new one
-            replace_layer(self.kv_b_proj, new_kv_b_proj)
-
-            # Dispose tensor from the original o_proj
-            dispose_layer(self.o_proj)
-            # Construct the new o_proj using ReplicatedLinear
-            config = self.vllm_config.model_config.hf_config
-            new_o_proj = ReplicatedLinear(
-                config.num_attention_heads * config.v_head_dim,
-                config.hidden_size,
-                bias=False,
-                quant_config=self.vllm_config.quant_config,
-                prefix=self.o_proj.prefix)
-            # Replace the o_proj with the new one
-            replace_layer(self.o_proj, new_o_proj)
-
-            from vllm_ascend.distributed.parallel_state import \
-                get_shared_weight_group
-            if is_hidden_layer(self.vllm_config, self.q_proj):
-                register_layer_to_shared_weight_series(
-                    series_name="q_proj",
-                    group=get_shared_weight_group(),
-                    layer=self.q_proj,
-                    prefetch_step=1)
-            if is_hidden_layer(self.vllm_config, self.o_proj):
-                register_layer_to_shared_weight_series(
-                    series_name="o_proj",
-                    group=get_shared_weight_group(),
-                    layer=self.o_proj,
-                    prefetch_step=1)
+            # TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
+            self._replace_linear_class_for_sfa_cp()
+            from vllm_ascend.distributed.parallel_state import \
+                get_shared_weight_group
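+            # Register q_proj/o_proj with the shared-weight series so their
+            # replicated weights can be prefetched one layer ahead
+            # (prefetch_step=1).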
+            register_layer_to_shared_weight_series(
+                series_name="q_proj",
+                group=get_shared_weight_group(),
+                layer=self.q_proj,
+                prefetch_step=1)
+            register_layer_to_shared_weight_series(
+                series_name="o_proj",
+                group=get_shared_weight_group(),
+                layer=self.o_proj,
+                prefetch_step=1)

         # indexer param
         self.n_head: int = self.indexer.n_head  # 64
@@ -479,6 +521,7 @@ def exec_kv(
             cache_mode=cache_mode,
             is_output_kv=True,
         )
+        # TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
         k_pe = get_tp_group().all_gather(k_pe, 0)
         k_nope = get_tp_group().all_gather(k_nope, 0)

@@ -538,11 +581,8 @@ def forward(
         has_prefill = attn_metadata.has_prefill
         num_actual_tokens = attn_metadata.num_actual_tokens
         hidden_states = hidden_states[:num_actual_tokens]
-        sfa_sp_context = None
         if self.enable_sfa_cp:
             need_gather_q_kv = False
-            set_sfa_sp_context(hidden_states, attn_metadata.num_actual_tokens)
-            sfa_sp_context = get_sfa_sp_context()
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_tokens]
@@ -570,76 +610,38 @@ def forward(
         cos = attn_metadata.cos
         sin = attn_metadata.sin
         slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
-
         slot_mapping_cp = None
-        if self.enable_sfa_cp and sfa_sp_context is not None:
-            if sfa_sp_context.pad_size > 0:
-                cos = nn.functional.pad(
-                    cos, (0, 0, 0, 0, 0, 0, 0, sfa_sp_context.pad_size))
-                sin = nn.functional.pad(
-                    sin, (0, 0, 0, 0, 0, 0, 0, sfa_sp_context.pad_size))
-                slot_mapping = nn.functional.pad(slot_mapping,
-                                                 (0, sfa_sp_context.pad_size),
-                                                 value=-1)
-            cos = cos[sfa_sp_context.local_start:sfa_sp_context.
-                      local_end_with_pad]
-            sin = sin[sfa_sp_context.local_start:sfa_sp_context.
-                      local_end_with_pad]
-            slot_mapping_cp = slot_mapping[
-                sfa_sp_context.local_start:sfa_sp_context.local_end_with_pad]
+        if self.enable_sfa_cp:
+            slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp

         self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
                      slot_mapping_cp)

-        if self.enable_sfa_cp and sfa_sp_context is not None:
+        if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
             if is_hidden_layer(self.vllm_config, self.q_proj):
                 reach_layer_for_shared_weight_series(self.q_proj)
             if is_hidden_layer(self.vllm_config, self.o_proj):
                 reach_layer_for_shared_weight_series(self.o_proj)

         ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
         q_pe = self.rope_single(q_pe, cos, sin)
+
+        actual_seq_lengths_query = attn_metadata.cum_query_lens
+        actual_seq_lengths_key = attn_metadata.seq_lens

-        cum_query_lens = attn_metadata.cum_query_lens
-        seq_lens = attn_metadata.seq_lens
-        actual_seq_lengths_query = cum_query_lens
-        actual_seq_lengths_key = seq_lens
-
-        if self.enable_sfa_cp and sfa_sp_context is not None:
-            actual_seq_lengths_query = torch.empty_like(cum_query_lens)
-            actual_seq_lengths_key = torch.empty_like(seq_lens)
-            num_segs = cum_query_lens.shape[0]
-            last_token = 0
-            cum = 0
-            for i in range(0, num_segs):
-                global_start = last_token
-                global_end = cum_query_lens[i].item()
-                last_token = global_end
-
-                local_start = max(global_start, sfa_sp_context.local_start)
-                local_end = min(global_end, sfa_sp_context.local_end_with_pad)
-                num_local_tokens = local_end - local_start
-
-                if num_local_tokens > 0:
-                    cum += num_local_tokens
-                    actual_seq_lengths_query[i] = cum
-
-                    offset = global_end - local_end
-                    actual_seq_lengths_key[i] = seq_lens[i].item() - offset
-                else:
-                    actual_seq_lengths_query[i] = cum
-                    actual_seq_lengths_key[i] = 0
-
-        topk_indices = self.indexer_select(
-            x=hidden_states,
-            qr=q_c,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
-            cos=cos,
-            sin=sin,
-            actual_seq_lengths_query=actual_seq_lengths_query,
-            actual_seq_lengths_key=actual_seq_lengths_key,
-            need_gather_q_kv=need_gather_q_kv)
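+        # Under SFA-CP the per-rank query/key lengths were precomputed in the
+        # metadata builder; use them instead of the global ones.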
+        if self.enable_sfa_cp:
+            actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
+            actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
+
+        topk_indices = self.indexer_select(x=hidden_states,
+                                           qr=q_c,
+                                           kv_cache=kv_cache,
+                                           attn_metadata=attn_metadata,
+                                           cos=cos,
+                                           sin=sin,
+                                           actual_seq_lengths_query=actual_seq_lengths_query,
+                                           actual_seq_lengths_key=actual_seq_lengths_key,
+                                           need_gather_q_kv=need_gather_q_kv)
         attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
             query=ql_nope,
             key=kv_cache[0],
@@ -749,3 +751,45 @@ def indexer_select(
             sparse_count=2048,
             sparse_mode=3)
         return topk_indices
+
+    def _replace_linear_class_for_sfa_cp(self):
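+        """Rebuild q_proj/kv_b_proj/o_proj as ReplicatedLinear for SFA-CP.
+
+        Under SFA-CP every rank keeps the full set of attention heads, so the
+        TP-sharded projections are disposed and replaced by replicated layers
+        with the original prefixes and quant config.
+        """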
+
+        vllm_config = get_current_vllm_config()
+        # Dispose tensor from the original q_proj
+        dispose_layer(self.q_proj)
+        # Construct the new q_proj using ReplicatedLinear
+        new_q_proj = ReplicatedLinear(
+            self.q_lora_rank,
+            self.local_num_heads * self.qk_head_dim,
+            bias=False,
+            quant_config=vllm_config.quant_config,
+            prefix=self.q_proj.prefix)
+        # Replace the q_proj with the new one
+        replace_layer(self.q_proj, new_q_proj)
+
+        # Dispose tensor from the original kv_b_proj
+        dispose_layer(self.kv_b_proj)
+        # Construct the new kv_b_proj using ReplicatedLinear
+        new_kv_b_proj = ReplicatedLinear(
+            self.kv_lora_rank,
+            self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=vllm_config.quant_config,
+            prefix=self.kv_b_proj.prefix)
+        # Replace the kv_b_proj with the new one
+        replace_layer(self.kv_b_proj, new_kv_b_proj)
+
+        # Dispose tensor from the original o_proj
+        dispose_layer(self.o_proj)
+        # Construct the new o_proj using ReplicatedLinear
+        config = vllm_config.model_config.hf_config
+        new_o_proj = ReplicatedLinear(
+            config.num_attention_heads * config.v_head_dim,
+            config.hidden_size,
+            bias=False,
+            quant_config=vllm_config.quant_config,
+            prefix=self.o_proj.prefix)
+        # Replace the o_proj with the new one
+        replace_layer(self.o_proj, new_o_proj)
+
+