@@ -84,9 +84,7 @@ def fused_experts(
         self,
         hidden_states: torch.Tensor,
         w1: list[torch.Tensor],
-        w1_scale: list[torch.Tensor],
         w2: list[torch.Tensor],
-        w2_scale: list[torch.Tensor],
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
@@ -95,6 +93,8 @@ def fused_experts(
         use_int4_w4a8: bool = False,
         global_num_experts: Optional[int] = None,
         expert_map: Optional[torch.Tensor] = None,
+        w1_scale: Optional[list[torch.Tensor]] = None,
+        w2_scale: Optional[list[torch.Tensor]] = None,
         w1_scale_bias: torch.Tensor = None,
         w2_scale_bias: torch.Tensor = None,
         # For TorchAir graph
@@ -137,6 +137,7 @@ def fused_experts(
         permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
             results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")

+        assert w1_scale is not None and w2_scale is not None
         mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
                                        w1=w1,
                                        w1_scale=w1_scale,
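
For context, a minimal sketch of the pattern this diff applies, using a hypothetical fused_experts_sketch helper rather than the real vLLM-Ascend function: Python requires parameters with defaults to follow the required ones, so making w1_scale and w2_scale optional means relocating them after the last required parameter, while the code path that still consumes the scales is guarded with an assert before use.

from typing import Optional

import torch


def fused_experts_sketch(
        hidden_states: torch.Tensor,
        w1: list[torch.Tensor],
        w2: list[torch.Tensor],
        topk_weights: torch.Tensor,
        # Optional parameters must follow the required ones, which is why the
        # scales move here instead of keeping their original positions.
        w1_scale: Optional[list[torch.Tensor]] = None,
        w2_scale: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
    # On the path that consumes the scales they are still mandatory, so fail
    # fast with an explicit assertion instead of a later error deep in the MLP.
    assert w1_scale is not None and w2_scale is not None
    return hidden_states  # placeholder for the actual grouped expert MLP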