@@ -371,6 +371,7 @@ class ChunkedContextMetadata:
     query_start_loc: torch.Tensor
     max_query_len: int
     chunked_context: ChunkedContextMetadata | None = None
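+    # Per-request query lengths (consecutive differences of query_start_loc), hoisted from
+    # CudnnPrefillMetadata so the TRT-LLM ragged prefill path can use it as well.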
+    query_seq_lens: torch.Tensor | None = None


 @dataclass
@@ -386,7 +387,6 @@ class CudnnPrefillMetadata(MLACommonPrefillMetadata):
     class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata):
         seq_lens: torch.Tensor

-    query_seq_lens: torch.Tensor | None = None
     cudnn_workspace: torch.Tensor | None = None

@@ -457,6 +457,7 @@ def use_flashinfer_prefill() -> bool:
         not envs.VLLM_DISABLE_FLASHINFER_PREFILL
        and flashinfer_available
        and not envs.VLLM_USE_CUDNN_PREFILL
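+        # Defer to the dedicated TRT-LLM ragged DeepSeek path when its flag is set.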
+        and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
         and current_platform.is_device_capability(100)
     )

@@ -470,6 +471,15 @@ def use_cudnn_prefill() -> bool:
     )


+def use_trtllm_ragged_deepseek_prefill() -> bool:
+    """Check if TRT-LLM ragged DeepSeek prefill should be used."""
+    return (
+        flashinfer_available
+        and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
+        and current_platform.is_device_capability(100)
+    )
+
+
 # Currently 394MB, this can be tuned based on GEMM sizes used.
 # Chosen to be the same as sglang:
 # https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
@@ -593,6 +603,7 @@ def __init__(

         self._use_cudnn_prefill = use_cudnn_prefill()
         self._use_fi_prefill = use_flashinfer_prefill()
+        self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill()
         self.prefill_metadata_cls = (
             FlashInferPrefillMetadata
             if self._use_fi_prefill
@@ -613,6 +624,11 @@ def __init__(
                 get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
             )

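+        # Scratch buffer handed to trtllm_ragged_attention_deepseek; reuses the FlashInfer
+        # prefill workspace size.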
+        if self._use_trtllm_ragged_prefill:
+            self._workspace_buffer = torch.empty(
+                FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, device=device
+            )
+
         if self._use_cudnn_prefill:
             self.cudnn_workspace = torch.empty(
                 CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
@@ -934,6 +950,11 @@ def build(
             )
             prefill_metadata.cudnn_workspace = self.cudnn_workspace

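+        # Per-request query lengths from consecutive cumulative offsets; the TRT-LLM ragged
+        # kernel consumes these as its seq_lens argument.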
+        if self._use_trtllm_ragged_prefill:
+            prefill_metadata.query_seq_lens = (
+                prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
+            )
+
         decode_metadata = None
         if num_decodes > 0:
             decode_metadata = self._build_decode(
@@ -1230,6 +1251,13 @@ def __init__(self, *args, **kwargs) -> None:
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
             self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
             self._pad_v = False
+        elif use_trtllm_ragged_deepseek_prefill():
+            logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA")
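+            # Use the TRT-LLM ragged kernel for both the causal new-token pass and the
+            # non-causal context-chunk pass; V does not need the padding some other
+            # prefill backends require, so _pad_v stays False.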
+            self._run_prefill_context_chunk = (
+                self._run_prefill_context_chunk_trtllm_ragged
+            )
+            self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
+            self._pad_v = False
         elif use_cudnn_prefill():
             logger.debug_once("Using CUDNN prefill for MLA")
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
@@ -1326,6 +1354,7 @@ def _run_prefill_new_tokens_fi(
     ):
         assert isinstance(prefill, FlashInferPrefillMetadata)
         assert prefill.prefill_main is not None
+
         ret = prefill.prefill_main.run(
             q=q,
             k=k,
@@ -1334,7 +1363,6 @@ def _run_prefill_new_tokens_fi(
         )

         if isinstance(ret, tuple):
-            # Convert from (q_len, num_heads) to (num_heads, q_len)
             return ret[0], ret[1].transpose(0, 1).contiguous()
         return ret

@@ -1384,12 +1412,14 @@ def _run_prefill_context_chunk_fi(
         self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
     ):
         assert isinstance(prefill, FlashInferPrefillMetadata)
+
         attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
             q=q,
             k=k,
             v=v,
             return_lse=True,
         )
+
         # Convert from (q_len, num_heads) to (num_heads, q_len)
         return attn_out, lse.transpose(0, 1).contiguous()

@@ -1418,6 +1448,81 @@ def _run_prefill_context_chunk_cudnn(
             is_cuda_graph_compatible=True,
         )

+    def _run_prefill_new_tokens_trtllm_ragged(
+        self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
+    ):
+        """TRT-LLM ragged attention for new tokens (causal)."""
+        from flashinfer.prefill import trtllm_ragged_attention_deepseek
+
+        assert prefill.query_seq_lens is not None
+
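+        # For the new-token (causal) pass, K/V share the query's ragged layout, so the
+        # same cumulative offsets and max length are passed for both Q and KV below.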
+        ret = trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self._workspace_buffer,
+            seq_lens=prefill.query_seq_lens,
+            max_q_len=prefill.max_query_len,
+            max_kv_len=prefill.max_query_len,
+            bmm1_scale=self.scale,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=prefill.query_seq_lens.shape[0],
+            window_left=-1,
+            cum_seq_lens_q=prefill.query_start_loc,
+            cum_seq_lens_kv=prefill.query_start_loc,
+            enable_pdl=False,
+            is_causal=True,
+            return_lse=return_softmax_lse,
+        )
+
+        if isinstance(ret, tuple):
+            # Convert from (q_len, num_heads) to (num_heads, q_len)
+            return ret[0], ret[1].transpose(0, 1).contiguous()
+        return ret
+
+    def _run_prefill_context_chunk_trtllm_ragged(
+        self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
+    ):
+        """TRT-LLM ragged attention for context chunks (non-causal)."""
+        from flashinfer.prefill import trtllm_ragged_attention_deepseek
+
+        assert prefill.chunked_context is not None
+        assert prefill.chunked_context.seq_lens[chunk_idx] is not None
+
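+        # Non-causal cross-attention against a previously processed context chunk: queries
+        # keep their ragged layout while K/V use the chunk's cumulative lengths. The output
+        # buffer is preallocated and the workspace cleared before the call.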
+        out = torch.zeros(
+            q.shape[0],
+            q.shape[1],
+            v.shape[2],
+            device=q.device,
+            dtype=q.dtype,
+        )
+        self._workspace_buffer.fill_(0)
+
+        attn_out, lse = trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self._workspace_buffer,
+            seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
+            max_q_len=prefill.max_query_len,
+            max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
+            bmm1_scale=self.scale,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0],
+            window_left=-1,
+            cum_seq_lens_q=prefill.query_start_loc,
+            cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx],
+            enable_pdl=False,
+            is_causal=False,
+            return_lse=True,
+            out=out,
+        )
+
+        # Convert from (q_len, num_heads) to (num_heads, q_len)
+        return attn_out, lse.transpose(0, 1).contiguous()
+
     def process_weights_after_loading(self, act_dtype: torch.dtype):
         def get_layer_weight(layer):
             WEIGHT_NAMES = ("weight", "qweight", "weight_packed")