67 | 67 |
68 | 68 | import vllm_ascend.envs as envs_ascend |
69 | 69 | from vllm_ascend.ascend_config import get_ascend_config |
| 70 | +from vllm_ascend.distributed.parallel_state import get_ep_group |
70 | 71 | from vllm_ascend.ops.fused_moe import AscendFusedMoE |
71 | 72 | from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod |
72 | 73 | from vllm_ascend.utils import dispose_tensor |
@@ -211,13 +212,15 @@ def __init__( |
211 | 212 |
212 | 213 | self.tp_group = get_tp_group().device_group |
213 | 214 | self.tp_rank = get_tp_group().rank_in_group |
| 215 | + self.ep_group = get_ep_group() |
214 | 216 |
215 | 217 | self.params_dtype = torch.get_default_dtype() |
216 | 218 |
217 | 219 | ascend_config = get_ascend_config() |
218 | 220 | self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled |
| 221 | + # NOTE: multistream is only effective when `VLLM_ENABLE_MC2` is on
219 | 222 | self.enable_multistream_shared_expert = \ |
220 | | - ascend_config.torchair_graph_config.enable_multistream_shared_expert |
| 223 | + ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2 |
221 | 224 |
222 | 225 | def forward( |
223 | 226 | self, |
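
The gating added above means the multistream shared-expert path is only taken when both the torchair graph config flag and the MC2 flag are set. A minimal sketch of that gating, assuming `VLLM_ENABLE_MC2` is a boolean environment switch (the class and helper names below are illustrative, not vllm-ascend APIs):

```python
import os

# Illustrative stand-in for envs_ascend.VLLM_ENABLE_MC2; the exact parsing
# done by vllm_ascend.envs is an assumption here.
VLLM_ENABLE_MC2 = os.environ.get("VLLM_ENABLE_MC2", "0") == "1"


class TorchairGraphConfig:
    """Hypothetical stand-in for ascend_config.torchair_graph_config."""

    def __init__(self, enabled: bool, enable_multistream_shared_expert: bool):
        self.enabled = enabled
        self.enable_multistream_shared_expert = enable_multistream_shared_expert


def resolve_multistream_shared_expert(cfg: TorchairGraphConfig) -> bool:
    # Overlapping the shared expert on a second stream only pays off when
    # the MC2 communication path is active, so both switches must be on.
    return cfg.enable_multistream_shared_expert and VLLM_ENABLE_MC2
```
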
@@ -245,16 +248,12 @@ def forward( |
245 | 248 | old_hidden_states = hidden_states.clone() |
246 | 249 |
247 | 250 | if self.tp_size > 1: |
248 | | - if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: |
249 | | - chunks = torch.chunk(hidden_states, self.tp_size, dim=0) |
250 | | - hidden_states = chunks[self.tp_rank] |
251 | | - elif not self.torchair_graph_enabled: |
252 | | - num_padding_tokens = (self.tp_size - |
253 | | - num_tokens % self.tp_size) % self.tp_size |
254 | | - # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C |
255 | | - if num_padding_tokens > 0: |
| 251 | + if (VLLM_ENABLE_MC2 |
| 252 | + and not is_prefill) or not (self.torchair_graph_enabled or |
| 253 | + self.ep_group.world_size == 1): |
| 254 | + if num_tokens < self.tp_size: |
256 | 255 | hidden_states = nn.functional.pad( |
257 | | - hidden_states, (0, 0, 0, num_padding_tokens)) |
| 256 | + hidden_states, (0, 0, 0, self.tp_size - num_tokens)) |
258 | 257 | chunk_hidden_states = torch.tensor_split(hidden_states, |
259 | 258 | self.tp_size, |
260 | 259 | dim=0) |
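
With this change the decode path no longer pads `num_tokens` up to a multiple of `tp_size`: `torch.tensor_split` already copes with uneven sizes, so padding is only required when there are fewer tokens than TP ranks (otherwise some ranks would get an empty chunk). A standalone sketch under that assumption; `split_for_tp` is a hypothetical helper, not a function in this diff:

```python
import torch
import torch.nn as nn


def split_for_tp(hidden_states: torch.Tensor, tp_size: int, tp_rank: int):
    """Hypothetical standalone version of the split above."""
    num_tokens = hidden_states.size(0)
    # tensor_split tolerates sizes that are not divisible by tp_size, so we
    # only pad when there are fewer tokens than ranks.
    if num_tokens < tp_size:
        hidden_states = nn.functional.pad(
            hidden_states, (0, 0, 0, tp_size - num_tokens))
    chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)
    return chunk_hidden_states, chunk_hidden_states[tp_rank]


chunks, local = split_for_tp(torch.randn(3, 8), tp_size=4, tp_rank=1)
print([tuple(c.shape) for c in chunks])  # four (1, 8) chunks after padding
```
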
@@ -284,24 +283,16 @@ def forward( |
284 | 283 | hidden_states = hidden_states * self.routed_scaling_factor |
285 | 284 |
286 | 285 | if self.tp_size > 1: |
287 | | - if self.torchair_graph_enabled: |
288 | | - if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: |
289 | | - final_hidden_states = torch.zeros( |
290 | | - [num_tokens, hidden_size], |
291 | | - dtype=self.params_dtype, |
292 | | - device="npu") |
293 | | - dist.all_gather_into_tensor(final_hidden_states, |
294 | | - hidden_states, self.tp_group) |
295 | | - hidden_states = final_hidden_states |
296 | | - else: |
297 | | - hidden_states = tensor_model_parallel_all_reduce( |
298 | | - hidden_states) |
299 | | - else: |
| 286 | + if (VLLM_ENABLE_MC2 |
| 287 | + and not is_prefill) or not (self.torchair_graph_enabled or |
| 288 | + self.ep_group.world_size == 1): |
300 | 289 | dist.all_gather(list(chunk_hidden_states), hidden_states, |
301 | 290 | self.tp_group) |
302 | 291 | hidden_states = torch.cat(chunk_hidden_states, dim=0) |
303 | | - if num_padding_tokens > 0: |
304 | | - hidden_states = hidden_states[:-num_padding_tokens] |
| 292 | + if num_tokens < self.tp_size: |
| 293 | + hidden_states = hidden_states[:num_tokens] |
| 294 | + else: |
| 295 | + hidden_states = tensor_model_parallel_all_reduce(hidden_states) |
305 | 296 |
306 | 297 | if self.n_shared_experts is not None: |
307 | 298 | if not multistream: |
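
On the way back, the chunked branch gathers every rank's slice, concatenates, and trims any rows that were added by the pad; the other branch keeps the plain `tensor_model_parallel_all_reduce`. A hedged sketch of the gather-and-trim step, assuming an already initialised `torch.distributed` process group (`gather_and_trim` is illustrative, not part of the diff):

```python
import torch
import torch.distributed as dist


def gather_and_trim(local_chunk: torch.Tensor,
                    chunk_hidden_states,
                    num_tokens: int,
                    tp_group) -> torch.Tensor:
    """Hypothetical standalone version of the re-assembly above."""
    # Every rank writes its processed slice into the pre-shaped buffers
    # produced by tensor_split on the way in.
    dist.all_gather(list(chunk_hidden_states), local_chunk, group=tp_group)
    hidden_states = torch.cat(list(chunk_hidden_states), dim=0)
    # Drop the rows appended when num_tokens < tp_size forced padding.
    if num_tokens < hidden_states.size(0):
        hidden_states = hidden_states[:num_tokens]
    return hidden_states
```
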