
Commit e37a66d

whx-sjtu authored and Angazenn committed
[Cherry-pick] Port MoE multi-stream fix to v0.11.0-dev (vllm-project#3753)
This PR moves the shared experts' communication operation out of the extra stream, because running shared experts in a separate stream together with aclgraph can otherwise trigger rtMemcpy-related errors. It also uses a global variable as the extra-stream object, so that full-graph mode does not allocate a stream for every layer. Signed-off-by: whx-sjtu <[email protected]>
1 parent 9868268 commit e37a66d

File tree

3 files changed: +25, -13 lines

tests/e2e/singlecard/test_multistream_overlap_shared_expert.py
Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@
 from tests.e2e.model_utils import check_outputs_equal

 MODELS = [
-    "Qwen/Qwen3-0.6B",
+    "vllm-ascend/DeepSeek-V2-Lite-W8A8",
 ]
vllm_ascend/ops/common_fused_moe.py
Lines changed: 14 additions & 12 deletions

@@ -40,7 +40,8 @@
 from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
                                is_enable_nz, npu_stream_switch,
-                               shared_expert_dp_enabled)
+                               shared_expert_dp_enabled,
+                               shared_experts_compute_stream)


 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

@@ -419,8 +420,6 @@ def __init__(
         self.shared_expert_stream = None
         ascend_config = get_ascend_config()
         self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
-        if self.multistream_overlap_shared_expert:
-            self.shared_expert_stream = torch.npu.Stream()
         if enable_sp():
             logger.info_once(
                 "Sequence parallelism is enabled, shared experts are replicated for best performance."

@@ -442,25 +441,28 @@ def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         # Make sure the shared experts stream begins after hidden_states are ready.
         if self.multistream_overlap_shared_expert:
-            self.shared_expert_stream.wait_stream(  # type: ignore
+            shared_experts_compute_stream().wait_stream(  # type: ignore
                 torch.npu.current_stream())
-        with npu_stream_switch(self.shared_expert_stream,
+        with npu_stream_switch(shared_experts_compute_stream(),
                                enabled=self.multistream_overlap_shared_expert):
             # Use a separate stream to run shared experts.
+            # Note that currently we only support calculations in separate streams with aclgraph.
+            # Communication operations in another stream might cause unknown errors.
             shared_out = self._shared_experts(hidden_states)

-        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
-        forward_context = get_forward_context()
-        moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
-            and not shared_expert_dp_enabled():
-            shared_out = tensor_model_parallel_all_reduce(shared_out)
         fused_output = AscendFusedMoE.forward_impl(
             self,
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
         # Make sure the default stream waits for the shared experts stream to finish.
         if self.multistream_overlap_shared_expert:
-            torch.npu.current_stream().wait_stream(self.shared_expert_stream)
+            torch.npu.current_stream().wait_stream(
+                shared_experts_compute_stream())
+        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
+        forward_context = get_forward_context()
+        moe_comm_type = forward_context.moe_comm_type
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+            and not shared_expert_dp_enabled():
+            shared_out = tensor_model_parallel_all_reduce(shared_out)
         return shared_out, fused_output
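
For readers less familiar with stream semantics, the sketch below (not part of the commit) illustrates the ordering the updated forward_impl enforces: the side stream first waits for the default stream so hidden_states are ready, only pure computation is issued on the side stream, and the default stream waits for the side stream before the all-reduce runs. It uses torch.cuda streams as a stand-in for torch.npu (the wait_stream/stream-context APIs mirror each other); toy_shared_experts, toy_routed_experts and toy_all_reduce are made-up placeholders, not the repository's functions.

import torch


def toy_shared_experts(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for self._shared_experts(hidden_states).
    return torch.nn.functional.gelu(x)


def toy_routed_experts(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for AscendFusedMoE.forward_impl(...).
    return x * 2.0


def toy_all_reduce(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for tensor_model_parallel_all_reduce; a no-op on one device.
    return x


def overlapped_moe_forward(hidden_states: torch.Tensor,
                           side_stream: torch.cuda.Stream):
    default_stream = torch.cuda.current_stream()

    # 1. The side stream must not start before hidden_states are ready.
    side_stream.wait_stream(default_stream)

    # 2. Only computation is issued on the side stream.
    with torch.cuda.stream(side_stream):
        shared_out = toy_shared_experts(hidden_states)

    # 3. The routed experts run concurrently on the default stream.
    fused_out = toy_routed_experts(hidden_states)

    # 4. Re-join the streams before anything else touches shared_out.
    default_stream.wait_stream(side_stream)

    # 5. Communication stays on the default stream, which is the point of the
    #    fix: issuing a collective on the side stream can misbehave under
    #    aclgraph capture (the rtMemcpy errors mentioned in the commit message).
    shared_out = toy_all_reduce(shared_out)
    return shared_out, fused_out


if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(4, 8, device="cuda")
    shared, fused = overlapped_moe_forward(x, torch.cuda.Stream())
    print(shared.shape, fused.shape)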

vllm_ascend/utils.py
Lines changed: 10 additions & 0 deletions

@@ -52,6 +52,7 @@
 _SLEEP_MODE_ENABLED = None
 _CURRENT_STREAM = None
 _PREFETCH_STREAM = None
+_SHARED_EXPERTS_COMPUTE_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
 _DEFAULT_BUFFER_SIZE = 200
 _MIN_DP_BUFFER_SIZE = 50

@@ -259,6 +260,15 @@ def prefetch_stream() -> torch.npu.Stream:
     return _PREFETCH_STREAM


+def shared_experts_compute_stream() -> torch.npu.Stream:
+    global _SHARED_EXPERTS_COMPUTE_STREAM
+    if _SHARED_EXPERTS_COMPUTE_STREAM is None:
+        # when this function is called before any stream is set,
+        # we return the default stream.
+        _SHARED_EXPERTS_COMPUTE_STREAM = torch_npu.npu.Stream()
+    return _SHARED_EXPERTS_COMPUTE_STREAM
+
+
 def adapt_patch(is_global_patch: bool = False):
     if is_global_patch:
         from vllm_ascend.patch import platform  # noqa: F401
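
The new helper follows a lazily initialized module-global pattern, so every MoE layer reuses one extra stream instead of allocating its own. Below is a minimal sketch of the same pattern (not the repository's code); torch.cuda again stands in for torch_npu.npu, and get_extra_stream is a made-up name mirroring shared_experts_compute_stream().

import torch

_EXTRA_STREAM = None


def get_extra_stream() -> torch.cuda.Stream:
    # Create the stream once, on first use; later calls return the same object,
    # so stacking many layers does not allocate many streams (the concern the
    # commit message raises for full-graph mode).
    global _EXTRA_STREAM
    if _EXTRA_STREAM is None:
        _EXTRA_STREAM = torch.cuda.Stream()
    return _EXTRA_STREAM


if __name__ == "__main__" and torch.cuda.is_available():
    # Every caller (i.e. every layer) sees the same stream instance.
    assert get_extra_stream() is get_extra_stream()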

0 commit comments