 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                get_mtp_graph_params,
                                                update_graph_params_workspaces)
+from vllm_ascend.ops.shared_weight_layer import (
+    is_hidden_layer, post_process_after_loading_for_shared_weight_series,
+    reach_layer_for_shared_weight_series,
+    register_layer_to_shared_weight_series)
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_enable_nz, weak_ref_tensors)
+                               flashcomm2_o_shared_enabled, is_enable_nz,
+                               weak_ref_tensors)
 from vllm_ascend.worker.npu_input_batch import InputBatch

 if TYPE_CHECKING:
@@ -848,6 +853,19 @@ def __init__(
             'q_b_proj']
         self.kv_b_proj = kwargs['kv_b_proj']
         self.o_proj = kwargs['o_proj']
+        self.vllm_config = get_current_vllm_config()
+        self.fc2_o_shared_enable = flashcomm2_o_shared_enabled()
+
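+        # When FlashComm2 o_proj weight sharing is enabled and this is a
+        # hidden layer, register o_proj into the "o_proj" shared-weight
+        # series with a one-layer prefetch look-ahead.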
+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            from vllm_ascend.distributed.parallel_state import \
+                get_shared_weight_group
+            register_layer_to_shared_weight_series(
+                series_name="o_proj",
+                group=get_shared_weight_group(),
+                layer=self.o_proj,
+                prefetch_step=1)
+
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
         self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
@@ -858,10 +876,9 @@ def __init__(
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         self.enable_prefetch = ascend_config.weight_prefetch_config.enabled

-        vllm_config = get_current_vllm_config()
         self.ring_mla_mask_size = 512

-        self.speculative_config = vllm_config.speculative_config
+        self.speculative_config = self.vllm_config.speculative_config
         self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO

         self.pcp_size = get_pcp_group().world_size
@@ -995,6 +1012,10 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
         if self.enable_mlapo:
             self._process_weights_for_fused_mlapo(act_dtype)

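+        # Post-process the shared o_proj weights after loading when
+        # FlashComm2 o_proj sharing is enabled for this hidden layer.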
+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            post_process_after_loading_for_shared_weight_series(self.o_proj)
+
     def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
         kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
             ..., self.q_lora_rank:].contiguous()
@@ -1515,6 +1536,10 @@ def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
         kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
             kv_no_split.contiguous(), need_gather_q_kv)

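+        # Advance the o_proj shared-weight series to this layer; the series
+        # was registered with prefetch_step=1, so this drives the look-ahead.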
+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            reach_layer_for_shared_weight_series(self.o_proj)
+
         decode_preprocess_res = None
         prefill_preprocess_res = None
         if has_prefill:
@@ -1633,6 +1658,9 @@ def forward(
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
             # Profiling run.
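+            # Still advance the shared-weight series on profiling runs
+            # before the early zero-filled return below.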
+            if self.fc2_o_shared_enable and is_hidden_layer(
+                    self.vllm_config, self.o_proj):
+                reach_layer_for_shared_weight_series(self.o_proj)
             return output.fill_(0)
         if self.pcp_size > 1:
             num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size