Conversation

@whx-sjtu (Collaborator) commented Dec 7, 2025

What this PR does / why we need it?

This PR adds back paged attention (pa) for small-batch-size scenarios for performance reasons. We will remove pa once fia performs better than pa in all scenarios.
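
For illustration, a minimal sketch of what such a batch-size-based dispatch could look like; the threshold value and helper name below are assumptions for illustration, not taken from this PR:

    # Hypothetical sketch: SMALL_BATCH_THRESHOLD and select_attention_impl are
    # illustrative assumptions, not the actual vllm-ascend implementation.
    SMALL_BATCH_THRESHOLD = 16

    def select_attention_impl(num_tokens: int) -> str:
        # Use paged attention (pa) for small batches, where it currently
        # performs better; otherwise use fia.
        if num_tokens <= SMALL_BATCH_THRESHOLD:
            return "pa"
        return "fia"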

Does this PR introduce any user-facing change?

No

How was this patch tested?

CI passed with existing tests.

@whx-sjtu whx-sjtu requested a review from yiz-liu December 7, 2025 10:53
github-actions bot commented Dec 7, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@whx-sjtu whx-sjtu requested a review from weijinqian0 December 7, 2025 10:53
@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces a temporary performance optimization by using paged attention for small batch sizes during graph capture. The changes are logical and well-contained. I've identified one critical issue in the implementation of the new attention function that could lead to runtime errors under certain conditions. My feedback includes a code suggestion to resolve this issue.

Comment on lines +510 to +551
if forward_context.capturing:
# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))

# Handle graph capturing mode
stream = torch_npu.npu.current_stream()

event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
weak_ref_tensors(query),
weak_ref_tensors(self.key_cache),
weak_ref_tensors(self.value_cache),
self.num_kv_heads,
self.num_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens,
weak_ref_tensors(output),
))

torch.npu.graph_task_group_begin(stream)
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output,
workspace=workspace)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
return output

critical

The if forward_context.capturing: check is problematic. If this function is ever called in a non-capturing context, it will implicitly return None, but the caller expects a torch.Tensor. This would cause a runtime error.

The call site in forward_impl already ensures this function is only called when forward_context.capturing is true, making this check redundant. For consistency with the similar full_graph_attention function and to prevent potential bugs, this check should be removed. The function's logic should assume it is running in a capturing context.

        # Get workspace from cache or calculate it if not present.
        workspace = graph_params.workspaces.get(num_tokens)
        if workspace is None:
            workspace = torch_npu._npu_paged_attention_get_workspace(
                query=query,
                key_cache=self.key_cache,
                value_cache=self.value_cache,
                num_kv_heads=self.num_kv_heads,
                num_heads=self.num_heads,
                scale_value=self.scale,
                block_table=attn_metadata.block_tables,
                context_lens=attn_metadata.seq_lens,
                out=output)
            update_graph_params_workspaces(num_tokens,
                                           weak_ref_tensors(workspace))

        # Handle graph capturing mode
        stream = torch_npu.npu.current_stream()

        event = torch.npu.ExternalEvent()
        event.wait(stream)
        event.reset(stream)
        graph_params.events[num_tokens].append(event)
        graph_params.attn_params[num_tokens].append((
            weak_ref_tensors(query),
            weak_ref_tensors(self.key_cache),
            weak_ref_tensors(self.value_cache),
            self.num_kv_heads,
            self.num_heads,
            self.scale,
            attn_metadata.block_tables,
            attn_metadata.seq_lens,
            weak_ref_tensors(output),
        ))

        torch.npu.graph_task_group_begin(stream)
        torch_npu._npu_paged_attention(
            query=query,
            key_cache=self.key_cache,
            value_cache=self.value_cache,
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale_value=self.scale,
            block_table=attn_metadata.block_tables,
            context_lens=attn_metadata.seq_lens,
            out=output,
            workspace=workspace)
        handle = torch.npu.graph_task_group_end(stream)
        graph_params.handles[num_tokens].append(handle)
        return output
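
For reference, the call-site guard the reviewer refers to could look roughly like the sketch below; the helper name _paged_attention_capture is a hypothetical stand-in, not the actual vllm-ascend code, and only forward_impl and forward_context.capturing are taken from the review above:

    # Hypothetical sketch of the dispatch in the caller.
    def forward_impl(self, query, attn_metadata, output, forward_context):
        # Only the capturing branch reaches the paged-attention capture
        # helper, so that helper can assume graph-capture mode.
        if forward_context.capturing:
            return self._paged_attention_capture(query, attn_metadata, output)
        # Non-capturing (eager) attention path, elided here.
        ...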

@whx-sjtu whx-sjtu changed the title [Attention] Temporarily add back pa in small batch sizes. [Attention] Temporarily add back pa for small batch sizes. Dec 7, 2025
@yiz-liu (Collaborator) left a comment

One last question, when will we finally remove this PA?

@whx-sjtu (Collaborator, Author) commented:

One last question, when will we finally remove this PA?

After 1230, fia will support flash decoding. Then all GQA models will be performance-tested across different scenarios with the new fia. If the results show no performance problems, I will finally remove pa.

github-actions bot commented:

This pull request has conflicts; please resolve them before we can evaluate the pull request.

@wangxiyuan (Collaborator) left a comment

@whx-sjtu whx-sjtu force-pushed the add_back_pa_main branch 4 times, most recently from fb032fb to b44b7d7 Compare December 15, 2025 06:53
@weijinqian0 weijinqian0 merged commit a962585 into vllm-project:main Dec 15, 2025
12 of 14 checks passed
chenaoxuan pushed a commit to chenaoxuan/vllm-ascend that referenced this pull request Dec 20, 2025
…ect#4765)

### What this PR does / why we need it?
This PR adds back pa in scenarios of small batch sizes due to
performance consideration. Will remove pa once fia performs better than
pa in all scenarios.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed with existing test.


- vLLM version: v0.12.0
- vLLM main:
vllm-project/vllm@ad32e3e

---------

Signed-off-by: whx-sjtu <[email protected]>
Co-authored-by: weijinqian0 <[email protected]>

Labels

ready (read for review), ready-for-test (start test by label for PR)

4 participants