Commit 0f67d4d

[Attention] Add MLA prefill backend: trtllm_ragged_attention_deepseek (#26397)

Signed-off-by: Ming Yang <[email protected]>
1 parent: 7e1d697

2 files changed: 113 additions, 2 deletions

vllm/envs.py: 6 additions, 0 deletions
@@ -183,6 +183,7 @@
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480
     VLLM_USE_CUDNN_PREFILL: bool = False
+    VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False
     VLLM_ENABLE_CUDAGRAPH_GC: bool = False
     VLLM_LOOPBACK_IP: str = ""
     VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
@@ -1250,6 +1251,10 @@ def get_vllm_port() -> int | None:
     "VLLM_USE_CUDNN_PREFILL": lambda: bool(
         int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))
     ),
+    # Controls whether to use TRT-LLM ragged DeepSeek prefill
+    "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL": lambda: bool(
+        int(os.getenv("VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL", "0"))
+    ),
     # If set to 1/True, use the TRTLLM attention backend in flashinfer.
     # If set to 0/False, use the default attention backend in flashinfer.
     # If not set, auto-detect the attention backend in flashinfer.
@@ -1481,6 +1486,7 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
         "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
         "VLLM_USE_CUDNN_PREFILL",
+        "VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL",
         "VLLM_USE_TRTLLM_ATTENTION",
         "VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION",
         "VLLM_ROCM_USE_AITER",

vllm/v1/attention/backends/mla/common.py: 107 additions, 2 deletions
@@ -371,6 +371,7 @@ class ChunkedContextMetadata:
     query_start_loc: torch.Tensor
     max_query_len: int
     chunked_context: ChunkedContextMetadata | None = None
+    query_seq_lens: torch.Tensor | None = None
 
 
 @dataclass
@@ -386,7 +387,6 @@ class CudnnPrefillMetadata(MLACommonPrefillMetadata):
     class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata):
         seq_lens: torch.Tensor
 
-    query_seq_lens: torch.Tensor | None = None
     cudnn_workspace: torch.Tensor | None = None
 
 
@@ -457,6 +457,7 @@ def use_flashinfer_prefill() -> bool:
         not envs.VLLM_DISABLE_FLASHINFER_PREFILL
         and flashinfer_available
         and not envs.VLLM_USE_CUDNN_PREFILL
+        and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
         and current_platform.is_device_capability(100)
     )
 
@@ -470,6 +471,15 @@ def use_cudnn_prefill() -> bool:
     )
 
 
+def use_trtllm_ragged_deepseek_prefill() -> bool:
+    """Check if TRT-LLM ragged DeepSeek prefill should be used."""
+    return (
+        flashinfer_available
+        and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
+        and current_platform.is_device_capability(100)
+    )
+
+
 # Currently 394MB, this can be tuned based on GEMM sizes used.
 # Chosen to be the same as sglang:
 # https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
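The two predicates are deliberately exclusive: setting the flag disables the FlashInfer prefill path (via the new clause in use_flashinfer_prefill) and enables the TRT-LLM ragged path only when flashinfer is installed and the device reports compute capability 10.0. A standalone sketch of that gating, with the environment and platform checks reduced to plain booleans for illustration (these helper names are hypothetical, not vLLM's API):

```python
# Simplified, self-contained sketch of the prefill gating above. The booleans
# stand in for envs.* flags, flashinfer_available, and
# current_platform.is_device_capability(100).
def use_flashinfer_prefill_sketch(
    disable_fi: bool, fi_available: bool, cudnn: bool, trtllm_ragged: bool, sm100: bool
) -> bool:
    return not disable_fi and fi_available and not cudnn and not trtllm_ragged and sm100


def use_trtllm_ragged_deepseek_prefill_sketch(
    fi_available: bool, trtllm_ragged: bool, sm100: bool
) -> bool:
    return fi_available and trtllm_ragged and sm100


# With VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL=1 on a capability-10.0 GPU with
# flashinfer installed, exactly one of the two paths is selected.
assert use_flashinfer_prefill_sketch(False, True, False, True, True) is False
assert use_trtllm_ragged_deepseek_prefill_sketch(True, True, True) is True
```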
@@ -593,6 +603,7 @@ def __init__(
 
         self._use_cudnn_prefill = use_cudnn_prefill()
         self._use_fi_prefill = use_flashinfer_prefill()
+        self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill()
         self.prefill_metadata_cls = (
             FlashInferPrefillMetadata
             if self._use_fi_prefill
@@ -613,6 +624,11 @@ def __init__(
             get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
         )
 
+        if self._use_trtllm_ragged_prefill:
+            self._workspace_buffer = torch.empty(
+                FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
+            )
+
         if self._use_cudnn_prefill:
             self.cudnn_workspace = torch.empty(
                 CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
@@ -934,6 +950,11 @@ def build(
             )
             prefill_metadata.cudnn_workspace = self.cudnn_workspace
 
+        if self._use_trtllm_ragged_prefill:
+            prefill_metadata.query_seq_lens = (
+                prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
+            )
+
         decode_metadata = None
         if num_decodes > 0:
             decode_metadata = self._build_decode(
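Because prefill_query_start_loc holds cumulative start offsets, the elementwise difference recovers the per-request query lengths that the new runner later passes to the kernel as seq_lens. A small worked example of that computation:

```python
import torch

# Three prefill requests with 3, 4, and 5 new tokens: the cumulative start
# locations are [0, 3, 7, 12], and the first difference gives per-request lengths.
prefill_query_start_loc = torch.tensor([0, 3, 7, 12], dtype=torch.int32)
query_seq_lens = prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
assert query_seq_lens.tolist() == [3, 4, 5]
```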
@@ -1230,6 +1251,13 @@ def __init__(self, *args, **kwargs) -> None:
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
             self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
             self._pad_v = False
+        elif use_trtllm_ragged_deepseek_prefill():
+            logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA")
+            self._run_prefill_context_chunk = (
+                self._run_prefill_context_chunk_trtllm_ragged
+            )
+            self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
+            self._pad_v = False
         elif use_cudnn_prefill():
             logger.debug_once("Using CUDNN prefill for MLA")
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
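The new branch sits between the FlashInfer and cuDNN cases, so the prefill runner precedence is FlashInfer first, then TRT-LLM ragged DeepSeek, then cuDNN, and otherwise whatever the final fallback is. A minimal sketch of that precedence (the helper function and the "default" label are illustrative, not vLLM's actual API):

```python
# Hypothetical helper mirroring the elif chain above; "default" stands in for
# whatever path the final else branch selects.
def select_mla_prefill_backend(fi: bool, trtllm_ragged: bool, cudnn: bool) -> str:
    if fi:
        return "flashinfer"
    elif trtllm_ragged:
        return "trtllm_ragged_deepseek"
    elif cudnn:
        return "cudnn"
    return "default"


# The ragged path is only reached when use_flashinfer_prefill() is False,
# which the new env flag guarantees via the clause added to that predicate.
assert select_mla_prefill_backend(True, True, False) == "flashinfer"
assert select_mla_prefill_backend(False, True, False) == "trtllm_ragged_deepseek"
```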
@@ -1326,6 +1354,7 @@ def _run_prefill_new_tokens_fi(
     ):
         assert isinstance(prefill, FlashInferPrefillMetadata)
         assert prefill.prefill_main is not None
+
         ret = prefill.prefill_main.run(
             q=q,
             k=k,
@@ -1334,7 +1363,6 @@
         )
 
         if isinstance(ret, tuple):
-            # Convert from (q_len, num_heads) to (num_heads, q_len)
             return ret[0], ret[1].transpose(0, 1).contiguous()
         return ret
 
@@ -1384,12 +1412,14 @@ def _run_prefill_context_chunk_fi(
         self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
         assert isinstance(prefill, FlashInferPrefillMetadata)
+
         attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
             q=q,
             k=k,
             v=v,
             return_lse=True,
         )
+
         # Convert from (q_len, num_heads) to (num_heads, q_len)
         return attn_out, lse.transpose(0, 1).contiguous()
 
@@ -1418,6 +1448,81 @@ def _run_prefill_context_chunk_cudnn(
             is_cuda_graph_compatible=True,
         )
 
+    def _run_prefill_new_tokens_trtllm_ragged(
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
+    ):
+        """TRT-LLM ragged attention for new tokens (causal)."""
+        from flashinfer.prefill import trtllm_ragged_attention_deepseek
+
+        assert prefill.query_seq_lens is not None
+
+        ret = trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self._workspace_buffer,
+            seq_lens=prefill.query_seq_lens,
+            max_q_len=prefill.max_query_len,
+            max_kv_len=prefill.max_query_len,
+            bmm1_scale=self.scale,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=prefill.query_seq_lens.shape[0],
+            window_left=-1,
+            cum_seq_lens_q=prefill.query_start_loc,
+            cum_seq_lens_kv=prefill.query_start_loc,
+            enable_pdl=False,
+            is_causal=True,
+            return_lse=return_softmax_lse,
+        )
+
+        if isinstance(ret, tuple):
+            # Convert from (q_len, num_heads) to (num_heads, q_len)
+            return ret[0], ret[1].transpose(0, 1).contiguous()
+        return ret
+
+    def _run_prefill_context_chunk_trtllm_ragged(
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
+    ):
+        """TRT-LLM ragged attention for context chunks (non-causal)."""
+        from flashinfer.prefill import trtllm_ragged_attention_deepseek
+
+        assert prefill.chunked_context is not None
+        assert prefill.chunked_context.seq_lens[chunk_idx] is not None
+
+        out = torch.zeros(
+            q.shape[0],
+            q.shape[1],
+            v.shape[2],
+            device=q.device,
+            dtype=q.dtype,
+        )
+        self._workspace_buffer.fill_(0)
+
+        attn_out, lse = trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self._workspace_buffer,
+            seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
+            max_q_len=prefill.max_query_len,
+            max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
+            bmm1_scale=self.scale,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0],
+            window_left=-1,
+            cum_seq_lens_q=prefill.query_start_loc,
+            cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx],
+            enable_pdl=False,
+            is_causal=False,
+            return_lse=True,
+            out=out,
+        )
+
+        # Convert from (q_len, num_heads) to (num_heads, q_len)
+        return attn_out, lse.transpose(0, 1).contiguous()
+
     def process_weights_after_loading(self, act_dtype: torch.dtype):
         def get_layer_weight(layer):
             WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
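Both new runners end by transposing the returned log-sum-exp tensor, because the kernel emits it as (q_len, num_heads) while the surrounding MLA code expects (num_heads, q_len). A tiny illustration of that layout conversion, with arbitrary shapes:

```python
import torch

# The kernel returns LSE shaped (q_len, num_heads); the caller wants
# (num_heads, q_len), hence transpose(0, 1).contiguous() in both runners.
q_len, num_heads = 12, 16
lse = torch.randn(q_len, num_heads)
lse_for_merge = lse.transpose(0, 1).contiguous()
assert lse_for_merge.shape == (num_heads, q_len)
```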
