
Commit 35267e0

committed
The code has been restructured based on community feedback.
Signed-off-by: zzhx1 <[email protected]>
1 parent 830518a commit 35267e0

File tree

4 files changed: +171 -126 lines

vllm_ascend/ascend_config.py

Lines changed: 0 additions & 1 deletion
@@ -80,7 +80,6 @@ def __init__(self, vllm_config):
             enable_shared_expert_dp=True)
         self.multistream_overlap_shared_expert = additional_config.get(
             "multistream_overlap_shared_expert", False)
-        self.enable_sfa_cp = additional_config.get("enable_sfa_cp", False)
         self.recompute_scheduler_enable = additional_config.get(
             "recompute_scheduler_enable", False)
         self.lmhead_tensor_parallel_size = additional_config.get(
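
Note: enable_sfa_cp is no longer read from additional_config. As the sfa_v1.py diff below shows, the flag is now derived at runtime; the derivation, quoted from that diff, is:

    self.enable_sfa_cp = enable_sp() and \
        hasattr(self.model_config.hf_config, "index_topk")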

vllm_ascend/attention/sfa_v1.py

Lines changed: 163 additions & 119 deletions
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar
-
+import math
 import torch
 import torch_npu
 from torch import nn
@@ -17,8 +17,11 @@
 from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.distributed.sfa_sp_context import (get_sfa_sp_context,
-                                                     set_sfa_sp_context)
+from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
+                               is_enable_nz, _round_up)
+from vllm_ascend.worker.npu_input_batch import InputBatch
+from vllm_ascend.utils import dispose_tensor, dispose_layer, replace_layer, enable_sp
 from vllm_ascend.ops.shared_weight_layer import (
     is_hidden_layer, post_process_after_loading_for_shared_weight_series,
     reach_layer_for_shared_weight_series,
@@ -28,6 +31,7 @@
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                dispose_layer, is_enable_nz, replace_layer)
 from vllm_ascend.worker.npu_input_batch import InputBatch
+from vllm.forward_context import get_forward_context
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -55,6 +59,20 @@ def get_impl_cls() -> Type["AscendSFAImpl"]:
         return AscendSFAImpl
 
 
+
+@dataclass
+class SfaCpContext:
+    num_tokens: int
+    num_tokens_pad: int
+    local_start: int
+    local_end: int
+    local_end_with_pad: int
+    pad_size: int
+    local_pad_size: int
+    slot_mapping_cp: torch.Tensor
+    actual_seq_lengths_query: torch.Tensor
+    actual_seq_lengths_key: torch.Tensor
+
 @dataclass
 class AscendSFAMetadata:
     """Metadata for MLACommon.
@@ -85,6 +103,7 @@ class AscendSFAMetadata:
     attn_mask: torch.Tensor = None
     # chunked prefill by default if no attn_states passed
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
+    sfa_cp_context: Optional[SfaCpContext] = None
 
 
 M = TypeVar("M", bound=AscendSFAMetadata)
@@ -128,6 +147,9 @@ def __init__(self,
         self.cos_cache = None
         self.sin_cache = None
 
+        self.enable_sfa_cp = enable_sp() and \
+            hasattr(self.model_config.hf_config, "index_topk")
+
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         # No need to reorder for Ascend SFA
@@ -176,6 +198,63 @@ def build(
             1).unsqueeze(2)
         sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
             1).unsqueeze(2)
+
+        sfa_cp_context = None
+        if self.enable_sfa_cp:
+            global_tp_size = get_tp_group().world_size
+            num_tokens = num_actual_tokens
+            num_tokens_pad = _round_up(num_actual_tokens, global_tp_size)
+            num_tokens_per_device = num_tokens_pad // global_tp_size
+            pad_size = num_tokens_pad - num_tokens
+            local_start = get_tp_group().rank_in_group * num_tokens_per_device
+            local_end_with_pad = local_start + num_tokens_per_device
+            local_end = min(local_end_with_pad, num_actual_tokens)
+            local_pad_size = local_end_with_pad - local_end
+
+            if pad_size > 0:
+                cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
+                sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
+                slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size), value=-1)
+            cos = cos[local_start:local_end_with_pad]
+            sin = sin[local_start:local_end_with_pad]
+            slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
+
+            actual_seq_lengths_query = torch.empty_like(cum_query_lens)
+            actual_seq_lengths_key = torch.empty_like(seq_lens)
+            num_segs = cum_query_lens.shape[0]
+            last_token = 0
+            cum = 0
+            for i in range(0, num_segs):
+                global_start = last_token
+                global_end = cum_query_lens[i].item()
+                last_token = global_end
+
+                local_start = max(global_start, local_start)
+                local_end = min(global_end, local_end_with_pad)
+                num_local_tokens = local_end - local_start
+
+                if num_local_tokens > 0:
+                    cum += num_local_tokens
+                    actual_seq_lengths_query[i] = cum
+
+                    offset = global_end - local_end
+                    actual_seq_lengths_key[i] = seq_lens[i].item() - offset
+                else:
+                    actual_seq_lengths_query[i] = cum
+                    actual_seq_lengths_key[i] = 0
+
+            sfa_cp_context = SfaCpContext(
+                num_tokens=num_tokens,
+                num_tokens_pad=num_tokens_pad,
+                local_start=local_start,
+                local_end=local_end,
+                local_end_with_pad=local_end_with_pad,
+                pad_size=pad_size,
+                local_pad_size=local_pad_size,
+                slot_mapping_cp=slot_mapping_cp,
+                actual_seq_lengths_query=actual_seq_lengths_query,
+                actual_seq_lengths_key=actual_seq_lengths_key,
+            )
 
         return self.metadata_cls(  # type: ignore
             has_prefill=has_prefill,
@@ -189,7 +268,8 @@ def build(
             attn_state=common_attn_metadata.attn_state,
             block_tables=block_table,
             sin=sin,
-            cos=cos)
+            cos=cos,
+            sfa_cp_context=sfa_cp_context)
 
     def build_for_graph_capture(
         self,
@@ -265,67 +345,29 @@ def __init__(
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-
+        self.model_config = get_current_vllm_config().model_config
         assert self.indexer is not None, "Indexer is required for DSA."
-
-        self.enable_sfa_cp = ascend_config.enable_sfa_cp
+
+        self.enable_sfa_cp = enable_sp() and \
+            hasattr(self.model_config.hf_config, "index_topk")
         self.local_num_heads = self.num_heads
         self.vllm_config = get_current_vllm_config()
         if self.enable_sfa_cp:
             self.local_num_heads = self.num_heads * self.tp_size
 
-            # Dispose tensor from the original q_proj
-            dispose_layer(self.q_proj)
-            # Construct the new q_proj using ReplicatedLinear
-            new_q_proj = ReplicatedLinear(
-                self.q_lora_rank,
-                self.local_num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=self.vllm_config.quant_config,
-                prefix=self.q_proj.prefix)
-            # Replace the q_proj with the new one
-            replace_layer(self.q_proj, new_q_proj)
-
-            # Dispose tensor from the original kv_b_proj
-            dispose_layer(self.kv_b_proj)
-            # Construct the new kv_b_proj using ReplicatedLinear
-            new_kv_b_proj = ReplicatedLinear(
-                self.kv_lora_rank,
-                self.local_num_heads *
-                (self.qk_nope_head_dim + self.v_head_dim),
-                bias=False,
-                quant_config=self.vllm_config.quant_config,
-                prefix=self.kv_b_proj.prefix)
-            # Replace the kv_b_proj with the new one
-            replace_layer(self.kv_b_proj, new_kv_b_proj)
-
-            # Dispose tensor from the original o_proj
-            dispose_layer(self.o_proj)
-            # Construct the new o_proj using ReplicatedLinear
-            config = self.vllm_config.model_config.hf_config
-            new_o_proj = ReplicatedLinear(
-                config.num_attention_heads * config.v_head_dim,
-                config.hidden_size,
-                bias=False,
-                quant_config=self.vllm_config.quant_config,
-                prefix=self.o_proj.prefix)
-            # Replace the o_proj with the new one
-            replace_layer(self.o_proj, new_o_proj)
-
-            from vllm_ascend.distributed.parallel_state import \
-                get_shared_weight_group
-            if is_hidden_layer(self.vllm_config, self.q_proj):
-                register_layer_to_shared_weight_series(
-                    series_name="q_proj",
-                    group=get_shared_weight_group(),
-                    layer=self.q_proj,
-                    prefetch_step=1)
-            if is_hidden_layer(self.vllm_config, self.o_proj):
-                register_layer_to_shared_weight_series(
-                    series_name="o_proj",
-                    group=get_shared_weight_group(),
-                    layer=self.o_proj,
-                    prefetch_step=1)
+            #TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
+            self._replace_linear_class_for_sfa_cp()
+            from vllm_ascend.distributed.parallel_state import get_shared_weight_group
+            register_layer_to_shared_weight_series(
+                series_name="q_proj",
+                group=get_shared_weight_group(),
+                layer=self.q_proj,
+                prefetch_step=1)
+            register_layer_to_shared_weight_series(
+                series_name="o_proj",
+                group=get_shared_weight_group(),
+                layer=self.o_proj,
+                prefetch_step=1)
 
         # indexer param
         self.n_head: int = self.indexer.n_head  # 64
@@ -479,6 +521,7 @@ def exec_kv(
             cache_mode=cache_mode,
             is_output_kv=True,
         )
+        #TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
         k_pe = get_tp_group().all_gather(k_pe, 0)
         k_nope = get_tp_group().all_gather(k_nope, 0)
 
@@ -538,11 +581,8 @@ def forward(
         has_prefill = attn_metadata.has_prefill
         num_actual_tokens = attn_metadata.num_actual_tokens
         hidden_states = hidden_states[:num_actual_tokens]
-        sfa_sp_context = None
         if self.enable_sfa_cp:
             need_gather_q_kv = False
-            set_sfa_sp_context(hidden_states, attn_metadata.num_actual_tokens)
-            sfa_sp_context = get_sfa_sp_context()
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_tokens]
@@ -570,76 +610,38 @@ def forward(
         cos = attn_metadata.cos
         sin = attn_metadata.sin
         slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
-
         slot_mapping_cp = None
-        if self.enable_sfa_cp and sfa_sp_context is not None:
-            if sfa_sp_context.pad_size > 0:
-                cos = nn.functional.pad(
-                    cos, (0, 0, 0, 0, 0, 0, 0, sfa_sp_context.pad_size))
-                sin = nn.functional.pad(
-                    sin, (0, 0, 0, 0, 0, 0, 0, sfa_sp_context.pad_size))
-                slot_mapping = nn.functional.pad(slot_mapping,
-                                                 (0, sfa_sp_context.pad_size),
-                                                 value=-1)
-            cos = cos[sfa_sp_context.local_start:sfa_sp_context.
-                      local_end_with_pad]
-            sin = sin[sfa_sp_context.local_start:sfa_sp_context.
-                      local_end_with_pad]
-            slot_mapping_cp = slot_mapping[
-                sfa_sp_context.local_start:sfa_sp_context.local_end_with_pad]
+        if self.enable_sfa_cp:
+            slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
 
         self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
                      slot_mapping_cp)
 
-        if self.enable_sfa_cp and sfa_sp_context is not None:
+        if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
             if is_hidden_layer(self.vllm_config, self.q_proj):
                 reach_layer_for_shared_weight_series(self.q_proj)
             if is_hidden_layer(self.vllm_config, self.o_proj):
                 reach_layer_for_shared_weight_series(self.o_proj)
 
         ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
         q_pe = self.rope_single(q_pe, cos, sin)
+
+        actual_seq_lengths_query = attn_metadata.cum_query_lens
+        actual_seq_lengths_key = attn_metadata.seq_lens
 
-        cum_query_lens = attn_metadata.cum_query_lens
-        seq_lens = attn_metadata.seq_lens
-        actual_seq_lengths_query = cum_query_lens
-        actual_seq_lengths_key = seq_lens
-
-        if self.enable_sfa_cp and sfa_sp_context is not None:
-            actual_seq_lengths_query = torch.empty_like(cum_query_lens)
-            actual_seq_lengths_key = torch.empty_like(seq_lens)
-            num_segs = cum_query_lens.shape[0]
-            last_token = 0
-            cum = 0
-            for i in range(0, num_segs):
-                global_start = last_token
-                global_end = cum_query_lens[i].item()
-                last_token = global_end
-
-                local_start = max(global_start, sfa_sp_context.local_start)
-                local_end = min(global_end, sfa_sp_context.local_end_with_pad)
-                num_local_tokens = local_end - local_start
-
-                if num_local_tokens > 0:
-                    cum += num_local_tokens
-                    actual_seq_lengths_query[i] = cum
-
-                    offset = global_end - local_end
-                    actual_seq_lengths_key[i] = seq_lens[i].item() - offset
-                else:
-                    actual_seq_lengths_query[i] = cum
-                    actual_seq_lengths_key[i] = 0
-
-        topk_indices = self.indexer_select(
-            x=hidden_states,
-            qr=q_c,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
-            cos=cos,
-            sin=sin,
-            actual_seq_lengths_query=actual_seq_lengths_query,
-            actual_seq_lengths_key=actual_seq_lengths_key,
-            need_gather_q_kv=need_gather_q_kv)
+        if self.enable_sfa_cp:
+            actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
+            actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
+
+        topk_indices = self.indexer_select(x=hidden_states,
+                                           qr=q_c,
+                                           kv_cache=kv_cache,
+                                           attn_metadata=attn_metadata,
+                                           cos=cos,
+                                           sin=sin,
+                                           actual_seq_lengths_query=actual_seq_lengths_query,
+                                           actual_seq_lengths_key=actual_seq_lengths_key,
+                                           need_gather_q_kv=need_gather_q_kv)
         attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
             query=ql_nope,
             key=kv_cache[0],
@@ -749,3 +751,45 @@ def indexer_select(
             sparse_count=2048,
             sparse_mode=3)
         return topk_indices
+
+    def _replace_linear_class_for_sfa_cp(self):
+
+        vllm_config = get_current_vllm_config()
+        # Dispose tensor from the original q_proj
+        dispose_layer(self.q_proj)
+        # Construct the new q_proj using ReplicatedLinear
+        new_q_proj = ReplicatedLinear(
+            self.q_lora_rank,
+            self.local_num_heads * self.qk_head_dim,
+            bias=False,
+            quant_config=vllm_config.quant_config,
+            prefix=self.q_proj.prefix)
+        # Replace the q_proj with the new one
+        replace_layer(self.q_proj, new_q_proj)
+
+        # Dispose tensor from the original kv_b_proj
+        dispose_layer(self.kv_b_proj)
+        # Construct the new kv_b_proj using ReplicatedLinear
+        new_kv_b_proj = ReplicatedLinear(
+            self.kv_lora_rank,
+            self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=vllm_config.quant_config,
+            prefix=self.kv_b_proj.prefix)
+        # Replace the kv_b_proj with the new one
+        replace_layer(self.kv_b_proj, new_kv_b_proj)
+
+        # Dispose tensor from the original o_proj
+        dispose_layer(self.o_proj)
+        # Construct the new o_proj using ReplicatedLinear
+        config = vllm_config.model_config.hf_config
+        new_o_proj = ReplicatedLinear(
+            config.num_attention_heads * config.v_head_dim,
+            config.hidden_size,
+            bias=False,
+            quant_config=vllm_config.quant_config,
+            prefix=self.o_proj.prefix)
+        # Replace the o_proj with the new one
+        replace_layer(self.o_proj, new_o_proj)
+
+
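For reference, the core of the new SFA-CP path in build() is an even split of the padded token range across tensor-parallel ranks. The standalone sketch below mirrors the names used in the diff (_round_up, local_start, local_end_with_pad, pad_size, local_pad_size); the helper function and the example numbers are illustrative only and are not part of the commit.

def partition_tokens(num_actual_tokens: int, tp_size: int, tp_rank: int):
    # Round the token count up to a multiple of the TP world size.
    def _round_up(value: int, align: int) -> int:
        return (value + align - 1) // align * align

    num_tokens_pad = _round_up(num_actual_tokens, tp_size)
    num_tokens_per_device = num_tokens_pad // tp_size
    pad_size = num_tokens_pad - num_actual_tokens

    # Each rank owns one contiguous, equally sized slice of the padded range;
    # trailing ranks may carry padding slots (slot_mapping is padded with -1).
    local_start = tp_rank * num_tokens_per_device
    local_end_with_pad = local_start + num_tokens_per_device
    local_end = min(local_end_with_pad, num_actual_tokens)
    local_pad_size = local_end_with_pad - local_end
    return local_start, local_end, local_end_with_pad, pad_size, local_pad_size


# Example: 10 tokens on 4 ranks are padded to 12, i.e. 3 tokens per rank.
# Rank 3 owns global tokens [9, 10) and carries 2 padding slots.
print(partition_tokens(10, 4, 3))  # (9, 10, 12, 2, 2)

The per-request loop in build() then clips each request's query span to this local slice to produce actual_seq_lengths_query and actual_seq_lengths_key for the rank-local attention call.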