@@ -941,18 +941,22 @@ def flashcomm2_enable() -> bool:
941941 return envs_ascend .VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE > 0
942942
943943
944+ def flashcomm2_o_shared_enabled () -> bool :
945+ return envs_ascend .VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED > 0
946+
947+
944948def get_flashcomm2_config_and_validate (ascend_config , vllm_config ):
945949 flashcomm2_oproj_tp_size = envs_ascend .VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE
946950 global_tp_size = vllm_config .parallel_config .tensor_parallel_size
947- flashcomm2_oproj_shared = envs_ascend . VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED
951+ flashcomm2_oproj_shared = flashcomm2_o_shared_enabled ()
948952
949953 if not flashcomm2_enable ():
950954 flashcomm2_oproj_shared = False
951955 logger .info ("FLASHCOMM2 not enable." )
952956 return flashcomm2_oproj_tp_size , flashcomm2_oproj_shared
953957
954958 logger .info (
955- f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size= { flashcomm2_oproj_tp_size } and global_tp_size= { global_tp_size } "
959+ f"Enable FLASHCOMM2 with flashcomm2_oproj_tensor_parallel_size = { flashcomm2_oproj_tp_size } and oproj_shared_enabled = { flashcomm2_oproj_shared } "
956960 )
957961 if not envs_ascend .VLLM_ASCEND_ENABLE_FLASHCOMM1 :
958962 logger .warning_once (
@@ -979,14 +983,12 @@ def get_flashcomm2_config_and_validate(ascend_config, vllm_config):
979983 "FLASHCOMM2 primarily targets P-scenario deployments, "
980984 "with additional support for hybrid deployment scenarios. "
981985 "It is not applicable in D-scenario environments." )
986+ if flashcomm2_oproj_shared :
987+ logger .info ("Enable FLASHCOMM2 with oproj_shared." )
982988
983989 return flashcomm2_oproj_tp_size , flashcomm2_oproj_shared
984990
985991
986- def flashcomm2_o_shared_enabled () -> bool :
987- return get_ascend_config ().flashcomm2_oproj_shared
988-
989-
990992def get_flashcomm2_reorgnized_batch_ids (global_tp_size ) -> list [list [int ]]:
991993 # Reorganize batch_ids so that, after the all2all and reduce-scatter operation, each batch_id corresponds to the rank_id within the DP domain.
992994 # For example, when DP = [0, 1, 2, ..., 15] and flashcomm2_oproj_tensor_parallel_size = 2,
0 commit comments