@@ -21,10 +21,14 @@ def __init__(self,
2121 block_m : Optional [int ] = None ,
2222 allow_deep_gemm : bool = False ):
2323 super ().__init__ ()
24- self .triton_expert : TritonExperts = TritonExperts (
25- use_fp8_w8a8 , use_int8_w8a8 , use_int4_w4a16 , use_int8_w8a16 ,
26- per_channel_quant , block_shape , block_m )
27- self .deep_gemm_expert : DeepGemmExperts = DeepGemmExperts ()
24+ self .triton_expert = TritonExperts (use_fp8_w8a8 = use_fp8_w8a8 ,
25+ use_int8_w8a8 = use_int8_w8a8 ,
26+ use_int4_w4a16 = use_int4_w4a16 ,
27+ use_int8_w8a16 = use_int8_w8a16 ,
28+ per_channel_quant = per_channel_quant ,
29+ block_shape = block_shape ,
30+ block_m = block_m )
31+ self .deep_gemm_expert = DeepGemmExperts ()
2832 self .allow_deep_gemm = allow_deep_gemm
2933 self .use_fp8_w8a8 = use_fp8_w8a8
3034
@@ -69,7 +73,7 @@ def apply(
6973 N = w1 .shape [1 ]
7074 if (self .allow_deep_gemm and self .use_fp8_w8a8 and N > 512
7175 and _valid_deep_gemm (hidden_states , w1 , w2 , expert_map )):
72- return self .deep_gemm_expert (
76+ return self .deep_gemm_expert . apply (
7377 hidden_states ,
7478 w1 ,
7579 w2 ,
@@ -88,7 +92,7 @@ def apply(
8892 expert_num_tokens ,
8993 )
9094 else :
91- return self .triton_expert (
95+ return self .triton_expert . apply (
9296 hidden_states ,
9397 w1 ,
9498 w2 ,
0 commit comments