
Commit 00aa0bf

support prefill cache mode use fia op (#3696)
### What this PR does / why we need it?

Support prefill cache-hit mode using the fused infer attention (FIA) op, for full graph mode.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: vllm-project/vllm@17c540a

origin

============ Serving Benchmark Result ============
Successful requests:                  30
Maximum request concurrency:          256
Request rate configured (RPS):        0.70
Benchmark duration (s):               131.63
Total input tokens:                   61363
Total generated tokens:               61440
Request throughput (req/s):           0.23
Output token throughput (tok/s):      466.77
Peak output token throughput (tok/s): 750.00
Peak concurrent requests:             30.00
Total Token throughput (tok/s):       932.95
---------------Time to First Token----------------
Mean TTFT (ms):   125.17
Median TTFT (ms): 121.51
P50 TTFT (ms):    121.51
P90 TTFT (ms):    140.91
P99 TTFT (ms):    182.36
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):   43.85
Median TPOT (ms): 43.84
P50 TPOT (ms):    43.84
P90 TPOT (ms):    44.28
P99 TPOT (ms):    44.32
---------------Inter-token Latency----------------
Mean ITL (ms):   43.85
Median ITL (ms): 42.63
P50 ITL (ms):    42.63
P90 ITL (ms):    48.74
P99 ITL (ms):    59.62
==================================================

after

============ Serving Benchmark Result ============
Successful requests:                  30
Maximum request concurrency:          256
Request rate configured (RPS):        0.70
Benchmark duration (s):               130.10
Total input tokens:                   61363
Total generated tokens:               61440
Request throughput (req/s):           0.23
Output token throughput (tok/s):      472.26
Peak output token throughput (tok/s): 750.00
Peak concurrent requests:             30.00
Total Token throughput (tok/s):       943.94
---------------Time to First Token----------------
Mean TTFT (ms):   123.69
Median TTFT (ms): 122.51
P50 TTFT (ms):    122.51
P90 TTFT (ms):    143.69
P99 TTFT (ms):    165.00
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):   43.07
Median TPOT (ms): 43.13
P50 TPOT (ms):    43.13
P90 TPOT (ms):    43.50
P99 TPOT (ms):    43.57
---------------Inter-token Latency----------------
Mean ITL (ms):   43.07
Median ITL (ms): 41.81
P50 ITL (ms):    41.81
P90 ITL (ms):    48.11
P99 ITL (ms):    62.13
==================================================

Signed-off-by: shiyuan680 <[email protected]>
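
As a quick sanity check, the relative gains can be computed directly from the figures quoted above (a minimal Python sketch; all values are copied from the benchmark output in this description):

```python
# Relative gains computed from the benchmark output quoted above.
before_tpot_ms, after_tpot_ms = 43.85, 43.07    # Mean TPOT
before_tok_s, after_tok_s = 932.95, 943.94      # Total token throughput

tpot_gain = (before_tpot_ms - after_tpot_ms) / before_tpot_ms * 100
throughput_gain = (after_tok_s - before_tok_s) / before_tok_s * 100
print(f"Mean TPOT improvement:       {tpot_gain:.1f}%")        # ~1.8%
print(f"Total token throughput gain: {throughput_gain:.1f}%")  # ~1.2%
```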
1 parent: 3e5ae49 · commit: 00aa0bf

File tree

3 files changed: +45 -14 lines

- vllm_ascend/attention/attention_mask.py
- vllm_ascend/attention/attention_v1.py
- vllm_ascend/worker/model_runner_v1.py


vllm_ascend/attention/attention_mask.py

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,8 @@ def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
 
     def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                       device: torch.device):
+        if max_seq_len == 2048 and torch.version.cann.startswith("8.3"):
+            return self.chunked_prefill_attn_mask.to(torch.bool)
         self._update_attn_cache(max_seq_len, dtype)
         return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
         ).to(device, non_blocking=True)
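
For context, here is a minimal sketch of the kind of fixed-size mask this new branch returns. It assumes the fused infer-attention op's sparse_mode=3 convention of a 2048 x 2048 boolean causal template in which True marks masked-out positions; `build_chunked_prefill_attn_mask` is a hypothetical helper for illustration, not the repository's actual constructor of `chunked_prefill_attn_mask`.

```python
import torch

def build_chunked_prefill_attn_mask(mask_len: int = 2048) -> torch.Tensor:
    # Strictly-upper-triangular True entries mask out future positions;
    # the diagonal and lower triangle stay visible (causal attention).
    return torch.triu(torch.ones(mask_len, mask_len), diagonal=1).to(torch.bool)

chunked_prefill_attn_mask = build_chunked_prefill_attn_mask()
assert chunked_prefill_attn_mask.shape == (2048, 2048)
assert not chunked_prefill_attn_mask[10, 10] and chunked_prefill_attn_mask[10, 11]
```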

vllm_ascend/attention/attention_v1.py

Lines changed: 37 additions & 12 deletions
@@ -491,19 +491,44 @@ def _forward_prefill_cache_hit(
         compress_mask = attn_metadata.attn_mask
         batch_size = attn_metadata.query_lens.shape[0]
         block_table = attn_metadata.block_tables[:batch_size, :]
+        num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
 
-        torch_npu._npu_flash_attention_qlens(
-            query=query,
-            key_cache=self.key_cache,
-            value_cache=self.value_cache,
-            block_table=block_table,
-            mask=compress_mask,
-            seq_len=attn_metadata.query_lens,
-            context_lens=attn_metadata.seq_lens,
-            num_kv_heads=self.num_kv_heads,
-            num_heads=self.num_heads,
-            scale_value=self.scale,
-            out=output)
+        if torch.version.cann.startswith("8.3") and block_size == 128:
+            # TODO:The npu_fused_infer_attention_score op is planned to
+            # be utilized in a wider range in upcoming versions.
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+
+            output, _ = torch_npu.npu_fused_infer_attention_score(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=compress_mask,
+                block_table=block_table,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
+                actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale=self.scale,
+                sparse_mode=3,
+            )
+        else:
+            torch_npu._npu_flash_attention_qlens(
+                query=query,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                block_table=block_table,
+                mask=compress_mask,
+                seq_len=attn_metadata.query_lens,
+                context_lens=attn_metadata.seq_lens,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                out=output)
         return output
 
     def _forward_decode_only(
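
Below is a minimal, standalone sketch of what the new branch does with the paged KV cache, assuming a cache layout of (num_block, block_size, num_kv_heads, head_dim) as implied by the shape unpacking above; the example shapes and the `use_fia_prefill_cache_hit` helper are illustrative rather than repository code.

```python
import torch

def use_fia_prefill_cache_hit(cann_version: str, block_size: int) -> bool:
    # Same gate as in the diff above: the fused infer-attention path is only
    # taken on CANN 8.3 with the 128-token KV-cache block size.
    return cann_version.startswith("8.3") and block_size == 128

# Illustrative shapes for a paged KV cache (assumed, not taken from the repo).
num_block, block_size, num_kv_heads, head_dim = 64, 128, 8, 128
key_cache = torch.empty(num_block, block_size, num_kv_heads, head_dim)

# The fused op is fed the cache with the head and head-dim axes flattened,
# i.e. (num_block, block_size, num_kv_heads * head_dim), matching the
# view(num_block, block_size, -1) calls in the diff.
key = key_cache.view(num_block, block_size, -1)
assert key.shape == (num_block, block_size, num_kv_heads * head_dim)
assert use_fia_prefill_cache_hit("8.3.RC1", block_size)
```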

vllm_ascend/worker/model_runner_v1.py

Lines changed: 6 additions & 2 deletions
@@ -962,8 +962,12 @@ def _make_attention_mask(self, seq_lens, position,
                 max_seq_len, self.dtype, self.device)
         # Prefill with cache hit.
         elif attn_state == AscendAttentionState.PrefillCacheHit:
-            return self.attn_mask_builder.get_attn_mask(
-                128, self.dtype, self.device)
+            if torch.version.cann.startswith("8.3"):
+                return self.attn_mask_builder.get_attn_mask(
+                    2048, self.dtype, self.device)
+            else:
+                return self.attn_mask_builder.get_attn_mask(
+                    128, self.dtype, self.device)
         # Decode-only situation.
         else:
             return None
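
Put simply, the mask length requested for the PrefillCacheHit state now depends on the CANN version. A hedged sketch of that decision (the helper name is hypothetical, not repository code):

```python
def prefill_cache_hit_mask_len(cann_version: str) -> int:
    # CANN 8.3 requests the fixed 2048 x 2048 boolean template consumed by
    # npu_fused_infer_attention_score with sparse_mode=3; older CANN keeps
    # the 128 x 128 compressed mask used by _npu_flash_attention_qlens.
    return 2048 if cann_version.startswith("8.3") else 128

assert prefill_cache_hit_mask_len("8.3.RC1") == 2048
assert prefill_cache_hit_mask_len("8.2.RC1") == 128
```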
