[Attention] Temporarily add back pa for small batch sizes. #4765
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
- If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces a temporary performance optimization by using paged attention for small batch sizes during graph capture. The changes are logical and well-contained. I've identified one critical issue in the implementation of the new attention function that could lead to runtime errors under certain conditions. My feedback includes a code suggestion to resolve this issue.
if forward_context.capturing:
    # Get workspace from cache or calculate it if not present.
    workspace = graph_params.workspaces.get(num_tokens)
    if workspace is None:
        workspace = torch_npu._npu_paged_attention_get_workspace(
            query=query,
            key_cache=self.key_cache,
            value_cache=self.value_cache,
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale_value=self.scale,
            block_table=attn_metadata.block_tables,
            context_lens=attn_metadata.seq_lens,
            out=output)
        update_graph_params_workspaces(num_tokens,
                                       weak_ref_tensors(workspace))

    # Handle graph capturing mode
    stream = torch_npu.npu.current_stream()

    event = torch.npu.ExternalEvent()
    event.wait(stream)
    event.reset(stream)
    graph_params.events[num_tokens].append(event)
    graph_params.attn_params[num_tokens].append((
        weak_ref_tensors(query),
        weak_ref_tensors(self.key_cache),
        weak_ref_tensors(self.value_cache),
        self.num_kv_heads,
        self.num_heads,
        self.scale,
        attn_metadata.block_tables,
        attn_metadata.seq_lens,
        weak_ref_tensors(output),
    ))

    torch.npu.graph_task_group_begin(stream)
    torch_npu._npu_paged_attention(
        query=query,
        key_cache=self.key_cache,
        value_cache=self.value_cache,
        num_kv_heads=self.num_kv_heads,
        num_heads=self.num_heads,
        scale_value=self.scale,
        block_table=attn_metadata.block_tables,
        context_lens=attn_metadata.seq_lens,
        out=output,
        workspace=workspace)
    handle = torch.npu.graph_task_group_end(stream)
    graph_params.handles[num_tokens].append(handle)
    return output
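For context on the bookkeeping above: the tuples, events, and handles stashed in graph_params are what allow the captured paged-attention kernel to be re-parameterized when the graph is replayed with a new batch. Below is a minimal sketch of what such a replay-time update pass could look like; the helper name, the graph_params layout, and the torch.npu.graph_task_update_begin/graph_task_update_end calls are assumptions inferred from the capture-time graph_task_group_begin/end calls in the diff, not code taken from this PR.

import torch
import torch_npu

def update_attn_params(update_stream, graph_params, num_tokens):
    # Hypothetical replay-time update pass (name and container layout are assumed).
    for (query, key_cache, value_cache, num_kv_heads, num_heads, scale,
         block_tables, seq_lens, output), handle, event in zip(
            graph_params.attn_params[num_tokens],
            graph_params.handles[num_tokens],
            graph_params.events[num_tokens]):
        with torch.npu.stream(update_stream):
            # Re-issue the paged-attention launch into the captured task group
            # identified by `handle`, using the current batch's metadata.
            torch.npu.graph_task_update_begin(update_stream, handle)
            torch_npu._npu_paged_attention(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                num_kv_heads=num_kv_heads,
                num_heads=num_heads,
                scale_value=scale,
                block_table=block_tables,
                context_lens=seq_lens,
                out=output,
                workspace=graph_params.workspaces.get(num_tokens))
            torch.npu.graph_task_update_end(update_stream)
            # Unblock the replayed graph, which waits on the ExternalEvent
            # recorded during capture.
            event.record(update_stream)

Here the ExternalEvent recorded at capture time appears to act as the synchronization point: the replayed graph waits on it until the update pass records it on the update stream.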
The if forward_context.capturing: check is problematic. If this function is ever called in a non-capturing context, it will implicitly return None, but the caller expects a torch.Tensor. This would cause a runtime error.
The call site in forward_impl already ensures this function is only called when forward_context.capturing is true, making this check redundant. For consistency with the similar full_graph_attention function and to prevent potential bugs, this check should be removed. The function's logic should assume it is running in a capturing context.
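To make the failure mode concrete, here is a minimal, self-contained sketch of the implicit-None hazard; the function, tensors, and shapes are hypothetical stand-ins, not this PR's code. The suggested replacement that follows simply drops the guard, so the body always runs the capture path.

import torch

def guarded_attention(query: torch.Tensor, capturing: bool):
    if capturing:
        return query * 2  # stand-in for the paged-attention call that fills `output`
    # No else branch: the function implicitly returns None when capturing is False.

out = guarded_attention(torch.ones(4), capturing=False)
print(out)  # None
# out.reshape(-1) would raise AttributeError: 'NoneType' object has no attribute 'reshape',
# which is the kind of runtime error a caller expecting a torch.Tensor would hit.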
# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
    workspace = torch_npu._npu_paged_attention_get_workspace(
        query=query,
        key_cache=self.key_cache,
        value_cache=self.value_cache,
        num_kv_heads=self.num_kv_heads,
        num_heads=self.num_heads,
        scale_value=self.scale,
        block_table=attn_metadata.block_tables,
        context_lens=attn_metadata.seq_lens,
        out=output)
    update_graph_params_workspaces(num_tokens,
                                   weak_ref_tensors(workspace))

# Handle graph capturing mode
stream = torch_npu.npu.current_stream()

event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
    weak_ref_tensors(query),
    weak_ref_tensors(self.key_cache),
    weak_ref_tensors(self.value_cache),
    self.num_kv_heads,
    self.num_heads,
    self.scale,
    attn_metadata.block_tables,
    attn_metadata.seq_lens,
    weak_ref_tensors(output),
))

torch.npu.graph_task_group_begin(stream)
torch_npu._npu_paged_attention(
    query=query,
    key_cache=self.key_cache,
    value_cache=self.value_cache,
    num_kv_heads=self.num_kv_heads,
    num_heads=self.num_heads,
    scale_value=self.scale,
    block_table=attn_metadata.block_tables,
    context_lens=attn_metadata.seq_lens,
    out=output,
    workspace=workspace)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
return output

Force-pushed da55e0d to 91ba5ec (Compare)
yiz-liu left a comment
One last question: when will we finally remove this PA?
Force-pushed 91ba5ec to 58876f1 (Compare)
After 1230, fia will support flash decoding. All GQA models will then run performance tests across different scenarios using the new fia. If the results show no performance problems, I will finally remove pa.
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Force-pushed 58876f1 to 859dc59 (Compare)
wangxiyuan left a comment
Force-pushed fb032fb to b44b7d7 (Compare)
Signed-off-by: whx-sjtu <[email protected]>
Force-pushed b44b7d7 to a875aa7 (Compare)
- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e
Signed-off-by: whx-sjtu <[email protected]>
Co-authored-by: weijinqian0 <[email protected]>
What this PR does / why we need it?
This PR adds back pa for small batch sizes for performance reasons (a rough sketch of the intended dispatch appears at the end of this description). pa will be removed once fia performs better than pa in all scenarios.
Does this PR introduce any user-facing change?
No
How was this patch tested?
CI passed with the existing tests.
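For readers skimming the thread, here is a rough sketch of the kind of small-batch dispatch described above; the threshold constant and the callables are illustrative assumptions, not names or values taken from this PR.

from typing import Callable

# Assumed cut-off for "small batch"; the PR's actual criterion may differ.
SMALL_BATCH_THRESHOLD = 16

def decode_attention(num_tokens: int,
                     paged_attention: Callable,
                     fused_infer_attention: Callable,
                     *args, **kwargs):
    # Temporarily route small batches to paged attention (pa), where it
    # currently outperforms fused infer attention (fia); everything else
    # stays on the fia path.
    if num_tokens <= SMALL_BATCH_THRESHOLD:
        return paged_attention(*args, **kwargs)
    return fused_infer_attention(*args, **kwargs)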