diff --git a/tests/e2e/singlecard/test_multistream_overlap_shared_expert.py b/tests/e2e/singlecard/test_multistream_overlap_shared_expert.py
index 0f150c8acf0..89a78912231 100644
--- a/tests/e2e/singlecard/test_multistream_overlap_shared_expert.py
+++ b/tests/e2e/singlecard/test_multistream_overlap_shared_expert.py
@@ -28,7 +28,7 @@ from tests.e2e.model_utils import check_outputs_equal
 
 MODELS = [
-    "Qwen/Qwen3-0.6B",
+    "vllm-ascend/DeepSeek-V2-Lite-W8A8",
 ]
 
 
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index 604418cc570..1aceb89da0d 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -40,7 +40,8 @@ from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
                                is_enable_nz, npu_stream_switch,
-                               shared_expert_dp_enabled)
+                               shared_expert_dp_enabled,
+                               shared_experts_compute_stream)
 
 
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -419,8 +420,6 @@ def __init__(
         self.shared_expert_stream = None
         ascend_config = get_ascend_config()
         self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
-        if self.multistream_overlap_shared_expert:
-            self.shared_expert_stream = torch.npu.Stream()
         if enable_sp():
             logger.info_once(
                 "Sequence parallelism is enabled, shared experts are replicated for best performance."
             )
@@ -442,19 +441,15 @@ def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         # Make sure the shared experts stream begins after hidden_states are ready.
         if self.multistream_overlap_shared_expert:
-            self.shared_expert_stream.wait_stream(  # type: ignore
+            shared_experts_compute_stream().wait_stream(  # type: ignore
                 torch.npu.current_stream())
-        with npu_stream_switch(self.shared_expert_stream,
+        with npu_stream_switch(shared_experts_compute_stream(),
                                enabled=self.multistream_overlap_shared_expert):
             # Use a separate stream to run shared experts.
+            # Note that currently we only support calculations in separate streams with aclgraph.
+            # Communication operations in another stream might cause unknown errors.
             shared_out = self._shared_experts(hidden_states)
-            # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
-            forward_context = get_forward_context()
-            moe_comm_type = forward_context.moe_comm_type
-            if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
-                and not shared_expert_dp_enabled():
-                shared_out = tensor_model_parallel_all_reduce(shared_out)
         fused_output = AscendFusedMoE.forward_impl(
             self,
             hidden_states=hidden_states,
@@ -462,5 +457,12 @@ def forward_impl(self, hidden_states: torch.Tensor,
         )
         # Make sure the default stream waits for the shared experts stream to finish.
         if self.multistream_overlap_shared_expert:
-            torch.npu.current_stream().wait_stream(self.shared_expert_stream)
+            torch.npu.current_stream().wait_stream(
+                shared_experts_compute_stream())
+        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
+        forward_context = get_forward_context()
+        moe_comm_type = forward_context.moe_comm_type
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+            and not shared_expert_dp_enabled():
+            shared_out = tensor_model_parallel_all_reduce(shared_out)
         return shared_out, fused_output
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index a5f23c78c11..6745c1671a1 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -52,6 +52,7 @@
 _SLEEP_MODE_ENABLED = None
 _CURRENT_STREAM = None
 _PREFETCH_STREAM = None
+_SHARED_EXPERTS_COMPUTE_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
 _DEFAULT_BUFFER_SIZE = 200
 _MIN_DP_BUFFER_SIZE = 50
@@ -259,6 +260,15 @@ def prefetch_stream() -> torch.npu.Stream:
     return _PREFETCH_STREAM
 
 
+def shared_experts_compute_stream() -> torch.npu.Stream:
+    global _SHARED_EXPERTS_COMPUTE_STREAM
+    if _SHARED_EXPERTS_COMPUTE_STREAM is None:
+        # Lazily create a dedicated stream for shared-experts computation
+        # the first time this function is called.
+        _SHARED_EXPERTS_COMPUTE_STREAM = torch_npu.npu.Stream()
+    return _SHARED_EXPERTS_COMPUTE_STREAM
+
+
 def adapt_patch(is_global_patch: bool = False):
     if is_global_patch:
         from vllm_ascend.patch import platform  # noqa: F401
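
For reference, the overlap pattern this patch relies on, reduced to a standalone sketch. It assumes an Ascend environment with torch_npu installed; `run_shared_experts_overlapped` is an illustrative helper that is not part of the patch, and `torch.npu.stream` stands in for vllm_ascend's `npu_stream_switch` wrapper.

import torch
import torch_npu  # assumption: Ascend NPU environment with torch_npu available

_SHARED_EXPERTS_COMPUTE_STREAM = None


def shared_experts_compute_stream() -> torch.npu.Stream:
    # One lazily created, process-wide side stream, mirroring vllm_ascend/utils.py.
    global _SHARED_EXPERTS_COMPUTE_STREAM
    if _SHARED_EXPERTS_COMPUTE_STREAM is None:
        _SHARED_EXPERTS_COMPUTE_STREAM = torch_npu.npu.Stream()
    return _SHARED_EXPERTS_COMPUTE_STREAM


def run_shared_experts_overlapped(shared_experts, hidden_states):
    # Illustrative only: shows the ordering of the stream synchronization.
    side_stream = shared_experts_compute_stream()
    # 1. The side stream must not start before hidden_states is produced.
    side_stream.wait_stream(torch.npu.current_stream())
    # 2. Shared-expert compute runs on the side stream while the routed
    #    experts run concurrently on the default stream.
    with torch.npu.stream(side_stream):
        shared_out = shared_experts(hidden_states)
    # ... routed-expert / fused-MoE work happens here on the default stream ...
    # 3. The default stream waits for the side stream before anything else
    #    (e.g. tensor_model_parallel_all_reduce) consumes shared_out.
    torch.npu.current_stream().wait_stream(side_stream)
    return shared_out

Keeping the all-reduce on shared_out after the final wait_stream, as the patch does, keeps communication off the side stream; only computation is expected on the separate stream when aclgraph is used.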