Skip to content

Commit 29337d5

Browse files
fix
Signed-off-by: shenchuxiaofugui <[email protected]>
1 parent 26e93d8 commit 29337d5

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

vllm_ascend/ops/common_fused_moe.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,12 @@ def forward_impl(self, hidden_states: torch.Tensor,
279279
quantized_x_for_share, dynamic_scale_for_share = None, None
280280

281281
forward_context = get_forward_context()
282+
283+
# Load balancing for token distribution among experts in dummy_run
284+
# TODO: The community only considers load balancing when DP > 1.
285+
# This approach may overlook some extreme scenarios.
286+
enable_force_load_balance = forward_context.in_profile_run
287+
282288
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
283289
hidden_states=hidden_states,
284290
router_logits=router_logits,
@@ -305,6 +311,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
305311
quantized_x_for_share=quantized_x_for_share,
306312
dynamic_scale_for_share=dynamic_scale_for_share,
307313
shared_experts=None,
314+
enable_force_load_balance=enable_force_load_balance,
308315
log2phy=self.log2phy,
309316
global_redundant_expert_num=self.global_redundant_expert_num)
310317

vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,8 @@ def apply(
371371
# to avoid accumulating too much tokens on a single rank.
372372
# currently it is only activated when doing profile runs.
373373
if enable_force_load_balance:
374-
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
374+
topk_ids = torch.randint_like(
375+
topk_ids, 0, global_num_experts - global_redundant_expert_num)
375376

376377
topk_weights = topk_weights.to(x.dtype)
377378

vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,7 @@ def apply(
925925
scoring_func: str = "softmax",
926926
e_score_correction_bias: Optional[torch.Tensor] = None,
927927
is_prefill: bool = True,
928-
enable_force_load_balance: bool = True,
928+
enable_force_load_balance: bool = False,
929929
log2phy: torch.Tensor = None,
930930
global_redundant_expert_num: int = 0,
931931
shared_experts: Optional[Any] = None,

0 commit comments

Comments (0)