Skip to content

Commit 2627e98

Browse files
白永斌845473182
authored and committed
fix e2e-light
Signed-off-by: 白永斌 <[email protected]>
Signed-off-by: 欧派果奶我还要 <[email protected]>
1 parent 016ff63 commit 2627e98

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

vllm_ascend/ops/fused_moe/moe_comm_method.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ def finalize(self,
8383
def fused_experts(
8484
self,
8585
hidden_states: torch.Tensor,
86-
w1: list[torch.Tensor],
87-
w2: list[torch.Tensor],
86+
w1: torch.Tensor | list[torch.Tensor],
87+
w2: torch.Tensor | list[torch.Tensor],
8888
topk_weights: torch.Tensor,
8989
topk_ids: torch.Tensor,
9090
activation: str = "silu",

vllm_ascend/ops/fused_moe/moe_mlp.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -221,20 +221,20 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
221221

222222

223223
def unquant_apply_mlp(hidden_states: torch.Tensor,
224-
w1: list[torch.Tensor],
225-
w2: list[torch.Tensor],
224+
w1: torch.Tensor,
225+
w2: torch.Tensor,
226226
group_list: torch.Tensor,
227227
group_list_type: int = 1,
228228
topk_scales: Optional[torch.Tensor] = None,
229229
need_trans: bool = True) -> torch.Tensor:
230230

231231
if need_trans:
232-
w1[0] = w1[0].transpose(1, 2)
233-
w2[0] = w2[0].transpose(1, 2)
232+
w1 = w1.transpose(1, 2)
233+
w2 = w2.transpose(1, 2)
234234

235235
gate_up_out = torch_npu.npu_grouped_matmul(
236236
x=[hidden_states],
237-
weight=w1,
237+
weight=[w1],
238238
split_item=2,
239239
group_list_type=group_list_type,
240240
group_type=0,
@@ -251,7 +251,7 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,
251251

252252
hidden_states = torch_npu.npu_grouped_matmul(
253253
x=[gate_up_out],
254-
weight=w2,
254+
weight=[w2],
255255
split_item=2,
256256
group_list_type=group_list_type,
257257
group_type=0,
@@ -261,8 +261,8 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,
261261

262262

263263
def unified_apply_mlp(hidden_states: torch.Tensor,
264-
w1: list[torch.Tensor],
265-
w2: list[torch.Tensor],
264+
w1: torch.Tensor | list[torch.Tensor],
265+
w2: torch.Tensor | list[torch.Tensor],
266266
group_list: torch.Tensor,
267267
w1_scale: Optional[list[torch.Tensor]] = None,
268268
w2_scale: Optional[list[torch.Tensor]] = None,

0 commit comments

Comments (0)