Skip to content

Commit 8775cce

Browse files
committed
Fix bug with establishing the flashcomm2 and pp communication domains.
Signed-off-by: zzhx1 <[email protected]>
1 parent b5f7a83 commit 8775cce

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

vllm_ascend/distributed/parallel_state.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from vllm.config import ParallelConfig, get_current_vllm_config
55
from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
66
get_tp_group, get_world_group,
7-
init_model_parallel_group)
7+
init_model_parallel_group, get_pp_group)
88

99
import vllm_ascend.envs as envs_ascend
1010
from vllm_ascend.ascend_config import get_ascend_config
@@ -185,6 +185,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
185185
).flashcomm2_oproj_tensor_parallel_size
186186
global_tp_size = get_tp_group().world_size
187187
global_dp_size = get_dp_group().world_size
188+
global_pp_size = get_pp_group().world_size
188189
num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
189190
flashcomm2_otp_size)
190191

@@ -197,18 +198,21 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
197198
if flashcomm2_otp_size > 1:
198199
otp_group_ranks = []
199200
odp_group_ranks: list[list[int]] = [
200-
[] for _ in range(flashcomm2_otp_size * global_dp_size)
201+
[] for _ in range(flashcomm2_otp_size * global_dp_size * global_pp_size)
201202
]
202-
203203
for dp_group_index in range(global_dp_size):
204-
for i in range(num_fc2_oproj_tensor_parallel_groups):
205-
ranks = []
206-
for j in range(flashcomm2_otp_size):
207-
rank_idx = dp_group_index * global_tp_size + i + j * num_fc2_oproj_tensor_parallel_groups
208-
ranks.append(rank_idx)
209-
odp_group_index = dp_group_index * flashcomm2_otp_size + j
210-
odp_group_ranks[odp_group_index].append(rank_idx)
211-
otp_group_ranks.append(ranks)
204+
for pp_group_index in range(global_pp_size):
205+
tp_base_rank = (dp_group_index * global_pp_size + pp_group_index) * global_tp_size
206+
for i in range(num_fc2_oproj_tensor_parallel_groups):
207+
ranks = []
208+
for j in range(flashcomm2_otp_size):
209+
tp_local_rank = i + j * num_fc2_oproj_tensor_parallel_groups
210+
assert tp_local_rank < global_tp_size
211+
global_rank = tp_base_rank + tp_local_rank
212+
ranks.append(global_rank)
213+
odp_group_index = (dp_group_index * global_pp_size + pp_group_index) * flashcomm2_otp_size + j
214+
odp_group_ranks[odp_group_index].append(global_rank)
215+
otp_group_ranks.append(ranks)
212216

213217
_FLASHCOMM2_OTP = init_model_parallel_group(
214218
otp_group_ranks,

0 commit comments

Comments
 (0)