 from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
                                is_enable_nz, npu_stream_switch,
-                               shared_expert_dp_enabled)
+                               shared_expert_dp_enabled,
+                               shared_experts_compute_stream)
 
 
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -419,8 +420,6 @@ def __init__(
         self.shared_expert_stream = None
         ascend_config = get_ascend_config()
         self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
-        if self.multistream_overlap_shared_expert:
-            self.shared_expert_stream = torch.npu.Stream()
         if enable_sp():
             logger.info_once(
                 "Sequence parallelism is enabled, shared experts are replicated for best performance."
@@ -442,25 +441,28 @@ def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         # Make sure the shared experts stream begins after hidden_states are ready.
         if self.multistream_overlap_shared_expert:
-            self.shared_expert_stream.wait_stream(  # type: ignore
+            shared_experts_compute_stream().wait_stream(  # type: ignore
                 torch.npu.current_stream())
-        with npu_stream_switch(self.shared_expert_stream,
+        with npu_stream_switch(shared_experts_compute_stream(),
                                enabled=self.multistream_overlap_shared_expert):
             # Use a separate stream to run shared experts.
+            # Note that currently we only support calculations in separate streams with aclgraph.
+            # Communication operations in another stream might cause unknown errors.
             shared_out = self._shared_experts(hidden_states)
 
-        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
-        forward_context = get_forward_context()
-        moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
-                and not shared_expert_dp_enabled():
-            shared_out = tensor_model_parallel_all_reduce(shared_out)
         fused_output = AscendFusedMoE.forward_impl(
             self,
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
         # Make sure the default stream waits for the shared experts stream to finish.
         if self.multistream_overlap_shared_expert:
-            torch.npu.current_stream().wait_stream(self.shared_expert_stream)
+            torch.npu.current_stream().wait_stream(
+                shared_experts_compute_stream())
+        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
+        forward_context = get_forward_context()
+        moe_comm_type = forward_context.moe_comm_type
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+                and not shared_expert_dp_enabled():
+            shared_out = tensor_model_parallel_all_reduce(shared_out)
         return shared_out, fused_output
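
For reference, the synchronization pattern this change relies on is: the side stream waits for the default stream before computing the shared experts, and the default stream waits for the side stream before the shared-experts output is consumed by the all-reduce. Below is a minimal sketch of that pattern, written against the plain torch.cuda stream API purely for illustration; the actual code uses the NPU equivalents through npu_stream_switch and the shared shared_experts_compute_stream() helper, and the function and argument names in the sketch are hypothetical.

import torch

def overlapped_moe_forward(shared_experts, routed_experts, hidden_states):
    # Hypothetical sketch of the two-stream overlap used in the diff above.
    main_stream = torch.cuda.current_stream()
    side_stream = torch.cuda.Stream()

    # The side stream must not start before hidden_states are ready.
    side_stream.wait_stream(main_stream)
    with torch.cuda.stream(side_stream):
        # Compute only on the side stream; collectives stay on the default
        # stream (mirroring the in-code note about aclgraph limitations).
        shared_out = shared_experts(hidden_states)

    # Routed experts run on the default stream, overlapping the side stream.
    fused_out = routed_experts(hidden_states)

    # Re-join before shared_out is consumed (e.g. by an all-reduce).
    main_stream.wait_stream(side_stream)
    return shared_out, fused_out

Moving tensor_model_parallel_all_reduce after the final wait_stream follows the same rule: only computation runs on the side stream, and any communication on shared_out happens once the default stream has synchronized with it.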