diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py
index 89721702de3..f507282622e 100644
--- a/vllm_ascend/eplb/eplb_updator.py
+++ b/vllm_ascend/eplb/eplb_updator.py
@@ -23,6 +23,8 @@
 
 from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
 from vllm_ascend.eplb.core.eplb_worker import EplbProcess
+from vllm_ascend.eplb.utils import moe_load_async_stream
+from vllm_ascend.utils import npu_stream_switch
 
 
 class EplbUpdator:
@@ -153,21 +155,22 @@ def compute_and_set_moe_load(self, is_clear=False):
 
         self._gather_buffer = None
         if dist.is_initialized():
-            self.world_size = dist.get_world_size()
-            self.device = local_load.device
-            if self._gather_buffer is None:
-                shape = (self.world_size, *local_load.shape)
-                self._gather_buffer = torch.empty(shape,
-                                                  dtype=local_load.dtype,
-                                                  device=self.device)
-
-            dist.all_gather_into_tensor(self._gather_buffer, local_load)
-
-            moe_load = self._gather_buffer.permute(1, 0, 2)
-            self.shared_dict["moe_load"] = moe_load.cpu()
-            logger.debug(
-                f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
-            )
+            with npu_stream_switch(moe_load_async_stream()):
+                self.world_size = dist.get_world_size()
+                self.device = local_load.device
+                if self._gather_buffer is None:
+                    shape = (self.world_size, *local_load.shape)
+                    self._gather_buffer = torch.empty(shape,
+                                                      dtype=local_load.dtype,
+                                                      device=self.device)
+
+                dist.all_gather_into_tensor(self._gather_buffer, local_load)
+
+                moe_load = self._gather_buffer.permute(1, 0, 2)
+                self.shared_dict["moe_load"] = moe_load.cpu()
+                logger.debug(
+                    f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
+                )
         else:
             moe_load = local_load.unsqueeze(1)
             self.shared_dict["moe_load"] = moe_load.cpu()
diff --git a/vllm_ascend/eplb/utils.py b/vllm_ascend/eplb/utils.py
index 61e5735e4dc..8dfaf56293a 100644
--- a/vllm_ascend/eplb/utils.py
+++ b/vllm_ascend/eplb/utils.py
@@ -18,6 +18,9 @@
 import types
 
 import torch
+import torch_npu
+
+_MOE_LOAD_ASYNC_STREAM = None
 
 
 def get_expert_map(self, layer_id):
@@ -75,3 +78,12 @@
             model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
     else:
         raise NotImplementedError("EPLB is not supported.")
+
+
+def moe_load_async_stream() -> torch_npu.npu.Stream:
+    global _MOE_LOAD_ASYNC_STREAM
+    if _MOE_LOAD_ASYNC_STREAM is None:
+        # Lazily create a dedicated secondary NPU stream on first use so the
+        # moe_load bookkeeping can run off the default compute stream.
+        _MOE_LOAD_ASYNC_STREAM = torch_npu.npu.Stream()
+    return _MOE_LOAD_ASYNC_STREAM
\ No newline at end of file
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index 6bf8d6b750c..80c7fcc2c8c 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -36,6 +36,7 @@
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
+from vllm_ascend.eplb.utils import moe_load_async_stream
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
@@ -368,8 +369,15 @@ def forward_impl(self, hidden_states: torch.Tensor,
         if isinstance(final_hidden_states, tuple):
             final_hidden_states, group_list_type, expert_tokens = final_hidden_states
             if self.dynamic_eplb:
-                self.moe_load += expert_tokens if group_list_type == 1 else \
-                    torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
+
+                moe_load_stream = moe_load_async_stream()
+                cur_stream = torch.npu.current_stream()
+
+                moe_load_stream.wait_stream(cur_stream)
+                with npu_stream_switch(moe_load_stream):
+                    self.moe_load += expert_tokens if group_list_type == 1 else \
+                        torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
+                cur_stream.wait_stream(moe_load_stream)
 
         final_hidden_states = forward_context.moe_comm_method.finalize(
             hidden_states=final_hidden_states,
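For reviewers, here is a minimal, self-contained sketch of the stream-ordering pattern the fused_moe.py change relies on. It is written against the standard `torch.cuda` stream API, which the `torch_npu` stream API mirrors; the function name and tensor arguments are illustrative only and are not part of this patch.

```python
import torch


def accumulate_on_side_stream(acc: torch.Tensor, update: torch.Tensor,
                              side_stream: torch.cuda.Stream) -> None:
    """Accumulate `update` into `acc` on a secondary stream.

    Same ordering as the moe_load accumulation above: the side stream
    waits for the main stream to finish producing `update`, does the
    bookkeeping, and the main stream then waits before continuing.
    """
    main_stream = torch.cuda.current_stream()
    # Order the side stream after the kernel that produced `update`.
    side_stream.wait_stream(main_stream)
    with torch.cuda.stream(side_stream):
        acc += update
    # Order subsequent main-stream work after the accumulation.
    main_stream.wait_stream(side_stream)
```

Dropping either `wait_stream` call reintroduces a race: without the first, the side stream may read `update` before the main stream has written it; without the second, later main-stream work may read or overwrite `acc` (or reuse `update`'s memory) before the accumulation has finished.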