@@ -3,8 +3,9 @@
 import torch
 from vllm.config import ParallelConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
-                                             get_tp_group, get_world_group,
-                                             init_model_parallel_group, get_pp_group)
+                                             get_pp_group, get_tp_group,
+                                             get_world_group,
+                                             init_model_parallel_group)
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
@@ -198,20 +199,25 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
     if flashcomm2_otp_size > 1:
         otp_group_ranks = []
         odp_group_ranks: list[list[int]] = [
-            [] for _ in range(flashcomm2_otp_size * global_dp_size * global_pp_size)
+            [] for _ in range(flashcomm2_otp_size * global_dp_size *
+                              global_pp_size)
         ]
         for dp_group_index in range(global_dp_size):
             for pp_group_index in range(global_pp_size):
-                tp_base_rank = (dp_group_index * global_pp_size + pp_group_index) * global_tp_size
+                tp_base_rank = (dp_group_index * global_pp_size +
+                                pp_group_index) * global_tp_size
                 for i in range(num_fc2_oproj_tensor_parallel_groups):
                     ranks = []
                     for j in range(flashcomm2_otp_size):
                         tp_local_rank = i + j * num_fc2_oproj_tensor_parallel_groups
                         assert tp_local_rank < global_tp_size
                         global_rank = tp_base_rank + tp_local_rank
                         ranks.append(global_rank)
-                        odp_group_index = (dp_group_index * global_pp_size + pp_group_index) * flashcomm2_otp_size + j
-                        odp_group_ranks[odp_group_index].append(global_rank)
+                        odp_group_index = (
+                            dp_group_index * global_pp_size +
+                            pp_group_index) * flashcomm2_otp_size + j
+                        odp_group_ranks[odp_group_index].append(
+                            global_rank)
                     otp_group_ranks.append(ranks)
 
     _FLASHCOMM2_OTP = init_model_parallel_group(
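Note for reviewers: the hunk above only re-wraps long lines; the grouping logic is unchanged. The sketch below replays that logic standalone so the resulting FLASHCOMM2 oproj tensor-parallel (OTP) and oproj data-parallel (ODP) rank groups can be printed and inspected. The concrete sizes and the derivation `num_fc2_oproj_tensor_parallel_groups = global_tp_size // flashcomm2_otp_size` are illustrative assumptions, not values taken from this PR.

```python
# Standalone sketch of the FLASHCOMM2 group layout built in the hunk above.
# Assumed example sizes: 2 DP groups, 1 PP stage, TP size 4, OTP size 2.
global_dp_size = 2
global_pp_size = 1
global_tp_size = 4
flashcomm2_otp_size = 2
# Assumption for this sketch: the OTP groups evenly partition each TP block.
num_fc2_oproj_tensor_parallel_groups = global_tp_size // flashcomm2_otp_size

otp_group_ranks: list[list[int]] = []
odp_group_ranks: list[list[int]] = [
    [] for _ in range(flashcomm2_otp_size * global_dp_size * global_pp_size)
]
for dp_group_index in range(global_dp_size):
    for pp_group_index in range(global_pp_size):
        # First global rank of this (dp, pp) slice's TP block.
        tp_base_rank = (dp_group_index * global_pp_size +
                        pp_group_index) * global_tp_size
        for i in range(num_fc2_oproj_tensor_parallel_groups):
            ranks = []
            for j in range(flashcomm2_otp_size):
                # Ranks i, i + stride, i + 2*stride, ... (stride =
                # num_fc2_oproj_tensor_parallel_groups) form one OTP group.
                tp_local_rank = i + j * num_fc2_oproj_tensor_parallel_groups
                assert tp_local_rank < global_tp_size
                global_rank = tp_base_rank + tp_local_rank
                ranks.append(global_rank)
                # The j-th member of every OTP group in this (dp, pp)
                # slice lands in the same ODP group.
                odp_group_index = (dp_group_index * global_pp_size +
                                   pp_group_index) * flashcomm2_otp_size + j
                odp_group_ranks[odp_group_index].append(global_rank)
            otp_group_ranks.append(ranks)

print("OTP groups:", otp_group_ranks)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
print("ODP groups:", odp_group_ranks)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

With these sizes, OTP groups tile each TP block with a stride of `num_fc2_oproj_tensor_parallel_groups`, while each ODP group collects the j-th member of every OTP group within the same (dp, pp) slice, so the two group families are orthogonal.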