Skip to content

Commit 6856072

Browse files
committed
[CI] fix
Signed-off-by: zzhx1 <[email protected]>
1 parent 8775cce commit 6856072

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

vllm_ascend/distributed/parallel_state.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import torch
44
from vllm.config import ParallelConfig, get_current_vllm_config
55
from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
6-
get_tp_group, get_world_group,
7-
init_model_parallel_group, get_pp_group)
6+
get_pp_group, get_tp_group,
7+
get_world_group,
8+
init_model_parallel_group)
89

910
import vllm_ascend.envs as envs_ascend
1011
from vllm_ascend.ascend_config import get_ascend_config
@@ -198,20 +199,25 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
198199
if flashcomm2_otp_size > 1:
199200
otp_group_ranks = []
200201
odp_group_ranks: list[list[int]] = [
201-
[] for _ in range(flashcomm2_otp_size * global_dp_size * global_pp_size)
202+
[] for _ in range(flashcomm2_otp_size * global_dp_size *
203+
global_pp_size)
202204
]
203205
for dp_group_index in range(global_dp_size):
204206
for pp_group_index in range(global_pp_size):
205-
tp_base_rank = (dp_group_index * global_pp_size + pp_group_index) * global_tp_size
207+
tp_base_rank = (dp_group_index * global_pp_size +
208+
pp_group_index) * global_tp_size
206209
for i in range(num_fc2_oproj_tensor_parallel_groups):
207210
ranks = []
208211
for j in range(flashcomm2_otp_size):
209212
tp_local_rank = i + j * num_fc2_oproj_tensor_parallel_groups
210213
assert tp_local_rank < global_tp_size
211214
global_rank = tp_base_rank + tp_local_rank
212215
ranks.append(global_rank)
213-
odp_group_index = (dp_group_index * global_pp_size + pp_group_index) * flashcomm2_otp_size + j
214-
odp_group_ranks[odp_group_index].append(global_rank)
216+
odp_group_index = (
217+
dp_group_index * global_pp_size +
218+
pp_group_index) * flashcomm2_otp_size + j
219+
odp_group_ranks[odp_group_index].append(
220+
global_rank)
215221
otp_group_ranks.append(ranks)
216222

217223
_FLASHCOMM2_OTP = init_model_parallel_group(

0 commit comments

Comments
 (0)