@@ -2,6 +2,7 @@
 from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
                     TypeVar)
 
+import numpy as np
 import torch
 import torch_npu
 from torch import nn
@@ -290,6 +291,31 @@ def reorder_batch(self, input_batch: "InputBatch",
         # better way of doing this
         return modified_batch
 
+    def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs,
+                             actual_seq_lengths_q):
+        """
+        Only used for ACL full graph mode.
+        Pad the last element of actual_seq_lengths_q to equal T (of the TND
+        layout) and its length to equal the batch_size of the main model.
+
+        For example:
+        batch_size = 8, num_reqs = 4, num_speculative_tokens = 1
+        input actual_seq_lengths_q = [1, 2, 4, 5] (the 3rd req accepted a token)
+        After padding, actual_seq_lengths_q will be similar to [1, 2, 4, 5, 6, 6, 7, 8]
+        """
+        need_padding = num_reqs_pad_size > 0
+        if need_padding:
+            start_val = actual_seq_lengths_q[-1]
+            end_val = num_reqs + num_reqs_pad_size
+            num_step = num_reqs_pad_size
+            interpolated = np.round(
+                np.linspace(start_val, end_val,
+                            num_step + 1)[1:]).astype(int).tolist()
+            assert interpolated[-1] == end_val
+            assert len(interpolated) == num_reqs_pad_size
+            actual_seq_lengths_q = actual_seq_lengths_q + interpolated
+        return actual_seq_lengths_q
+
     def build(
         self,
         common_prefix_len: int,
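
A minimal standalone sketch (not part of the patch, assuming plain Python lists) of the interpolation that pad_actual_seq_len_q performs, just to check the docstring example; only numpy is needed:

import numpy as np

def pad_actual_seq_len_q(num_reqs_pad_size, num_reqs, actual_seq_lengths_q):
    # Extend the cumulative query-length list with values linearly
    # interpolated between its last real entry and the padded batch size.
    if num_reqs_pad_size > 0:
        start_val = actual_seq_lengths_q[-1]
        end_val = num_reqs + num_reqs_pad_size
        interpolated = np.round(
            np.linspace(start_val, end_val,
                        num_reqs_pad_size + 1)[1:]).astype(int).tolist()
        actual_seq_lengths_q = actual_seq_lengths_q + interpolated
    return actual_seq_lengths_q

# batch_size = 8, num_reqs = 4: np.linspace(5, 8, 5)[1:] = [5.75, 6.5, 7.25, 8.0]
# and np.round rounds half to even, so 6.5 becomes 6.
print(pad_actual_seq_len_q(4, 4, [1, 2, 4, 5]))  # [1, 2, 4, 5, 6, 6, 7, 8]
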
@@ -310,7 +336,13 @@ def build(
         # it blocks on all previous kernels.
         device = self.device
 
-        block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
+        graph_pad_size = common_attn_metadata.graph_pad_size
+        if graph_pad_size > num_reqs:
+            # Should we add another check to ensure full graph mode?
+            block_table = (
+                common_attn_metadata.block_table_tensor[:graph_pad_size])
+        else:
+            block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
         slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         input_positions = common_attn_metadata.positions[:
                                                           num_actual_tokens].long(
@@ -407,9 +439,28 @@ def build(
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            block_table = block_table[:num_decodes, ...]
+            if graph_pad_size > num_decodes:
+                # Should we add another check to ensure full graph mode?
+                block_table = block_table[:graph_pad_size, ...]
+            else:
+                block_table = block_table[:num_decodes, ...]
             seq_lens_list = seq_lens.tolist()
 
+            if graph_pad_size > num_reqs:
+                # Should we add another check to ensure full graph mode?
+                num_reqs_pad_size = graph_pad_size - num_reqs
+                actual_seq_lengths_q = self.pad_actual_seq_len_q(
+                    num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
+                seq_lens_list = seq_lens_list + [0] * num_reqs_pad_size
+                num_block_pad_size = graph_pad_size - block_table.shape[0]
+                if num_block_pad_size > 0:
+                    block_table_padding = torch.zeros(
+                        (num_block_pad_size, ) + block_table.shape[1:],
+                        dtype=block_table.dtype,
+                        device=block_table.device)
+                    block_table = torch.cat([block_table, block_table_padding],
+                                            dim=0)
+
             # TODO: After fullgraph supports MTP, this if branch needs to be deleted
             assert self.cos_cache is not None
             assert self.sin_cache is not None
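
For the block_table padding branch, here is a minimal sketch with assumed shapes (graph_pad_size = 8, 5 real requests, and 16 block slots per row are illustrative values, not from the patch): zero rows are appended so the table matches the captured graph batch size.

import torch

graph_pad_size = 8  # assumed padded (full graph) batch size
block_table = torch.randint(0, 100, (5, 16), dtype=torch.int32)  # 5 real rows
num_block_pad_size = graph_pad_size - block_table.shape[0]
if num_block_pad_size > 0:
    # Append zero rows for the padded requests (their seq_lens entries are
    # padded with 0 in the patch above).
    block_table_padding = torch.zeros(
        (num_block_pad_size, ) + block_table.shape[1:],
        dtype=block_table.dtype,
        device=block_table.device)
    block_table = torch.cat([block_table, block_table_padding], dim=0)
print(block_table.shape)  # torch.Size([8, 16])
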