
Commit 6765cad

[Test] Optimize test_trtllm_gen_fused_moe.py (#2072)
1 parent abf6a14 commit 6765cad

1 file changed: 46 additions, 22 deletions


tests/moe/test_trtllm_gen_fused_moe.py

@@ -49,6 +49,10 @@
 from flashinfer.utils import get_compute_capability


+# Max num tokens to tune for trtllm-gen fused moe
+TUNE_MAX_NUM_TOKENS = 4096
+
+
 def check_cuda(err):
     """Unified CUDA error checking function used throughout the file."""
     if err != runtime.cudaError_t.cudaSuccess:
@@ -76,6 +80,7 @@ def __init__(self, moe_impl, static_data, **config):
         self.moe_impl = moe_impl
         self.static_data = static_data
         self.config = config
+        self.enable_autotune = config.get("enable_autotune", True)
         self.graph = None
         self.graph_exec = None
         self.stream = None
@@ -106,7 +111,7 @@ def capture(self, hidden_states_sample, **runtime_args):
         self.input_tensor = hidden_states_sample.clone()

         # Warmup
-        with torch.cuda.stream(torch_stream), autotune(True):
+        with torch.cuda.stream(torch_stream), autotune(self.enable_autotune):
             for _ in range(1):
                 self._run_moe_computation(runtime_args)

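Note: the warmup pass that precedes CUDA graph capture now respects the per-config flag instead of always autotuning. A minimal sketch of that pattern with assumed names (`warmup_then_capture`, `run_moe`) and an assumed import path for the same `autotune` context manager the test file already has in scope; it is an illustration, not the helper class used above:

import torch
from flashinfer.autotuner import autotune  # assumed import path; the test file already imports autotune


def warmup_then_capture(run_moe, enable_autotune=True):
    # Warm up on a side stream; kernels are (re)tuned only when autotuning is enabled.
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream), autotune(enable_autotune):
        run_moe()
    torch.cuda.current_stream().wait_stream(stream)

    # Capture the already-selected kernels into a CUDA graph for later replay.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream):
        run_moe()
    return graph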
@@ -207,6 +212,7 @@ def _run_moe_computation(self, runtime_args):
             routing_method_type=self.config["routing_method_type"],
             gated_act_type=self.config["gated_act_type"],
             do_finalize=True,
+            tune_max_num_tokens=TUNE_MAX_NUM_TOKENS,
         )
         return output  # Extract tensor from tuple

@@ -551,6 +557,7 @@ def call_moe(
         routed_scaling = kwargs["routed_scaling"]
         gated_act_type = kwargs["gated_act_type"]
         routing_method_type = kwargs["routing_method_type"]
+        enable_autotune = kwargs.get("enable_autotune", True)

         # Create CUDA graph configuration
         config = {
@@ -563,6 +570,7 @@ def call_moe(
             "routed_scaling": routed_scaling,
             "gated_act_type": gated_act_type,
             "routing_method_type": routing_method_type,
+            "enable_autotune": enable_autotune,
         }

         runtime_args = {
@@ -761,6 +769,7 @@ def call_moe(
         intermediate_size = kwargs["intermediate_size"]
         routed_scaling = kwargs["routed_scaling"]
         routing_method_type = kwargs["routing_method_type"]
+        enable_autotune = kwargs.get("enable_autotune", True)
         enable_pdl = kwargs.get("enable_pdl")
         hidden_states_scale = kwargs["hidden_states_scale"]
         hidden_states_quant = kwargs["hidden_states_quant"]
@@ -772,7 +781,7 @@ def call_moe(
         )

         # Use autotuner for optimal kernel selection
-        with autotune(True):
+        with autotune(enable_autotune):
             output = trtllm_fp8_block_scale_moe(
                 expert_logits,
                 routing_bias,
@@ -795,6 +804,7 @@ def call_moe(
                 use_shuffled_weight=static_data["use_shuffled_weight"],
                 weight_layout=static_data["weight_layout"],
                 enable_pdl=enable_pdl,
+                tune_max_num_tokens=TUNE_MAX_NUM_TOKENS,
             )
         return output.to(torch.float)

@@ -937,14 +947,15 @@ def call_moe(
         intermediate_size = kwargs["intermediate_size"]
         routed_scaling = kwargs["routed_scaling"]
         routing_method_type = kwargs["routing_method_type"]
+        enable_autotune = kwargs.get("enable_autotune", True)

         # Quantize to FP8 per-tensor using pre-computed global scale factor
         hidden_states_fp8, _ = quant_fp8_per_tensor(
             hidden_states_orig, hidden_states_scale_global
         )

         # Use autotuner for optimal kernel selection
-        with autotune(True):
+        with autotune(enable_autotune):
             output = trtllm_fp8_per_tensor_scale_moe(
                 (
                     expert_logits.to(torch.bfloat16)
@@ -970,6 +981,7 @@ def call_moe(
                 == RoutingMethodType.Llama4,  # Use_routing_scales_on_input
                 None,
                 routing_method_type,
+                tune_max_num_tokens=TUNE_MAX_NUM_TOKENS,
             )

         return output.to(torch.float)
@@ -1101,9 +1113,10 @@ def call_moe(
         top_k_groups = kwargs["top_k_groups"]
         intermediate_size = kwargs["intermediate_size"]
         routing_method_type = kwargs["routing_method_type"]
+        enable_autotune = kwargs.get("enable_autotune", True)

         # Use autotuner for optimal kernel selection
-        with autotune(True):
+        with autotune(enable_autotune):
             output = trtllm_bf16_moe(
                 expert_logits,  # float
                 routing_bias,
@@ -1120,6 +1133,7 @@ def call_moe(
                 use_shuffled_weight=static_data["use_shuffled_weight"],
                 weight_layout=static_data["weight_layout"],
                 routing_method_type=routing_method_type,
+                tune_max_num_tokens=TUNE_MAX_NUM_TOKENS,
             )
         return output.to(torch.float)

@@ -1408,20 +1422,18 @@ def routing_reference_topk(expert_logits, top_k, num_experts, padding):

 def check_accuracy(a, b, atol, rtol, percent):
     """Unified accuracy checking function with detailed error reporting."""
-    if torch.any(torch.isnan(a)):
-        raise Exception("NaN in reference output")
-    if torch.any(torch.isnan(b)):
-        raise Exception("NaN in actual output")
-    if torch.any(torch.isinf(a)):
-        raise Exception("Inf in reference output")
-    if torch.any(torch.isinf(b)):
-        raise Exception("Inf in actual output")
+    if not torch.isfinite(a).all():
+        raise Exception("Non-finite values in reference output")
+    if not torch.isfinite(b).all():
+        raise Exception("Non-finite values in actual output")
     assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"

-    left = torch.abs(a - b)
-    right = atol + rtol * torch.abs(b)
-    count = torch.sum(left > right)
-    mismatch_percent = count / a.numel()
+    close = torch.isclose(a, b, atol=atol, rtol=rtol)
+    match_ratio = close.float().mean()
+    if match_ratio >= percent:
+        return
+
+    mismatch_percent = 1.0 - match_ratio.item()
     if mismatch_percent > 1 - percent:
         raise Exception(
             f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
@@ -1999,6 +2011,7 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs):
         "gated_act_type": args.gated_act_type,
         "hidden_states_scale": args.hidden_states_scale,
         "hidden_states_quant": kwargs["hidden_states_quant"],
+        "enable_autotune": kwargs.get("enable_autotune", True),
     }

     return moe_impl.call_moe(
@@ -2238,6 +2251,8 @@ def run_moe_test(
         pytest.fail("Reference computation failed to produce output")

     # Compute actual output
+    enable_autotune = routing_config.get("enable_autotune", True)
+
     output_dequant_actual = moe_impl.compute_production(
         args_dequant,
         args,
@@ -2253,6 +2268,7 @@ def run_moe_test(
         weight_processing=weight_processing,
         enable_pdl=True,
         hidden_states_quant=inputs_data["hidden_states"],
+        enable_autotune=enable_autotune,
     )

     # Compare outputs
@@ -2267,7 +2283,7 @@ def run_moe_test(


 # Test: Renormalize routing
-@pytest.mark.parametrize("num_tokens", [1, 8, 1024, 3072])
+@pytest.mark.parametrize("num_tokens", [8, 768, 3072])
 @pytest.mark.parametrize("hidden_size", [1024])
 @pytest.mark.parametrize("intermediate_size", [1024, 768, 512, 384])
 @pytest.mark.parametrize(
@@ -2301,8 +2317,9 @@ def run_moe_test(
                 BF16Moe,
             ],
             "compatible_intermediate_size": [384, 768, 1024],
+            "enable_autotune": True,
         },
-        id="Qwen3",
+        id="Qwen3_MOE",
     ),
     pytest.param(
         {
@@ -2321,6 +2338,7 @@ def run_moe_test(
                 BF16Moe,
             ],
             "compatible_intermediate_size": [384, 1024],
+            "enable_autotune": False,
         },
         id="Renorm",
     ),
@@ -2341,6 +2359,7 @@ def run_moe_test(
                 BF16Moe,
             ],
             "compatible_intermediate_size": [512],
+            "enable_autotune": True,
         },
         id="Qwen3_next",
     ),
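Note: each routing config now carries an enable_autotune flag alongside its routing settings; run_moe_test reads it with routing_config.get("enable_autotune", True) and forwards it to the MoE implementation under test. A minimal sketch of that parametrization pattern with made-up config contents (not the real test matrix):

import pytest


@pytest.mark.parametrize(
    "routing_config",
    [
        pytest.param({"top_k": 8, "enable_autotune": True}, id="tuned"),
        pytest.param({"top_k": 8, "enable_autotune": False}, id="untuned"),
        pytest.param({"top_k": 8}, id="default"),
    ],
)
def test_autotune_flag(routing_config):
    # Configs that omit the flag fall back to True, preserving the old always-autotune behavior.
    assert routing_config.get("enable_autotune", True) in (True, False)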
@@ -2406,7 +2425,7 @@ def test_renormalize_routing(


 # Test: DeepSeekV3 routing
-@pytest.mark.parametrize("num_tokens", [1, 8, 1024, 3072])
+@pytest.mark.parametrize("num_tokens", [8, 768, 3072])
 @pytest.mark.parametrize("hidden_size", [1024])
 @pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384])
 @pytest.mark.parametrize(
@@ -2433,6 +2452,7 @@ def test_renormalize_routing(
             "routing_method_type": RoutingMethodType.DeepSeekV3,
             "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe],
             "compatible_intermediate_size": [1024, 2048],
+            "enable_autotune": True,
         },
         id="kimi_k2",
     ),
@@ -2448,6 +2468,7 @@ def test_renormalize_routing(
             "routing_method_type": RoutingMethodType.DeepSeekV3,
             "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe],
             "compatible_intermediate_size": [512, 1024, 2048],
+            "enable_autotune": True,
         },
         id="DSv3",
     ),
@@ -2463,6 +2484,7 @@ def test_renormalize_routing(
             "routing_method_type": RoutingMethodType.DeepSeekV3,
             "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe],
             "compatible_intermediate_size": [384, 768],
+            "enable_autotune": False,
         },
         id="DSLite",
     ),
@@ -2528,7 +2550,7 @@ def test_deepseekv3_routing(


 # Test: TopK routing
-@pytest.mark.parametrize("num_tokens", [1, 8, 128])  # Limited for GeGlu
+@pytest.mark.parametrize("num_tokens", [8, 128])  # Limited for GeGlu
 @pytest.mark.parametrize("hidden_size", [1024])
 @pytest.mark.parametrize("intermediate_size", [384, 512, 768, 1024])
 @pytest.mark.parametrize(
@@ -2552,7 +2574,8 @@ def test_deepseekv3_routing(
             "has_routing_bias": False,
             "routing_method_type": RoutingMethodType.TopK,
             "compatible_moe_impls": [FP4Moe],
-            "compatible_intermediate_size": [384, 512, 768, 1024],
+            "compatible_intermediate_size": [512, 768, 1024],
+            "enable_autotune": True,
         },
         id="TopK",
     ),
@@ -2602,7 +2625,7 @@ def test_topk_routing(


 # Test: Llama4 routing
-@pytest.mark.parametrize("num_tokens", [1, 8, 1024])
+@pytest.mark.parametrize("num_tokens", [8, 768, 3072])
 @pytest.mark.parametrize("hidden_size", [1024])
 @pytest.mark.parametrize("intermediate_size", [1024, 2048])
 @pytest.mark.parametrize(
@@ -2626,6 +2649,7 @@ def test_topk_routing(
             "routing_method_type": RoutingMethodType.Llama4,
             "compatible_moe_impls": [FP8PerTensorMoe],
             "compatible_intermediate_size": [1024, 2048],
+            "enable_autotune": True,
         },
         id="Llama4",
     ),
