
Commit 4011513

zzhx1 and clrs97 committed
Fix the bug in sfa-cp under multi-DP scenarios
Signed-off-by: zzhxx <[email protected]>
Co-authored-by: clrs97 <[email protected]>
1 parent dd622aa commit 4011513

File tree

1 file changed: +23 -9 lines changed


vllm_ascend/attention/sfa_v1.py

Lines changed: 23 additions & 9 deletions
@@ -66,8 +66,6 @@ class SfaCpContext:
     local_start: int
     local_end: int
     local_end_with_pad: int
-    pad_size: int
-    local_pad_size: int
     slot_mapping_cp: torch.Tensor
     actual_seq_lengths_query: torch.Tensor
     actual_seq_lengths_key: torch.Tensor
@@ -206,23 +204,41 @@ def build(
         sfa_cp_context = None
         if self.enable_sfa_cp:
             global_tp_size = get_tp_group().world_size
-            num_tokens = num_actual_tokens
-            num_tokens_pad = _round_up(num_actual_tokens, global_tp_size)
+            num_tokens = num_input_tokens
+            num_tokens_pad = _round_up(num_tokens, global_tp_size)
             num_tokens_per_device = num_tokens_pad // global_tp_size
-            pad_size = num_tokens_pad - num_tokens
             local_start = get_tp_group().rank_in_group * num_tokens_per_device
             local_end_with_pad = local_start + num_tokens_per_device
             local_end = min(local_end_with_pad, num_actual_tokens)
-            local_pad_size = local_end_with_pad - local_end
+
+            pad_size = num_tokens_pad - cos.shape[0]
+            assert cos.shape == sin.shape, \
+                f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}"
 
             if pad_size > 0:
                 cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
                 sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
-                slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size),
+
+            pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
+            if pad_size_slot > 0:
+                slot_mapping = nn.functional.pad(slot_mapping,
+                                                 (0, pad_size_slot),
                                                  value=-1)
+            else:
+                slot_mapping = slot_mapping[:num_tokens_pad]
+
             cos = cos[local_start:local_end_with_pad]
             sin = sin[local_start:local_end_with_pad]
             slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
+            assert cos.shape[0] == num_tokens_per_device, \
+                f"cos.shape[0] must be equal to num_tokens_per_device, \
+                got {cos.shape[0]} and {num_tokens_per_device}"
+            assert slot_mapping_cp.shape[0] == num_tokens_per_device, \
+                f"slot_mapping_cp.shape[0] must be equal to num_tokens_per_device, \
+                got {slot_mapping_cp.shape[0]} and {num_tokens_per_device}"
+            assert slot_mapping.shape[0] == num_tokens_pad, \
+                f"slot_mapping.shape[0] must be equal to num_tokens_pad, \
+                got {slot_mapping.shape[0]} and {num_tokens_pad}"
 
             actual_seq_lengths_query = torch.empty_like(cum_query_lens)
             actual_seq_lengths_key = torch.empty_like(seq_lens)
@@ -254,8 +270,6 @@ def build(
                 local_start=local_start,
                 local_end=local_end,
                 local_end_with_pad=local_end_with_pad,
-                pad_size=pad_size,
-                local_pad_size=local_pad_size,
                 slot_mapping_cp=slot_mapping_cp,
                 actual_seq_lengths_query=actual_seq_lengths_query,
                 actual_seq_lengths_key=actual_seq_lengths_key,
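
For context, here is a minimal standalone sketch (not part of the commit) of the padding-and-slicing scheme the fixed build() arrives at. The helper name shard_for_sfa_cp, the explicit tp_size/tp_rank arguments, and the generic pad-width construction are assumptions made for this example; the real code obtains the rank and world size from get_tp_group() and pads 4-D cos/sin tensors with a fixed-length pad tuple.

import torch
import torch.nn as nn


def _round_up(x: int, align: int) -> int:
    # Smallest multiple of `align` that is >= x.
    return (x + align - 1) // align * align


def shard_for_sfa_cp(cos, sin, slot_mapping, num_input_tokens,
                     num_actual_tokens, tp_size, tp_rank):
    # Under multi-DP, num_input_tokens (the DP-padded batch size) can exceed
    # num_actual_tokens, and cos/sin/slot_mapping may arrive with different
    # first-dimension lengths, so each tensor is padded (or truncated) to the
    # same num_tokens_pad before slicing out this rank's contiguous chunk.
    num_tokens_pad = _round_up(num_input_tokens, tp_size)
    num_tokens_per_device = num_tokens_pad // tp_size
    local_start = tp_rank * num_tokens_per_device
    local_end_with_pad = local_start + num_tokens_per_device
    local_end = min(local_end_with_pad, num_actual_tokens)

    # Pad cos/sin along the first (token) dimension up to num_tokens_pad.
    assert cos.shape == sin.shape
    pad_size = num_tokens_pad - cos.shape[0]
    if pad_size > 0:
        pad_widths = (0, 0) * (cos.dim() - 1) + (0, pad_size)
        cos = nn.functional.pad(cos, pad_widths)
        sin = nn.functional.pad(sin, pad_widths)

    # slot_mapping may be shorter or longer than num_tokens_pad: pad with -1
    # (an out-of-range slot) or truncate so every rank sees the same length.
    pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
    if pad_size_slot > 0:
        slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size_slot),
                                         value=-1)
    else:
        slot_mapping = slot_mapping[:num_tokens_pad]

    # Each TP rank keeps only its own chunk; padded positions past local_end
    # carry dummy values and are ignored downstream.
    return (cos[local_start:local_end_with_pad],
            sin[local_start:local_end_with_pad],
            slot_mapping[local_start:local_end_with_pad],
            local_start, local_end)


if __name__ == "__main__":
    # Tiny smoke test: 10 real tokens, DP-padded to 12 input tokens, 4 TP ranks.
    cos = torch.randn(12, 1, 1, 8)
    sin = torch.randn(12, 1, 1, 8)
    slot_mapping = torch.arange(10)
    for rank in range(4):
        c, s, m, lo, hi = shard_for_sfa_cp(cos, sin, slot_mapping,
                                           num_input_tokens=12,
                                           num_actual_tokens=10,
                                           tp_size=4, tp_rank=rank)
        assert c.shape[0] == s.shape[0] == m.shape[0] == 3

The key change relative to the old code is that cos/sin and slot_mapping are padded independently from their own observed lengths (and slot_mapping is truncated if it is already longer), instead of deriving a single pad_size from num_actual_tokens, which under multi-DP no longer matches the padded input length.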
