Skip to content

Commit 765a3c3

Browse files
白永斌845473182
authored and committed
fix ut
Signed-off-by: 白永斌 <[email protected]>
Signed-off-by: 欧派果奶我还要 <[email protected]>
1 parent 18b13f5 commit 765a3c3

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

vllm_ascend/ops/fused_moe/moe_comm_method.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,7 @@ def fused_experts(
8484
self,
8585
hidden_states: torch.Tensor,
8686
w1: list[torch.Tensor],
87-
w1_scale: list[torch.Tensor],
8887
w2: list[torch.Tensor],
89-
w2_scale: list[torch.Tensor],
9088
topk_weights: torch.Tensor,
9189
topk_ids: torch.Tensor,
9290
activation: str = "silu",
@@ -95,6 +93,8 @@ def fused_experts(
9593
use_int4_w4a8: bool = False,
9694
global_num_experts: Optional[int] = None,
9795
expert_map: Optional[torch.Tensor] = None,
96+
w1_scale: Optional[list[torch.Tensor]] = None,
97+
w2_scale: Optional[list[torch.Tensor]] = None,
9898
w1_scale_bias: torch.Tensor = None,
9999
w2_scale_bias: torch.Tensor = None,
100100
# For TorchAir graph
@@ -137,6 +137,7 @@ def fused_experts(
137137
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
138138
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")
139139

140+
assert w1_scale is not None and w2_scale is not None
140141
mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
141142
w1=w1,
142143
w1_scale=w1_scale,

0 commit comments

Comments
 (0)