Skip to content

Commit 35c6d4d

Browse files
zzhx1 and Levi-JQ committed
Make modifications according to the suggestions proposed by Gemini Code Assist.
Co-authored-by: Levi-JQ <[email protected]>
Signed-off-by: zzhx1 <[email protected]>
1 parent 6856072 commit 35c6d4d

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

vllm_ascend/distributed/parallel_state.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -204,18 +204,19 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
204204
]
205205
for dp_group_index in range(global_dp_size):
206206
for pp_group_index in range(global_pp_size):
207-
tp_base_rank = (dp_group_index * global_pp_size +
208-
pp_group_index) * global_tp_size
207+
dp_pp_serial_index = dp_group_index * global_pp_size + pp_group_index
208+
tp_base_rank = dp_pp_serial_index * global_tp_size
209+
odp_base_index = dp_pp_serial_index * flashcomm2_otp_size
210+
209211
for i in range(num_fc2_oproj_tensor_parallel_groups):
210212
ranks = []
211213
for j in range(flashcomm2_otp_size):
212214
tp_local_rank = i + j * num_fc2_oproj_tensor_parallel_groups
213215
assert tp_local_rank < global_tp_size
214216
global_rank = tp_base_rank + tp_local_rank
215217
ranks.append(global_rank)
216-
odp_group_index = (
217-
dp_group_index * global_pp_size +
218-
pp_group_index) * flashcomm2_otp_size + j
218+
219+
odp_group_index = odp_base_index + j
219220
odp_group_ranks[odp_group_index].append(
220221
global_rank)
221222
otp_group_ranks.append(ranks)

0 commit comments

Comments
 (0)