Commit a86bdf8

zzhx1 committed

delete sfa_cp_context.py

Signed-off-by: zzhx1 <[email protected]>

1 parent 35267e0 · commit a86bdf8

File tree (3 files changed: +36 −103 lines)

  vllm_ascend/attention/sfa_v1.py
  vllm_ascend/distributed/parallel_state.py
  vllm_ascend/distributed/sfa_sp_context.py (deleted)

vllm_ascend/attention/sfa_v1.py (35 additions, 40 deletions)
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar
-import math
+
 import torch
 import torch_npu
 from torch import nn
@@ -17,21 +17,16 @@
 from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_enable_nz, _round_up)
-from vllm_ascend.worker.npu_input_batch import InputBatch
-from vllm_ascend.utils import dispose_tensor, dispose_layer, replace_layer, enable_sp
 from vllm_ascend.ops.shared_weight_layer import (
     is_hidden_layer, post_process_after_loading_for_shared_weight_series,
     reach_layer_for_shared_weight_series,
     register_layer_to_shared_weight_series)
 from vllm_ascend.ops.triton.rope import rope_forward_triton
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               dispose_layer, is_enable_nz, replace_layer)
+                               _round_up, dispose_layer, enable_sp,
+                               is_enable_nz, replace_layer)
 from vllm_ascend.worker.npu_input_batch import InputBatch
-from vllm.forward_context import get_forward_context

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -59,7 +54,6 @@ def get_impl_cls() -> Type["AscendSFAImpl"]:
         return AscendSFAImpl


-
 @dataclass
 class SfaCpContext:
     num_tokens: int
@@ -73,6 +67,7 @@ class SfaCpContext:
     actual_seq_lengths_query: torch.Tensor
     actual_seq_lengths_key: torch.Tensor

+
 @dataclass
 class AscendSFAMetadata:
     """Metadata for MLACommon.
@@ -198,7 +193,7 @@ def build(
             1).unsqueeze(2)
         sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
             1).unsqueeze(2)
-
+
         sfa_cp_context = None
         if self.enable_sfa_cp:
             global_tp_size = get_tp_group().world_size
@@ -214,12 +209,13 @@ def build(
             if pad_size > 0:
                 cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
                 sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
-                slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size), value=-1)
+                slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size),
+                                                 value=-1)
             cos = cos[local_start:local_end_with_pad]
             sin = sin[local_start:local_end_with_pad]
             slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]

-            actual_seq_lengths_query = torch.empty_like(cum_query_lens)
+            actual_seq_lengths_query = torch.empty_like(cum_query_lens)
             actual_seq_lengths_key = torch.empty_like(seq_lens)
             num_segs = cum_query_lens.shape[0]
             last_token = 0
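
Note on the pad-and-slice step in the build() hunk above: when SFA context parallelism is enabled, cos, sin and slot_mapping are padded so the token count divides evenly across the parallel ranks (padded slot_mapping positions are filled with -1), and each rank then keeps only its local_start:local_end_with_pad slice. Below is a minimal, self-contained sketch of that idea for the 1-D slot_mapping case only; cp_size, rank and the ceiling arithmetic are local stand-ins, whereas the real build() derives them from get_tp_group() and vllm_ascend's _round_up helper.

```python
import torch
import torch.nn.functional as F


def shard_for_cp(slot_mapping: torch.Tensor, cp_size: int, rank: int) -> torch.Tensor:
    """Pad a 1-D slot mapping to a multiple of cp_size and return this rank's slice."""
    num_tokens = slot_mapping.shape[0]
    padded_len = -(-num_tokens // cp_size) * cp_size  # ceil to a multiple of cp_size
    pad_size = padded_len - num_tokens
    if pad_size > 0:
        # -1 marks padded positions so downstream code can skip them
        # (mirroring the value=-1 pad in the diff above).
        slot_mapping = F.pad(slot_mapping, (0, pad_size), value=-1)
    per_rank = padded_len // cp_size
    local_start = rank * per_rank
    local_end = local_start + per_rank
    return slot_mapping[local_start:local_end]


# 10 tokens split across 4 ranks -> padded to 12, 3 tokens per rank.
slots = torch.arange(10)
print([shard_for_cp(slots, cp_size=4, rank=r).tolist() for r in range(4)])
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, -1, -1]]
```
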
@@ -347,7 +343,7 @@ def __init__(
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
         self.model_config = get_current_vllm_config().model_config
         assert self.indexer is not None, "Indexer is required for DSA."
-
+
         self.enable_sfa_cp = enable_sp() and \
             hasattr(self.model_config.hf_config, "index_topk")
         self.local_num_heads = self.num_heads
@@ -357,7 +353,8 @@ def __init__(

            #TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
            self._replace_linear_class_for_sfa_cp()
-            from vllm_ascend.distributed.parallel_state import get_shared_weight_group
+            from vllm_ascend.distributed.parallel_state import \
+                get_shared_weight_group
            register_layer_to_shared_weight_series(
                series_name="q_proj",
                group=get_shared_weight_group(),
@@ -625,23 +622,24 @@ def forward(

         ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
         q_pe = self.rope_single(q_pe, cos, sin)
-
+
         actual_seq_lengths_query = attn_metadata.cum_query_lens
         actual_seq_lengths_key = attn_metadata.seq_lens

         if self.enable_sfa_cp:
             actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
             actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
-
-        topk_indices = self.indexer_select(x=hidden_states,
-                                           qr=q_c,
-                                           kv_cache=kv_cache,
-                                           attn_metadata=attn_metadata,
-                                           cos=cos,
-                                           sin=sin,
-                                           actual_seq_lengths_query=actual_seq_lengths_query,
-                                           actual_seq_lengths_key=actual_seq_lengths_key,
-                                           need_gather_q_kv=need_gather_q_kv)
+
+        topk_indices = self.indexer_select(
+            x=hidden_states,
+            qr=q_c,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+            cos=cos,
+            sin=sin,
+            actual_seq_lengths_query=actual_seq_lengths_query,
+            actual_seq_lengths_key=actual_seq_lengths_key,
+            need_gather_q_kv=need_gather_q_kv)
         attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
             query=ql_nope,
             key=kv_cache[0],
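
For context on the length selection visible at the top of this hunk: the query and key length tensors default to attn_metadata.cum_query_lens and attn_metadata.seq_lens, but when SFA CP is enabled they are taken from the per-rank SfaCpContext instead, since each rank only processes its own token slice. A reduced sketch of that selection, using simplified stand-in dataclasses that carry only the fields shown in this diff:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class CpContextSketch:  # stand-in for SfaCpContext (only the fields shown above)
    actual_seq_lengths_query: torch.Tensor
    actual_seq_lengths_key: torch.Tensor


@dataclass
class MetadataSketch:  # stand-in for AscendSFAMetadata (only the fields shown above)
    cum_query_lens: torch.Tensor
    seq_lens: torch.Tensor
    sfa_cp_context: Optional[CpContextSketch] = None


def pick_seq_lengths(meta: MetadataSketch, enable_sfa_cp: bool):
    # Default: global cumulative query lengths and per-request KV lengths.
    q_lens, k_lens = meta.cum_query_lens, meta.seq_lens
    if enable_sfa_cp and meta.sfa_cp_context is not None:
        # Under SFA CP each rank only sees its own token slice, so the lengths
        # come from the per-rank context instead of the global metadata.
        q_lens = meta.sfa_cp_context.actual_seq_lengths_query
        k_lens = meta.sfa_cp_context.actual_seq_lengths_key
    return q_lens, k_lens
```
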
@@ -751,19 +749,18 @@ def indexer_select(
            sparse_count=2048,
            sparse_mode=3)
        return topk_indices
-
+
    def _replace_linear_class_for_sfa_cp(self):

        vllm_config = get_current_vllm_config()
        # Dispose tensor from the original q_proj
        dispose_layer(self.q_proj)
        # Construct the new q_proj using ReplicatedLinear
-        new_q_proj = ReplicatedLinear(
-            self.q_lora_rank,
-            self.local_num_heads * self.qk_head_dim,
-            bias=False,
-            quant_config=vllm_config.quant_config,
-            prefix=self.q_proj.prefix)
+        new_q_proj = ReplicatedLinear(self.q_lora_rank,
+                                      self.local_num_heads * self.qk_head_dim,
+                                      bias=False,
+                                      quant_config=vllm_config.quant_config,
+                                      prefix=self.q_proj.prefix)
        # Replace the q_proj with the new one
        replace_layer(self.q_proj, new_q_proj)

@@ -783,13 +780,11 @@ def _replace_linear_class_for_sfa_cp(self):
        dispose_layer(self.o_proj)
        # Construct the new o_proj using ReplicatedLinear
        config = vllm_config.model_config.hf_config
-        new_o_proj = ReplicatedLinear(
-            config.num_attention_heads * config.v_head_dim,
-            config.hidden_size,
-            bias=False,
-            quant_config=vllm_config.quant_config,
-            prefix=self.o_proj.prefix)
+        new_o_proj = ReplicatedLinear(config.num_attention_heads *
+                                      config.v_head_dim,
+                                      config.hidden_size,
+                                      bias=False,
+                                      quant_config=vllm_config.quant_config,
+                                      prefix=self.o_proj.prefix)
        # Replace the o_proj with the new one
        replace_layer(self.o_proj, new_o_proj)
-
-
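
The _replace_linear_class_for_sfa_cp hunks above only reflow the ReplicatedLinear constructor calls; the underlying pattern stays the same: dispose the original projection's weights, rebuild a full-size ReplicatedLinear under the same prefix, then splice it in via replace_layer. Below is a generic, framework-free sketch of that dispose, rebuild, replace pattern using plain torch.nn modules as stand-ins for vllm's ReplicatedLinear and vllm_ascend's dispose_layer / replace_layer helpers; the tensor-parallel size of 2 is an arbitrary assumption for illustration.

```python
import torch
from torch import nn


class TinyAttn(nn.Module):
    """Toy module whose q_proj starts as a tensor-parallel shard (TP size 2 assumed)."""

    def __init__(self, hidden: int, heads: int, head_dim: int, tp_size: int = 2):
        super().__init__()
        self.hidden, self.heads, self.head_dim = hidden, heads, head_dim
        # Stand-in for the original sharded q_proj: this rank holds heads // tp_size heads.
        self.q_proj = nn.Linear(hidden, (heads // tp_size) * head_dim, bias=False)

    def replace_q_proj_with_replicated(self) -> None:
        # 1) "dispose" the old layer's parameters so their memory can be released early
        #    (stand-in for vllm_ascend's dispose_layer).
        for p in self.q_proj.parameters():
            p.data = torch.empty(0)
        # 2) rebuild the projection at full width: every rank now holds all heads,
        #    which is what a replicated (non-sharded) linear gives the SFA CP path.
        new_q_proj = nn.Linear(self.hidden, self.heads * self.head_dim, bias=False)
        # 3) splice the new module in under the same attribute name
        #    (stand-in for replace_layer).
        self.q_proj = new_q_proj


attn = TinyAttn(hidden=64, heads=8, head_dim=16)
attn.replace_q_proj_with_replicated()
print(attn.q_proj.weight.shape)  # torch.Size([128, 64])
```
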

vllm_ascend/distributed/parallel_state.py (1 addition, 1 deletion)
@@ -233,7 +233,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
        get_world_group().local_rank,
        backend,
        group_name="flashcomm2_odp")
-
+
    vllm_config = get_current_vllm_config()
    # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97
    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")

vllm_ascend/distributed/sfa_sp_context.py (0 additions, 62 deletions)

This file was deleted.