
Commit d00047a

Author: wangxiaoxin-sherie, committed XX
1 parent 2839f2a, commit d00047a

File tree: 2 files changed (+129, -144 lines)


vllm_ascend/attention/attention_v1.py

Lines changed: 121 additions & 136 deletions
@@ -441,6 +441,110 @@ def __init__(
         ) if self.dcp_size > 1 else 0
         self.dcp_group = get_dcp_group(
         ).device_group if self.dcp_size > 1 else None
+
+    def full_graph_attention(self,
+                             query: torch.Tensor,
+                             key: torch.Tensor,
+                             value: torch.Tensor,
+                             attn_metadata: AscendMetadata,
+                             block_size: int,
+                             output: Optional[torch.Tensor] = None,
+                             num_tokens=0):
+        num_tokens = query.shape[0]
+        forward_context: ForwardContext = get_forward_context()
+        if forward_context.capturing:
+            graph_params = get_graph_params()
+            query_start_loc = attn_metadata.actual_seq_lengths_q
+            seq_lens = attn_metadata.seq_lens_list
+            # Prepare tensors for attention output
+            # TODO: Refactor this to step-level instead of layer-level
+
+            # Get workspace from cache or calculate it if not present.
+            workspace = graph_params.workspaces.get(num_tokens)
+            softmax_lse = torch.empty(num_tokens,
+                                      dtype=query.dtype,
+                                      device=query.device)
+            if workspace is None:
+                workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                    query=query,
+                    key=key,
+                    value=value,
+                    atten_mask=attn_metadata.attn_mask,
+                    block_table=attn_metadata.block_tables,
+                    input_layout="TND",
+                    block_size=block_size,
+                    actual_seq_lengths=query_start_loc,
+                    actual_seq_lengths_kv=seq_lens,
+                    num_key_value_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    sparse_mode=3,
+                    scale=self.scale)
+                graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
+
+            # Handle graph capturing mode
+            stream = torch_npu.npu.current_stream()
+
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            graph_params.events[num_tokens].append(event)
+            graph_params.attn_params[num_tokens].append((
+                weak_ref_tensors(query),
+                weak_ref_tensors(key),
+                weak_ref_tensors(value),
+                weak_ref_tensors(attn_metadata.block_tables),
+                weak_ref_tensors(attn_metadata.attn_mask),
+                block_size,
+                seq_lens,
+                query_start_loc,
+                self.num_kv_heads,
+                self.num_heads,
+                self.scale,
+                weak_ref_tensors(output),
+                weak_ref_tensors(softmax_lse)
+            ))
+
+            torch.npu.graph_task_group_begin(stream)
+            torch_npu.npu_fused_infer_attention_score.out(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=attn_metadata.attn_mask,
+                block_table=attn_metadata.block_tables,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=query_start_loc,
+                actual_seq_lengths_kv=seq_lens,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3,
+                workspace=workspace,
+                out=[output, softmax_lse],
+            )
+
+            output = output.view(num_tokens, self.num_heads,
+                                 self.head_size)
+
+            handle = torch.npu.graph_task_group_end(stream)
+            graph_params.handles[num_tokens].append(handle)
+        else:
+            output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key,
+                value=value,
+                block_table=attn_metadata.block_tables,
+                atten_mask=attn_metadata.attn_mask,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=attn_metadata.query_start_loc_list,
+                actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3
+            )
+        return output
 
     def _forward_prefill_no_cache(
         self,
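A note on the workspace handling in the new helper: during capture, the fused-attention workspace is cached per token count, so repeated captures or replays at the same batch size reuse one allocation. Below is a minimal, CPU-runnable sketch of that caching pattern only; GraphParams here is a simplified stand-in, not the real vllm_ascend object, and the allocator is a dummy.

from typing import Callable, Dict

import torch


class GraphParams:
    """Simplified stand-in for the graph params used by full_graph_attention."""

    def __init__(self) -> None:
        # One cached workspace tensor per captured token count.
        self.workspaces: Dict[int, torch.Tensor] = {}

    def get_or_create_workspace(self, num_tokens: int,
                                alloc: Callable[[int], torch.Tensor]) -> torch.Tensor:
        # Reuse the workspace if this token count was seen before,
        # otherwise allocate once and remember it.
        workspace = self.workspaces.get(num_tokens)
        if workspace is None:
            workspace = alloc(num_tokens)
            self.workspaces[num_tokens] = workspace
        return workspace


params = GraphParams()
alloc = lambda n: torch.empty(n * 1024, dtype=torch.uint8)  # dummy sizing
first = params.get_or_create_workspace(256, alloc)
second = params.get_or_create_workspace(256, alloc)
assert first is second  # the second capture at 256 tokens reuses the buffer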
@@ -467,15 +571,7 @@ def _forward_prefill_no_cache(
             mask = torch_npu.npu_format_cast(mask.contiguous(),
                                              ACL_FORMAT_FRACTAL_NZ)
 
-        torch_npu._npu_flash_attention(query=query,
-                                       key=key,
-                                       value=value,
-                                       mask=mask,
-                                       seq_len=attn_metadata.seq_lens,
-                                       scale_value=self.scale,
-                                       num_heads=self.num_heads,
-                                       num_kv_heads=self.num_kv_heads,
-                                       out=output)
+        output = self.full_graph_attention(query, key, value, attn_metadata, 128, output)
         assert output is not None
         return output[:num_tokens]
 
@@ -569,84 +665,12 @@ def _forward_decode_only(
 
             output = output.view(batch_size, self.num_heads, self.head_size)
         else:
-            graph_params = get_graph_params()
-            forward_context: ForwardContext = get_forward_context()
-            num_tokens = query.shape[0]
-            if forward_context.capturing:
-                if self.torch_npu_check:
-                    # Get workspace from cache or calculate it if not present.
-                    workspace = graph_params.workspaces.get(num_tokens)
-                    if workspace is None:
-                        workspace = torch_npu._npu_paged_attention_get_workspace(
-                            query=query,
-                            key_cache=self.key_cache,
-                            value_cache=self.value_cache,
-                            num_kv_heads=self.num_kv_heads,
-                            num_heads=self.num_heads,
-                            scale_value=self.scale,
-                            block_table=attn_metadata.block_tables,
-                            context_lens=attn_metadata.seq_lens,
-                            out=output)
-                        update_graph_params_workspaces(
-                            num_tokens, weak_ref_tensors(workspace))
-
-                # Handle graph capturing mode
-                stream = torch_npu.npu.current_stream()
-
-                event = torch.npu.ExternalEvent()
-                event.wait(stream)
-                event.reset(stream)
-                graph_params.events[num_tokens].append(event)
-                graph_params.attn_params[num_tokens].append((
-                    weak_ref_tensors(query),
-                    weak_ref_tensors(self.key_cache),
-                    weak_ref_tensors(self.value_cache),
-                    self.num_kv_heads,
-                    self.num_heads,
-                    self.scale,
-                    attn_metadata.block_tables,
-                    attn_metadata.seq_lens,
-                    weak_ref_tensors(output),
-                ))
-
-                torch.npu.graph_task_group_begin(stream)
-
-                if self.torch_npu_check:
-                    torch_npu._npu_paged_attention(
-                        query=query,
-                        key_cache=self.key_cache,
-                        value_cache=self.value_cache,
-                        num_kv_heads=self.num_kv_heads,
-                        num_heads=self.num_heads,
-                        scale_value=self.scale,
-                        block_table=attn_metadata.block_tables,
-                        context_lens=attn_metadata.seq_lens,
-                        out=output,
-                        workspace=workspace)
-                else:
-                    torch_npu._npu_paged_attention(
-                        query=query,
-                        key_cache=self.key_cache,
-                        value_cache=self.value_cache,
-                        num_kv_heads=self.num_kv_heads,
-                        num_heads=self.num_heads,
-                        scale_value=self.scale,
-                        block_table=attn_metadata.block_tables,
-                        context_lens=attn_metadata.seq_lens,
-                        out=output)
-                handle = torch.npu.graph_task_group_end(stream)
-                graph_params.handles[num_tokens].append(handle)
-            else:
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            output = self.full_graph_attention(query, key, value, attn_metadata, block_size, output)
         return output
 
     def _forward_v1_style(
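The decode branch now feeds the paged KV cache straight into the fused kernel: the (num_blocks, block_size, num_kv_heads, head_size) cache is viewed as (num_blocks, block_size, num_kv_heads * head_size) to match the flattened hidden dimension used with input_layout="TND". A small shape check of that view, using made-up sizes (the real values come from the kv-cache configuration); the same reshape is reused in _forward_v1_style below.

import torch

# Hypothetical sizes; the real values come from the kv-cache configuration.
num_blocks, block_size, num_kv_heads, head_size = 8, 128, 4, 64

key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

# Same flattening as in the hunk above: keep the block layout, fold the
# per-token (num_kv_heads, head_size) dims into one hidden dimension.
key = key_cache.view(num_blocks, block_size, -1)

assert key.shape == (num_blocks, block_size, num_kv_heads * head_size)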
@@ -687,43 +711,12 @@ def _forward_v1_style(
         attn_metadata.seq_lens = \
             attn_metadata.seq_lens.to(device=query.device)
 
-        if torch.version.cann.startswith("8.3"):
-            # TODO:The npu_fused_infer_attention_score op is planned to
-            # be utilized in a wider range in upcoming versions.
-            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
-            key = self.key_cache.view(  # type: ignore
-                num_block, block_size, -1)
-            value = self.value_cache.view(  # type: ignore
-                num_block, block_size, -1)
-
-            output, _ = torch_npu.npu_fused_infer_attention_score(
-                query=query,
-                key=key,
-                value=value,
-                atten_mask=attn_metadata.attn_mask,
-                block_table=attn_metadata.block_tables,
-                input_layout="TND",
-                block_size=block_size,
-                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
-                actual_seq_lengths_kv=attn_metadata.seq_lens_list,
-                num_key_value_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale=self.scale,
-                sparse_mode=3,
-            )
-        else:
-            torch_npu._npu_paged_attention_splitfuse(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                mask=attn_metadata.attn_mask,
-                block_table=attn_metadata.block_tables,
-                seq_len=attn_metadata.query_lens,
-                context_lens=attn_metadata.seq_lens,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                out=output)
+        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+        key = self.key_cache.view(  # type: ignore
+            num_block, block_size, -1)
+        value = self.value_cache.view(  # type: ignore
+            num_block, block_size, -1)
+        output = self.full_graph_attention(query, key, value, attn_metadata, block_size, output)
         return output
 
     def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
@@ -1161,26 +1154,18 @@ def forward(
                 )[0]
             # V0-Style scheduler situation.
            elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-                intermediate_output = self._forward_prefill_no_cache(
+                output = self._forward_prefill_no_cache(
                    query, key, value, attn_metadata, output, num_tokens)
            elif attn_metadata.attn_state == \
                    AscendAttentionState.PrefillCacheHit:
-                intermediate_output = self._forward_prefill_cache_hit(
+                output = self._forward_prefill_cache_hit(
                    query, attn_metadata, output)
            elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-                intermediate_output = self._forward_decode_only(
+                output = self._forward_decode_only(
                    query, attn_metadata, output)
            # Normal V1 situation.
            else:
-                if torch.version.cann.startswith("8.3"):
-                    # npu_fused_infer_attention_score does not support cases
-                    # where query.shape[0] != attn_metadata.query_start_loc[-1].
-                    # Thus we need unpad it here.
-                    num_tokens = attn_metadata.query_start_loc[-1]
-                    query = query[:num_tokens]
-                intermediate_output = self._forward_v1_style(
+                output = self._forward_v1_style(
                    query, attn_metadata, output)
 
-            output[:num_tokens] = intermediate_output[:num_tokens]
-
        return output
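The forward() hunk also changes how branch results reach the caller: each _forward_* helper now returns a tensor that is bound directly to output, instead of writing into an intermediate buffer that was then slice-copied back. A toy illustration of the difference, where run_branch is only a stand-in for any of the helpers above:

import torch


def run_branch() -> torch.Tensor:
    # Stand-in for _forward_prefill_no_cache / _forward_decode_only / etc.
    return torch.ones(4, 8)


output = torch.zeros(4, 8)
num_tokens = 4

# Before this commit: compute into a temporary, then slice-copy it back.
intermediate_output = run_branch()
output[:num_tokens] = intermediate_output[:num_tokens]

# After this commit: the helper's result is assigned to `output` directly,
# so the extra slice copy is no longer issued.
output = run_branch()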

vllm_ascend/platform.py

Lines changed: 8 additions & 8 deletions
@@ -240,14 +240,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         vllm_config.compilation_config.init_with_cudagraph_sizes(
             sp_aclgraph_sizes)
 
-        # TODO: Full graph is fully supported later, and the default value will be set to full graph.
-        if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
-            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
-
         if vllm_version_is("0.11.0"):
             if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
                 compilation_config.level = CompilationLevel.NO_COMPILATION
-            elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+            elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or\
+                    compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
                 logger.info(
                     "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
                     "using only ACL Graph mode")
@@ -260,7 +257,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                         "vllm.mla_forward"
                     ])
                 update_aclgraph_sizes(vllm_config)
-            elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
+                    compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
                 logger.info(
                     "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
                     "using only ACL Graph mode")
@@ -287,7 +285,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         else:
             if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
                 compilation_config.mode = CompilationMode.NONE
-            elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+            elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or\
+                    compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
                 logger.info(
                     "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
                     "using only ACL Graph mode")
@@ -297,7 +296,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 compilation_config.use_inductor = False
                 compilation_config.splitting_ops.extend(["vllm::mla_forward"])
                 update_aclgraph_sizes(vllm_config)
-            elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\
+                    compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
                 logger.info(
                     "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
                     "using only ACL Graph mode")
