
Commit e0eb802

Optimize code
Signed-off-by: zzhxx <[email protected]>
1 parent e137319 commit e0eb802

2 files changed (+8, -12 lines)


vllm_ascend/ascend_config.py

Lines changed: 0 additions & 6 deletions
@@ -137,12 +137,6 @@ def __init__(self, vllm_config):
         from vllm_ascend.utils import get_flashcomm2_config_and_validate
         self.flashcomm2_oproj_tensor_parallel_size, self.flashcomm2_oproj_shared = get_flashcomm2_config_and_validate(
             self, vllm_config)
-        if self.flashcomm2_oproj_shared:
-            if self.flashcomm2_oproj_tensor_parallel_size == 0:
-                raise AssertionError(
-                    "flashcomm2_oproj_shared must be enabled with flashcomm2_oproj_tensor_parallel_size > 0"
-                )
-            logger.info("Enable Flashcomm2 with flashcomm2_oproj_shared")


 class TorchairGraphConfig:
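After this change, the config constructor simply stores whatever the helper returns; the shared-o_proj consistency check and its log line no longer live in __init__. A minimal sketch of the remaining flow, with everything unrelated to FLASHCOMM2 omitted (the class name below is a hypothetical stand-in, not the project's real AscendConfig):

# Hypothetical, trimmed-down sketch; it only mirrors the FLASHCOMM2-related
# lines that this commit keeps in the constructor.
class AscendConfigSketch:
    def __init__(self, vllm_config):
        from vllm_ascend.utils import get_flashcomm2_config_and_validate
        # Both values now come back from the helper; __init__ itself no
        # longer raises or logs for the shared-o_proj case.
        (self.flashcomm2_oproj_tensor_parallel_size,
         self.flashcomm2_oproj_shared) = get_flashcomm2_config_and_validate(
             self, vllm_config)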

vllm_ascend/utils.py

Lines changed: 8 additions & 6 deletions
@@ -941,18 +941,22 @@ def flashcomm2_enable() -> bool:
     return envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0


+def flashcomm2_o_shared_enabled() -> bool:
+    return envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED > 0
+
+
 def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
     flashcomm2_oproj_tp_size = envs_ascend.VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
     global_tp_size = vllm_config.parallel_config.tensor_parallel_size
-    flashcomm2_oproj_shared = envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED
+    flashcomm2_oproj_shared = flashcomm2_o_shared_enabled()

     if not flashcomm2_enable():
         flashcomm2_oproj_shared = False
         logger.info("FLASHCOMM2 not enable.")
         return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared

     logger.info(
-        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size={flashcomm2_oproj_tp_size} and global_tp_size={global_tp_size}"
+        f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = {flashcomm2_oproj_tp_size} and oproj_shared_enabled = {flashcomm2_oproj_shared}"
     )
     if not envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1:
         logger.warning_once(
@@ -979,14 +983,12 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
             "FLASHCOMM2 primarily targets P-scenario deployments, "
             "with additional support for hybrid deployment scenarios. "
             "It is not applicable in D-scenario environments.")
+    if flashcomm2_oproj_shared:
+        logger.info("Enable FLASHCOMM2 with oproj_shared.")

     return flashcomm2_oproj_tp_size, flashcomm2_oproj_shared


-def flashcomm2_o_shared_enabled() -> bool:
-    return get_ascend_config().flashcomm2_oproj_shared
-
-
 def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
     # Reorganize batch_ids so that, after the all2all and reduce-scatter operation, each batch_id corresponds to the rank_id within the DP domain.
     # For example, when DP = [0, 1, 2, ..., 15] and flashcomm2_oproj_tensor_parallel_size = 2,
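The net effect on the shared-o_proj flag: it is now read from VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED through the new flashcomm2_o_shared_enabled() helper, and it is forced off whenever FLASHCOMM2 itself is disabled. A minimal self-contained sketch of that decision flow follows; the function and parameter names are illustrative stand-ins, not part of vllm_ascend:

# Illustrative stand-alone mirror of the early-exit logic in the diff above;
# it does not import vllm_ascend and the names are hypothetical.
def resolve_flashcomm2(oproj_tp_size: int, oproj_shared_env: int) -> tuple[int, bool]:
    flashcomm2_enabled = oproj_tp_size > 0      # mirrors flashcomm2_enable()
    oproj_shared = oproj_shared_env > 0         # mirrors flashcomm2_o_shared_enabled()
    if not flashcomm2_enabled:
        # FLASHCOMM2 disabled: the shared-o_proj request is dropped.
        oproj_shared = False
    return oproj_tp_size, oproj_shared


if __name__ == "__main__":
    # FLASHCOMM2 off: the shared flag is ignored even if the env var is set.
    assert resolve_flashcomm2(0, 1) == (0, False)
    # FLASHCOMM2 on, shared o_proj requested.
    assert resolve_flashcomm2(2, 1) == (2, True)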
