
Commit 343b465

Update the way of capture

Signed-off-by: anon189Ty <[email protected]>
1 parent 6a3ec43 · commit 343b465

3 files changed: +161 -60 lines

vllm_ascend/attention/mla_v1.py

Lines changed: 53 additions & 2 deletions
@@ -2,6 +2,7 @@
 from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
                     TypeVar)
 
+import numpy as np
 import torch
 import torch_npu
 from torch import nn
@@ -290,6 +291,31 @@ def reorder_batch(self, input_batch: "InputBatch",
         # better way of doing this
         return modified_batch
 
+    def pad_actual_seq_len_q(self, num_reqs_pad_size, num_reqs,
+                             actual_seq_lengths_q):
+        """
+        Only use for acl full graph mode.
+        Pad the last element of the actual_seq_lengths_q equal to the TND(T) and
+        the num of dimensions equal to the batch_size of main model.
+
+        For example:
+        batch_size = 8, num_reqs = 4, num_speculative_tokens = 1
+        input actual_seq_lengths_q = [1, 2, 4, 5] (the 3rd req was accept a token)
+        After padding the actual_seq_lengths_q will be similar to [1, 2, 4, 5, 6, 6, 7, 8]
+        """
+        need_padding = num_reqs_pad_size > 0
+        if need_padding:
+            start_val = actual_seq_lengths_q[-1]
+            end_val = num_reqs + num_reqs_pad_size
+            num_step = num_reqs_pad_size
+            interpolated = np.round(
+                np.linspace(start_val, end_val,
+                            num_step + 1)[1:]).astype(int).tolist()
+            assert interpolated[-1] == end_val
+            assert len(interpolated) == num_reqs_pad_size
+            actual_seq_lengths_q = actual_seq_lengths_q + interpolated
+        return actual_seq_lengths_q
+
     def build(
         self,
         common_prefix_len: int,
@@ -310,7 +336,13 @@ def build(
         # it blocks on all previous kernels.
         device = self.device
 
-        block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
+        graph_pad_size = common_attn_metadata.graph_pad_size
+        if graph_pad_size > num_reqs:
+            # Should we add other check to make sure is in full graph mode?
+            block_table = (
+                common_attn_metadata.block_table_tensor[:graph_pad_size])
+        else:
+            block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
         slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
         input_positions = common_attn_metadata.positions[:
                                                           num_actual_tokens].long(
@@ -407,9 +439,28 @@ def build(
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            block_table = block_table[:num_decodes, ...]
+            if graph_pad_size > num_decodes:
+                # Should we add other check to make sure is in full graph mode?
+                block_table = block_table[:graph_pad_size, ...]
+            else:
+                block_table = block_table[:num_decodes, ...]
             seq_lens_list = seq_lens.tolist()
 
+            if graph_pad_size > num_reqs:
+                # Should we add other check to make sure is in full graph mode?
+                num_reqs_pad_size = graph_pad_size - num_reqs
+                actual_seq_lengths_q = self.pad_actual_seq_len_q(
+                    num_reqs_pad_size, num_reqs, actual_seq_lengths_q)
+                seq_lens_list = seq_lens_list + [0] * num_reqs_pad_size
+                num_block_pad_size = graph_pad_size - block_table.shape[0]
+                if num_block_pad_size > 0:
+                    block_table_padding = torch.zeros(
+                        (num_block_pad_size, ) + block_table.shape[1:],
+                        dtype=block_table.dtype,
+                        device=block_table.device)
+                    block_table = torch.cat([block_table, block_table_padding],
+                                            dim=0)
+
             # TODO: After the fullgraph supports MTP, the if branch needs to deleted
             assert self.cos_cache is not None
             assert self.sin_cache is not None
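
For reference, a minimal standalone sketch of what the new pad_actual_seq_len_q helper computes (not part of the commit, just the docstring example replayed with plain NumPy; num_reqs_pad_size = 4 is assumed from batch_size 8 minus num_reqs 4):

import numpy as np

# 4 real requests padded up to a captured batch size of 8.
actual_seq_lengths_q = [1, 2, 4, 5]
num_reqs, num_reqs_pad_size = 4, 4

start_val = actual_seq_lengths_q[-1]    # 5, last real cumulative query length
end_val = num_reqs + num_reqs_pad_size  # 8, the padded T of the TND layout per the docstring
interpolated = np.round(
    np.linspace(start_val, end_val,
                num_reqs_pad_size + 1)[1:]).astype(int).tolist()

# The padded tail rises monotonically from the last real value up to end_val.
print(actual_seq_lengths_q + interpolated)  # [1, 2, 4, 5, 6, 6, 7, 8]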

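The block-table handling in build follows the same idea: when graph_pad_size (the captured batch size) exceeds the number of live requests, the table is sliced to the capture size and any missing rows are filled with zeros so the captured graph always sees a tensor of the same shape. A hedged sketch of that padding step, with made-up shapes (not the module's actual call site):

import torch

def pad_block_table(block_table: torch.Tensor, graph_pad_size: int) -> torch.Tensor:
    # Append all-zero rows until the table has graph_pad_size rows.
    num_block_pad_size = graph_pad_size - block_table.shape[0]
    if num_block_pad_size > 0:
        padding = torch.zeros((num_block_pad_size, ) + block_table.shape[1:],
                              dtype=block_table.dtype,
                              device=block_table.device)
        block_table = torch.cat([block_table, padding], dim=0)
    return block_table

# 4 real requests, graph captured for a batch of 8.
table = torch.arange(12, dtype=torch.int32).reshape(4, 3)
print(pad_block_table(table, 8).shape)  # torch.Size([8, 3])
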
vllm_ascend/compilation/acl_graph.py

Lines changed: 10 additions & 11 deletions
@@ -279,10 +279,17 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
                     spec_multiple * (i + 1)
                     for i in range(runtime_shape // spec_multiple)
                 ]
+            elif forward_context.is_mtp_model:
+                actual_seq_lengths = forward_context.attn_metadata[
+                    key].decode.actual_seq_lengths_q
+                block_table = forward_context.attn_metadata[
+                    key].decode.block_table
+                seq_lens_list = seq_lens_list + [0] * (
+                    len(actual_seq_lengths) - len(seq_lens_list))
             else:
                 seq_lens_list = seq_lens_list + [0] * (runtime_shape -
                                                        len(seq_lens_list))
-            torch.npu.graph_task_update_begin(update_stream, handle)
+            torch.npu.graph_task_update_begin(update_stream, handle)
 
             torch_npu.npu_fused_infer_attention_score.out(
                 q_nope,
@@ -346,22 +353,14 @@ def get_graph_params():
     return _graph_params
 
 
-@dataclass
-class MTPGraphParams:
-    events: dict[int, list[torch.npu.ExternalEvent]]
-    workspaces: dict[int, torch.Tensor]
-    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
-    attn_params: dict[int, list[tuple]]
-
-
-_mtp_graph_params: Optional[MTPGraphParams] = None
+_mtp_graph_params: Optional[GraphParams] = None
 
 
 def set_mtp_graph_params(aclgraph_capture_sizes: set[int]):
     global _mtp_graph_params
     if _mtp_graph_params is not None:
         raise ValueError("MTPGraph parameters have already been set!")
-    _mtp_graph_params = MTPGraphParams(
+    _mtp_graph_params = GraphParams(
         {size: []
          for size in aclgraph_capture_sizes},
         {size: None
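
On the graph-update side, the new elif forward_context.is_mtp_model branch reads actual_seq_lengths_q and block_table back from the captured decode metadata and pads seq_lens_list with zeros up to the captured length rather than up to runtime_shape. The dedicated MTPGraphParams dataclass is dropped; _mtp_graph_params now reuses GraphParams, which presumably carries the same events/workspaces/handles/attn_params fields as the removed class. A tiny sketch of the padding arithmetic with hypothetical values:

# Hypothetical values, only to illustrate the zero-padding in the MTP branch.
actual_seq_lengths = [1, 2, 4, 5, 6, 6, 7, 8]  # recorded at capture time (len 8)
seq_lens_list = [13, 7, 21, 9]                 # live KV lengths this step (len 4)

seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list))
print(seq_lens_list)  # [13, 7, 21, 9, 0, 0, 0, 0]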
