@@ -192,6 +192,7 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None
     seq_lens_list: List[int] = None  # type: ignore
     actual_seq_lengths_q: List[int] = None  # type: ignore
+    query_start_loc_list: List[int] = None  # type: ignore

     query_start_loc: torch.Tensor = None
     query_lens: torch.Tensor = None
@@ -360,6 +361,7 @@ def build(
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
             block_tables=block_table,
             query_start_loc=query_start_loc,
+            query_start_loc_list=query_start_loc_cpu.tolist(),
             query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
@@ -454,7 +456,7 @@ def full_graph_attention(self,
         forward_context: ForwardContext = get_forward_context()
         if forward_context.capturing:
             graph_params = get_graph_params()
-            query_start_loc = attn_metadata.actual_seq_lengths_q
+            query_start_loc = attn_metadata.query_start_loc_list
             seq_lens = attn_metadata.seq_lens_list
             # Prepare tensors for attention output
             # TODO: Refactor this to step-level instead of layer-level
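For context, a minimal sketch of what the new `query_start_loc_list` field carries, assuming `query_start_loc_cpu` is the CPU cumulative-offsets tensor built from per-request query lengths (the `query_lens` values below are made up for illustration, not taken from the repo):

```python
import torch

# Hypothetical per-request query lengths for one captured batch.
query_lens = [3, 1, 4]

# query_start_loc holds the cumulative start offset of each request's
# tokens in the flattened batch: [0, 3, 4, 8].
query_start_loc_cpu = torch.zeros(len(query_lens) + 1, dtype=torch.int64)
query_start_loc_cpu[1:] = torch.cumsum(torch.tensor(query_lens), dim=0)

# The new metadata field is just the host-side list form of that tensor,
# so the graph-capture path can read plain Python ints without touching
# device memory.
query_start_loc_list = query_start_loc_cpu.tolist()
assert query_start_loc_list == [0, 3, 4, 8]
```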