@@ -221,20 +221,20 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
 
 
 def unquant_apply_mlp(hidden_states: torch.Tensor,
-                      w1: list[torch.Tensor],
-                      w2: list[torch.Tensor],
+                      w1: torch.Tensor,
+                      w2: torch.Tensor,
                       group_list: torch.Tensor,
                       group_list_type: int = 1,
                       topk_scales: Optional[torch.Tensor] = None,
                       need_trans: bool = True) -> torch.Tensor:
 
     if need_trans:
-        w1[0] = w1[0].transpose(1, 2)
-        w2[0] = w2[0].transpose(1, 2)
+        w1 = w1.transpose(1, 2)
+        w2 = w2.transpose(1, 2)
 
     gate_up_out = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
-        weight=w1,
+        weight=[w1],
         split_item=2,
         group_list_type=group_list_type,
         group_type=0,
@@ -251,7 +251,7 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,
 
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[gate_up_out],
-        weight=w2,
+        weight=[w2],
         split_item=2,
         group_list_type=group_list_type,
         group_type=0,
@@ -261,8 +261,8 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,
 
 
 def unified_apply_mlp(hidden_states: torch.Tensor,
-                      w1: list[torch.Tensor],
-                      w2: list[torch.Tensor],
+                      w1: torch.Tensor | list[torch.Tensor],
+                      w2: torch.Tensor | list[torch.Tensor],
                       group_list: torch.Tensor,
                       w1_scale: Optional[list[torch.Tensor]] = None,
                       w2_scale: Optional[list[torch.Tensor]] = None,
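A short call-site sketch of the signature change for reviewers: `w1`/`w2` are now passed to `unquant_apply_mlp` as plain tensors rather than single-element lists, and the function wraps them itself (`weight=[w1]`, `weight=[w2]`) before calling `torch_npu.npu_grouped_matmul`. The shapes below are illustrative assumptions only, and the commented calls require an Ascend environment with `torch_npu` and this repo on the path.

```python
import torch

# Illustrative shapes (assumptions, not the exact layout required by the NPU kernel).
num_experts, hidden_size, intermediate_size = 4, 128, 256
hidden_states = torch.empty(16, hidden_size)
w1 = torch.empty(num_experts, 2 * intermediate_size, hidden_size)  # per-expert gate/up weights
w2 = torch.empty(num_experts, hidden_size, intermediate_size)      # per-expert down-projection weights
group_list = torch.tensor([4, 4, 4, 4])  # tokens routed to each expert

# Before this change the caller wrapped the weights in single-element lists:
#   out = unquant_apply_mlp(hidden_states, [w1], [w2], group_list)
# After this change the weights are plain tensors; the list wrapping now
# happens inside unquant_apply_mlp when it calls npu_grouped_matmul:
#   out = unquant_apply_mlp(hidden_states, w1, w2, group_list)
```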