from torch import nn
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.config import VllmConfig, get_current_vllm_config
- from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.model_executor.layers.linear import (LinearBase, ReplicatedLinear,
                                               UnquantizedLinearMethod)
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.utils import AttentionCGSupport
+ from vllm.logger import logger

+ from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         wait_for_kv_layer_from_connector)
+                                         wait_for_kv_layer_from_connector,
+                                         trans_rope_weight, transdata)
+ from vllm_ascend.distributed.sfa_sp_context import (get_sfa_sp_context,
+                                                     set_sfa_sp_context)
from vllm_ascend.ops.shared_weight_layer import (
    is_hidden_layer, post_process_after_loading_for_shared_weight_series,
    reach_layer_for_shared_weight_series,
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                               _round_up, dispose_layer, enable_sp,
                               is_enable_nz, replace_layer)
+ from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.worker.npu_input_batch import InputBatch

if TYPE_CHECKING:
@@ -341,17 +346,54 @@ def __init__(
        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
        self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-         self.vllm_config = get_current_vllm_config()
+         self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
+
        assert self.indexer is not None, "Indexer is required for DSA."

        self.enable_sfa_cp = enable_sp()
        self.local_num_heads = self.num_heads
-
+         self.vllm_config = get_current_vllm_config()
        if self.enable_sfa_cp:
            self.local_num_heads = self.num_heads * self.tp_size

-             # TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
-             self._replace_linear_class_for_sfa_cp()
+             # Dispose tensor from the original q_proj
+             dispose_layer(self.q_proj)
+             # Construct the new q_proj using ReplicatedLinear
+             new_q_proj = ReplicatedLinear(
+                 self.q_lora_rank,
+                 self.local_num_heads * self.qk_head_dim,
+                 bias=False,
+                 quant_config=self.vllm_config.quant_config,
+                 prefix=self.q_proj.prefix)
+             # Replace the q_proj with the new one
+             replace_layer(self.q_proj, new_q_proj)
+
+             # Dispose tensor from the original kv_b_proj
+             dispose_layer(self.kv_b_proj)
+             # Construct the new kv_b_proj using ReplicatedLinear
+             new_kv_b_proj = ReplicatedLinear(
+                 self.kv_lora_rank,
+                 self.local_num_heads *
+                 (self.qk_nope_head_dim + self.v_head_dim),
+                 bias=False,
+                 quant_config=self.vllm_config.quant_config,
+                 prefix=self.kv_b_proj.prefix)
+             # Replace the kv_b_proj with the new one
+             replace_layer(self.kv_b_proj, new_kv_b_proj)
+
+             # Dispose tensor from the original o_proj
+             dispose_layer(self.o_proj)
+             # Construct the new o_proj using ReplicatedLinear
+             config = self.vllm_config.model_config.hf_config
+             new_o_proj = ReplicatedLinear(
+                 config.num_attention_heads * config.v_head_dim,
+                 config.hidden_size,
+                 bias=False,
+                 quant_config=self.vllm_config.quant_config,
+                 prefix=self.o_proj.prefix)
+             # Replace the o_proj with the new one
+             replace_layer(self.o_proj, new_o_proj)
+
            from vllm_ascend.distributed.parallel_state import \
                get_shared_weight_group
            if is_hidden_layer(self.vllm_config, self.q_proj):
@@ -555,6 +597,98 @@ def rope_single(
        x = torch_npu.npu_interleave_rope(x, cos, sin)
        return x.view(B, N, D)

+     def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
+         assert self.kv_a_proj_with_mqa is None
+         assert self.fused_qkv_a_proj is not None
+
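+         # Split the fused qkv_a down-projection weight into its kv_a and q_a
+         # halves, reorder the rope sub-block with trans_rope_weight, and pack
+         # the result into NPU format 29 (fractal NZ) for the fused kernel.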
+         kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
+             ..., self.q_lora_rank:].contiguous()
+         q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
+             ..., :self.q_lora_rank].contiguous()
+
+         self.fused_qkv_a_proj.weight = None
+
+         kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
+         kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
+         kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
+         wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
+         wd_qkv = wd_qkv.t().contiguous()
+         wd_qkv = transdata(wd_qkv,
+                            block_size=(16, 32)).unsqueeze(0).contiguous()
+         self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
+
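+         # Apply the same split and rope reordering to the dequant scales and
+         # quant biases of fused_qkv_a_proj, concatenated in the same kv/q
+         # order as wd_qkv.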
+         kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[
+             self.q_lora_rank:].contiguous()
+         q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[
+             :self.q_lora_rank].contiguous()
+         kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
+             self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
+         kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
+                                               self.qk_rope_head_dim)
+         kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
+             self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
+         self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl),
+                                        dim=-1).contiguous()
+
+         kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[
+             self.q_lora_rank:].contiguous()
+         q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[
+             :self.q_lora_rank].contiguous()
+
+         kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
+             self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
+         kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
+                                               self.qk_rope_head_dim)
+         kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
+             self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
+         self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias),
+                                         dim=-1).contiguous()
+
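+         # Pre-transform the q up-projection (q_proj) weight, dequant scale
+         # and quant bias in the same way, keeping the per-head nope/rope
+         # layout.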
+         wu_q = self.q_proj.weight.data
+         wu_q = wu_q.t().reshape(self.num_heads,
+                                 self.qk_nope_head_dim + self.qk_rope_head_dim,
+                                 -1)
+         wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
+         wu_q = wu_q.reshape(
+             self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
+             -1)
+         wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
+         self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
+
+         qb_deq_scl = self.q_proj.deq_scale.data
+         qb_deq_scl = qb_deq_scl.reshape(
+             self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
+         qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
+         self.qb_deq_scl = qb_deq_scl.reshape(
+             self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
+
+         qb_qt_bias = self.q_proj.quant_bias.data
+         qb_qt_bias = qb_qt_bias.reshape(
+             self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
+         qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
+         self.qb_qt_bias = qb_qt_bias.reshape(
+             self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
+
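+         # Cache the layernorm weights and static quantization parameters that
+         # the fused kernel consumes directly.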
+         device = self.q_proj.weight.device
+         self.gamma1 = self.q_a_layernorm.weight.data
+         self.beta1 = self.q_a_layernorm.bias.data
+         self.gamma2 = self.kv_a_layernorm.weight.data
+         self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data
+         self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data
+         self.quant_scale1 = self.q_proj.input_scale.data
+         self.quant_offset1 = self.q_proj.input_offset.data
+         self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
+         self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
+
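+         # If a KV-transfer connector is configured, drop the original dequant
+         # scales and quant biases now that the fused copies exist, then free
+         # cached NPU memory.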
+         if self.vllm_config.kv_transfer_config is not None:
+             self.fused_qkv_a_proj.deq_scale = None
+             self.fused_qkv_a_proj.quant_bias = None
+             self.q_proj.deq_scale = None
+             self.q_proj.quant_bias = None
+         torch.npu.empty_cache()
+
    def forward(
        self,
        layer_name,
@@ -578,56 +712,116 @@ def forward(
            return output.fill_(0)
        has_prefill = attn_metadata.has_prefill
        num_actual_tokens = attn_metadata.num_actual_tokens
+         cos = attn_metadata.cos
+         sin = attn_metadata.sin
+         cum_query_lens = attn_metadata.cum_query_lens
+         seq_lens = attn_metadata.seq_lens
+         actual_seq_lengths_query = cum_query_lens
+         actual_seq_lengths_key = seq_lens
        hidden_states = hidden_states[:num_actual_tokens]
        if self.enable_sfa_cp:
            need_gather_q_kv = False
        # Inputs and outputs may be padded for CUDA graphs
        output_padded = output
        output = output[:num_actual_tokens]
-         assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
-         maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
-                            dependency=hidden_states,
-                            enabled=self.enable_prefetch)
-         qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
-         q_c, kv_no_split = qkv_lora.split(
-             [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-             dim=-1,
-         )
-         q_c = self.q_a_layernorm(q_c)
-
-         # Process for Flash Comm V1
-         if need_gather_q_kv:
-             q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                 q_c.contiguous(), need_gather_q_kv)
-             kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                 kv_no_split.contiguous(), need_gather_q_kv)
-
-         if has_prefill:
-             wait_for_kv_layer_from_connector(layer_name)

-         cos = attn_metadata.cos
-         sin = attn_metadata.sin
-         slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
-         slot_mapping_cp = None
-         actual_seq_lengths_query = attn_metadata.cum_query_lens
-         actual_seq_lengths_key = attn_metadata.seq_lens
-         if self.enable_sfa_cp:
-             assert attn_metadata.sfa_cp_context is not None
-             slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
-             actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
-             actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
-
-         self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
-                      slot_mapping_cp)
-
-         if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
-             if is_hidden_layer(self.vllm_config, self.q_proj):
-                 reach_layer_for_shared_weight_series(self.q_proj)
-             if is_hidden_layer(self.vllm_config, self.o_proj):
-                 reach_layer_for_shared_weight_series(self.o_proj)
-
-         ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
-         q_pe = self.rope_single(q_pe, cos, sin)
+         from vllm.forward_context import get_forward_context
+         forward_context = get_forward_context()
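+         # Decode-only fast path: a single fused MLAPO kernel covers the
+         # qkv_a down-projection, layernorms, rope, q up-projection and the
+         # KV-cache update that the else branch performs step by step.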
+         if self.enable_mlapo and not forward_context.with_prefill:
+             hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                 hidden_states.contiguous(), need_gather_q_kv)
+             k_nope, k_pe = kv_cache[0], kv_cache[1]
+             ql_nope = torch.empty(
+                 (num_actual_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
+                 dtype=hidden_states.dtype,
+                 device=hidden_states.device,
+             )
+             q_pe = torch.empty(
+                 (num_actual_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]),
+                 dtype=hidden_states.dtype,
+                 device=hidden_states.device,
+             )
+             q_c = torch.empty(
+                 (num_actual_tokens, self.q_lora_rank),
+                 dtype=hidden_states.dtype,
+                 device=hidden_states.device,
+             )
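+             # mla_preprocess fills ql_nope / q_pe / q_c and updates the
+             # k_nope / k_pe caches in place through its *_out arguments.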
+             torch.ops._C_ascend.mla_preprocess(
+                 hidden_states,
+                 self.wd_qkv,
+                 self.deq_scale_qkv,
+                 self.gamma1,
+                 self.beta1,
+                 self.wu_q,
+                 self.qb_deq_scl,
+                 self.gamma2,
+                 cos,
+                 sin,
+                 self.W_UK_T,
+                 k_nope,
+                 k_pe,
+                 attn_metadata.slot_mapping[:num_actual_tokens].flatten(),
+                 quant_scale0=self.quant_scale0,
+                 quant_offset0=self.quant_offset0,
+                 bias0=self.quant_bias_qkv,
+                 quant_scale1=self.quant_scale1,
+                 quant_offset1=self.quant_offset1,
+                 bias1=self.qb_qt_bias,
+                 ctkv_scale=self.ctkv_scale,
+                 q_nope_scale=self.q_nope_scale,
+                 cache_mode="krope_ctkv",
+                 quant_mode="per_tensor_quant_asymm",
+                 enable_inner_out=True,
+                 q_out0=ql_nope,
+                 kv_cache_out0=k_nope,
+                 q_out1=q_pe,
+                 kv_cache_out1=k_pe,
+                 inner_out=q_c,
+             )
+         else:
+             assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
+             maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
+                                dependency=hidden_states,
+                                enabled=self.enable_prefetch)
+             qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
+             q_c, kv_no_split = qkv_lora.split(
+                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+                 dim=-1,
+             )
+             q_c = self.q_a_layernorm(q_c)
+             # Process for Flash Comm V1
+             if need_gather_q_kv:
+                 q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                     q_c.contiguous(), need_gather_q_kv)
+                 kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                     kv_no_split.contiguous(), need_gather_q_kv)
+
+             if has_prefill:
+                 wait_for_kv_layer_from_connector(layer_name)
+
+             cos = attn_metadata.cos
+             sin = attn_metadata.sin
+             slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
+             slot_mapping_cp = None
+             actual_seq_lengths_query = attn_metadata.cum_query_lens
+             actual_seq_lengths_key = attn_metadata.seq_lens
+             if self.enable_sfa_cp:
+                 assert attn_metadata.sfa_cp_context is not None
+                 slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
+                 actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
+                 actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
+
+             self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
+                          slot_mapping_cp)
+
+             if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
+                 if is_hidden_layer(self.vllm_config, self.q_proj):
+                     reach_layer_for_shared_weight_series(self.q_proj)
+                 if is_hidden_layer(self.vllm_config, self.o_proj):
+                     reach_layer_for_shared_weight_series(self.o_proj)
+
+             ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
+             q_pe = self.rope_single(q_pe, cos, sin)

        topk_indices = self.indexer_select(
            x=hidden_states,