 import os
 import time
 import weakref
+import math
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
@@ -976,7 +977,9 @@ def capture_model(self) -> None:
 
         start_time = time.perf_counter()
         start_free_npu_memory = torch.npu.mem_get_info()[0]
-
+        # The npugraph_batch_sizes list from vLLM may be too long for
+        # NPU graph capture, so adjust it to a supported length first.
+        self.verify_adjust_npugraph_batch_sizes()
         # Trigger NPU graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
@@ -994,3 +997,59 @@ def capture_model(self) -> None:
         # This usually takes 5~20 seconds.
         logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                     elapsed_time, npu_graph_size / (1 << 30))
+
+    def verify_adjust_npugraph_batch_sizes(self) -> None:
+        # Currently, the max capture size supported by vllm-ascend is 1920.
+        max_capture_size = 1920
+        original_npugraph_batch_sizes = self.npugraph_batch_sizes
+        num_hidden_layers = self.vllm_config.model_config.hf_config.num_hidden_layers
+        max_support_len_npugraph = self.get_max_support_len(max_capture_size, num_hidden_layers)
+
+        if max_support_len_npugraph < len(original_npugraph_batch_sizes):
+            self.npugraph_batch_sizes = self.sample_from_list(max_support_len_npugraph)
+            logger.info("Model:%s-num_hidden_layers:%d will adjust npugraph_batch_sizes, pre-adjust-len: %s, post-adjust-len: %s",
+                        self.vllm_config.model_config.architectures[0],
+                        num_hidden_layers,
+                        len(original_npugraph_batch_sizes),
+                        len(self.npugraph_batch_sizes)
+                        )
+        else:
+            logger.info("Model:%s-num_hidden_layers:%d does not need to adjust npugraph_batch_sizes, list_len: %s",
+                        self.vllm_config.model_config.architectures[0],
+                        num_hidden_layers,
+                        len(original_npugraph_batch_sizes)
+                        )
+
+    def get_max_support_len(self, max_capture_size, num_hidden_layers) -> int:
+        parallel_type_cnt = 0
+        dp_size = self.vllm_config.parallel_config.data_parallel_size
+        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+        if dp_size > 1:
+            parallel_type_cnt += 1
+        if tp_size > 1:
+            parallel_type_cnt += 1
+        max_support_len_npugraph = math.floor(max_capture_size / (num_hidden_layers + 1) / (parallel_type_cnt + 1))
+        logger.info("max_capture_size:%s, dp_size:%s, tp_size:%s, parallel_type_cnt:%s, max_support_len_npugraph:%s",
+                    max_capture_size,
+                    dp_size,
+                    tp_size,
+                    parallel_type_cnt,
+                    max_support_len_npugraph
+                    )
+
+        return max_support_len_npugraph
+
+    def sample_from_list(self, sample_len) -> list[int]:
+        # Sample a new list of the given length from the original list,
+        # maintaining uniformity, e.g. from [1 8 16 24 32 40 48 56 64]:
+        # --> sample length = 3: [1 32 64]
+        # --> sample length = 5: [1 16 32 48 64]
+        original_len = len(self.npugraph_batch_sizes)
+        step = (original_len - 1) / (sample_len - 1)
+        indices = [round(i * step) for i in range(sample_len)]
+        # Align the first and last elements of the original list and sub-list.
+        indices[0] = 0
+        indices[-1] = original_len - 1
+        # Sample the new list.
+        new_list = [self.npugraph_batch_sizes[i] for i in indices]
+        return new_list
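
To illustrate the capacity heuristic in get_max_support_len, here is a minimal standalone sketch (not part of the patch; the 1920 limit comes from the diff, while the example layer and parallel sizes are made-up numbers):

import math

def max_support_len(max_capture_size: int, num_hidden_layers: int,
                    dp_size: int, tp_size: int) -> int:
    # Each enabled parallelism type (DP, TP) halves the budget again,
    # and deeper models leave room for fewer capture sizes.
    parallel_type_cnt = int(dp_size > 1) + int(tp_size > 1)
    return math.floor(max_capture_size / (num_hidden_layers + 1)
                      / (parallel_type_cnt + 1))

# e.g. a 32-layer model with tp_size=2, dp_size=1:
# floor(1920 / 33 / 2) = floor(29.09) = 29 capture sizes at most.
print(max_support_len(1920, 32, 1, 2))  # 29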
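
Likewise, a self-contained sketch of the uniform downsampling in sample_from_list (the batch-size values here are illustrative, not taken from a real config):

def sample_uniform(values: list[int], sample_len: int) -> list[int]:
    # Pick sample_len indices spaced evenly across the original list,
    # pinning the first and last elements so the endpoints survive.
    original_len = len(values)
    step = (original_len - 1) / (sample_len - 1)
    indices = [round(i * step) for i in range(sample_len)]
    indices[0] = 0
    indices[-1] = original_len - 1
    return [values[i] for i in indices]

batch_sizes = [1, 8, 16, 24, 32, 40, 48, 56, 64]
print(sample_uniform(batch_sizes, 3))  # [1, 32, 64]
print(sample_uniform(batch_sizes, 5))  # [1, 16, 32, 48, 64]

Note that the step division assumes sample_len > 1; in the patch this holds whenever the computed max_support_len_npugraph is at least 2.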