Commit f6cec07

Author: Levi-JQ (committed)

optimization of oshard parallel group

Signed-off-by: Levi-JQ <[email protected]>

1 parent bfbda42 · commit f6cec07

File tree

1 file changed: +11 -9 lines changed


vllm_ascend/distributed/parallel_state.py

Lines changed: 11 additions & 9 deletions
@@ -231,15 +231,17 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
 
     if flashcomm2_o_shared_enabled():
         global _FLASHCOMM2_O_SHARED
-        group_ranks = [[] for _ in range(flashcomm2_otp_size)]
-        for i in range(world_size):
-            group_ranks[i % global_tp_size * flashcomm2_otp_size //
-                        global_tp_size].append(i)
-        _FLASHCOMM2_O_SHARED = init_model_parallel_group(
-            group_ranks,
-            get_world_group().local_rank,
-            backend,
-            group_name="flashcomm2_o_shared")
+        _FLASHCOMM2_O_SHARED = _FLASHCOMM2_ODP
+        if global_dp_size > 1:
+            group_ranks = [[] for _ in range(flashcomm2_otp_size)]
+            for i in range(world_size):
+                group_ranks[i % global_tp_size * flashcomm2_otp_size //
+                            global_tp_size].append(i)
+            _FLASHCOMM2_O_SHARED = init_model_parallel_group(
+                group_ranks,
+                get_world_group().local_rank,
+                backend,
+                group_name="flashcomm2_o_shared")
 
 
 def get_mlp_tensor_model_parallel_world_size():
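
Note on the change: the new hunk first aliases _FLASHCOMM2_O_SHARED to the existing _FLASHCOMM2_ODP group, and only builds a dedicated o-shared group when global_dp_size > 1, avoiding a redundant process group in the single-replica case. The bucketing arithmetic in the loop can be checked in isolation; below is a minimal sketch (the helper name build_o_shared_group_ranks and the example sizes are illustrative, not part of the commit):

# Minimal sketch of the group_ranks bucketing from the hunk above.
# The helper name and example sizes are illustrative, not from the commit.

def build_o_shared_group_ranks(world_size: int, global_tp_size: int,
                               flashcomm2_otp_size: int) -> list[list[int]]:
    # %, * and // share precedence and bind left to right, so the index is
    # ((i % global_tp_size) * flashcomm2_otp_size) // global_tp_size:
    # ranks whose TP-local index falls in the same O-projection shard land
    # in the same group, across all data-parallel replicas.
    group_ranks = [[] for _ in range(flashcomm2_otp_size)]
    for i in range(world_size):
        group_ranks[i % global_tp_size * flashcomm2_otp_size //
                    global_tp_size].append(i)
    return group_ranks


if __name__ == "__main__":
    # Example: 8 ranks, TP=4, OTP=2 -> [[0, 1, 4, 5], [2, 3, 6, 7]]
    print(build_o_shared_group_ranks(8, 4, 2))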
