diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py
index 95f5982bc8c7..a1a0feab9b4c 100644
--- a/vllm/model_executor/warmup/kernel_warmup.py
+++ b/vllm/model_executor/warmup/kernel_warmup.py
@@ -73,6 +73,55 @@ def _is_flashinfer_backend(backend):
             create_mixed_batch=True,
         )
 
+    # Attention backend warmup (non-FlashInfer too).
+    #
+    # NOTE: The default startup "profile run" intentionally avoids building
+    # attention metadata unless running FULL CUDA graphs, which can leave
+    # attention kernels/backends (e.g. FlashAttention/Triton) uninitialized.
+    # That shows up as a slow first real request due to JIT/autotune work.
+    #
+    # Here we force a small attention-inclusive dummy run so the attention path
+    # is exercised during startup across backends.
+    if (
+        (not worker.model_runner.is_pooling_model)
+        and worker.model_runner.attn_groups
+        and (not worker.model_config.is_encoder_decoder)
+    ):
+        try:
+            max_batched_tokens = worker.scheduler_config.max_num_batched_tokens
+            max_num_seqs = worker.scheduler_config.max_num_seqs
+            uniform_decode_query_len = getattr(
+                worker.model_runner, "uniform_decode_query_len", 1
+            )
+
+            # 1) Mixed prefill+decode warmup (covers both paths, low cost).
+            mixed_tokens = min(16, max_batched_tokens)
+            if mixed_tokens > 0:
+                worker.model_runner._dummy_run(
+                    num_tokens=mixed_tokens,
+                    skip_eplb=True,
+                    is_profile=True,
+                    force_attention=True,
+                    create_mixed_batch=True,
+                )
+
+            # 2) Uniform decode warmup (hits decode-specialized attention paths).
+            # Keep it small to avoid extra startup latency.
+            q = max(1, int(uniform_decode_query_len))
+            num_reqs = min(16, max_num_seqs, max_batched_tokens // q)
+            decode_tokens = q * num_reqs
+            if decode_tokens > 0:
+                worker.model_runner._dummy_run(
+                    num_tokens=decode_tokens,
+                    skip_eplb=True,
+                    is_profile=True,
+                    force_attention=True,
+                    uniform_decode=True,
+                )
+        except Exception:
+            # Best-effort: warmup should never block engine startup.
+            logger.exception("Attention backend warmup failed; continuing startup.")
+
 
 def flashinfer_autotune(runner: "GPUModelRunner") -> None:
     """
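
Reviewer note: below is a minimal, standalone sketch of the uniform-decode
sizing arithmetic from the hunk above. The three config values
(max_num_batched_tokens, max_num_seqs, uniform_decode_query_len) are
illustrative assumptions for the example, not vLLM defaults.

# Illustrative sketch only; the values are assumed, not real vLLM defaults.
max_num_batched_tokens = 8192
max_num_seqs = 256
uniform_decode_query_len = 1  # tokens per request in a uniform decode step

q = max(1, int(uniform_decode_query_len))
# Cap at 16 requests, never exceeding the scheduler's batch/seq limits.
num_reqs = min(16, max_num_seqs, max_num_batched_tokens // q)
decode_tokens = q * num_reqs
assert (num_reqs, decode_tokens) == (16, 16)  # 16 requests x 1 token each

# A tiny token budget shrinks the warmup batch instead of failing:
assert min(16, max_num_seqs, 10 // 4) == 2    # 2 requests of 4 tokens -> 8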