
Conversation

@AlvisGong (Contributor) commented Dec 4, 2025

What this PR does / why we need it?

RFC: vllm-project/vllm#30055

Does this PR introduce any user-facing change?

How was this patch tested?

  1. Enable FlashComm1:
    export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
  2. Enable SFA-CP:
    --additional-config '{ "enable_sfa_cp": true }' \

We achieved a 3.36x improvement in prefill.

@gemini-code-assist bot (Contributor) left a comment
Code Review

This pull request enables Sparse Flash Attention with Context Parallelism (SFA-CP) for dsv3.2. The changes are extensive, introducing sequence parallelism by modifying the attention mechanism to use ReplicatedLinear layers and a new shared_weight_layer utility for memory-efficient weight management and prefetching. While the implementation of sequence parallelism seems mostly correct, I've identified a few issues. There is a critical bug in the weight prefetching logic that could lead to incorrect weights being loaded. Additionally, there's a performance bottleneck in a new loop that should be vectorized, and a logical error in a new utility function. Addressing these points will improve the correctness and performance of the implementation.

Comment on lines +113 to +114
next_layer_idx = (layer_idx + self.prefetch_step
) % self.num_layers + self.start_layer

Severity: critical

There is a bug in the calculation of next_layer_idx. The formula (layer_idx + self.prefetch_step) % self.num_layers incorrectly uses the global layer_idx. This will lead to prefetching the wrong layer's weights unless self.start_layer is a multiple of self.num_layers. The calculation should be based on the layer's index relative to the start of the series.

Suggested change
next_layer_idx = (layer_idx + self.prefetch_step
) % self.num_layers + self.start_layer
next_layer_idx = ((layer_idx - self.start_layer + self.prefetch_step) %
self.num_layers + self.start_layer)
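
For illustration, a minimal sketch with hypothetical values (start_layer=4, num_layers=8, prefetch_step=1, so the series spans layers 4..11); the numbers are not taken from the PR:

    # Hypothetical series: layers 4..11 (start_layer=4, num_layers=8).
    start_layer, num_layers, prefetch_step = 4, 8, 1
    layer_idx = 5  # global index of the current layer

    # Original formula wraps on the global index and lands on layer 10.
    original = (layer_idx + prefetch_step) % num_layers + start_layer  # -> 10

    # Relative formula prefetches the next layer in the series, layer 6.
    fixed = (layer_idx - start_layer + prefetch_step) % num_layers + start_layer  # -> 6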

Comment on lines 584 to 630
if self.enable_sfa_cp:
    actual_seq_lengths_query = torch.empty_like(cum_query_lens)
    actual_seq_lengths_key = torch.empty_like(seq_lens)
    num_segs = cum_query_lens.shape[0]
    last_token = 0
    cum = 0
    for i in range(0, num_segs):
        global_start = last_token
        global_end = cum_query_lens[i].item()
        last_token = global_end

        local_start = max(global_start, sfa_sp_context.local_start)
        local_end = min(global_end, sfa_sp_context.local_end_with_pad)
        num_local_tokens = local_end - local_start

        if num_local_tokens > 0:
            cum += num_local_tokens
            actual_seq_lengths_query[i] = cum

            offset = global_end - local_end
            actual_seq_lengths_key[i] = seq_lens[i].item() - offset
        else:
            actual_seq_lengths_query[i] = cum
            actual_seq_lengths_key[i] = 0

Severity: high

The loop to calculate actual_seq_lengths_query and actual_seq_lengths_key is executed on the CPU and involves .item() calls, which will cause GPU-CPU synchronization for each sequence in the batch. This can be a significant performance bottleneck, especially for large batches. This logic should be vectorized to run efficiently on the GPU, avoiding per-item synchronization.

        if self.enable_sfa_cp:
            global_ends = cum_query_lens
            global_starts = torch.cat((torch.tensor([0], device=global_ends.device, dtype=global_ends.dtype), global_ends[:-1]))

            local_starts = torch.max(global_starts, torch.tensor(sfa_sp_context.local_start, device=global_starts.device, dtype=global_starts.dtype))
            local_ends = torch.min(global_ends, torch.tensor(sfa_sp_context.local_end_with_pad, device=global_ends.device, dtype=global_ends.dtype))
            
            num_local_tokens = torch.clamp(local_ends - local_starts, min=0)
            
            actual_seq_lengths_query = torch.cumsum(num_local_tokens, dim=0)
            
            offsets = torch.clamp(global_ends - local_ends, min=0)
            actual_seq_lengths_key = torch.clamp(seq_lens - offsets, min=0)
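
As a toy check of the vectorized form (the tensor values and the local window below are made up for illustration, and clamp is used as an equivalent of the elementwise max/min above), the outputs match what the per-item loop produces for these inputs:

    import torch

    cum_query_lens = torch.tensor([3, 7, 10])  # cumulative query lengths per request
    seq_lens = torch.tensor([5, 6, 4])         # full KV lengths per request
    local_start, local_end_with_pad = 2, 8     # this rank owns global tokens [2, 8)

    global_ends = cum_query_lens
    global_starts = torch.cat((torch.zeros(1, dtype=global_ends.dtype), global_ends[:-1]))
    local_starts = global_starts.clamp(min=local_start)
    local_ends = global_ends.clamp(max=local_end_with_pad)
    num_local_tokens = (local_ends - local_starts).clamp(min=0)

    print(torch.cumsum(num_local_tokens, dim=0))  # tensor([1, 5, 6])
    offsets = (global_ends - local_ends).clamp(min=0)
    print((seq_lens - offsets).clamp(min=0))      # tensor([5, 6, 2])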

Comment on lines +48 to +62
def check_diff(a: torch.Tensor, b: torch.Tensor) -> Any:
    if torch.equal(a, b):
        absolute = torch.abs(a - b)
        relative = torch.abs(a - b) / (torch.abs(a) + 1e-9)
        return (torch.max(absolute).item(), torch.max(relative).item())
    return False

Severity: high

The logic in check_diff is incorrect. It calculates the difference only if the tensors are equal (in which case the difference is zero), and returns False if they are not equal. The intention is likely the opposite: to calculate and return the difference if the tensors are not equal, and return False if they are.

Suggested change
def check_diff(a: torch.Tensor, b: torch.Tensor) -> Any:
    if torch.equal(a, b):
        absolute = torch.abs(a - b)
        relative = torch.abs(a - b) / (torch.abs(a) + 1e-9)
        return (torch.max(absolute).item(), torch.max(relative).item())
    return False
def check_diff(a: torch.Tensor, b: torch.Tensor) -> Any:
    if torch.equal(a, b):
        return False
    absolute = torch.abs(a - b)
    relative = torch.abs(a - b) / (torch.abs(a) + 1e-9)
    return (torch.max(absolute).item(), torch.max(relative).item())
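
A small usage sketch of the corrected helper, assuming the fixed definition above is in scope (the values are illustrative and the printed numbers approximate):

    import torch

    a = torch.tensor([1.0, 2.0, 3.0])
    print(check_diff(a, a.clone()))  # False: the tensors are identical
    print(check_diff(a, a + 0.01))   # roughly (0.01, 0.01): max absolute and max relative difference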

@github-actions bot commented Dec 4, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@AlvisGong changed the title from "enable sfa cp for dsv3.2" to "[Feat]enable sfa cp for dsv3.2" on Dec 4, 2025
Signed-off-by: AlvisGong <[email protected]>
Co-authored-by: clrs97 <[email protected]>
Co-authored-by: zzhx1 <[email protected]>
Signed-off-by: zzhx1 <[email protected]>
Co-authored-by: clrs97 <[email protected]>
Co-authored-by: AlvisGong <[email protected]>
@zzhx1 (Contributor) commented Dec 4, 2025

Amazing work

Signed-off-by: zzhx1 <[email protected]>