@@ -452,6 +452,51 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
452452 update_cudagraph_capture_sizes (vllm_config ,
453453 new_cudagraph_capture_sizes )
454454
455+ # modify the default capture_sizes for num_speculative_tokens >= 1 scenario.
456+ # this is mainly because in the scenario where MTP is superimposed with Full Graph, the FIA operator needs to perform
457+ # padding operations to adapt to its actual_seq_lengths parameter. The padding operation will
458+ # expand each request to the maximum request count under MTP, therefore the input shape must be
459+ # equal to a multiple of the MTP layer count (k+1). Assuming k=2, capture_sizes = [3, 6, 9, 15, 18, ...].
460+ # Consequently, it is necessary to modify the default captured graph shape of Full Graph to
461+ # accommodate this requirement of the FIA operator.
462+ # TODO: It is more appropriate to place the initialization of shape capture for the fullgraph of the FIA
463+ # operator adapted for MTP in the vLLM community. Therefore, this section will be removed
464+ # migrated to the vLLM community.
465+ from vllm .config .compilation import CUDAGraphMode
466+ aclgraph_mode = vllm_config .compilation_config .cudagraph_mode
467+ if vllm_config .speculative_config is not None and \
468+ aclgraph_mode == CUDAGraphMode .FULL_DECODE_ONLY :
469+ num_speculative_tokens = vllm_config .speculative_config .num_speculative_tokens
470+ max_num_seqs = vllm_config .scheduler_config .max_num_seqs
471+ target_sizes = (num_speculative_tokens + 1 ) * max_num_seqs
472+ original_sizes , vllm_config .compilation_config .cudagraph_capture_sizes = \
473+ vllm_config .compilation_config .cudagraph_capture_sizes , None
474+ assert len (original_sizes ) > 0
475+ assert max_num_seqs > 0
476+ assert num_speculative_tokens > 0
477+ if num_speculative_tokens > 1 :
478+ if original_sizes [0 ] < (num_speculative_tokens + 1 ) * max_num_seqs :
479+ new_original_sizes = sorted (set (list (range (1 , min (10 , max_num_seqs + 1 ), 2 )) + list (range (8 , max_num_seqs + 1 , 4 ))))
480+ enlarged_sizes = [(num_speculative_tokens + 1 ) * sizes for sizes in new_original_sizes ]
481+ if enlarged_sizes [- 1 ] < target_sizes :
482+ enlarged_sizes .append (target_sizes )
483+ update_cudagraph_capture_sizes (vllm_config , enlarged_sizes )
484+ logger .info (
485+ "Adjusted ACL full graphs: %s → %s for speculative decoding" ,
486+ original_sizes , enlarged_sizes )
487+ else :
488+ vllm_config .compilation_config .cudagraph_capture_sizes = original_sizes
489+ if num_speculative_tokens == 1 :
490+ padding_sizes = original_sizes .copy ()
491+ if padding_sizes [- 1 ] < target_sizes :
492+ padding_sizes .append (target_sizes )
493+ update_cudagraph_capture_sizes (vllm_config , padding_sizes )
494+ logger .info (
495+ "Adjusted ACL full graphs: %s → %s for speculative decoding" ,
496+ original_sizes , padding_sizes )
497+ else :
498+ vllm_config .compilation_config .cudagraph_capture_sizes = original_sizes
499+
455500
456501def update_aclgraph_sizes (vllm_config : VllmConfig ) -> None :
457502 """Update ACL graph capture sizes based on hardware limitations"""
@@ -571,13 +616,21 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
571616 max_num_seqs = vllm_config .scheduler_config .max_num_seqs
572617 original_sizes , compilation_config .cudagraph_capture_sizes = \
573618 compilation_config .cudagraph_capture_sizes , None
619+ new_original_sizes = sorted (set (list (range (1 , min (10 , max_num_seqs + 1 ), 2 )) + list (range (8 , max_num_seqs + 1 , 4 ))))
620+ step = (len (new_original_sizes ) - 1 ) / (max_num_batch_sizes - 1 )
621+ indices = [round (i * step ) for i in range (max_num_batch_sizes )]
622+ indices [0 ], indices [- 1 ] = 0 , len (new_original_sizes ) - 1
623+ new_sampled_sizes = [new_original_sizes [i ] for i in indices ]
624+ target_sizes = (num_speculative_tokens + 1 ) * max_num_seqs
574625 assert len (original_sizes ) > 0
575626 if original_sizes [0 ] < (num_speculative_tokens + 1 ) * max_num_seqs :
576627 enlarged_sizes = [(num_speculative_tokens + 1 ) * size
577- for size in original_sizes ]
628+ for size in new_sampled_sizes ]
629+ if enlarged_sizes [- 1 ] < target_sizes :
630+ enlarged_sizes [- 1 ] = target_sizes
578631 update_cudagraph_capture_sizes (vllm_config , enlarged_sizes )
579632 logger .info (
580- "Adjusted ACL graphs: %s → %s for speculative decoding" ,
633+ "Adjusted PieceWise ACL graphs: %s → %s for speculative decoding" ,
581634 original_sizes , enlarged_sizes )
582635 else :
583636 compilation_config .cudagraph_capture_sizes = original_sizes
0 commit comments