@@ -28,7 +28,7 @@
 from tests.e2e.model_utils import check_outputs_equal
 
 MODELS = [
-    "Qwen/Qwen3-0.6B",
+    "vllm-ascend/DeepSeek-V2-Lite-W8A8",
 ]
 
 
vllm_ascend/ops/common_fused_moe.py (26 changes: 14 additions & 12 deletions)
@@ -40,7 +40,8 @@
 from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
                                is_enable_nz, npu_stream_switch,
-                               shared_expert_dp_enabled)
+                               shared_expert_dp_enabled,
+                               shared_experts_compute_stream)
 
 
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -419,8 +420,6 @@ def __init__(
         self.shared_expert_stream = None
         ascend_config = get_ascend_config()
         self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
-        if self.multistream_overlap_shared_expert:
-            self.shared_expert_stream = torch.npu.Stream()
         if enable_sp():
             logger.info_once(
                 "Sequence parallelism is enabled, shared experts are replicated for best performance."
@@ -442,25 +441,28 @@ def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         # Make sure the shared experts stream begins after hidden_states are ready.
         if self.multistream_overlap_shared_expert:
-            self.shared_expert_stream.wait_stream(  # type: ignore
+            shared_experts_compute_stream().wait_stream(  # type: ignore
                 torch.npu.current_stream())
-        with npu_stream_switch(self.shared_expert_stream,
+        with npu_stream_switch(shared_experts_compute_stream(),
                                enabled=self.multistream_overlap_shared_expert):
             # Use a separate stream to run shared experts.
             # Note that currently we only support calculations in separate streams with aclgraph.
             # Communication operations in another stream might cause unknown errors.
             shared_out = self._shared_experts(hidden_states)
 
-            # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
-            forward_context = get_forward_context()
-            moe_comm_type = forward_context.moe_comm_type
-            if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
-                    and not shared_expert_dp_enabled():
-                shared_out = tensor_model_parallel_all_reduce(shared_out)
         fused_output = AscendFusedMoE.forward_impl(
             self,
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
         # Make sure the default stream waits for the shared experts stream to finish.
         if self.multistream_overlap_shared_expert:
-            torch.npu.current_stream().wait_stream(self.shared_expert_stream)
+            torch.npu.current_stream().wait_stream(
+                shared_experts_compute_stream())
+        # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
+        forward_context = get_forward_context()
+        moe_comm_type = forward_context.moe_comm_type
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+                and not shared_expert_dp_enabled():
+            shared_out = tensor_model_parallel_all_reduce(shared_out)
         return shared_out, fused_output
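
The net effect of this hunk is that the shared-expert all-reduce (a communication op) no longer runs inside the side stream: it now runs on the default stream, after the explicit wait_stream join, which matches the in-code note that only calculations are supported on a separate stream under aclgraph. Below is a minimal sketch of that overlap pattern, not the vLLM-Ascend implementation; the names (overlapped_shared_experts, routed_forward, side_stream) are illustrative, and plain torch.cuda streams stand in for torch.npu ones since the wait_stream/stream-context semantics are the same.

import torch

def overlapped_shared_experts(shared_experts, routed_forward, hidden_states,
                              side_stream, overlap=True):
    # Illustrative stand-ins: shared_experts plays the role of self._shared_experts,
    # routed_forward the role of AscendFusedMoE.forward_impl in the diff above.
    if overlap:
        # The side stream must not read hidden_states before the default stream
        # has finished producing it.
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            # Compute only; no collectives are issued on the side stream.
            shared_out = shared_experts(hidden_states)
    else:
        shared_out = shared_experts(hidden_states)

    # The routed experts run concurrently on the default stream.
    fused_out = routed_forward(hidden_states)

    if overlap:
        # Join back before anything on the default stream (for example an
        # all-reduce of shared_out) consumes the side stream's results.
        torch.cuda.current_stream().wait_stream(side_stream)
    return shared_out, fused_out
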
vllm_ascend/utils.py (10 changes: 10 additions & 0 deletions)
@@ -52,6 +52,7 @@
 _SLEEP_MODE_ENABLED = None
 _CURRENT_STREAM = None
 _PREFETCH_STREAM = None
+_SHARED_EXPERTS_COMPUTE_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
 _DEFAULT_BUFFER_SIZE = 200
 _MIN_DP_BUFFER_SIZE = 50
@@ -259,6 +260,15 @@ def prefetch_stream() -> torch.npu.Stream:
     return _PREFETCH_STREAM
 
 
+def shared_experts_compute_stream() -> torch.npu.Stream:
+    global _SHARED_EXPERTS_COMPUTE_STREAM
+    if _SHARED_EXPERTS_COMPUTE_STREAM is None:
+        # Lazily create a single dedicated side stream that all
+        # shared-expert computations run on.
+        _SHARED_EXPERTS_COMPUTE_STREAM = torch_npu.npu.Stream()
+    return _SHARED_EXPERTS_COMPUTE_STREAM
+
+
 def adapt_patch(is_global_patch: bool = False):
     if is_global_patch:
         from vllm_ascend.patch import platform  # noqa: F401
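
For reference, the new helper is a lazily initialized, process-wide singleton: every MoE layer that enables multistream overlap now shares one compute stream instead of allocating its own in __init__ (the allocation removed from common_fused_moe.py above). A hedged usage sketch, assuming an Ascend environment where torch_npu is installed and this PR's vllm_ascend.utils is importable:

from vllm_ascend.utils import shared_experts_compute_stream

# Both calls return the same torch.npu.Stream object; the stream is created on
# the first call and cached in the module-level global afterwards.
s1 = shared_experts_compute_stream()
s2 = shared_experts_compute_stream()
assert s1 is s2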