from flashinfer.utils import get_compute_capability


+# Max num tokens to tune for trtllm-gen fused moe
+TUNE_MAX_NUM_TOKENS = 4096
+
+
def check_cuda(err):
    """Unified CUDA error checking function used throughout the file."""
    if err != runtime.cudaError_t.cudaSuccess:
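The new TUNE_MAX_NUM_TOKENS constant caps the token-count bucket the autotuner profiles for the trtllm-gen fused-MoE kernels; calls with fewer tokens then reuse the cached tactic. A minimal sketch of the pattern, assuming flashinfer's autotune context manager and the trtllm_fp8_block_scale_moe entry point used later in this diff (import paths and the helper name are assumptions, and the kwargs stand in for the weight/scale/routing arguments this file builds elsewhere):

from flashinfer.autotuner import autotune  # assumed import path
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe  # assumed import path

TUNE_MAX_NUM_TOKENS = 4096


def tuned_moe_call(**moe_kwargs):
    # Placeholder helper: with tuning enabled, candidate kernels are profiled
    # for token counts up to TUNE_MAX_NUM_TOKENS and the best tactic is cached.
    with autotune(True):
        return trtllm_fp8_block_scale_moe(
            tune_max_num_tokens=TUNE_MAX_NUM_TOKENS,
            **moe_kwargs,
        )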
@@ -76,6 +80,7 @@ def __init__(self, moe_impl, static_data, **config):
        self.moe_impl = moe_impl
        self.static_data = static_data
        self.config = config
+        self.enable_autotune = config.get("enable_autotune", True)
        self.graph = None
        self.graph_exec = None
        self.stream = None
@@ -106,7 +111,7 @@ def capture(self, hidden_states_sample, **runtime_args):
        self.input_tensor = hidden_states_sample.clone()

        # Warmup
-        with torch.cuda.stream(torch_stream), autotune(True):
+        with torch.cuda.stream(torch_stream), autotune(self.enable_autotune):
            for _ in range(1):
                self._run_moe_computation(runtime_args)

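Gating the warmup with autotune(self.enable_autotune) matters because tuning launches many candidate kernels, and that has to happen eagerly before the graph is recorded so only the selected kernel ends up in the capture. A simplified sketch of that order of operations, using torch.cuda.graph instead of this file's raw CUDA-graph plumbing (function and variable names are placeholders):

import torch
from flashinfer.autotuner import autotune  # assumed import path


def capture_moe_graph(run_moe, sample_input, enable_autotune=True):
    # Warmup (and optional autotuning) runs eagerly on a side stream; the
    # tuner's trial launches must not end up inside the captured graph.
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream), autotune(enable_autotune):
        run_moe(sample_input)
    torch.cuda.current_stream().wait_stream(stream)

    # Capture a single replayable launch of the already-tuned kernel.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        run_moe(sample_input)
    return graph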
@@ -207,6 +212,7 @@ def _run_moe_computation(self, runtime_args):
            routing_method_type=self.config["routing_method_type"],
            gated_act_type=self.config["gated_act_type"],
            do_finalize=True,
+            tune_max_num_tokens=TUNE_MAX_NUM_TOKENS,
        )
        return output  # Extract tensor from tuple

@@ -551,6 +557,7 @@ def call_moe(
        routed_scaling = kwargs["routed_scaling"]
        gated_act_type = kwargs["gated_act_type"]
        routing_method_type = kwargs["routing_method_type"]
+        enable_autotune = kwargs.get("enable_autotune", True)

        # Create CUDA graph configuration
        config = {
@@ -563,6 +570,7 @@ def call_moe(
            "routed_scaling": routed_scaling,
            "gated_act_type": gated_act_type,
            "routing_method_type": routing_method_type,
+            "enable_autotune": enable_autotune,
        }

        runtime_args = {
@@ -761,6 +769,7 @@ def call_moe(
        intermediate_size = kwargs["intermediate_size"]
        routed_scaling = kwargs["routed_scaling"]
        routing_method_type = kwargs["routing_method_type"]
+        enable_autotune = kwargs.get("enable_autotune", True)
        enable_pdl = kwargs.get("enable_pdl")
        hidden_states_scale = kwargs["hidden_states_scale"]
        hidden_states_quant = kwargs["hidden_states_quant"]
@@ -772,7 +781,7 @@ def call_moe(
        )

        # Use autotuner for optimal kernel selection
-        with autotune(True):
+        with autotune(enable_autotune):
            output = trtllm_fp8_block_scale_moe(
                expert_logits,
                routing_bias,
@@ -795,6 +804,7 @@ def call_moe(
                use_shuffled_weight=static_data["use_shuffled_weight"],
                weight_layout=static_data["weight_layout"],
                enable_pdl=enable_pdl,
+                tune_max_num_tokens=TUNE_MAX_NUM_TOKENS,
            )
        return output.to(torch.float)

@@ -937,14 +947,15 @@ def call_moe(
        intermediate_size = kwargs["intermediate_size"]
        routed_scaling = kwargs["routed_scaling"]
        routing_method_type = kwargs["routing_method_type"]
+        enable_autotune = kwargs.get("enable_autotune", True)

        # Quantize to FP8 per-tensor using pre-computed global scale factor
        hidden_states_fp8, _ = quant_fp8_per_tensor(
            hidden_states_orig, hidden_states_scale_global
        )

        # Use autotuner for optimal kernel selection
-        with autotune(True):
+        with autotune(enable_autotune):
            output = trtllm_fp8_per_tensor_scale_moe(
                (
                    expert_logits.to(torch.bfloat16)
@@ -970,6 +981,7 @@ def call_moe(
                == RoutingMethodType.Llama4,  # Use_routing_scales_on_input
                None,
                routing_method_type,
+                tune_max_num_tokens=TUNE_MAX_NUM_TOKENS,
            )

        return output.to(torch.float)
@@ -1101,9 +1113,10 @@ def call_moe(
        top_k_groups = kwargs["top_k_groups"]
        intermediate_size = kwargs["intermediate_size"]
        routing_method_type = kwargs["routing_method_type"]
+        enable_autotune = kwargs.get("enable_autotune", True)

        # Use autotuner for optimal kernel selection
-        with autotune(True):
+        with autotune(enable_autotune):
            output = trtllm_bf16_moe(
                expert_logits,  # float
                routing_bias,
@@ -1120,6 +1133,7 @@ def call_moe(
                use_shuffled_weight=static_data["use_shuffled_weight"],
                weight_layout=static_data["weight_layout"],
                routing_method_type=routing_method_type,
+                tune_max_num_tokens=TUNE_MAX_NUM_TOKENS,
            )
        return output.to(torch.float)

@@ -1408,20 +1422,18 @@ def routing_reference_topk(expert_logits, top_k, num_experts, padding):

def check_accuracy(a, b, atol, rtol, percent):
    """Unified accuracy checking function with detailed error reporting."""
-    if torch.any(torch.isnan(a)):
-        raise Exception("NaN in reference output")
-    if torch.any(torch.isnan(b)):
-        raise Exception("NaN in actual output")
-    if torch.any(torch.isinf(a)):
-        raise Exception("Inf in reference output")
-    if torch.any(torch.isinf(b)):
-        raise Exception("Inf in actual output")
+    if not torch.isfinite(a).all():
+        raise Exception("Non-finite values in reference output")
+    if not torch.isfinite(b).all():
+        raise Exception("Non-finite values in actual output")
    assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"

-    left = torch.abs(a - b)
-    right = atol + rtol * torch.abs(b)
-    count = torch.sum(left > right)
-    mismatch_percent = count / a.numel()
+    close = torch.isclose(a, b, atol=atol, rtol=rtol)
+    match_ratio = close.float().mean()
+    if match_ratio >= percent:
+        return
+
+    mismatch_percent = 1.0 - match_ratio.item()
    if mismatch_percent > 1 - percent:
        raise Exception(
            f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
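The rewritten check treats percent as the minimum fraction of elements that must satisfy the elementwise tolerance |a - b| <= atol + rtol * |b|, which is exactly what torch.isclose evaluates with b as the reference operand. A small self-contained example of that semantics (values are illustrative only):

import torch

a = torch.zeros(1000)
b = torch.zeros(1000)
b[:5] = 1.0  # 0.5% of elements are far outside tolerance

# Same check as above: fraction of elementwise-close entries.
close = torch.isclose(a, b, atol=1e-2, rtol=1e-2)
print(close.float().mean().item())  # 0.995 -> passes a percent=0.99 threshold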
@@ -1999,6 +2011,7 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs):
        "gated_act_type": args.gated_act_type,
        "hidden_states_scale": args.hidden_states_scale,
        "hidden_states_quant": kwargs["hidden_states_quant"],
+        "enable_autotune": kwargs.get("enable_autotune", True),
    }

    return moe_impl.call_moe(
@@ -2238,6 +2251,8 @@ def run_moe_test(
        pytest.fail("Reference computation failed to produce output")

    # Compute actual output
+    enable_autotune = routing_config.get("enable_autotune", True)
+
    output_dequant_actual = moe_impl.compute_production(
        args_dequant,
        args,
@@ -2253,6 +2268,7 @@ def run_moe_test(
        weight_processing=weight_processing,
        enable_pdl=True,
        hidden_states_quant=inputs_data["hidden_states"],
+        enable_autotune=enable_autotune,
    )

    # Compare outputs
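run_moe_test now reads the flag from each routing config (defaulting to True) and forwards it through the kwargs chain, where every call_moe implementation wraps its kernel launch in autotune(enable_autotune). A minimal sketch of that plumbing, with a placeholder config dict rather than a real test configuration:

# Placeholder config: only the autotune plumbing is shown.
routing_config = {"enable_autotune": False}

# run_moe_test reads the per-config flag (default True when absent) ...
enable_autotune = routing_config.get("enable_autotune", True)

# ... and forwards it via compute_production's kwargs, where each call_moe
# implementation wraps its kernel launch in autotune(enable_autotune).
kwargs = {"enable_autotune": enable_autotune}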
@@ -2267,7 +2283,7 @@ def run_moe_test(


# Test: Renormalize routing
-@pytest.mark.parametrize("num_tokens", [1, 8, 1024, 3072])
+@pytest.mark.parametrize("num_tokens", [8, 768, 3072])
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [1024, 768, 512, 384])
@pytest.mark.parametrize(
@@ -2301,8 +2317,9 @@ def run_moe_test(
                    BF16Moe,
                ],
                "compatible_intermediate_size": [384, 768, 1024],
+                "enable_autotune": True,
            },
-            id="Qwen3",
+            id="Qwen3_MOE",
        ),
        pytest.param(
            {
@@ -2321,6 +2338,7 @@ def run_moe_test(
                    BF16Moe,
                ],
                "compatible_intermediate_size": [384, 1024],
+                "enable_autotune": False,
            },
            id="Renorm",
        ),
@@ -2341,6 +2359,7 @@ def run_moe_test(
                    BF16Moe,
                ],
                "compatible_intermediate_size": [512],
+                "enable_autotune": True,
            },
            id="Qwen3_next",
        ),
@@ -2406,7 +2425,7 @@ def test_renormalize_routing(


# Test: DeepSeekV3 routing
-@pytest.mark.parametrize("num_tokens", [1, 8, 1024, 3072])
+@pytest.mark.parametrize("num_tokens", [8, 768, 3072])
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384])
@pytest.mark.parametrize(
@@ -2433,6 +2452,7 @@ def test_renormalize_routing(
                "routing_method_type": RoutingMethodType.DeepSeekV3,
                "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe],
                "compatible_intermediate_size": [1024, 2048],
+                "enable_autotune": True,
            },
            id="kimi_k2",
        ),
@@ -2448,6 +2468,7 @@ def test_renormalize_routing(
                "routing_method_type": RoutingMethodType.DeepSeekV3,
                "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe],
                "compatible_intermediate_size": [512, 1024, 2048],
+                "enable_autotune": True,
            },
            id="DSv3",
        ),
@@ -2463,6 +2484,7 @@ def test_renormalize_routing(
                "routing_method_type": RoutingMethodType.DeepSeekV3,
                "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe],
                "compatible_intermediate_size": [384, 768],
+                "enable_autotune": False,
            },
            id="DSLite",
        ),
@@ -2528,7 +2550,7 @@ def test_deepseekv3_routing(


# Test: TopK routing
-@pytest.mark.parametrize("num_tokens", [1, 8, 128])  # Limited for GeGlu
+@pytest.mark.parametrize("num_tokens", [8, 128])  # Limited for GeGlu
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [384, 512, 768, 1024])
@pytest.mark.parametrize(
@@ -2552,7 +2574,8 @@ def test_deepseekv3_routing(
                "has_routing_bias": False,
                "routing_method_type": RoutingMethodType.TopK,
                "compatible_moe_impls": [FP4Moe],
-                "compatible_intermediate_size": [384, 512, 768, 1024],
+                "compatible_intermediate_size": [512, 768, 1024],
+                "enable_autotune": True,
            },
            id="TopK",
        ),
@@ -2602,7 +2625,7 @@ def test_topk_routing(


# Test: Llama4 routing
-@pytest.mark.parametrize("num_tokens", [1, 8, 1024])
+@pytest.mark.parametrize("num_tokens", [8, 768, 3072])
@pytest.mark.parametrize("hidden_size", [1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
@pytest.mark.parametrize(
@@ -2626,6 +2649,7 @@ def test_topk_routing(
                "routing_method_type": RoutingMethodType.Llama4,
                "compatible_moe_impls": [FP8PerTensorMoe],
                "compatible_intermediate_size": [1024, 2048],
+                "enable_autotune": True,
            },
            id="Llama4",
        ),