37 | 37 | from vllm_ascend.distributed.parallel_state import get_mc2_group |
38 | 38 | from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map |
39 | 39 | from vllm_ascend.eplb.utils import moe_load_async_stream |
| 40 | +from vllm_ascend.flash_common3_context import (get_flash_common3_context, |
| 41 | + set_flash_common3_context) |
40 | 42 | from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer |
41 | 43 | from vllm_ascend.ops.fused_moe.experts_selector import select_experts |
42 | | -from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method |
| 44 | +from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl, |
| 45 | + setup_moe_comm_method) |
43 | 46 | from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType |
44 | 47 | from vllm_ascend.quantization.w4a8_dynamic import \ |
45 | 48 | AscendW4A8DynamicFusedMoEMethod |
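The two helpers imported above carry per-forward routing state between the shared-expert/gate computation and the rest of the MoE layer. Below is a minimal sketch of what such a context module could look like; only the field names (`shared_experts`, `shared_out`, `topk_weights`, `topk_ids`) and the get/set call shapes are taken from this diff, everything else is an assumption rather than the actual `vllm_ascend.flash_common3_context` implementation.

```python
# Hedged sketch (not the actual vllm_ascend code): a plausible shape for
# flash_common3_context, inferred only from the calls visible in this diff.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FlashCommon3Context:
    # Populated incrementally while the gate stream runs.
    shared_experts: Optional[torch.nn.Module] = None
    shared_out: Optional[torch.Tensor] = None
    topk_weights: Optional[torch.Tensor] = None
    topk_ids: Optional[torch.Tensor] = None


_flash_common3_context: Optional[FlashCommon3Context] = None


def get_flash_common3_context() -> Optional[FlashCommon3Context]:
    return _flash_common3_context


def set_flash_common3_context(**kwargs) -> None:
    # Create the context lazily and update only the fields that were passed in.
    global _flash_common3_context
    if _flash_common3_context is None:
        _flash_common3_context = FlashCommon3Context()
    for name, value in kwargs.items():
        setattr(_flash_common3_context, name, value)
```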
@@ -139,6 +142,7 @@ def apply(self, |
139 | 142 |
140 | 143 | class AscendFusedMoE(FusedMoE): |
141 | 144 | moe_counter = -1 |
| 145 | + gate_stream: Optional[torch.npu.Stream] = None |
142 | 146 |
143 | 147 | def __init__(self, *args, **kwargs): |
144 | 148 | super().__init__(*args, **kwargs) |
@@ -170,6 +174,10 @@ def __init__(self, *args, **kwargs): |
170 | 174 | self.expert_map_path = ascend_config.expert_map_path |
171 | 175 | self.global_redundant_expert_num = ascend_config.init_redundancy_expert |
172 | 176 | self.global_num_experts = num_experts + self.global_redundant_expert_num |
|     | 177 | +        # flash_common3 gate stream (used when multistream_overlap_gate is enabled)
| 178 | + self.multistream_overlap_gate = ascend_config.multistream_overlap_gate |
| 179 | + if self.multistream_overlap_gate and AscendFusedMoE.gate_stream is None: |
| 180 | + AscendFusedMoE.gate_stream = torch.npu.Stream() |
173 | 181 | if self.custom_routing_function is None and self.e_score_correction_bias is not None: |
174 | 182 | vllm_config = get_current_vllm_config() |
175 | 183 | self.e_score_correction_bias.data = self.e_score_correction_bias.data.to( |
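For reference, a hedged example of turning the new option on. This assumes `multistream_overlap_gate` is surfaced through vLLM's `additional_config` in the same way as the existing `multistream_overlap_shared_expert` option; the plumbing inside `get_ascend_config()` is not shown in this diff.

```python
# Hypothetical usage, assuming the flag is read from additional_config by
# vllm_ascend's get_ascend_config(); the model path is only a placeholder.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder MoE model
    additional_config={
        "multistream_overlap_gate": True,           # new flag added in this diff
        "multistream_overlap_shared_expert": True,  # existing overlap option
    },
)
```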
@@ -332,13 +340,58 @@ def forward_impl(self, hidden_states: torch.Tensor, |
332 | 340 | enable_force_load_balance = forward_context.in_profile_run |
333 | 341 |
334 | 342 | forward_context = get_forward_context() |
| 343 | + if self.multistream_overlap_gate: |
| 344 | + assert AscendFusedMoE.gate_stream is not None |
| 345 | + fc3_context = get_flash_common3_context() |
| 346 | + assert fc3_context is not None |
| 347 | + AscendFusedMoE.gate_stream.wait_stream(torch.npu.current_stream()) |
| 348 | + with npu_stream_switch(AscendFusedMoE.gate_stream, |
| 349 | + enabled=self.multistream_overlap_gate): |
|     | 350 | +                # Run the shared experts on the gate stream.
| 351 | + assert fc3_context.shared_experts is not None |
| 352 | + shared_out = fc3_context.shared_experts(hidden_states) |
| 353 | + # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` |
| 354 | + moe_comm_type = forward_context.moe_comm_type |
| 355 | + if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \ |
| 356 | + and not shared_expert_dp_enabled(): |
| 357 | + shared_out = tensor_model_parallel_all_reduce(shared_out) |
| 358 | + set_flash_common3_context(shared_out=shared_out) |
| 359 | + |
| 360 | + topk_weights, topk_ids = select_experts( |
| 361 | + hidden_states=hidden_states, |
| 362 | + router_logits=router_logits, |
| 363 | + top_k=self.top_k, |
| 364 | + use_grouped_topk=self.use_grouped_topk, |
| 365 | + renormalize=self.renormalize, |
| 366 | + topk_group=self.topk_group, |
| 367 | + num_expert_group=self.num_expert_group, |
| 368 | + custom_routing_function=self.custom_routing_function, |
| 369 | + scoring_func=self.scoring_func, |
| 370 | + routed_scaling_factor=self.routed_scaling_factor, |
| 371 | + e_score_correction_bias=self.e_score_correction_bias, |
| 372 | + global_num_experts=self.global_num_experts) |
| 373 | + |
| 374 | + if isinstance(forward_context.moe_comm_method, |
| 375 | + AllGatherCommImpl): |
| 376 | + topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( |
| 377 | + topk_weights, True, True) |
| 378 | + topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( |
| 379 | + topk_ids, True, True) |
| 380 | + |
| 381 | + set_flash_common3_context(topk_weights=topk_weights, |
| 382 | + topk_ids=topk_ids) |
| 383 | + |
335 | 384 | hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare( |
336 | 385 | hidden_states=hidden_states, |
337 | 386 | router_logits=router_logits, |
338 | 387 | replace_allreduce=forward_context.sp_enabled, |
339 | 388 | enable_shared_expert_dp=self.enable_shared_expert_dp, |
340 | 389 | quant_type=self.quant_type) |
341 | 390 |
| 391 | + # Make sure the default stream waits for the gate stream to finish. |
| 392 | + if self.multistream_overlap_gate: |
| 393 | + torch.npu.current_stream().wait_stream(AscendFusedMoE.gate_stream) |
| 394 | + |
342 | 395 | if isinstance(hidden_states, tuple): |
343 | 396 | hidden_states, pertoken_scale = hidden_states |
344 | 397 | else: |
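The ordering above is the usual two-sided handshake for side-stream overlap: the gate stream first waits on the default stream (so `hidden_states` and `router_logits` are materialized), the gate/shared-expert work then runs concurrently with `moe_comm_method.prepare()`, and the default stream finally waits on the gate stream before anything reads `shared_out`, `topk_weights`, or `topk_ids`. A minimal sketch of that pattern, assuming torch_npu mirrors the torch.cuda stream API (`torch.npu.Stream`, `torch.npu.stream`, `wait_stream`):

```python
# Minimal sketch of the side-stream overlap handshake used above (illustration only).
import torch


def run_overlapped(side_stream: "torch.npu.Stream", gate_fn, prepare_fn):
    # 1) The side stream must not start before the default stream has produced the inputs.
    side_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(side_stream):
        gate_out = gate_fn()    # routing / shared experts run on the side stream
    prepared = prepare_fn()     # runs on the default stream, overlapping with gate_fn
    # 2) The default stream must not read gate_out before the side stream has finished.
    torch.npu.current_stream().wait_stream(side_stream)
    return gate_out, prepared
```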
@@ -407,6 +460,7 @@ def __init__( |
407 | 460 | self.shared_expert_stream = None |
408 | 461 | ascend_config = get_ascend_config() |
409 | 462 | self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert |
| 463 | + self.multistream_overlap_gate = ascend_config.multistream_overlap_gate |
410 | 464 | if enable_sp(): |
411 | 465 | logger.info_once( |
412 | 466 | "Sequence parallelism is enabled, shared experts are replicated for best performance." |
@@ -443,30 +497,42 @@ def forward( |
443 | 497 |
444 | 498 | def forward_impl(self, hidden_states: torch.Tensor, |
445 | 499 | router_logits: torch.Tensor): |
446 | | - # Make sure the shared experts stream begins after hidden_states are ready. |
447 | | - if self.multistream_overlap_shared_expert: |
448 | | - shared_experts_calculation_stream().wait_stream( # type: ignore |
449 | | - torch.npu.current_stream()) |
450 | | - with npu_stream_switch(shared_experts_calculation_stream(), |
451 | | - enabled=self.multistream_overlap_shared_expert): |
452 | | - # Use a separate stream to run shared experts. |
453 | | - # Note that currently we only support calculations in separate streams with aclgraph. |
454 | | - # Communication operations in another stream might cause unknown errors. |
455 | | - shared_out = self._shared_experts(hidden_states) |
| 500 | + shared_out = None |
| 501 | + if not self.multistream_overlap_gate: |
| 502 | + # Make sure the shared experts stream begins after hidden_states are ready. |
| 503 | + if self.multistream_overlap_shared_expert: |
| 504 | + shared_experts_calculation_stream( |
| 505 | + ).wait_stream( # type: ignore |
| 506 | + torch.npu.current_stream()) |
| 507 | + with npu_stream_switch( |
| 508 | + shared_experts_calculation_stream(), |
| 509 | + enabled=self.multistream_overlap_shared_expert): |
| 510 | + # Use a separate stream to run shared experts. |
| 511 | + shared_out = self._shared_experts(hidden_states) |
| 512 | + else: |
| 513 | + set_flash_common3_context(shared_experts=self._shared_experts) |
456 | 514 |
457 | 515 | fused_output = AscendFusedMoE.forward_impl( |
458 | 516 | self, |
459 | 517 | hidden_states=hidden_states, |
460 | 518 | router_logits=router_logits, |
461 | 519 | ) |
462 | | - # Make sure the default stream waits for the shared experts stream to finish. |
463 | | - if self.multistream_overlap_shared_expert: |
464 | | - torch.npu.current_stream().wait_stream( |
465 | | - shared_experts_calculation_stream()) |
466 | | - # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` |
467 | | - forward_context = get_forward_context() |
468 | | - moe_comm_type = forward_context.moe_comm_type |
469 | | - if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \ |
470 | | - and not shared_expert_dp_enabled(): |
471 | | - shared_out = tensor_model_parallel_all_reduce(shared_out) |
| 520 | + |
| 521 | + if not self.multistream_overlap_gate: |
| 522 | + # Make sure the default stream waits for the shared experts stream to finish. |
| 523 | + if self.multistream_overlap_shared_expert: |
| 524 | + torch.npu.current_stream().wait_stream( |
| 525 | + shared_experts_calculation_stream()) |
| 526 | + |
| 527 | + # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` |
| 528 | + forward_context = get_forward_context() |
| 529 | + moe_comm_type = forward_context.moe_comm_type |
| 530 | + if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \ |
| 531 | + and not shared_expert_dp_enabled(): |
| 532 | + shared_out = tensor_model_parallel_all_reduce(shared_out) |
| 533 | + else: |
| 534 | + fc3_context = get_flash_common3_context() |
| 535 | + assert fc3_context is not None |
| 536 | + shared_out = fc3_context.shared_out |
| 537 | + |
472 | 538 | return shared_out, fused_output |
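Note that with `multistream_overlap_gate` enabled, the shared experts are no longer run or all-reduced here: `SharedFusedMoE.forward_impl` only registers `self._shared_experts` in the flash_common3 context, `AscendFusedMoE.forward_impl` executes them and performs the tensor-parallel all-reduce on the gate stream, and the result is read back from `fc3_context.shared_out` after the default stream has synchronized with the gate stream.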