vllm/model_executor/warmup/kernel_warmup.py (49 additions, 0 deletions)
@@ -73,6 +73,55 @@ def _is_flashinfer_backend(backend):
            create_mixed_batch=True,
        )

    # Attention backend warmup (covers non-FlashInfer backends as well).
    #
    # NOTE: The default startup "profile run" intentionally avoids building
    # attention metadata unless running FULL CUDA graphs, which can leave
    # attention kernels/backends (e.g. FlashAttention/Triton) uninitialized.
    # That shows up as a slow first real request due to JIT/autotune work.
    #
    # Here we force a small attention-inclusive dummy run so the attention
    # path is exercised during startup across backends.
    if (
        (not worker.model_runner.is_pooling_model)
        and worker.model_runner.attn_groups
        and (not worker.model_config.is_encoder_decoder)
    ):
        try:
            max_batched_tokens = worker.scheduler_config.max_num_batched_tokens
            max_num_seqs = worker.scheduler_config.max_num_seqs
            uniform_decode_query_len = getattr(
                worker.model_runner, "uniform_decode_query_len", 1
            )

            # 1) Mixed prefill+decode warmup (covers both paths, low cost).
            mixed_tokens = min(16, max_batched_tokens)
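            # (Hypothetical example: with max_num_batched_tokens=8192 the cap
            # yields mixed_tokens=16; a tiny budget of 8 degrades gracefully
            # to mixed_tokens=8.)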
            if mixed_tokens > 0:
                worker.model_runner._dummy_run(
                    num_tokens=mixed_tokens,
                    skip_eplb=True,
                    is_profile=True,
                    force_attention=True,
                    create_mixed_batch=True,
                )

            # 2) Uniform decode warmup (hits decode-specialized attention paths).
            # Keep it small to avoid extra startup latency.
            q = max(1, int(uniform_decode_query_len))
            num_reqs = min(16, max_num_seqs, max_batched_tokens // q)
            decode_tokens = q * num_reqs
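            # (Hypothetical example: q=1, max_num_seqs=256, and
            # max_num_batched_tokens=8192 give num_reqs=min(16, 256, 8192)=16,
            # so decode_tokens=16, i.e. a 16-request pure-decode batch.)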
            if decode_tokens > 0:
                worker.model_runner._dummy_run(
                    num_tokens=decode_tokens,
                    skip_eplb=True,
                    is_profile=True,
                    force_attention=True,
                    uniform_decode=True,
                )
        except Exception:
            # Best-effort: warmup should never block engine startup.
            logger.exception("Attention backend warmup failed; continuing startup.")


def flashinfer_autotune(runner: "GPUModelRunner") -> None:
"""
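For review convenience, here is a minimal standalone sketch of the sizing arithmetic from the hunk above, so it can be sanity-checked without a GPU. The `warmup_sizes` helper name and the example config values are hypothetical, not part of this PR:

```python
# Standalone sketch (hypothetical helper, not part of the PR) of the
# warmup-size arithmetic used in the kernel_warmup hunk above.


def warmup_sizes(
    max_num_batched_tokens: int,
    max_num_seqs: int,
    uniform_decode_query_len: int = 1,
) -> tuple[int, int]:
    """Return (mixed_tokens, decode_tokens) as computed during warmup."""
    # Mixed prefill+decode batch: capped at 16 tokens.
    mixed_tokens = min(16, max_num_batched_tokens)

    # Uniform decode batch: at most 16 requests, q tokens per request,
    # never exceeding the token budget.
    q = max(1, int(uniform_decode_query_len))
    num_reqs = min(16, max_num_seqs, max_num_batched_tokens // q)
    return mixed_tokens, q * num_reqs


# Typical config (hypothetical values): both warmup batches stay tiny.
assert warmup_sizes(8192, 256) == (16, 16)
# Degenerate budget: both sizes hit zero and the dummy runs are skipped,
# matching the `if ... > 0` guards in the PR.
assert warmup_sizes(0, 256) == (0, 0)
```

Because both `_dummy_run` calls are guarded by `> 0` checks and the whole block is wrapped in a broad `try/except`, a misconfigured or zero-sized warmup cannot block engine startup.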