
Commit 27b7aab

Author: wangxiaoxin-sherie (committed)

add full and piecewise graph.

Signed-off-by: wangxiaoxin-sherie <[email protected]>

1 parent: 9ff6b0b

3 files changed: +246 −93 lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 151 additions & 53 deletions
@@ -43,7 +43,6 @@

 from ..utils import weak_ref_tensors

-
 class AscendAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True

@@ -144,6 +143,9 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None

     query_start_loc: torch.Tensor = None
+    seq_lens_list: List[int] = None
+
+    query_start_loc_list: List[int] = None
     query_lens: torch.Tensor = None
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
@@ -211,8 +213,6 @@ def build(
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs
                                                                        + 1]
-        query_start_loc = query_start_loc_cpu.to(self.device,
-                                                 non_blocking=True)

         if is_310p():
             if attn_state == AscendAttentionState.PrefillNoCache:
@@ -227,8 +227,10 @@ def build(
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
-            query_start_loc=query_start_loc,
+            query_start_loc=query_start_loc_cpu,
+            query_start_loc_list=query_start_loc_cpu[1:].cpu().int().tolist(),
             query_lens=query_lens,
+            seq_lens_list=seq_lens.cpu().int().tolist(),
             seq_lens=seq_lens,
             max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
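
For orientation, a minimal sketch (not part of the commit; the batch values are hypothetical) of what the two new metadata fields hold. Dropping the leading zero from query_start_loc_cpu turns cumulative start offsets into per-request end offsets, the cumulative form the fused kernel consumes as actual_seq_lengths under the "TND" layout, while seq_lens_list carries per-request KV lengths for actual_seq_lengths_kv:

import torch

# Hypothetical decode batch: 3 requests, 1 new token each (illustrative only).
query_start_loc_cpu = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
seq_lens = torch.tensor([17, 5, 42], dtype=torch.int32)

# Mirrors the build() diff above.
query_start_loc_list = query_start_loc_cpu[1:].cpu().int().tolist()
seq_lens_list = seq_lens.cpu().int().tolist()

assert query_start_loc_list == [1, 2, 3]  # cumulative query end offsets
assert seq_lens_list == [17, 5, 42]       # per-request KV sequence lengths
assert query_start_loc_list[-1] == 3      # num_tokens in the decode path below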
@@ -397,13 +399,136 @@ def _forward_decode_only(
         else:
             graph_params = get_graph_params()
             forward_context: ForwardContext = get_forward_context()
-            num_tokens = query.shape[0]
             if forward_context.capturing:
-                if self.torch_npu_check:
+                if torch.version.cann.startswith("8.3"):
+                    # Prepare tensors for attention output
+                    query_start_loc = attn_metadata.query_start_loc_list
+                    seq_lens = attn_metadata.seq_lens_list
+                    num_tokens = query_start_loc[-1]
+                    query = query[:num_tokens]
+
                     # Get workspace from cache or calculate it if not present.
                     workspace = graph_params.workspaces.get(num_tokens)
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    softmax_lse = torch.empty(num_tokens,
+                                              dtype=query.dtype,
+                                              device=query.device)
                     if workspace is None:
-                        workspace = torch_npu._npu_paged_attention_get_workspace(
+                        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                            query=query,
+                            key=key,
+                            value=value,
+                            block_table=attn_metadata.block_tables,
+                            input_layout="TND",
+                            block_size=block_size,
+                            actual_seq_lengths=query_start_loc,
+                            actual_seq_lengths_kv=seq_lens,
+                            num_key_value_heads=self.num_kv_heads,
+                            num_heads=self.num_heads,
+                            sparse_mode=0,
+                            scale=self.scale)
+                        update_graph_params_workspaces(num_tokens, workspace)
+
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(key),
+                        weak_ref_tensors(value),
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        block_size,
+                        seq_lens,
+                        query_start_loc,
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(output),
+                        weak_ref_tensors(softmax_lse)
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch_npu.npu_fused_infer_attention_score.out(
+                        query=query,
+                        key=key,
+                        value=value,
+                        block_table=attn_metadata.block_tables,
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=query_start_loc,
+                        actual_seq_lengths_kv=seq_lens,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0,
+                        workspace=workspace,
+                        out=[output, softmax_lse],
+                    )
+
+                    output = output.view(num_tokens, self.num_heads,
+                                         self.head_size)
+
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
+                else:
+                    if self.torch_npu_check:
+                        # Get workspace from cache or calculate it if not present.
+                        workspace = graph_params.workspaces.get(num_tokens)
+                        if workspace is None:
+                            workspace = torch_npu._npu_paged_attention_get_workspace(
+                                query=query,
+                                key_cache=self.key_cache,
+                                value_cache=self.value_cache,
+                                num_kv_heads=self.num_kv_heads,
+                                num_heads=self.num_heads,
+                                scale_value=self.scale,
+                                block_table=attn_metadata.block_tables,
+                                context_lens=attn_metadata.seq_lens,
+                                out=output)
+                            update_graph_params_workspaces(num_tokens, workspace)
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(self.key_cache),
+                        weak_ref_tensors(self.value_cache),
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        attn_metadata.seq_lens,
+                        weak_ref_tensors(output),
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+
+                    if self.torch_npu_check:
+                        torch_npu._npu_paged_attention(
+                            query=query,
+                            key_cache=self.key_cache,
+                            value_cache=self.value_cache,
+                            num_kv_heads=self.num_kv_heads,
+                            num_heads=self.num_heads,
+                            scale_value=self.scale,
+                            block_table=attn_metadata.block_tables,
+                            context_lens=attn_metadata.seq_lens,
+                            out=output,
+                            workspace=workspace)
+                    else:
+                        torch_npu._npu_paged_attention(
                             query=query,
                             key_cache=self.key_cache,
                             value_cache=self.value_cache,
@@ -413,41 +538,27 @@ def _forward_decode_only(
                             block_table=attn_metadata.block_tables,
                             context_lens=attn_metadata.seq_lens,
                             out=output)
-                        update_graph_params_workspaces(num_tokens, workspace)
-
-                # Handle graph capturing mode
-                stream = torch_npu.npu.current_stream()
-
-                event = torch.npu.ExternalEvent()
-                event.wait(stream)
-                event.reset(stream)
-                graph_params.events[num_tokens].append(event)
-                graph_params.attn_params[num_tokens].append((
-                    weak_ref_tensors(query),
-                    weak_ref_tensors(self.key_cache),
-                    weak_ref_tensors(self.value_cache),
-                    self.num_kv_heads,
-                    self.num_heads,
-                    self.scale,
-                    weak_ref_tensors(attn_metadata.block_tables),
-                    attn_metadata.seq_lens,
-                    weak_ref_tensors(output),
-                ))
-
-                torch.npu.graph_task_group_begin(stream)
-
-                if self.torch_npu_check:
-                    torch_npu._npu_paged_attention(
+            else:
+                if torch.version.cann.startswith("8.3"):
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(
+                        num_block, block_size, -1)
+                    output, _ = torch_npu.npu_fused_infer_attention_score(
                         query=query,
-                        key_cache=self.key_cache,
-                        value_cache=self.value_cache,
-                        num_kv_heads=self.num_kv_heads,
-                        num_heads=self.num_heads,
-                        scale_value=self.scale,
+                        key=key,
+                        value=value,
                         block_table=attn_metadata.block_tables,
-                        context_lens=attn_metadata.seq_lens,
-                        out=output,
-                        workspace=workspace)
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=attn_metadata.query_start_loc_list,
+                        actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0
+                    )
                 else:
                     torch_npu._npu_paged_attention(
                         query=query,
@@ -459,19 +570,6 @@ def _forward_decode_only(
                         block_table=attn_metadata.block_tables,
                         context_lens=attn_metadata.seq_lens,
                         out=output)
-                handle = torch.npu.graph_task_group_end(stream)
-                graph_params.handles[num_tokens].append(handle)
-            else:
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
         return output

     def _forward_v1_style(
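
Taken together, the capture path above follows one pattern: record the attention kernel inside a graph task group so its arguments can be re-bound later, and gate each replay on an ExternalEvent. Below is a schematic sketch of that pattern, assuming a torch_npu/Ascend runtime (the helper name capture_fia and the kernel_kwargs dict are illustrative stand-ins, not APIs from the commit; all torch_npu calls appear verbatim in the diff):

import torch
import torch_npu

def capture_fia(graph_params, num_tokens, kernel_kwargs, output, softmax_lse):
    stream = torch_npu.npu.current_stream()

    # Replay of the captured graph blocks on this event until
    # update_attn_params() has patched the arguments and recorded it.
    event = torch.npu.ExternalEvent()
    event.wait(stream)
    event.reset(stream)
    graph_params.events[num_tokens].append(event)
    graph_params.attn_params[num_tokens].append(kernel_kwargs)  # weak refs only

    # Bracket the launch so it becomes an updatable task group.
    torch.npu.graph_task_group_begin(stream)
    torch_npu.npu_fused_infer_attention_score.out(**kernel_kwargs,
                                                  out=[output, softmax_lse])
    handle = torch.npu.graph_task_group_end(stream)

    # The handle is what graph_task_update_begin/end re-binds at replay time.
    graph_params.handles[num_tokens].append(handle)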

vllm_ascend/compilation/acl_graph.py

Lines changed: 83 additions & 40 deletions
@@ -201,48 +201,87 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             graph_params.handles[runtime_shape],
             graph_params.events[runtime_shape],
     ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        # block_table = forward_context.attn_metadata[key].block_tables
-        seq_lens = forward_context.attn_metadata[key].seq_lens
-        torch_npu_check = version_check()
-
-        with torch.npu.stream(update_stream):
-            torch.npu.graph_task_update_begin(update_stream, handle)
-            if torch_npu_check:
-                torch_npu._npu_paged_attention(
+        if torch.version.cann.startswith("8.3"):
+            (
+                query,
+                key_cache,
+                value,
+                block_tables,
+                block_size,
+                seq_lens,
+                query_start_loc,
+                num_kv_heads,
+                num_heads,
+                scale,
+                attn_output,
+                softmax_lse
+            ) = param
+
+            # block_table = forward_context.attn_metadata[key].block_tables
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+
+            with torch.npu.stream(update_stream):
+                torch.npu.graph_task_update_begin(update_stream, handle)
+                torch_npu.npu_fused_infer_attention_score.out(
                     query=query,
-                    key_cache=key_cache,
-                    value_cache=value_cache,
-                    num_kv_heads=num_kv_heads,
+                    key=key_cache,
+                    value=value,
+                    block_table=block_tables,
+                    input_layout="TND",
+                    block_size=block_size,
+                    actual_seq_lengths=query_start_loc,
+                    actual_seq_lengths_kv=seq_lens,
+                    num_key_value_heads=num_kv_heads,
                     num_heads=num_heads,
-                    scale_value=scale,
-                    block_table=block_table,
-                    context_lens=seq_lens,
-                    out=output,
-                    workspace=graph_params.workspaces.get(runtime_shape))
-            else:
-                torch_npu._npu_paged_attention(query=query,
-                                               key_cache=key_cache,
-                                               value_cache=value_cache,
-                                               num_kv_heads=num_kv_heads,
-                                               num_heads=num_heads,
-                                               scale_value=scale,
-                                               block_table=block_table,
-                                               context_lens=seq_lens,
-                                               out=output)
-            torch.npu.graph_task_update_end(update_stream)
-
-        event.record(update_stream)
+                    scale=scale,
+                    sparse_mode=0,
+                    workspace=graph_params.workspaces.get(runtime_shape),
+                    out=[attn_output, softmax_lse],
+                )
+                torch.npu.graph_task_update_end(update_stream)
+
+            event.record(update_stream)
+        else:
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            # block_table = forward_context.attn_metadata[key].block_tables
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+            torch_npu_check = version_check()
+
+            with torch.npu.stream(update_stream):
+                torch.npu.graph_task_update_begin(update_stream, handle)
+                if torch_npu_check:
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=key_cache,
+                        value_cache=value_cache,
+                        num_kv_heads=num_kv_heads,
+                        num_heads=num_heads,
+                        scale_value=scale,
+                        block_table=block_table,
+                        context_lens=seq_lens,
+                        out=output,
+                        workspace=graph_params.workspaces.get(runtime_shape))
+                else:
+                    torch_npu._npu_paged_attention(query=query,
+                                                   key_cache=key_cache,
+                                                   value_cache=value_cache,
+                                                   num_kv_heads=num_kv_heads,
+                                                   num_heads=num_heads,
+                                                   scale_value=scale,
+                                                   block_table=block_table,
+                                                   context_lens=seq_lens,
+                                                   out=output)
+                torch.npu.graph_task_update_end(update_stream)


 def update_mla_attn_params(update_stream, forward_context, runtime_shape):
@@ -317,6 +356,10 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
                     for size in aclgraph_capture_sizes},
     )

+def update_graph_params_workspaces(num_tokens: int, workspace: int):
+    global _graph_params
+    if _graph_params is not None:
+        _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)

 def update_graph_params_workspaces(num_tokens: int, workspace: int):
     global _graph_params
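
And the replay side of the same handshake, as update_attn_params implements it above: on a dedicated update stream, graph_task_update_begin/end re-binds the captured launch to the current batch's metadata, then event.record releases the event the captured graph is waiting on. A schematic sketch under the same torch_npu assumptions (the helper name replay_update and the new_kwargs dict are illustrative stand-ins for the freshly re-read metadata):

import torch
import torch_npu

def replay_update(update_stream, handle, event, new_kwargs,
                  attn_output, softmax_lse, workspace):
    with torch.npu.stream(update_stream):
        # Patch the captured kernel's arguments for this replay.
        torch.npu.graph_task_update_begin(update_stream, handle)
        torch_npu.npu_fused_infer_attention_score.out(
            **new_kwargs,
            workspace=workspace,
            out=[attn_output, softmax_lse],
        )
        torch.npu.graph_task_update_end(update_stream)

    # Unblock the event.wait() recorded into the captured graph.
    event.record(update_stream)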
