[Feat]enable sfa cp for dsv3.2 #4702
base: main
Conversation
Code Review
This pull request enables Sparse Flash Attention with Context Parallelism (SFA-CP) for dsv3.2. The changes are extensive, introducing sequence parallelism by modifying the attention mechanism to use ReplicatedLinear layers and a new shared_weight_layer utility for memory-efficient weight management and prefetching. While the implementation of sequence parallelism seems mostly correct, I've identified a few issues. There is a critical bug in the weight prefetching logic that could lead to incorrect weights being loaded. Additionally, there's a performance bottleneck in a new loop that should be vectorized, and a logical error in a new utility function. Addressing these points will improve the correctness and performance of the implementation.
```python
next_layer_idx = (layer_idx + self.prefetch_step
                  ) % self.num_layers + self.start_layer
```
There is a bug in the calculation of next_layer_idx. The formula (layer_idx + self.prefetch_step) % self.num_layers incorrectly uses the global layer_idx. This will lead to prefetching the wrong layer's weights unless self.start_layer is a multiple of self.num_layers. The calculation should be based on the layer's index relative to the start of the series.
Suggested change:

```diff
-next_layer_idx = (layer_idx + self.prefetch_step
-                  ) % self.num_layers + self.start_layer
+next_layer_idx = ((layer_idx - self.start_layer + self.prefetch_step) %
+                  self.num_layers + self.start_layer)
```
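A quick numeric check of the two formulas; the values for start_layer, num_layers, and prefetch_step are made up for illustration:

```python
start_layer, num_layers, prefetch_step = 4, 8, 1   # illustrative values only

layer_idx = start_layer                            # first layer of this series
buggy = (layer_idx + prefetch_step) % num_layers + start_layer
fixed = (layer_idx - start_layer + prefetch_step) % num_layers + start_layer
print(buggy, fixed)   # 9 5 -> the original formula jumps to layer 9
                      #        instead of prefetching the next layer, 5
```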
vllm_ascend/attention/sfa_v1.py (outdated diff)
```python
if self.enable_sfa_cp:
    actual_seq_lengths_query = torch.empty_like(cum_query_lens)
    actual_seq_lengths_key = torch.empty_like(seq_lens)
    num_segs = cum_query_lens.shape[0]
    last_token = 0
    cum = 0
    for i in range(0, num_segs):
        global_start = last_token
        global_end = cum_query_lens[i].item()
        last_token = global_end

        local_start = max(global_start, sfa_sp_context.local_start)
        local_end = min(global_end, sfa_sp_context.local_end_with_pad)
        num_local_tokens = local_end - local_start

        if num_local_tokens > 0:
            cum += num_local_tokens
            actual_seq_lengths_query[i] = cum

            offset = global_end - local_end
            actual_seq_lengths_key[i] = seq_lens[i].item() - offset
        else:
            actual_seq_lengths_query[i] = cum
            actual_seq_lengths_key[i] = 0
```
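For intuition, a standalone sketch of the same per-sequence arithmetic on made-up inputs (the sequence lengths and the rank-local token window below are invented for illustration):

```python
import torch

# Made-up example: three sequences with cumulative query lengths and key lengths,
# and a hypothetical local window of global tokens [4, 10) owned by this rank.
cum_query_lens = torch.tensor([3, 7, 12])
seq_lens = torch.tensor([5, 9, 14])
local_start, local_end_with_pad = 4, 10

q_lens = torch.empty_like(cum_query_lens)
k_lens = torch.empty_like(seq_lens)
last_token, cum = 0, 0
for i in range(cum_query_lens.shape[0]):
    global_start, global_end = last_token, int(cum_query_lens[i])
    last_token = global_end
    local_s = max(global_start, local_start)          # clip the sequence to the
    local_e = min(global_end, local_end_with_pad)     # rank-local token window
    n_local = local_e - local_s
    if n_local > 0:
        cum += n_local
        q_lens[i] = cum                               # cumulative local query tokens
        k_lens[i] = int(seq_lens[i]) - (global_end - local_e)  # keys visible to the
    else:                                                       # last local query token
        q_lens[i] = cum
        k_lens[i] = 0

print(q_lens.tolist(), k_lens.tolist())   # [0, 3, 6] [0, 9, 12]
```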
The loop to calculate actual_seq_lengths_query and actual_seq_lengths_key is executed on the CPU and involves .item() calls, which will cause GPU-CPU synchronization for each sequence in the batch. This can be a significant performance bottleneck, especially for large batches. This logic should be vectorized to run efficiently on the GPU, avoiding per-item synchronization.
```python
if self.enable_sfa_cp:
    global_ends = cum_query_lens
    global_starts = torch.cat((torch.tensor([0], device=global_ends.device, dtype=global_ends.dtype), global_ends[:-1]))
    local_starts = torch.max(global_starts, torch.tensor(sfa_sp_context.local_start, device=global_starts.device, dtype=global_starts.dtype))
    local_ends = torch.min(global_ends, torch.tensor(sfa_sp_context.local_end_with_pad, device=global_ends.device, dtype=global_ends.dtype))
    num_local_tokens = torch.clamp(local_ends - local_starts, min=0)
    actual_seq_lengths_query = torch.cumsum(num_local_tokens, dim=0)
    offsets = torch.clamp(global_ends - local_ends, min=0)
    actual_seq_lengths_key = torch.clamp(seq_lens - offsets, min=0)
```

```python
def check_diff(a: torch.Tensor, b: torch.Tensor) -> Any:
    if torch.equal(a, b):
        absolute = torch.abs(a - b)
        relative = torch.abs(a - b) / (torch.abs(a) + 1e-9)
        return (torch.max(absolute).item(), torch.max(relative).item())
    return False
```
The logic in check_diff is incorrect. It calculates the difference only if the tensors are equal (in which case the difference is zero), and returns False if they are not equal. The intention is likely the opposite: to calculate and return the difference if the tensors are not equal, and return False if they are.
Suggested change:

```diff
 def check_diff(a: torch.Tensor, b: torch.Tensor) -> Any:
     if torch.equal(a, b):
-        absolute = torch.abs(a - b)
-        relative = torch.abs(a - b) / (torch.abs(a) + 1e-9)
-        return (torch.max(absolute).item(), torch.max(relative).item())
-    return False
+        return False
+    absolute = torch.abs(a - b)
+    relative = torch.abs(a - b) / (torch.abs(a) + 1e-9)
+    return (torch.max(absolute).item(), torch.max(relative).item())
```
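A short usage sketch of the corrected helper, reusing the fixed definition from the suggestion above; the sample tensors are arbitrary:

```python
import torch
from typing import Any

def check_diff(a: torch.Tensor, b: torch.Tensor) -> Any:
    # Fixed version from the suggestion above.
    if torch.equal(a, b):
        return False
    absolute = torch.abs(a - b)
    relative = torch.abs(a - b) / (torch.abs(a) + 1e-9)
    return (torch.max(absolute).item(), torch.max(relative).item())

x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0, 2.1, 3.0])
print(check_diff(x, x))   # False: the tensors are bit-identical
print(check_diff(x, y))   # roughly (0.1, 0.05): max absolute and relative difference
```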
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Force-pushed from 33e8f64 to 6b36700.
Signed-off-by: AlvisGong <[email protected]> Co-authored-by: clrs97 <[email protected]> Co-authored-by: zzhx1 <[email protected]>
Signed-off-by: zzhx1 <[email protected]> Co-authored-by: clrs97 <[email protected]> Co-authored-by: AlvisGong <[email protected]>
Amazing work
Signed-off-by: zzhx1 <[email protected]>
What this PR does / why we need it?
RFC: vllm-project/vllm#30055
Does this PR introduce any user-facing change?
How was this patch tested?
```bash
export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
--additional-config '{ "enable_sfa_cp": true }' \
```
We achieved a 3.36x improvement in prefill.
