@@ -2232,7 +2232,8 @@ def get_finished_kv_transfer(
22322232 return None , None
22332233
22342234 def _build_attention_metadata (self , create_mixed_batch , num_reqs ,
2235- num_tokens , max_query_len , force_attention ):
2235+ num_tokens , max_query_len , force_attention ,
2236+ num_scheduled_tokens ):
22362237 attn_metadata : Optional [dict [str , Any ]] = None
22372238
22382239 if force_attention :
@@ -2246,8 +2247,10 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
22462247 self .seq_lens_np [:num_reqs ] = seq_lens
22472248 self .seq_lens_np [num_reqs :] = 0
22482249
2249- self .query_start_loc [:num_reqs + 1 ] = torch .arange (num_reqs + 1 )
2250- self .query_start_loc_cpu [:num_reqs + 1 ] = torch .arange (num_reqs + 1 )
2250+ cu_num_tokens , arange = self ._get_cumsum_and_arange (num_scheduled_tokens )
2251+
2252+ self .query_start_loc [1 :num_reqs + 1 ] = torch .Tensor (cu_num_tokens )
2253+ self .query_start_loc_cpu [1 :num_reqs + 1 ] = torch .Tensor (cu_num_tokens )
22512254
22522255 num_computed_tokens_cpu = (
22532256 self .input_batch .num_computed_tokens_cpu_tensor [:num_reqs ])
@@ -2393,6 +2396,7 @@ def _dummy_run(
23932396 num_tokens = num_tokens ,
23942397 max_query_len = max_query_len ,
23952398 force_attention = force_attention ,
2399+ num_scheduled_tokens = num_scheduled_tokens ,
23962400 )
23972401
23982402 if not self .in_profile_run and self .dynamic_eplb :
0 commit comments