Commit 6d2bc7a
wangxiaoxin-sherie authored and committed:
add full-and-piecewise graph.
Signed-off-by: wangxiaoxin-sherie <[email protected]>
1 parent ec1d2b5 commit 6d2bc7a

File tree

3 files changed: +246 −91 lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 151 additions & 51 deletions
@@ -42,7 +42,6 @@
 
 from ..utils import weak_ref_tensors
 
-
 class AscendAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
@@ -149,6 +148,9 @@ class AscendMetadata:
     actual_seq_lengths_q: List[int] = None  # type: ignore
 
     query_start_loc: torch.Tensor = None
+    seq_lens_list: List[int] = None
+
+    query_start_loc_list: List[int] = None
     query_lens: torch.Tensor = None
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
@@ -255,8 +257,10 @@ def build(
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
-            query_start_loc=query_start_loc,
+            query_start_loc=query_start_loc_cpu,
+            query_start_loc_list=query_start_loc_cpu[1:].cpu().int().tolist(),
             query_lens=query_lens,
+            seq_lens_list=seq_lens.cpu().int().tolist(),
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
             max_query_len=common_attn_metadata.max_query_len,
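
Note: the two new list fields exist because the TND-layout fused attention call below takes cumulative query end offsets (actual_seq_lengths) and per-request KV lengths (actual_seq_lengths_kv) as host-side Python lists. A minimal sketch of the derivation used in build() above; the tensor values are illustrative, not from the commit:

import torch

# Hypothetical batch: three requests contributing 2, 1 and 4 new tokens.
query_start_loc_cpu = torch.tensor([0, 2, 3, 7])
seq_lens = torch.tensor([10, 5, 12])

# As in build(): drop the leading zero to get cumulative end offsets.
query_start_loc_list = query_start_loc_cpu[1:].cpu().int().tolist()  # [2, 3, 7]
# Per-request KV lengths for actual_seq_lengths_kv.
seq_lens_list = seq_lens.cpu().int().tolist()  # [10, 5, 12]

# The last offset doubles as the total token count used by the capture path.
num_tokens = query_start_loc_list[-1]  # 7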
@@ -427,13 +431,136 @@ def _forward_decode_only(
         else:
             graph_params = get_graph_params()
             forward_context: ForwardContext = get_forward_context()
-            num_tokens = query.shape[0]
             if forward_context.capturing:
-                if self.torch_npu_check:
+                if torch.version.cann.startswith("8.3"):
+                    # Prepare tensors for attention output
+                    query_start_loc = attn_metadata.query_start_loc_list
+                    seq_lens = attn_metadata.seq_lens_list
+                    num_tokens = query_start_loc[-1]
+                    query = query[:num_tokens]
+
                     # Get workspace from cache or calculate it if not present.
                     workspace = graph_params.workspaces.get(num_tokens)
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    softmax_lse = torch.empty(num_tokens,
+                                              dtype=query.dtype,
+                                              device=query.device)
                     if workspace is None:
-                        workspace = torch_npu._npu_paged_attention_get_workspace(
+                        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                            query=query,
+                            key=key,
+                            value=value,
+                            block_table=attn_metadata.block_tables,
+                            input_layout="TND",
+                            block_size=block_size,
+                            actual_seq_lengths=query_start_loc,
+                            actual_seq_lengths_kv=seq_lens,
+                            num_key_value_heads=self.num_kv_heads,
+                            num_heads=self.num_heads,
+                            sparse_mode=0,
+                            scale=self.scale)
+                        update_graph_params_workspaces(num_tokens, workspace)
+
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(key),
+                        weak_ref_tensors(value),
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        block_size,
+                        seq_lens,
+                        query_start_loc,
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(output),
+                        weak_ref_tensors(softmax_lse)
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch_npu.npu_fused_infer_attention_score.out(
+                        query=query,
+                        key=key,
+                        value=value,
+                        block_table=attn_metadata.block_tables,
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=query_start_loc,
+                        actual_seq_lengths_kv=seq_lens,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0,
+                        workspace=workspace,
+                        out=[output, softmax_lse],
+                    )
+
+                    output = output.view(num_tokens, self.num_heads,
+                                         self.head_size)
+
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
+                else:
+                    if self.torch_npu_check:
+                        # Get workspace from cache or calculate it if not present.
+                        workspace = graph_params.workspaces.get(num_tokens)
+                        if workspace is None:
+                            workspace = torch_npu._npu_paged_attention_get_workspace(
+                                query=query,
+                                key_cache=self.key_cache,
+                                value_cache=self.value_cache,
+                                num_kv_heads=self.num_kv_heads,
+                                num_heads=self.num_heads,
+                                scale_value=self.scale,
+                                block_table=attn_metadata.block_tables,
+                                context_lens=attn_metadata.seq_lens,
+                                out=output)
+                            update_graph_params_workspaces(num_tokens, workspace)
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(self.key_cache),
+                        weak_ref_tensors(self.value_cache),
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        attn_metadata.seq_lens,
+                        weak_ref_tensors(output),
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+
+                    if self.torch_npu_check:
+                        torch_npu._npu_paged_attention(
+                            query=query,
+                            key_cache=self.key_cache,
+                            value_cache=self.value_cache,
+                            num_kv_heads=self.num_kv_heads,
+                            num_heads=self.num_heads,
+                            scale_value=self.scale,
+                            block_table=attn_metadata.block_tables,
+                            context_lens=attn_metadata.seq_lens,
+                            out=output,
+                            workspace=workspace)
+                    else:
+                        torch_npu._npu_paged_attention(
                             query=query,
                             key_cache=self.key_cache,
                             value_cache=self.value_cache,
@@ -443,41 +570,27 @@ def _forward_decode_only(
                             block_table=attn_metadata.block_tables,
                             context_lens=attn_metadata.seq_lens,
                             out=output)
-                        update_graph_params_workspaces(num_tokens, workspace)
-
-                # Handle graph capturing mode
-                stream = torch_npu.npu.current_stream()
-
-                event = torch.npu.ExternalEvent()
-                event.wait(stream)
-                event.reset(stream)
-                graph_params.events[num_tokens].append(event)
-                graph_params.attn_params[num_tokens].append((
-                    weak_ref_tensors(query),
-                    weak_ref_tensors(self.key_cache),
-                    weak_ref_tensors(self.value_cache),
-                    self.num_kv_heads,
-                    self.num_heads,
-                    self.scale,
-                    weak_ref_tensors(attn_metadata.block_tables),
-                    attn_metadata.seq_lens,
-                    weak_ref_tensors(output),
-                ))
-
-                torch.npu.graph_task_group_begin(stream)
-
-                if self.torch_npu_check:
-                    torch_npu._npu_paged_attention(
+            else:
+                if torch.version.cann.startswith("8.3"):
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(
+                        num_block, block_size, -1)
+                    output, _ = torch_npu.npu_fused_infer_attention_score(
                         query=query,
-                        key_cache=self.key_cache,
-                        value_cache=self.value_cache,
-                        num_kv_heads=self.num_kv_heads,
-                        num_heads=self.num_heads,
-                        scale_value=self.scale,
+                        key=key,
+                        value=value,
                         block_table=attn_metadata.block_tables,
-                        context_lens=attn_metadata.seq_lens,
-                        out=output,
-                        workspace=workspace)
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=attn_metadata.query_start_loc_list,
+                        actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0
+                    )
                 else:
                     torch_npu._npu_paged_attention(
                         query=query,
@@ -489,19 +602,6 @@ def _forward_decode_only(
                         block_table=attn_metadata.block_tables,
                         context_lens=attn_metadata.seq_lens,
                         out=output)
-                handle = torch.npu.graph_task_group_end(stream)
-                graph_params.handles[num_tokens].append(handle)
-        else:
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
         return output
 
     def _forward_v1_style(
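
Note: both capturing branches above follow the same pattern: wrap the attention launch in a graph task group so it can be re-parameterized at replay time, and gate each replay on an ExternalEvent that the update path records. A minimal sketch of that capture-side pattern, assuming only the torch_npu graph-task APIs that appear in the diff; launch_attention is a hypothetical stand-in for the fused or paged attention call:

import torch
import torch_npu

def capture_attention_task(graph_params, num_tokens, launch_attention):
    stream = torch_npu.npu.current_stream()

    # Replays of the captured graph wait on this event until
    # update_attn_params records it after refreshing the parameters.
    event = torch.npu.ExternalEvent()
    event.wait(stream)
    event.reset(stream)
    graph_params.events[num_tokens].append(event)

    # Group the launch so the replay-time update can rebind its arguments.
    torch.npu.graph_task_group_begin(stream)
    launch_attention()
    handle = torch.npu.graph_task_group_end(stream)
    graph_params.handles[num_tokens].append(handle)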

vllm_ascend/compilation/acl_graph.py

Lines changed: 83 additions & 40 deletions
@@ -201,48 +201,87 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             graph_params.handles[runtime_shape],
             graph_params.events[runtime_shape],
     ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        # block_table = forward_context.attn_metadata[key].block_tables
-        seq_lens = forward_context.attn_metadata[key].seq_lens
-        torch_npu_check = version_check()
-
-        with torch.npu.stream(update_stream):
-            torch.npu.graph_task_update_begin(update_stream, handle)
-            if torch_npu_check:
-                torch_npu._npu_paged_attention(
+        if torch.version.cann.startswith("8.3"):
+            (
+                query,
+                key_cache,
+                value,
+                block_tables,
+                block_size,
+                seq_lens,
+                query_start_loc,
+                num_kv_heads,
+                num_heads,
+                scale,
+                attn_output,
+                softmax_lse
+            ) = param
+
+            # block_table = forward_context.attn_metadata[key].block_tables
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+
+            with torch.npu.stream(update_stream):
+                torch.npu.graph_task_update_begin(update_stream, handle)
+                torch_npu.npu_fused_infer_attention_score.out(
                     query=query,
-                    key_cache=key_cache,
-                    value_cache=value_cache,
-                    num_kv_heads=num_kv_heads,
+                    key=key_cache,
+                    value=value,
+                    block_table=block_tables,
+                    input_layout="TND",
+                    block_size=block_size,
+                    actual_seq_lengths=query_start_loc,
+                    actual_seq_lengths_kv=seq_lens,
+                    num_key_value_heads=num_kv_heads,
                     num_heads=num_heads,
-                    scale_value=scale,
-                    block_table=block_table,
-                    context_lens=seq_lens,
-                    out=output,
-                    workspace=graph_params.workspaces.get(runtime_shape))
-            else:
-                torch_npu._npu_paged_attention(query=query,
-                                               key_cache=key_cache,
-                                               value_cache=value_cache,
-                                               num_kv_heads=num_kv_heads,
-                                               num_heads=num_heads,
-                                               scale_value=scale,
-                                               block_table=block_table,
-                                               context_lens=seq_lens,
-                                               out=output)
-            torch.npu.graph_task_update_end(update_stream)
-
-        event.record(update_stream)
+                    scale=scale,
+                    sparse_mode=0,
+                    workspace=graph_params.workspaces.get(runtime_shape),
+                    out=[attn_output, softmax_lse],
+                )
+                torch.npu.graph_task_update_end(update_stream)
+
+            event.record(update_stream)
+        else:
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            # block_table = forward_context.attn_metadata[key].block_tables
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+            torch_npu_check = version_check()
+
+            with torch.npu.stream(update_stream):
+                torch.npu.graph_task_update_begin(update_stream, handle)
+                if torch_npu_check:
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=key_cache,
+                        value_cache=value_cache,
+                        num_kv_heads=num_kv_heads,
+                        num_heads=num_heads,
+                        scale_value=scale,
+                        block_table=block_table,
+                        context_lens=seq_lens,
+                        out=output,
+                        workspace=graph_params.workspaces.get(runtime_shape))
+                else:
+                    torch_npu._npu_paged_attention(query=query,
+                                                   key_cache=key_cache,
+                                                   value_cache=value_cache,
+                                                   num_kv_heads=num_kv_heads,
+                                                   num_heads=num_heads,
+                                                   scale_value=scale,
+                                                   block_table=block_table,
+                                                   context_lens=seq_lens,
+                                                   out=output)
+                torch.npu.graph_task_update_end(update_stream)
 
 
 def update_mla_attn_params(update_stream, forward_context, runtime_shape,
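
Note: this is the replay-time half of the pattern sketched after attention_v1.py above. On a dedicated update stream, graph_task_update_begin/graph_task_update_end re-issue the captured op with the current sequence lengths, and recording the event releases the captured graph that waits on it. A condensed sketch, assuming the same torch_npu APIs; relaunch_attention is a hypothetical callable that re-issues the op with fresh parameters:

import torch
import torch_npu

def update_attention_task(update_stream, handle, event, relaunch_attention):
    with torch.npu.stream(update_stream):
        # Rebind the captured task group to new runtime parameters
        # (e.g. refreshed seq_lens) without re-capturing the graph.
        torch.npu.graph_task_update_begin(update_stream, handle)
        relaunch_attention()
        torch.npu.graph_task_update_end(update_stream)

    # Unblock the next replay, which waits on this event (see capture side).
    event.record(update_stream)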
@@ -328,6 +367,10 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
         for size in aclgraph_capture_sizes},
     )
 
+def update_graph_params_workspaces(num_tokens: int, workspace: int):
+    global _graph_params
+    if _graph_params is not None:
+        _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
 
 def update_graph_params_workspaces(num_tokens: int, workspace: int):
     global _graph_params
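
Note: update_graph_params_workspaces caches one attention workspace per captured token count, so every capture and replay at the same shape reuses a single allocation instead of recomputing it. A small usage sketch, assuming set_graph_params has already initialized the module-level _graph_params; compute_workspace is a hypothetical stand-in for the workspace-size query:

# First capture at a given shape computes and caches the workspace.
workspace = compute_workspace()  # hypothetical helper
update_graph_params_workspaces(64, workspace)

# Subsequent captures and replays at the same shape hit the cache, as in
# _forward_decode_only: graph_params.workspaces.get(num_tokens).
cached = get_graph_params().workspaces.get(64)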
