44from vllm .config import ParallelConfig , get_current_vllm_config
55from vllm .distributed .parallel_state import (GroupCoordinator , get_dp_group ,
66 get_tp_group , get_world_group ,
7- init_model_parallel_group )
7+ init_model_parallel_group , get_pp_group )
88
99import vllm_ascend .envs as envs_ascend
1010from vllm_ascend .ascend_config import get_ascend_config
@@ -185,6 +185,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
185185 ).flashcomm2_oproj_tensor_parallel_size
186186 global_tp_size = get_tp_group ().world_size
187187 global_dp_size = get_dp_group ().world_size
188+ global_pp_size = get_pp_group ().world_size
188189 num_fc2_oproj_tensor_parallel_groups : int = (global_tp_size //
189190 flashcomm2_otp_size )
190191
@@ -197,18 +198,21 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
197198 if flashcomm2_otp_size > 1 :
198199 otp_group_ranks = []
199200 odp_group_ranks : list [list [int ]] = [
200- [] for _ in range (flashcomm2_otp_size * global_dp_size )
201+ [] for _ in range (flashcomm2_otp_size * global_dp_size * global_pp_size )
201202 ]
202-
203203 for dp_group_index in range (global_dp_size ):
204- for i in range (num_fc2_oproj_tensor_parallel_groups ):
205- ranks = []
206- for j in range (flashcomm2_otp_size ):
207- rank_idx = dp_group_index * global_tp_size + i + j * num_fc2_oproj_tensor_parallel_groups
208- ranks .append (rank_idx )
209- odp_group_index = dp_group_index * flashcomm2_otp_size + j
210- odp_group_ranks [odp_group_index ].append (rank_idx )
211- otp_group_ranks .append (ranks )
204+ for pp_group_index in range (global_pp_size ):
205+ tp_base_rank = (dp_group_index * global_pp_size + pp_group_index ) * global_tp_size
206+ for i in range (num_fc2_oproj_tensor_parallel_groups ):
207+ ranks = []
208+ for j in range (flashcomm2_otp_size ):
209+ tp_local_rank = i + j * num_fc2_oproj_tensor_parallel_groups
210+ assert tp_local_rank < global_tp_size
211+ global_rank = tp_base_rank + tp_local_rank
212+ ranks .append (global_rank )
213+ odp_group_index = (dp_group_index * global_pp_size + pp_group_index ) * flashcomm2_otp_size + j
214+ odp_group_ranks [odp_group_index ].append (global_rank )
215+ otp_group_ranks .append (ranks )
212216
213217 _FLASHCOMM2_OTP = init_model_parallel_group (
214218 otp_group_ranks ,
0 commit comments