1 change: 1 addition & 0 deletions .github/workflows/_e2e_test.yaml
@@ -175,6 +175,7 @@ jobs:
VLLM_USE_MODELSCOPE: True
if: ${{ inputs.type == 'full' }}
run: |
pytest -sv tests/e2e/multicard/test_full_graph_mode.py
pytest -sv tests/e2e/multicard/test_data_parallel.py
pytest -sv tests/e2e/multicard/test_expert_parallel.py
# pytest -sv tests/e2e/multicard/test_external_launcher.py
44 changes: 43 additions & 1 deletion tests/e2e/multicard/test_full_graph_mode.py
@@ -29,7 +29,7 @@
from tests.e2e.model_utils import check_outputs_equal


def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_DECODE_ONLY():
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
del os.environ['HCCL_OP_EXPANSION_MODE']
prompts = [
@@ -70,3 +70,45 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
name_0="vllm_eager_outputs",
name_1="vllm_fullgraph_outputs",
)

def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL():
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
del os.environ['HCCL_OP_EXPANSION_MODE']
prompts = [
"Hello, my name is", "The president of the United States is",
"The capital of France is", "The future of AI is"
]
model = "Qwen/Qwen3-30B-A3B"
sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
with VllmRunner(model,
max_model_len=1024,
tensor_parallel_size=2,
enforce_eager=False,
compilation_config={"cudagraph_mode":
"FULL"}) as runner:
vllm_fullgraph_outputs = runner.model.generate(prompts,
sampling_params)

with VllmRunner(
model,
max_model_len=1024,
enforce_eager=True,
) as runner:
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)

vllm_fullgraph_outputs_list = []
for output in vllm_fullgraph_outputs:
vllm_fullgraph_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))

vllm_eager_outputs_list = []
for output in vllm_eager_outputs:
vllm_eager_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))

check_outputs_equal(
outputs_0_lst=vllm_eager_outputs_list,
outputs_1_lst=vllm_fullgraph_outputs_list,
name_0="vllm_eager_outputs",
name_1="vllm_fullgraph_outputs",
)
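For orientation, the FULL cudagraph configuration exercised above can also be enabled outside the test harness. The sketch below is illustrative only and rests on assumptions: it presumes the keyword arguments that VllmRunner forwards (including compilation_config) are accepted unchanged by vllm.LLM, and it simply reuses the model and sampling settings from the test.

# Illustrative sketch only, not part of this PR.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-30B-A3B",
          tensor_parallel_size=2,
          max_model_len=1024,
          enforce_eager=False,
          compilation_config={"cudagraph_mode": "FULL"})
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32, temperature=0.0))
print(outputs[0].outputs[0].text)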
205 changes: 161 additions & 44 deletions vllm_ascend/attention/attention_v1.py
@@ -191,6 +191,7 @@ class AscendMetadata:
seq_lens: torch.Tensor = None
seq_lens_list: List[int] = None # type: ignore
actual_seq_lengths_q: List[int] = None # type: ignore
query_start_loc_list: List[int] = None # type: ignore

query_start_loc: torch.Tensor = None
query_lens: torch.Tensor = None
@@ -217,7 +218,8 @@ class AscendMetadata:
class AscendAttentionMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
AttentionCGSupport.ALWAYS
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
@@ -368,6 +370,7 @@ def build(
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
block_tables=block_table,
query_start_loc=query_start_loc,
query_start_loc_list=query_start_loc_cpu[1:].tolist(),
query_lens=query_lens,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
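# Side note (illustrative, invented values): the new query_start_loc_list is
# query_start_loc without the leading zero, i.e. the cumulative end offset of
# each request's query in TND layout. It is what the fused-infer-attention
# call consumes as actual_seq_lengths, and its last element is the unpadded
# token count used later in full_graph_attention.
from itertools import accumulate

query_lens = [3, 5, 2]                              # tokens per request
query_start_loc = [0, *accumulate(query_lens)]      # [0, 3, 8, 10]
query_start_loc_list = query_start_loc[1:]          # [3, 8, 10]
assert query_start_loc_list[-1] == sum(query_lens)  # unpadded token count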
@@ -449,6 +452,123 @@ def __init__(
) if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group(
).device_group if self.dcp_size > 1 else None

def full_graph_attention(self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
num_tokens=0):
# Record the fused-infer-attention call inside a graph task group so the whole
# forward pass can be captured as one ACL graph; the workspace, wait event,
# call parameters and task-group handle are cached per captured token count.
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
block_size = 128
block_table = None
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
batch_size = attn_metadata.query_lens.shape[0]
block_table = attn_metadata.block_tables[:batch_size, :]
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
actual_seq_lengths_kv = attn_metadata.seq_lens_list
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list
# Normal V1 situation.
else:
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
block_table = attn_metadata.block_tables
actual_seq_lengths_kv = attn_metadata.seq_lens_list

# Unpad to the real token count: the fused op requires query.shape[0] to
# match the last cumulative query offset.
num_tokens = attn_metadata.query_start_loc_list[-1]
query = query[:num_tokens]
graph_params = get_graph_params()
query_start_loc = attn_metadata.query_start_loc_list
# Prepare tensors for attention output
# TODO: Refactor this to step-level instead of layer-level

# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
softmax_lse = torch.empty(1,
dtype=query.dtype,
device=query.device)
if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
query=query,
key=key,
value=value,
atten_mask=attn_metadata.attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=query_start_loc,
actual_seq_lengths_kv=actual_seq_lengths_kv,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
sparse_mode=3,
scale=self.scale,)
graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)

# Handle graph capturing mode
stream = torch_npu.npu.current_stream()

event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
weak_ref_tensors(query),
weak_ref_tensors(key),
weak_ref_tensors(value),
weak_ref_tensors(block_table),
weak_ref_tensors(attn_metadata.attn_mask),
block_size,
actual_seq_lengths_kv,
query_start_loc,
self.num_kv_heads,
self.num_heads,
self.scale,
weak_ref_tensors(output),
weak_ref_tensors(softmax_lse)
))

torch.npu.graph_task_group_begin(stream)
torch_npu.npu_fused_infer_attention_score.out(
query=query,
key=key,
value=value,
atten_mask=attn_metadata.attn_mask,
block_table=block_table,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=query_start_loc,
actual_seq_lengths_kv=actual_seq_lengths_kv,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
workspace=workspace,
out=[output, softmax_lse],
)

output = output.view(num_tokens, self.num_heads,
self.head_size)

handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
return output, num_tokens
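# Orientation only: the bookkeeping above relies on the object returned by
# get_graph_params(), whose real definition lives elsewhere in vllm_ascend and
# is not part of this diff. The assumed sketch below (the name and field types
# are assumptions) shows how workspaces, wait events, call parameters and
# graph task-group handles are all keyed by the captured token count.
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class GraphParamsSketch:
    workspaces: Dict[int, Any] = field(default_factory=dict)
    events: Dict[int, List[Any]] = field(default_factory=lambda: defaultdict(list))
    attn_params: Dict[int, List[Tuple]] = field(default_factory=lambda: defaultdict(list))
    handles: Dict[int, List[Any]] = field(default_factory=lambda: defaultdict(list))


graph_params = GraphParamsSketch()
graph_params.workspaces[32] = None            # workspace cached per num_tokens
graph_params.events[32].append(object())      # one wait event per captured layer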

def _forward_prefill_no_cache(
self,
@@ -500,7 +620,6 @@ def _forward_prefill_cache_hit(
batch_size = attn_metadata.query_lens.shape[0]
block_table = attn_metadata.block_tables[:batch_size, :]
num_block, block_size, _, _ = self.key_cache.shape # type: ignore

if torch.version.cann.startswith("8.3") and block_size == 128:
# TODO:The npu_fused_infer_attention_score op is planned to
# be utilized in a wider range in upcoming versions.
@@ -545,6 +664,7 @@ def _forward_decode_only(
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
num_tokens = query.shape[0]
if is_310p():
# seq_lens_tensor needs to be transferred to the device for 310P.
attn_metadata.seq_lens = \
@@ -1136,49 +1256,46 @@ def forward(
slot_mapping[self.pcp_size *
num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded])

if self.pcp_size * self.dcp_size > 1:
intermediate_output = self._forward_pcp_dcp(
query, key, value, attn_metadata, output)
elif attn_type == AttentionType.ENCODER_ONLY:
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
intermediate_output = torch_npu.npu_fusion_attention(
query,
key,
value,
head_num=self.num_heads,
input_layout="TND",
scale=self.scale,
sparse_mode=4,
atten_mask=attn_metadata.attn_mask,
pre_tockens=attn_metadata.max_query_len,
next_tockens=attn_metadata.max_query_len,
actual_seq_qlen=cum_seq_len,
actual_seq_kvlen=cum_seq_len,
)[0]
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
intermediate_output = self._forward_prefill_no_cache(
query, key, value, attn_metadata, output, num_tokens)
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
intermediate_output = self._forward_prefill_cache_hit(
query, attn_metadata, output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
intermediate_output = self._forward_decode_only(
query, attn_metadata, output)
# Normal V1 situation.
forward_context: ForwardContext = get_forward_context()
# Eager attention paths; the full-graph path below is used while capturing.
if not forward_context.capturing:
if self.pcp_size * self.dcp_size > 1:
intermediate_output = self._forward_pcp_dcp(
query, key, value, attn_metadata, output)
elif attn_type == AttentionType.ENCODER_ONLY:
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
intermediate_output = torch_npu.npu_fusion_attention(
query,
key,
value,
head_num=self.num_heads,
input_layout="TND",
scale=self.scale,
sparse_mode=4,
atten_mask=attn_metadata.attn_mask,
pre_tockens=attn_metadata.max_query_len,
next_tockens=attn_metadata.max_query_len,
actual_seq_qlen=cum_seq_len,
actual_seq_kvlen=cum_seq_len,
)[0]
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
intermediate_output = self._forward_prefill_no_cache(
query, key, value, attn_metadata, output, num_tokens)
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
intermediate_output = self._forward_prefill_cache_hit(
query, attn_metadata, output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
intermediate_output = self._forward_decode_only(
query, attn_metadata, output)
# Normal V1 situation.
else:
intermediate_output = self._forward_v1_style(
query, attn_metadata, output)
else:
if torch.version.cann.startswith("8.3"):
# npu_fused_infer_attention_score does not support cases
# where query.shape[0] != attn_metadata.query_start_loc[-1].
# Thus we need unpad it here.
num_tokens = attn_metadata.query_start_loc[-1]
query = query[:num_tokens]
intermediate_output = self._forward_v1_style(
query, attn_metadata, output)

intermediate_output, num_tokens = self.full_graph_attention(
    query, key, value, attn_metadata, output)
output[:num_tokens] = intermediate_output[:num_tokens]

return output