
Commit bcda351

Author: wangxiaoxin-sherie
Commit message: add fullandpiecesewise graph.
Signed-off-by: wangxiaoxin-sherie <[email protected]>
1 parent 19b85ef · commit bcda351

File tree: 3 files changed, +236 -76 lines changed


vllm_ascend/attention/attention_v1.py

Lines changed: 151 additions & 48 deletions
@@ -34,14 +34,14 @@
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          maybe_save_kv_layer_to_connector,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.compilation.acl_graph import get_graph_params
+from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               update_graph_params_workspaces)
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
 
 from ..utils import weak_ref_tensors
 
-
 class AscendAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
@@ -142,6 +142,9 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None
 
     query_start_loc: torch.Tensor = None
+    seq_lens_list: List[int] = None
+
+    query_start_loc_list: List[int] = None
     query_lens: torch.Tensor = None
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
@@ -209,8 +212,6 @@ def build(
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs
                                                                        + 1]
-        query_start_loc = query_start_loc_cpu.to(self.device,
-                                                 non_blocking=True)
 
         if is_310p():
             if attn_state == AscendAttentionState.PrefillNoCache:
@@ -225,8 +226,10 @@ def build(
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
-            query_start_loc=query_start_loc,
+            query_start_loc=query_start_loc_cpu,
+            query_start_loc_list=query_start_loc_cpu[1:].cpu().int().tolist(),
             query_lens=query_lens,
+            seq_lens_list=seq_lens.cpu().int().tolist(),
             seq_lens=seq_lens,
             max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
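As a side note on the two new list fields introduced here: a minimal Python sketch with toy values (not taken from the commit) of what the conversion in build() produces. query_start_loc_cpu holds cumulative token offsets, so dropping the leading zero yields the per-request end offsets that the fused kernel later consumes as actual_seq_lengths.

    import torch

    # Toy decode batch of three requests, one new token each (illustrative values).
    query_start_loc_cpu = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
    seq_lens = torch.tensor([17, 33, 9], dtype=torch.int32)

    query_start_loc_list = query_start_loc_cpu[1:].cpu().int().tolist()  # [1, 2, 3]
    seq_lens_list = seq_lens.cpu().int().tolist()                        # [17, 33, 9]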
@@ -394,51 +397,151 @@ def _forward_decode_only(
         else:
             graph_params = get_graph_params()
             forward_context: ForwardContext = get_forward_context()
-            num_tokens = query.shape[0]
             if forward_context.capturing:
-                stream = torch_npu.npu.current_stream()
-
-                event = torch.npu.ExternalEvent()
-                event.wait(stream)
-                event.reset(stream)
-                graph_params.events[num_tokens].append(event)
-
-                graph_params.attn_params[num_tokens].append((
-                    weak_ref_tensors(query),
-                    weak_ref_tensors(self.key_cache),
-                    weak_ref_tensors(self.value_cache),
-                    self.num_kv_heads,
-                    self.num_heads,
-                    self.scale,
-                    weak_ref_tensors(attn_metadata.block_tables),
-                    attn_metadata.seq_lens,
-                    weak_ref_tensors(output),
-                ))
-
-                torch.npu.graph_task_group_begin(stream)
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
-                handle = torch.npu.graph_task_group_end(stream)
-                graph_params.handles[num_tokens].append(handle)
+                if torch.version.cann.startswith("8.3"):
+                    # Prepare tensors for attention output
+                    query_start_loc = attn_metadata.query_start_loc_list
+                    seq_lens = attn_metadata.seq_lens_list
+                    num_tokens = query_start_loc[-1]
+                    query = query[:num_tokens]
+
+                    # Get workspace from cache or calculate it if not present.
+                    workspace = graph_params.workspaces.get(num_tokens)
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    softmax_lse = torch.empty(num_tokens,
+                                              dtype=query.dtype,
+                                              device=query.device)
+                    if workspace is None:
+                        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                            query=query,
+                            key=key,
+                            value=value,
+                            block_table=attn_metadata.block_tables,
+                            input_layout="TND",
+                            block_size=block_size,
+                            actual_seq_lengths=query_start_loc,
+                            actual_seq_lengths_kv=seq_lens,
+                            num_key_value_heads=self.num_kv_heads,
+                            num_heads=self.num_heads,
+                            sparse_mode=0,
+                            scale=self.scale,
+                        )
+                        update_graph_params_workspaces(num_tokens, workspace)
+
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(key),
+                        weak_ref_tensors(value),
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        block_size,
+                        seq_lens,
+                        query_start_loc,
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(output),
+                        weak_ref_tensors(softmax_lse),
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch_npu.npu_fused_infer_attention_score.out(
+                        query=query,
+                        key=key,
+                        value=value,
+                        block_table=attn_metadata.block_tables,
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=query_start_loc,
+                        actual_seq_lengths_kv=seq_lens,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0,
+                        workspace=workspace,
+                        out=[output, softmax_lse],
+                    )
+
+                    output = output.view(num_tokens, self.num_heads,
+                                         self.head_size)
+
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
+                else:
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(self.key_cache),
+                        weak_ref_tensors(self.value_cache),
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        attn_metadata.seq_lens,
+                        weak_ref_tensors(output),
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
             else:
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
+                if torch.version.cann.startswith("8.3"):
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+
+                    output, _ = torch_npu.npu_fused_infer_attention_score(
+                        query=query,
+                        key=key,
+                        value=value,
+                        block_table=attn_metadata.block_tables,
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=attn_metadata.query_start_loc_list,
+                        actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0)
+                else:
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
         return output
 
     def _forward_v1_style(
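The whole decode path above is gated on the CANN toolkit version string. A minimal sketch of that gate, assuming (as the diff does) that torch_npu publishes the CANN version as a string via torch.version.cann; the helper name is hypothetical, since the commit inlines this check:

    import torch

    def _use_fused_infer_attention() -> bool:
        # Fall back to the paged-attention path when no CANN version is exposed
        # or it is older than 8.3.
        cann_version = getattr(torch.version, "cann", None)
        return cann_version is not None and cann_version.startswith("8.3")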

vllm_ascend/compilation/acl_graph.py

Lines changed: 73 additions & 28 deletions
@@ -199,34 +199,75 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             graph_params.handles[runtime_shape],
             graph_params.events[runtime_shape],
     ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        # block_table = forward_context.attn_metadata[key].block_tables
-        seq_lens = forward_context.attn_metadata[key].seq_lens
-
-        with torch.npu.stream(update_stream):
-            torch.npu.graph_task_update_begin(update_stream, handle)
-            torch_npu._npu_paged_attention(query=query,
-                                           key_cache=key_cache,
-                                           value_cache=value_cache,
-                                           num_kv_heads=num_kv_heads,
-                                           num_heads=num_heads,
-                                           scale_value=scale,
-                                           block_table=block_table,
-                                           context_lens=seq_lens,
-                                           out=output)
-            torch.npu.graph_task_update_end(update_stream)
-
-        event.record(update_stream)
+        if torch.version.cann.startswith("8.3"):
+            (
+                query,
+                key_cache,
+                value,
+                block_tables,
+                block_size,
+                seq_lens,
+                query_start_loc,
+                num_kv_heads,
+                num_heads,
+                scale,
+                attn_output,
+                softmax_lse,
+            ) = param
+
+            # block_table = forward_context.attn_metadata[key].block_tables
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+
+            with torch.npu.stream(update_stream):
+                torch.npu.graph_task_update_begin(update_stream, handle)
+                torch_npu.npu_fused_infer_attention_score.out(
+                    query=query,
+                    key=key_cache,
+                    value=value,
+                    block_table=block_tables,
+                    input_layout="TND",
+                    block_size=block_size,
+                    actual_seq_lengths=query_start_loc,
+                    actual_seq_lengths_kv=seq_lens,
+                    num_key_value_heads=num_kv_heads,
+                    num_heads=num_heads,
+                    scale=scale,
+                    sparse_mode=0,
+                    workspace=graph_params.workspaces.get(runtime_shape),
+                    out=[attn_output, softmax_lse],
+                )
+                torch.npu.graph_task_update_end(update_stream)
+
+            event.record(update_stream)
+        else:
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            # block_table = forward_context.attn_metadata[key].block_tables
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+
+            with torch.npu.stream(update_stream):
+                torch.npu.graph_task_update_begin(update_stream, handle)
+                torch_npu._npu_paged_attention(query=query,
+                                               key_cache=key_cache,
+                                               value_cache=value_cache,
+                                               num_kv_heads=num_kv_heads,
+                                               num_heads=num_heads,
+                                               scale_value=scale,
+                                               block_table=block_table,
+                                               context_lens=seq_lens,
+                                               out=output)
+                torch.npu.graph_task_update_end(update_stream)
+
+            event.record(update_stream)
 
 
 def update_mla_attn_params(update_stream, forward_context, runtime_shape):
@@ -301,6 +342,10 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
          for size in aclgraph_capture_sizes},
     )
 
+def update_graph_params_workspaces(num_tokens: int, workspace: int):
+    global _graph_params
+    if _graph_params is not None:
+        _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
 
 def get_graph_params():
     return _graph_params
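To make the capture/replay contract explicit: a minimal sketch of how the workspace computed once at capture time is reused on every later replay of the same runtime shape. The helper name cache_workspace_once is hypothetical, and the per-shape workspaces dict on GraphParams is an assumption implied by the .get() calls in the diff above.

    from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                   update_graph_params_workspaces)

    def cache_workspace_once(num_tokens, compute_workspace):
        # Look up the workspace captured earlier for this token count; compute
        # and register it only on the first capture of this shape.
        graph_params = get_graph_params()
        workspace = graph_params.workspaces.get(num_tokens)
        if workspace is None:
            workspace = compute_workspace()  # e.g. the *_get_max_workspace call above
            update_graph_params_workspaces(num_tokens, workspace)
        return workspace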

vllm_ascend/platform.py

Lines changed: 12 additions & 0 deletions
@@ -226,6 +226,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
             ])
             update_aclgraph_sizes(vllm_config)
+        elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
+            logger.info(
+                "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
+                "using only ACL Graph mode")
+            assert compilation_config.level == CompilationLevel.PIECEWISE, \
+                "When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
+            compilation_config.set_splitting_ops_for_v1()
+            compilation_config.use_inductor = False
+            compilation_config.splitting_ops.extend([
+                "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
+            ])
+            update_aclgraph_sizes(vllm_config)
         elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
             logger.info(
                 "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
