@@ -326,14 +326,12 @@ def build(
                 + 1]
         num_computed_tokens_cpu = (seq_lens - query_lens)
 
-        if attn_state == AscendAttentionState.DecodeOnly and \
-                common_attn_metadata.num_input_tokens > num_actual_tokens:
+        if common_attn_metadata.num_input_tokens > num_actual_tokens:
             padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
             seq_lens = torch.cat([
                 seq_lens,
-                torch.ones(padded_num_tokens,
-                           dtype=seq_lens.dtype,
-                           device=seq_lens.device)
+                torch.tensor([padded_num_tokens
+                              ]).to(seq_lens.device).to(seq_lens.dtype)
             ])
             block_table_padding = torch.zeros(
                 (padded_num_tokens, ) + block_table.shape[1:],
@@ -342,10 +340,8 @@ def build(
             block_table = torch.cat([block_table, block_table_padding], dim=0)
             query_start_loc_cpu = torch.cat([
                 query_start_loc_cpu,
-                torch.arange(query_start_loc_cpu[-1] + 1,
-                             query_start_loc_cpu[-1] + padded_num_tokens,
-                             dtype=query_start_loc_cpu.dtype,
-                             device=query_start_loc_cpu.device)
+                torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
+                    query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
             ])
 
         query_start_loc = query_start_loc_cpu.to(self.device,
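
Taken together, the two hunks above change how the padding metadata in `build()` is constructed: the pad tokens are now represented as a single dummy sequence of length `padded_num_tokens` (one extra `seq_lens` entry and one extra `query_start_loc` boundary) instead of `padded_num_tokens` length-1 entries, and the padded path is no longer gated on `AscendAttentionState.DecodeOnly`. Below is a minimal, self-contained sketch of the new padding arithmetic; the token counts and tensor values are made up for illustration and are not taken from the repo.

```python
import torch

num_actual_tokens = 5        # tokens actually scheduled this step (illustrative)
num_input_tokens = 8         # padded input size required by the runner (illustrative)
padded_num_tokens = num_input_tokens - num_actual_tokens   # 3

seq_lens = torch.tensor([2, 3], dtype=torch.int32)
query_start_loc_cpu = torch.tensor([0, 2, 5], dtype=torch.int32)

# All pad tokens are folded into one dummy sequence of length padded_num_tokens,
# instead of padded_num_tokens sequences of length 1 as before.
seq_lens = torch.cat([
    seq_lens,
    torch.tensor([padded_num_tokens]).to(seq_lens.device).to(seq_lens.dtype)
])
# -> tensor([2, 3, 3], dtype=torch.int32)

# query_start_loc gains one closing boundary for that dummy sequence,
# so the last boundary again equals the padded token count.
query_start_loc_cpu = torch.cat([
    query_start_loc_cpu,
    torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to(
        query_start_loc_cpu.device).to(query_start_loc_cpu.dtype)
])
# -> tensor([0, 2, 5, 8], dtype=torch.int32)
```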
@@ -621,7 +617,6 @@ def full_graph_attention(self,
         actual_seq_lengths_kv = attn_metadata.seq_lens_list
 
         num_tokens = attn_metadata.query_start_loc_list[-1]
-        query = query[:num_tokens]
         graph_params = get_graph_params()
         query_start_loc = attn_metadata.query_start_loc_list
         # Prepare tensors for attention output
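
The last hunk drops the `query = query[:num_tokens]` slice in `full_graph_attention`, so the query tensor keeps its padded first dimension. A small illustrative sketch of the shape reasoning, under the assumption that the padded input size is the shape the captured graph expects; the names and sizes below are hypothetical, not from the repo.

```python
import torch

padded_num_input_tokens = 8   # shape the graph path works with (illustrative)
num_tokens = 5                # query_start_loc_list[-1], actual tokens (illustrative)

query = torch.randn(padded_num_input_tokens, 4, 16)   # (tokens, heads, head_dim)

# Previously `query = query[:num_tokens]` would shrink this to 5 rows;
# keeping the full padded tensor keeps its shape consistent with the
# other padded inputs used on the graph path.
assert query.shape[0] == padded_num_input_tokens
```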