
Commit 4022a9d

[BugFix][Performance] Restore flashinfer autotuning for all scenarios (#27904)
Parent commit: 53f6e81

4 files changed: +14 −44 lines

tests/quantization/test_blackwell_moe.py

Lines changed: 2 additions & 14 deletions
@@ -172,21 +172,9 @@ def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch
     can_initialize("openai/gpt-oss-20b", hf_overrides=HF_OVERRIDE_TEXT)


-def test_gptoss_dp2_mxfp4mxfp8_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
-    monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
-    can_initialize(
-        "openai/gpt-oss-20b",
-        extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
-        hf_overrides=HF_OVERRIDE_TEXT,
-    )
-
-
-def test_gptoss_dp2_mxfp4bf16_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
-    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
-    monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "deepep_high_throughput")
+def test_gptoss_eager(monkeypatch: pytest.MonkeyPatch):
     can_initialize(
         "openai/gpt-oss-20b",
-        extra_args=["--data-parallel-size", "2", "--enable-expert-parallel"],
         hf_overrides=HF_OVERRIDE_TEXT,
+        extra_args=["--enforce-eager"],
     )
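
The two data-parallel (dp2) variants are dropped and a plain eager-mode test takes their place; --enforce-eager disables CUDA graph capture, which is the scenario the max(..., 1) clamps further down protect. For readers unfamiliar with the idiom, here is a minimal, self-contained sketch of the monkeypatch.setenv pattern this test file leans on. The environment variable name comes from the removed tests; the helper and test names below are hypothetical stand-ins, not the real can_initialize machinery.

import os

import pytest


def select_flashinfer_mxfp4_mxfp8(monkeypatch: pytest.MonkeyPatch) -> None:
    # Force the FlashInfer MXFP4xMXFP8 MoE path for the duration of one test;
    # pytest undoes the change automatically when the test finishes.
    monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")


def test_backend_env_is_scoped(monkeypatch: pytest.MonkeyPatch) -> None:
    select_flashinfer_mxfp4_mxfp8(monkeypatch)
    assert os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"] == "1"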

vllm/model_executor/layers/fused_moe/trtllm_moe.py

Lines changed: 9 additions & 2 deletions
@@ -127,10 +127,17 @@ def apply(
             "routing_method_type": 1,
             "do_finalize": True,
             "output": output,
-            "tune_max_num_tokens": self.max_capture_size,
+            "tune_max_num_tokens": max(self.max_capture_size, 1),
         }

         from flashinfer import trtllm_fp4_block_scale_routed_moe

-        trtllm_fp4_block_scale_routed_moe(**kwargs)
+        from vllm.utils.flashinfer import autotune
+
+        with autotune(False):
+            # Enable autotune when,
+            # https://github.com/flashinfer-ai/flashinfer/issues/2023 is
+            # resolved.
+            trtllm_fp4_block_scale_routed_moe(**kwargs)
+
         return output
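
The kernel call is now wrapped in vLLM's autotune context manager with tuning explicitly disabled, so warmup can keep autotuning everything else while this one kernel opts out until the linked FlashInfer issue is resolved. Below is a minimal sketch of that context-manager pattern, assuming vllm.utils.flashinfer.autotune behaves as a simple on/off switch; the module-level flag and stub function are illustrative, not vLLM's implementation.

from contextlib import contextmanager

_autotune_enabled = False  # stand-in for FlashInfer's tuner state


@contextmanager
def autotune(enable: bool = True):
    """Enable or disable kernel autotuning for calls made inside the block."""
    global _autotune_enabled
    previous, _autotune_enabled = _autotune_enabled, enable
    try:
        yield
    finally:
        _autotune_enabled = previous  # restore the outer setting on exit


def run_kernel(**kwargs):
    # The commit keeps tuning off around this particular MoE kernel.
    with autotune(False):
        assert _autotune_enabled is False  # tuning is disabled inside the block
        # trtllm_fp4_block_scale_routed_moe(**kwargs) would run here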

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 2 additions & 2 deletions
@@ -1047,7 +1047,7 @@ def apply(
             None,
             1 if renormalize else 0,  # routing_method_type, renormalize
             True,  # do finalize
-            tune_max_num_tokens=self.max_capture_size,
+            tune_max_num_tokens=max(self.max_capture_size, 1),
         )[0]
         return trtllm_gen_output
     elif (
@@ -1122,7 +1122,7 @@ def apply(
             tp_rank=self.moe.tp_rank,
             ep_size=self.moe.ep_size,
             ep_rank=self.moe.ep_rank,
-            tune_max_num_tokens=self.max_capture_size,
+            tune_max_num_tokens=max(self.max_capture_size, 1),
             **extra_kwargs,
         )
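
Both call sites now clamp tune_max_num_tokens to at least 1. With --enforce-eager no CUDA graphs are captured, so self.max_capture_size can be 0, and a zero tuning budget is presumably what broke the eager path; the clamp keeps the argument positive. A tiny sketch of the guard follows; the helper name is hypothetical.

def clamp_tune_max_num_tokens(max_capture_size: int) -> int:
    # Mirrors the max(self.max_capture_size, 1) guard added in this commit:
    # eager mode (no CUDA graph capture) can report a capture size of 0,
    # while the tuning budget handed to FlashInfer must stay positive.
    return max(max_capture_size, 1)


assert clamp_tune_max_num_tokens(0) == 1      # --enforce-eager
assert clamp_tune_max_num_tokens(512) == 512  # CUDA graphs captured up to 512 tokens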

vllm/model_executor/warmup/kernel_warmup.py

Lines changed: 1 addition & 26 deletions
@@ -11,7 +11,6 @@
 import torch

 import vllm.envs as envs
-from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
 from vllm.platforms import current_platform
@@ -25,26 +24,6 @@
 logger = init_logger(__name__)


-def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
-    """
-    Record known issues with vllm + flashinfer autotune here. Return True if
-    and only if flashinfer autotune will run through without issues.
-    """
-    is_tp_or_dp = (vllm_config.parallel_config.data_parallel_size > 1) or (
-        vllm_config.parallel_config.tensor_parallel_size > 1
-    )
-    is_fi_mxfp4_backend = (
-        envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
-        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
-    ) or (
-        current_platform.is_cuda() and current_platform.is_device_capability(100)
-    )  # on >=sm100, default mxfp4 backend is flashinfer
-    is_eager = vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
-
-    return not (is_tp_or_dp and is_fi_mxfp4_backend and is_eager)
-
-
 def kernel_warmup(worker: "Worker"):
     # Deep GEMM warmup
     do_deep_gemm_warmup = (
@@ -58,11 +37,7 @@ def kernel_warmup(worker: "Worker"):
         deep_gemm_warmup(model, max_tokens)

     # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs
-    if (
-        has_flashinfer()
-        and current_platform.has_device_capability(90)
-        and flashinfer_autotune_supported(worker.vllm_config)
-    ):
+    if has_flashinfer() and current_platform.has_device_capability(90):
        flashinfer_autotune(worker.model_runner)

     # FlashInfer attention warmup
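
With the per-kernel autotune(False) escape hatch in place, the commit drops the flashinfer_autotune_supported allow-list entirely: warmup now autotunes whenever FlashInfer is available on SM 9.0 (Hopper) or newer hardware. A small sketch of the simplified gate, with the two probes stubbed as plain values for illustration (the real code calls has_flashinfer() and current_platform.has_device_capability(90)):

def should_autotune(flashinfer_available: bool, device_capability: int) -> bool:
    # SM 9.0 (Hopper) and SM 10.0 (Blackwell) both pass the >= 90 check.
    return flashinfer_available and device_capability >= 90


assert should_autotune(True, 90)        # Hopper
assert should_autotune(True, 100)       # Blackwell
assert not should_autotune(True, 80)    # Ampere: warmup skips autotune
assert not should_autotune(False, 100)  # FlashInfer not installed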
