Support DeepSeekV3.2 with MLAPO operator #4753
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request adds support for DeepSeekV3.2 with the MLAPO operator and introduces context parallelism for SFA (SFA CP). The changes are extensive, involving modifications to the attention mechanism, introduction of shared weight layers for memory optimization, and new distributed communication patterns.
My review has identified a few critical issues that will cause NameError exceptions at runtime due to variables being used before they are defined in all code paths. Specifically, cum_query_lens and seq_lens are used without being defined, and actual_seq_lengths_query and actual_seq_lengths_key are not initialized in all execution branches before being passed to a function. I have also found a logical error in a new utility function. Please address these issues to ensure the code is correct and robust.
vllm_ascend/attention/sfa_v1.py
```python
actual_seq_lengths_query = torch.empty_like(cum_query_lens)
actual_seq_lengths_key = torch.empty_like(seq_lens)
```
The variables cum_query_lens and seq_lens are used here but are not defined within this scope. They should be accessed from attn_metadata, e.g., attn_metadata.cum_query_lens. This also applies to their usages on lines 775, 780, and 792 within this if self.enable_sfa_cp: block. This will cause a NameError at runtime.
Suggested change:

```diff
-actual_seq_lengths_query = torch.empty_like(cum_query_lens)
-actual_seq_lengths_key = torch.empty_like(seq_lens)
+actual_seq_lengths_query = torch.empty_like(attn_metadata.cum_query_lens)
+actual_seq_lengths_key = torch.empty_like(attn_metadata.seq_lens)
```
```python
topk_indices = self.indexer_select(
    x=hidden_states,
    qr=q_c,
    kv_cache=kv_cache,
    attn_metadata=attn_metadata,
    cos=cos,
    sin=sin,
    actual_seq_lengths_query=actual_seq_lengths_query,
    actual_seq_lengths_key=actual_seq_lengths_key,
    need_gather_q_kv=need_gather_q_kv)
```
The call to self.indexer_select requires actual_seq_lengths_query and actual_seq_lengths_key as arguments. However, these variables are only defined when self.enable_sfa_cp is true. They are not defined for the else case, nor for the if self.enable_mlapo and not forward_context.with_prefill: path. This will lead to a NameError.
You should define these variables in all code paths leading to this call. For the paths where they are not defined, they should likely be initialized to attn_metadata.cum_query_lens and attn_metadata.seq_lens respectively.
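One way to structure the fix suggested above is to resolve both tensors from `attn_metadata` before any branching, so every path reaching `indexer_select` sees them defined. The helper below is a hypothetical sketch, not the PR's actual code; the field names `cum_query_lens` and `seq_lens` are taken from the review comments, and the real control flow in `sfa_v1.py` may differ.

```python
import torch

def resolve_seq_lengths(attn_metadata, enable_sfa_cp: bool):
    """Hypothetical helper: define both length tensors in every code path."""
    if enable_sfa_cp:
        # SFA CP path: allocate placeholder tensors (filled in later by the
        # CP gather logic, per the surrounding code in the PR).
        actual_seq_lengths_query = torch.empty_like(attn_metadata.cum_query_lens)
        actual_seq_lengths_key = torch.empty_like(attn_metadata.seq_lens)
    else:
        # Non-CP paths (including the MLAPO decode path): fall back to the
        # metadata tensors directly, as the review suggests.
        actual_seq_lengths_query = attn_metadata.cum_query_lens
        actual_seq_lengths_key = attn_metadata.seq_lens
    return actual_seq_lengths_query, actual_seq_lengths_key
```

Because the fallback returns the metadata tensors themselves, the non-CP paths add no allocation.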
```python
def check_diff(a: torch.Tensor, b: torch.Tensor) -> Any:
    if torch.equal(a, b):
        absolute = torch.abs(a - b)
        relative = torch.abs(a - b) / (torch.abs(a) + 1e-9)
        return (torch.max(absolute).item(), torch.max(relative).item())
    return False
```
The logic in this function seems inverted. If torch.equal(a, b) is true, it proceeds to calculate the difference (which will be zero) and returns (0.0, 0.0). If they are not equal, it returns False. A function named check_diff would typically return False or (0.0, 0.0) to indicate no difference, and return the difference metrics when the tensors are not equal.
Suggested change:

```diff
 def check_diff(a: torch.Tensor, b: torch.Tensor) -> Any:
     if torch.equal(a, b):
-        absolute = torch.abs(a - b)
-        relative = torch.abs(a - b) / (torch.abs(a) + 1e-9)
-        return (torch.max(absolute).item(), torch.max(relative).item())
-    return False
+        return False
+    absolute = torch.abs(a - b)
+    relative = torch.abs(a - b) / (torch.abs(a) + 1e-9)
+    return (torch.max(absolute).item(), torch.max(relative).item())
```
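With the branch un-inverted as suggested, the helper returns `False` for identical tensors and the (max absolute, max relative) error pair otherwise. A minimal self-contained check of the corrected version:

```python
import torch

def check_diff(a: torch.Tensor, b: torch.Tensor):
    # Corrected semantics: False means "no difference"; otherwise return
    # (max absolute error, max relative error).
    if torch.equal(a, b):
        return False
    absolute = torch.abs(a - b)
    relative = torch.abs(a - b) / (torch.abs(a) + 1e-9)
    return (torch.max(absolute).item(), torch.max(relative).item())

a = torch.tensor([1.0, 2.0])
same = check_diff(a, a.clone())          # → False
diff = check_diff(a, torch.tensor([1.0, 2.5]))  # → (0.5, 0.25)
```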
This pull request has conflicts, please resolve those before we can evaluate the pull request.
[Feat] enable sfa cp for dsv3.2 (vllm-project#4702)

RFC: vllm-project/vllm#30055

1. Enable FlashComm1: `export VLLM_ASCEND_ENABLE_FLASHCOMM1=1`
2. Enable SFA CP: `--additional-config '{ "enable_sfa_cp": true }'`

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: ZYang6263 <[email protected]>
Co-authored-by: Yizhou Liu <[email protected]>
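Combining the two switches from the commit message, a launch might look like the sketch below. Only the environment variable and the `--additional-config` payload come from the commit; the model path is a placeholder and other serve flags are omitted.

```shell
# Enable FlashComm1 (from the commit message)
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1

# Enable SFA context parallelism via additional-config; <model> is a
# placeholder for the actual model path.
vllm serve <model> \
    --additional-config '{ "enable_sfa_cp": true }'
```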
What this PR does / why we need it?
This PR adds support for the optimized MLAPO operator in DSV3.2. The operator provides an optimized implementation that avoids redundant q_down recomputation.
The operator implementation and optimizations were introduced in PR #4707.
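The redundancy being removed can be illustrated with a toy contrast: an unfused path derives the q_down projection from the hidden states once per consumer, while a fused path computes it once and shares the result. This is illustrative Python only; the real MLAPO operator is a fused NPU kernel, and the shapes and names below are invented.

```python
import torch

torch.manual_seed(0)
hidden = torch.randn(4, 16)       # toy hidden states
w_q_down = torch.randn(16, 8)     # toy q_down projection weight

def unfused(hidden):
    # q_down is recomputed for each consumer of the projection.
    q_down_for_query = hidden @ w_q_down
    q_down_for_indexer = hidden @ w_q_down   # redundant recomputation
    return q_down_for_query, q_down_for_indexer

def fused(hidden):
    # Compute q_down once and share the result with both consumers,
    # the kind of saving the fused operator is described as providing.
    q_down = hidden @ w_q_down
    return q_down, q_down
```

Both paths produce identical results; the fused one simply does the matmul once.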
Does this PR introduce any user-facing change?
How was this patch tested?