
Commit eac72f5

Authored by zzhx1, zzh02232027, clrs97 and Levi-JQ
[Feat] Flashcomm2 use o_shared linear (#4188)
### What this PR does / why we need it?

The [FlashComm2 technical report](https://gitcode.com/ascend-tribe/ascend-inference-cluster/blob/main/FlashComm/FlashComm2%E5%A4%A7%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86%E4%B8%AD%E4%BB%A5%E5%AD%98%E6%8D%A2%E4%BC%A0%E7%9A%84%E9%80%9A%E4%BF%A1%E4%BC%98%E5%8C%96%E6%8A%80%E6%9C%AF.pdf) notes that FC2 introduces fully redundant storage of the o_proj matrix, which puts pressure on device memory. The report proposes a compromise (otp2), but that adds an extra reduce-scatter communication. We instead propose a shared linear feature (#2931) that distributes the weights layer by layer across cards, avoiding TP splitting and resolving the memory issue.

This PR depends on #3232 and #2931.

### Flashcomm2 flowchart

![Flashcomm2 flowchart](https://github.com/user-attachments/assets/d45ea8db-d8ef-4d45-8e18-abd4d82ce3e0)

### Does this PR introduce _any_ user-facing change?

Use the following environment variables to enable the feature:

```bash
export VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE=1
export VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED=1
```

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

---------

Signed-off-by: zzhx1 <[email protected]>
Signed-off-by: zzhxx <[email protected]>
Co-authored-by: zzh02232027 <[email protected]>
Co-authored-by: clrs97 <[email protected]>
Co-authored-by: Levi-JQ <[email protected]>
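To make the memory argument concrete, here is a back-of-the-envelope sketch (not part of this PR) comparing per-rank o_proj storage under full replication, the otp2-style split, and the layer-wise shared linear approach. The projection shape, layer count and world size below are illustrative assumptions, not measurements.

```python
def o_proj_bytes(num_heads=128, v_head_dim=128, hidden_size=7168, dtype_bytes=2):
    """Size of one layer's o_proj weight [num_heads * v_head_dim, hidden_size] in bf16."""
    return num_heads * v_head_dim * hidden_size * dtype_bytes


def per_rank_gib(scheme, num_layers=61, world_size=16):
    full = o_proj_bytes() * num_layers
    if scheme == "replicated":       # FlashComm2 with fully redundant o_proj storage
        per_rank = full
    elif scheme == "otp2_split":     # compromise from the report: shard o_proj across 2 ranks
        per_rank = full / 2
    else:                            # "layer_shared": each rank owns ~1/world_size of the layers
        per_rank = full / world_size  # steady state; small prefetch buffers ignored
    return per_rank / 1024 ** 3


for scheme in ("replicated", "otp2_split", "layer_shared"):
    print(f"{scheme:>12}: {per_rank_gib(scheme):.2f} GiB per rank")
```

Under these assumptions the layer-wise sharing keeps the per-rank footprint close to a full TP split while still handing each layer its complete o_proj on demand, which is what the o_shared hooks in this PR wire up.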
1 parent bb76f79 commit eac72f5

File tree: 8 files changed, +86 −25 lines changed


tests/ut/attention/test_mla_v1.py

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,7 @@
 from vllm.model_executor.layers.linear import LinearBase
 
 from tests.ut.base import TestBase
+from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
                                           AscendMLADecodeMetadata,
@@ -845,6 +846,8 @@ def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
         model_config.dtype = torch.float16
         vllm_config.model_config = model_config
         get_current_vllm_config.return_value = vllm_config
+        vllm_config.additional_config = {"refresh": True}
+        init_ascend_config(vllm_config)
 
         num_heads = 256
         head_size = 1024

tests/ut/distributed/test_parallel_state.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
     mock_vllm_config.kv_transfer_config.is_kv_producer = True
     mock_envs_ascend = MagicMock()
     mock_envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE = 2
+    mock_envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED = 0
     mock_envs_ascend.VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL = 0
     with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
          patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \

vllm_ascend/ascend_config.py

Lines changed: 2 additions & 3 deletions
@@ -165,9 +165,8 @@ def __init__(self, vllm_config):
             "Only support P node tp size lagger then D node tp size")
         self.SLO_limits_for_dynamic_batch = additional_config.get(
             "SLO_limits_for_dynamic_batch", -1)
-        from vllm_ascend.utils import \
-            get_flashcomm2_oproj_tp_size_and_validate_config
-        self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(
+        from vllm_ascend.utils import get_flashcomm2_config_and_validate
+        self.flashcomm2_oproj_tensor_parallel_size, self.flashcomm2_oproj_shared = get_flashcomm2_config_and_validate(
             self, vllm_config)
         self.enable_npugraph_ex = additional_config.get(
             "enable_npugraph_ex", False)

vllm_ascend/attention/mla_v1.py

Lines changed: 31 additions & 3 deletions
@@ -34,10 +34,15 @@
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                get_mtp_graph_params,
                                                update_graph_params_workspaces)
+from vllm_ascend.ops.shared_weight_layer import (
+    is_hidden_layer, post_process_after_loading_for_shared_weight_series,
+    reach_layer_for_shared_weight_series,
+    register_layer_to_shared_weight_series)
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_enable_nz, weak_ref_tensors)
+                               flashcomm2_o_shared_enabled, is_enable_nz,
+                               weak_ref_tensors)
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -848,6 +853,19 @@ def __init__(
             'q_b_proj']
         self.kv_b_proj = kwargs['kv_b_proj']
         self.o_proj = kwargs['o_proj']
+        self.vllm_config = get_current_vllm_config()
+        self.fc2_o_shared_enable = flashcomm2_o_shared_enabled()
+
+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            from vllm_ascend.distributed.parallel_state import \
+                get_shared_weight_group
+            register_layer_to_shared_weight_series(
+                series_name="o_proj",
+                group=get_shared_weight_group(),
+                layer=self.o_proj,
+                prefetch_step=1)
+
         self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
         self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
         self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
@@ -858,10 +876,9 @@ def __init__(
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
 
-        vllm_config = get_current_vllm_config()
         self.ring_mla_mask_size = 512
 
-        self.speculative_config = vllm_config.speculative_config
+        self.speculative_config = self.vllm_config.speculative_config
         self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
 
         self.pcp_size = get_pcp_group().world_size
@@ -995,6 +1012,10 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
         if self.enable_mlapo:
             self._process_weights_for_fused_mlapo(act_dtype)
 
+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            post_process_after_loading_for_shared_weight_series(self.o_proj)
+
     def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
         kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
             ..., self.q_lora_rank:].contiguous()
@@ -1515,6 +1536,10 @@ def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
         kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
             kv_no_split.contiguous(), need_gather_q_kv)
 
+        if self.fc2_o_shared_enable and is_hidden_layer(
+                self.vllm_config, self.o_proj):
+            reach_layer_for_shared_weight_series(self.o_proj)
+
         decode_preprocess_res = None
         prefill_preprocess_res = None
         if has_prefill:
@@ -1633,6 +1658,9 @@ def forward(
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
             # Profiling run.
+            if self.fc2_o_shared_enable and is_hidden_layer(
+                    self.vllm_config, self.o_proj):
+                reach_layer_for_shared_weight_series(self.o_proj)
             return output.fill_(0)
         if self.pcp_size > 1:
             num_actual_tokens = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
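The diff above touches three hook points: each hidden layer registers its o_proj into a shared weight series at `__init__`, finalizes it once after weight loading, and calls `reach_layer_for_shared_weight_series` once per forward (including the profiling path) so the weight needed a step ahead can be fetched in time. A toy, dependency-free sketch of that ordering (hypothetical class, not the vllm_ascend implementation):

```python
from dataclasses import dataclass, field


@dataclass
class ToySharedWeightSeries:
    """Toy stand-in for the register / post-process / reach hooks used above."""
    prefetch_step: int = 1
    layers: list = field(default_factory=list)
    ready: dict = field(default_factory=dict)

    def register(self, name):                     # __init__-time hook
        self.layers.append(name)

    def post_process_after_loading(self, name):   # after weight loading
        self.ready[name] = False

    def reach(self, name):                        # forward-time hook
        idx = self.layers.index(name)
        nxt = min(idx + self.prefetch_step, len(self.layers) - 1)
        # Mark the o_proj weight needed `prefetch_step` layers ahead as in flight.
        self.ready[self.layers[nxt]] = True


series = ToySharedWeightSeries()
for name in ("layers.0.o_proj", "layers.1.o_proj", "layers.2.o_proj"):
    series.register(name)
    series.post_process_after_loading(name)
series.reach("layers.0.o_proj")
print(series.ready)  # layer 1's o_proj is flagged one step ahead
```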

vllm_ascend/distributed/parallel_state.py

Lines changed: 30 additions & 12 deletions
@@ -9,7 +9,8 @@
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.utils import enable_sp, flashcomm2_enable
+from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
+                               flashcomm2_o_shared_enabled)
 
 # Currently, mc2 op need their own group coordinator.
 _MC2: Optional[GroupCoordinator] = None
@@ -77,6 +78,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
     backend = torch.distributed.get_backend(get_world_group().device_group)
+    vllm_config = get_current_vllm_config()
 
     # The layout of all ranks: ExternalDP * EP
     # ExternalDP is the data parallel group that is not part of the model,
@@ -182,6 +184,29 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                          backend,
                                          group_name="lmheadtp")
 
+    def _create_shared_weight_group(group_name: str) -> GroupCoordinator:
+        #This communication domain is used for asynchronous broadcasting, so we will create a new communication group to avoid interference
+        group_ranks = []
+        for pp_idx in range(global_pp_size):
+            group = []
+            for dp_idx in range(global_dp_size):
+                base = (dp_idx * global_pp_size + pp_idx) * global_tp_size
+                for i in range(global_tp_size):
+                    global_rank = base + i
+                    group.append(global_rank)
+            group_ranks.append(group)
+
+        return init_model_parallel_group(group_ranks,
+                                         get_world_group().local_rank,
+                                         backend,
+                                         group_name=group_name)
+
+    global _SHARED_WEIGHT
+    # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
+    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
+    if enable_sp() and is_ds_v32:
+        _SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")
+
     # TODO: Extract and unify the logic across different communication group.
     if flashcomm2_enable():
         flashcomm2_otp_size = get_ascend_config(
@@ -234,17 +259,10 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                              backend,
                              group_name="flashcomm2_odp")
 
-    vllm_config = get_current_vllm_config()
-    # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
-    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
-    if enable_sp() and is_ds_v32:
-        global _SHARED_WEIGHT
-        group_ranks = [list(range(torch.distributed.get_world_size()))]
-        _SHARED_WEIGHT = init_model_parallel_group(
-            group_ranks,
-            get_world_group().local_rank,
-            backend,
-            group_name="CP_shared_weight")
+        # Create shared weight group for flashcomm2 oproj
+        if flashcomm2_o_shared_enabled():
+            assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
+            _SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
 
 
 def get_mlp_tensor_model_parallel_world_size():
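For reference, here is a standalone reproduction of the rank layout that `_create_shared_weight_group` builds: one broadcast group per pipeline stage, spanning all DP and TP ranks of that stage, so the asynchronous o_proj broadcast never crosses pipeline stages. The PP=2, DP=2, TP=2 sizes below are just an illustrative example.

```python
def shared_weight_group_ranks(pp_size: int, dp_size: int, tp_size: int):
    """Same rank arithmetic as _create_shared_weight_group, extracted for illustration."""
    group_ranks = []
    for pp_idx in range(pp_size):
        group = []
        for dp_idx in range(dp_size):
            base = (dp_idx * pp_size + pp_idx) * tp_size
            group.extend(base + i for i in range(tp_size))
        group_ranks.append(group)
    return group_ranks


print(shared_weight_group_ranks(pp_size=2, dp_size=2, tp_size=2))
# [[0, 1, 4, 5], [2, 3, 6, 7]]
```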

vllm_ascend/envs.py

Lines changed: 5 additions & 0 deletions
@@ -97,6 +97,11 @@
     # between this feature and FLASHCOMM1, please refer to the feature guide in the documentation.
     "VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE":
     lambda: int(os.getenv("VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE", 0)),
+    # This feature is bound to the previous VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE, and it adds the shared weight feature,
+    # which can eliminate redundant storage of weights. More detailed information can be found in PR#4188.
+    # We recommend that you enable it when Flashcomm2 is enabled and the VRAM capacity is limited.
+    "VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED", "0"))),
     # Whether to enable MLP weight prefetch, only used in small concurrency.
     "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
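A minimal sketch of how the new flag is consumed, mirroring the bash example in the PR description; the parsing matches the `bool(int(...))` lambda registered above.

```python
import os

# Set both flags before the engine process starts (values mirror the PR description).
os.environ["VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE"] = "1"
os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED"] = "1"

# The env table parses the flag as bool(int(...)), so only integer strings
# such as "0" or "1" are valid values.
oshared = bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED", "0")))
print(oshared)  # True
```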

vllm_ascend/ops/shared_weight_layer.py

Lines changed: 1 addition & 1 deletion
@@ -249,4 +249,4 @@ def reach_layer_for_shared_weight_series(layer: LinearBase):
 def is_hidden_layer(vllm_config, layer: LinearBase) -> bool:
     num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
     layer_idx = extract_layer_index(layer.prefix)
-    return layer_idx < num_hidden_layers
+    return layer_idx < num_hidden_layers

vllm_ascend/utils.py

Lines changed: 13 additions & 6 deletions
@@ -953,17 +953,22 @@ def flashcomm2_enable() -> bool:
     return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
 
 
-def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
-                                                     vllm_config):
+def flashcomm2_o_shared_enabled() -> bool:
+    return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED
+
+
+def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
     flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
     global_tp_size = vllm_config.parallel_config.tensor_parallel_size
+    flashcomm2_oproj_shared = flashcomm2_o_shared_enabled()
 
     if not flashcomm2_enable():
-        logger.debug("FLASHCOMM2 not enable.")
-        return flashcomm2_oproj_tp_size
+        flashcomm2_oproj_shared = False
+        logger.info("FLASHCOMM2 not enable.")
+        return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
 
     logger.info(
-        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size={flashcomm2_oproj_tp_size} and global_tp_size={global_tp_size}"
+        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size} and oproj_shared_enabled = {flashcomm2_oproj_shared}"
     )
     if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
         logger.warning_once(
@@ -990,8 +995,10 @@ def get_flashcomm2_oproj_tp_size_and_validate_config(ascend_config,
             "FLASHCOMM2 primarily targets P-scenario deployments, "
             "with additional support for hybrid deployment scenarios. "
             "It is not applicable in D-scenario environments.")
+    if flashcomm2_oproj_shared:
+        logger.info("Enable FLASHCOMM2 with oproj_shared.")
 
-    return flashcomm2_oproj_tp_size
+    return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared
 
 
 def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
