Commit fb032fb

whx-sjtu committed
add graph mode check
Signed-off-by: whx-sjtu <[email protected]>
1 parent 540ca57 commit fb032fb

File tree

3 files changed: +23 -6 lines

    vllm_ascend/attention/attention_v1.py
    vllm_ascend/attention/utils.py
    vllm_ascend/compilation/acl_graph.py

vllm_ascend/attention/attention_v1.py

Lines changed: 3 additions & 4 deletions
@@ -34,7 +34,8 @@
 from vllm.v1.kv_cache_interface import AttentionSpec

 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         split_decodes_and_prefills)
+                                         split_decodes_and_prefills,
+                                         using_paged_attention)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
 from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,

@@ -763,9 +764,7 @@ def forward_impl(
                 attn_metadata, output)
         else:
             num_tokens = query.shape[0]
-            if get_current_vllm_config(
-            ).speculative_config is None and attn_metadata.attn_state == AscendAttentionState.DecodeOnly and num_tokens in get_ascend_config(
-            ).pa_shape_list:
+            if using_paged_attention(attn_metadata.attn_state, num_tokens):
                 output = self.full_graph_attention_with_pa(
                     query, attn_metadata, output)
             else:
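Note: besides extracting the inline condition into a helper, the new predicate is stricter. The removed check only required that no speculative config was set, that the batch was DecodeOnly, and that num_tokens appeared in pa_shape_list; using_paged_attention() (defined in utils.py below) additionally requires the compilation config's cudagraph_mode to be CUDAGraphMode.FULL_DECODE_ONLY, which is the graph-mode check named in the commit message.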

vllm_ascend/attention/utils.py

Lines changed: 16 additions & 0 deletions
@@ -1,13 +1,29 @@
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import Any, List, Optional

 import torch
 import torch.nn.functional as F
+from vllm.config import get_current_vllm_config
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group,
                                           is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context

+from vllm_ascend.utils import get_ascend_config
+
+
+@lru_cache
+def using_paged_attention(attn_state, runtime_shape: int) -> bool:
+    vllm_config = get_current_vllm_config()
+    if vllm_config.speculative_config is not None:
+        return False
+    from vllm.config.compilation import CUDAGraphMode
+
+    from vllm_ascend.attention.attention_v1 import AscendAttentionState
+    return vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY and attn_state == AscendAttentionState.DecodeOnly and runtime_shape in get_ascend_config(
+    ).pa_shape_list
+

 @dataclass
 # class AscendCommonLongSequenceMetadata:
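For readers who want to poke at the new predicate in isolation, below is a minimal, runnable sketch of its decision logic with the vLLM and Ascend config objects replaced by hypothetical stand-ins. CUDAGraphMode, AscendAttentionState, and FakeConfig here are illustrative substitutes only; the real helper reads get_current_vllm_config() and get_ascend_config() instead.

    # Self-contained sketch of the check added in this commit. All classes
    # below are stand-ins for illustration, not the real vLLM/Ascend types.
    from dataclasses import dataclass, field
    from enum import Enum, auto
    from typing import List, Optional


    class CUDAGraphMode(Enum):          # stand-in for vllm.config.compilation.CUDAGraphMode
        NONE = auto()
        FULL_DECODE_ONLY = auto()


    class AscendAttentionState(Enum):   # stand-in for the Ascend attention state enum
        PrefillNoCache = auto()
        DecodeOnly = auto()


    @dataclass
    class FakeConfig:                   # stand-in for the merged vLLM/Ascend config
        speculative_config: Optional[object] = None
        cudagraph_mode: CUDAGraphMode = CUDAGraphMode.FULL_DECODE_ONLY
        pa_shape_list: List[int] = field(default_factory=lambda: [1, 8, 16])


    def using_paged_attention_sketch(cfg: FakeConfig,
                                     attn_state: AscendAttentionState,
                                     runtime_shape: int) -> bool:
        # Mirrors the new helper: no speculative decoding, full-decode-only
        # graph mode, a decode-only batch, and a shape captured for paged attention.
        if cfg.speculative_config is not None:
            return False
        return (cfg.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
                and attn_state == AscendAttentionState.DecodeOnly
                and runtime_shape in cfg.pa_shape_list)


    cfg = FakeConfig()
    print(using_paged_attention_sketch(cfg, AscendAttentionState.DecodeOnly, 8))      # True
    print(using_paged_attention_sketch(cfg, AscendAttentionState.DecodeOnly, 7))      # False: shape not captured
    print(using_paged_attention_sketch(cfg, AscendAttentionState.PrefillNoCache, 8))  # False: not a decode-only batch

The real helper is additionally wrapped in functools.lru_cache, so for a given (attn_state, runtime_shape) pair the config lookups run once and later calls on the decode hot path return the memoized result.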

vllm_ascend/compilation/acl_graph.py

Lines changed: 4 additions & 2 deletions
@@ -19,7 +19,8 @@
 from vllm.logger import logger
 from vllm.platforms import current_platform

-from ..attention.utils import PAGED_ATTENTION_LIST
+from vllm_ascend.attention.utils import using_paged_attention
+
 from ..utils import weak_ref_tensors


@@ -296,7 +297,8 @@ def _update_attn_fia_params(update_stream, forward_context, runtime_shape):


 def update_attn_params(update_stream, forward_context, runtime_shape):
-    if runtime_shape in PAGED_ATTENTION_LIST:
+    if using_paged_attention(forward_context.attn_metadata.attn_state,
+                             runtime_shape):
         _update_attn_pa_params(update_stream, forward_context, runtime_shape)
     else:
         _update_attn_fia_params(update_stream, forward_context, runtime_shape)
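With this change the replay-time parameter update and the capture-time dispatch share one predicate: update_attn_params() now takes the paged-attention branch only when using_paged_attention() holds for the current attn_state and runtime_shape, rather than whenever runtime_shape appears in the removed module-level PAGED_ATTENTION_LIST, keeping it consistent with the path chosen in attention_v1.py.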

0 commit comments