
Commit a33df92

Author: wangxiaoxin-sherie
Commit message: xx
1 parent 694e6ae · commit a33df92

2 files changed: +81 -19 lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 79 additions & 13 deletions
@@ -329,6 +329,82 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.key_cache = None
         self.value_cache = None
+
+    def full_graph_attention(self,
+                             query: torch.Tensor,
+                             key: torch.Tensor,
+                             value: torch.Tensor,
+                             kv_cache: Tuple[torch.Tensor],
+                             attn_metadata: AscendMetadata,
+                             output: torch.Tensor,
+                             num_tokens=0):
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            block_size = 128
+            block_table = None
+            actual_seq_lengths_kv = attn_metadata.actual_seq_lengths_q
+            if is_310p():
+                # Align q, k, v and output tensors.
+                query = aligned_16(query)
+                key = aligned_16(key)
+                value = aligned_16(value)
+                output = aligned_16(output)
+                # Do reformat in case of broadcasted tensors.
+                mask = attn_metadata.attn_mask.repeat(attn_metadata.seq_lens.size(0), 1, 1, 1)
+                attn_metadata.attn_mask = torch_npu.npu_format_cast(mask.contiguous(),
+                                                                    ACL_FORMAT_FRACTAL_NZ)
+        elif attn_metadata.attn_state == \
+                AscendAttentionState.PrefillCacheHit:
+            batch_size = attn_metadata.query_lens.shape[0]
+            block_table = attn_metadata.block_tables[:batch_size, :]
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            actual_seq_lengths_kv = attn_metadata.seq_lens_list
+        # Normal V1 situation.
+        else:
+            if is_310p():
+                # Do reformat in case of broadcasted tensors.
+                attn_metadata.attn_mask = \
+                    torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(),
+                                              ACL_FORMAT_FRACTAL_NZ)
+                attn_metadata.seq_lens = \
+                    attn_metadata.seq_lens.to(device=query.device)
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            block_table = attn_metadata.block_tables
+            actual_seq_lengths_kv = attn_metadata.seq_lens_list
+
+        # npu_fused_infer_attention_score requires query.shape[0] to match
+        # actual_seq_lengths_q[-1], so unpad the query here.
+        num_tokens = attn_metadata.actual_seq_lengths_q[-1]
+        query = query[:num_tokens]
+        # TODO: Refactor this to step-level instead of layer-level.
+
+        output, _ = torch_npu.npu_fused_infer_attention_score(
+            query=query,
+            key=key,
+            value=value,
+            atten_mask=attn_metadata.attn_mask,
+            block_table=block_table,
+            input_layout="TND",
+            block_size=block_size,
+            actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+            num_key_value_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale=self.scale,
+            sparse_mode=3,
+        )
+
+        output = output.view(num_tokens, self.num_heads, self.head_size)
+
+        return output, num_tokens
+
 
     def _forward_prefill_no_cache(
         self,
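The tail of the new helper encodes its contract with the fused kernel: with input_layout="TND", npu_fused_infer_attention_score expects exactly actual_seq_lengths_q[-1] query rows, so the padded query is truncated first and the flat result is viewed back as (num_tokens, num_heads, head_size) before being returned together with the valid-token count. A minimal CPU-only sketch of that shape handling, using plain PyTorch with made-up sizes instead of the NPU kernel:

import torch

# Made-up sizes: 2 heads of width 4, batch padded to 8 token rows.
num_heads, head_size = 2, 4
padded_tokens = 8
actual_seq_lengths_q = [3, 5, 6]      # cumulative query lengths; last entry = valid tokens
num_tokens = actual_seq_lengths_q[-1]

query = torch.randn(padded_tokens, num_heads, head_size)

# Unpad: the fused kernel rejects query.shape[0] != actual_seq_lengths_q[-1].
query = query[:num_tokens]

# Stand-in for the kernel's flat output, one row per valid token.
flat_out = torch.zeros(num_tokens, num_heads * head_size)

# Same final view the helper applies before returning (output, num_tokens).
output = flat_out.view(num_tokens, num_heads, head_size)
print(output.shape, num_tokens)       # torch.Size([6, 2, 4]) 6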
@@ -662,24 +738,14 @@ def forward(
             )
             output = attn_out[0]
         # V0-Style scheduler situation.
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            output = self._forward_prefill_no_cache(
-                query, key, value, attn_metadata, output, num_tokens)
-        elif attn_metadata.attn_state == \
-                AscendAttentionState.PrefillCacheHit:
-            output = self._forward_prefill_cache_hit(
-                query, attn_metadata, output)
         elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
             output = self._forward_decode_only(query, attn_metadata,
                                                output)
         # Normal V1 situation.
         else:
-            # npu_fused_infer_attention_score does not support cases
-            # where query.shape[0] != attn_metadata.query_start_loc[-1].
-            # Thus we need unpad it here.
-            num_tokens = attn_metadata.query_start_loc[-1]
-            query = query[:num_tokens]
-            output = self._forward_v1_style(query, attn_metadata, output)
+            intermediate_output, query_num_tokens = self.full_graph_attention(
+                query, key, value, kv_cache, attn_metadata, output)
+            output[:query_num_tokens] = intermediate_output[:query_num_tokens]
 
         # to make in-place change to the output tensor
         if hasattr(layer, 'quant_method') and use_kv_cache_int8:
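After this hunk, only DecodeOnly keeps a dedicated path in forward(); PrefillNoCache, PrefillCacheHit, and the regular V1 case all route through full_graph_attention, and only the rows for real tokens are copied back so the padded output buffer is still updated in place. A hedged sketch of that dispatch and write-back pattern, with toy stand-ins for the enum and both handlers (nothing here is the real vLLM-Ascend API):

from enum import Enum, auto

import torch


class AscendAttentionState(Enum):   # stand-in for the real enum
    DecodeOnly = auto()
    PrefillNoCache = auto()
    PrefillCacheHit = auto()
    ChunkedPrefill = auto()


def forward_sketch(attn_state, query, output, num_valid_tokens):
    # Mirrors the new forward() branching; the handlers are toy stand-ins.
    if attn_state == AscendAttentionState.DecodeOnly:
        output[:] = query                                # pretend _forward_decode_only
        return output
    # Everything else goes through the unified full-graph helper.
    intermediate = query[:num_valid_tokens] * 2.0        # pretend full_graph_attention
    output[:num_valid_tokens] = intermediate[:num_valid_tokens]
    return output


query = torch.randn(8, 4)
output = torch.zeros(8, 4)
forward_sketch(AscendAttentionState.PrefillCacheHit, query, output, num_valid_tokens=6)
print(output[6:].abs().sum())   # padded tail stays zero: tensor(0.)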

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 6 deletions
@@ -897,13 +897,9 @@ def _make_attention_mask(self, seq_lens, position,
         elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
             return self.attn_mask_builder.get_splitfuse_attn_mask()
 
-        # Prefill without cache situation.
-        elif attn_state == AscendAttentionState.PrefillNoCache:
+        # Prefill without cache and prefill with cache hit situations.
+        elif attn_state == AscendAttentionState.PrefillNoCache or attn_state == AscendAttentionState.PrefillCacheHit:
             return self.attn_mask_builder.get_splitfuse_attn_mask()
-        # Prefill with cache hit.
-        elif attn_state == AscendAttentionState.PrefillCacheHit:
-            return self.attn_mask_builder.get_attn_mask(
-                128, self.dtype, self.device)
         # Decode-only situation.
         else:
             return None
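The runner-side change collapses the two prefill cases into a single branch that reuses the split-fuse mask, instead of building a separate 128-wide standard mask for the cache-hit case. A simplified sketch of the resulting selection logic, with a toy causal mask standing in for attn_mask_builder and the MLA/sparse checks omitted:

from enum import Enum, auto
from typing import Optional

import torch


class AscendAttentionState(Enum):   # stand-in for the real enum
    ChunkedPrefill = auto()
    PrefillNoCache = auto()
    PrefillCacheHit = auto()
    DecodeOnly = auto()


def make_attention_mask(attn_state: AscendAttentionState,
                        size: int = 8) -> Optional[torch.Tensor]:
    # Mirrors the new _make_attention_mask branching (toy mask; the real
    # split-fuse mask construction is not shown here).
    if attn_state in (AscendAttentionState.ChunkedPrefill,
                      AscendAttentionState.PrefillNoCache,
                      AscendAttentionState.PrefillCacheHit):
        # Both prefill cases now share the split-fuse style mask.
        return torch.triu(torch.ones(size, size, dtype=torch.bool), diagonal=1)
    # Decode-only needs no mask.
    return None


print(make_attention_mask(AscendAttentionState.PrefillCacheHit).shape)  # torch.Size([8, 8])
print(make_attention_mask(AscendAttentionState.DecodeOnly))             # None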
