
Commit 3bc2319

ZYang6263, clrs97, zzhx1, hwhaokun, and wangxiyuan committed

Support DeepSeekV3.2 with MLAPO operator

Signed-off-by: ZYang6263 <[email protected]>

[Feat] enable sfa cp for dsv3.2 (#4702)

RFC: vllm-project/vllm#30055

1. Enable flashcomm1:
   export VLLM_ASCEND_ENABLE_FLASHCOMM1=1
2. Enable sfa-cp:
   --additional-config '{ "enable_sfa_cp": true }'

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: AlvisGong <[email protected]>
Co-authored-by: clrs97 <[email protected]>
Co-authored-by: zzhx1 <[email protected]>
Co-authored-by: hwhaokun <[email protected]>
Co-authored-by: wangxiyuan <[email protected]>
1 parent a5163c8 commit 3bc2319
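
For convenience, a minimal offline-inference sketch that turns on both switches from the commit message. It assumes the additional_config engine argument is available on an Ascend build of vLLM; the model id and parallel size are placeholders, not taken from this commit.

import os

# Enable Flash Comm V1 before the engine starts (from the commit message).
os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM1"] = "1"

from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V3.2-Exp",      # placeholder model id
    tensor_parallel_size=16,                    # placeholder parallel size
    additional_config={"enable_sfa_cp": True},  # enable sfa-cp
)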

File tree

3 files changed (+306 −49 lines)

vllm_ascend/ascend_config.py

Lines changed: 1 addition & 0 deletions

@@ -80,6 +80,7 @@ def __init__(self, vllm_config):
             enable_shared_expert_dp=True)
         self.multistream_overlap_shared_expert = additional_config.get(
             "multistream_overlap_shared_expert", False)
+        self.enable_sfa_cp = additional_config.get("enable_sfa_cp", False)
         self.recompute_scheduler_enable = additional_config.get(
             "recompute_scheduler_enable", False)
         self.lmhead_tensor_parallel_size = additional_config.get(
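
A minimal sketch, assuming the engine has already populated AscendConfig, of how other vllm-ascend code can read the new flag:

from vllm_ascend.ascend_config import get_ascend_config

ascend_config = get_ascend_config()
if ascend_config.enable_sfa_cp:  # False unless set via --additional-config
    ...  # take the sfa-cp code path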

vllm_ascend/attention/sfa_v1.py

Lines changed: 243 additions & 49 deletions
@@ -6,17 +6,21 @@
 from torch import nn
 from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
 from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
 from vllm.model_executor.layers.linear import (LinearBase, ReplicatedLinear,
                                                UnquantizedLinearMethod)
 from vllm.triton_utils import HAS_TRITON
 from vllm.v1.attention.backends.utils import AttentionCGSupport
+from vllm.logger import logger
 
+from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         wait_for_kv_layer_from_connector)
+                                         wait_for_kv_layer_from_connector,
+                                         trans_rope_weight, transdata)
+from vllm_ascend.distributed.sfa_sp_context import (get_sfa_sp_context,
+                                                    set_sfa_sp_context)
 from vllm_ascend.ops.shared_weight_layer import (
     is_hidden_layer, post_process_after_loading_for_shared_weight_series,
     reach_layer_for_shared_weight_series,
@@ -26,6 +30,7 @@
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                _round_up, dispose_layer, enable_sp,
                                is_enable_nz, replace_layer)
+from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -341,17 +346,54 @@ def __init__(
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
         self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-        self.vllm_config = get_current_vllm_config()
+        self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
+
         assert self.indexer is not None, "Indexer is required for DSA."
 
         self.enable_sfa_cp = enable_sp()
         self.local_num_heads = self.num_heads
-
+        self.vllm_config = get_current_vllm_config()
         if self.enable_sfa_cp:
             self.local_num_heads = self.num_heads * self.tp_size
 
-            #TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
-            self._replace_linear_class_for_sfa_cp()
+            # Dispose tensor from the original q_proj
+            dispose_layer(self.q_proj)
+            # Construct the new q_proj using ReplicatedLinear
+            new_q_proj = ReplicatedLinear(
+                self.q_lora_rank,
+                self.local_num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=self.vllm_config.quant_config,
+                prefix=self.q_proj.prefix)
+            # Replace the q_proj with the new one
+            replace_layer(self.q_proj, new_q_proj)
+
+            # Dispose tensor from the original kv_b_proj
+            dispose_layer(self.kv_b_proj)
+            # Construct the new kv_b_proj using ReplicatedLinear
+            new_kv_b_proj = ReplicatedLinear(
+                self.kv_lora_rank,
+                self.local_num_heads *
+                (self.qk_nope_head_dim + self.v_head_dim),
+                bias=False,
+                quant_config=self.vllm_config.quant_config,
+                prefix=self.kv_b_proj.prefix)
+            # Replace the kv_b_proj with the new one
+            replace_layer(self.kv_b_proj, new_kv_b_proj)
+
+            # Dispose tensor from the original o_proj
+            dispose_layer(self.o_proj)
+            # Construct the new o_proj using ReplicatedLinear
+            config = self.vllm_config.model_config.hf_config
+            new_o_proj = ReplicatedLinear(
+                config.num_attention_heads * config.v_head_dim,
+                config.hidden_size,
+                bias=False,
+                quant_config=self.vllm_config.quant_config,
+                prefix=self.o_proj.prefix)
+            # Replace the o_proj with the new one
+            replace_layer(self.o_proj, new_o_proj)
+
             from vllm_ascend.distributed.parallel_state import \
                 get_shared_weight_group
             if is_hidden_layer(self.vllm_config, self.q_proj):
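
A toy sketch of the shape bookkeeping behind the ReplicatedLinear construction above, with DeepSeek-V3-like placeholder dimensions (not read from any config): under sfa-cp every rank keeps all heads, so the projection output dims grow by a factor of tp_size.

tp_size = 16
num_heads = 8                                     # heads per rank under TP
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

local_num_heads = num_heads * tp_size             # all heads on every rank

q_proj_out = local_num_heads * qk_head_dim                          # 128 * 192 = 24576
kv_b_proj_out = local_num_heads * (qk_nope_head_dim + v_head_dim)   # 128 * 256 = 32768
print(q_proj_out, kv_b_proj_out)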
@@ -555,6 +597,98 @@ def rope_single(
         x = torch_npu.npu_interleave_rope(x, cos, sin)
         return x.view(B, N, D)
 
+    def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
+        assert self.kv_a_proj_with_mqa is None
+        assert self.fused_qkv_a_proj is not None
+
+        kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
+            ..., self.q_lora_rank:].contiguous()
+        q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
+            ..., :self.q_lora_rank].contiguous()
+
+        self.fused_qkv_a_proj.weight = None
+
+        kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
+        kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
+        kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
+        wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
+        wd_qkv = wd_qkv.t().contiguous()
+        wd_qkv = transdata(wd_qkv,
+                           block_size=(16, 32)).unsqueeze(0).contiguous()
+        self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
+
+        kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[
+            self.q_lora_rank:].contiguous()
+        q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[:self.
+                                                           q_lora_rank].contiguous(
+                                                           )
+        kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
+            self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
+        kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
+                                              self.qk_rope_head_dim)
+        kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
+            self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
+        self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl),
+                                       dim=-1).contiguous()
+
+        kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[
+            self.q_lora_rank:].contiguous()
+        q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[:self.
+                                                            q_lora_rank].contiguous(
+                                                            )
+
+        kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
+            self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
+        kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
+                                              self.qk_rope_head_dim)
+        kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
+            self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
+        self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias),
+                                        dim=-1).contiguous()
+
+        wu_q = self.q_proj.weight.data
+        wu_q = wu_q.t().reshape(self.num_heads,
+                                self.qk_nope_head_dim + self.qk_rope_head_dim,
+                                -1)
+        wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
+        wu_q = wu_q.reshape(
+            self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
+            -1)
+        wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
+        self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
+
+        qb_deq_scl = self.q_proj.deq_scale.data
+        qb_deq_scl = qb_deq_scl.reshape(
+            self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
+        qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
+        self.qb_deq_scl = qb_deq_scl.reshape(
+            self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
+
+        qb_qt_bias = self.q_proj.quant_bias.data
+        qb_qt_bias = qb_qt_bias.reshape(
+            self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
+        qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
+        self.qb_qt_bias = qb_qt_bias.reshape(
+            self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
+
+        device = self.q_proj.weight.device
+        self.gamma1 = self.q_a_layernorm.weight.data
+        self.beta1 = self.q_a_layernorm.bias.data
+        self.gamma2 = self.kv_a_layernorm.weight.data
+        self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data
+        self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data
+        self.quant_scale1 = self.q_proj.input_scale.data
+        self.quant_offset1 = self.q_proj.input_offset.data
+        self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
+        self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
+
+        if self.vllm_config.kv_transfer_config is not None:
+            self.fused_qkv_a_proj.deq_scale = None
+            self.fused_qkv_a_proj.quant_bias = None
+            self.q_proj.deq_scale = None
+            self.q_proj.quant_bias = None
+        torch.npu.empty_cache()
+
     def forward(
         self,
         layer_name,
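
A toy sketch of the first step of the weight repacking above: slicing the fused q/kv down-projection weight and re-concatenating it in the (kv, q) order the fused MLAPO kernel consumes. Shapes and the [in_features, out_features] layout are assumptions for illustration; the rope reordering (trans_rope_weight), transdata blocking, and NZ format cast are omitted.

import torch

hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 7168, 1536, 512, 64
# Assumed layout: last dim is the fused output dim, q part first, then kv + rope.
fused_wt = torch.randn(hidden_size, q_lora_rank + kv_lora_rank + qk_rope_head_dim)

q_a_wt = fused_wt[..., :q_lora_rank].contiguous()    # q down-projection slice
kv_a_wt = fused_wt[..., q_lora_rank:].contiguous()   # kv down-projection + k_pe slice
wd_qkv = torch.cat((kv_a_wt, q_a_wt), dim=-1)        # kernel expects kv first, then q
assert wd_qkv.shape == fused_wt.shape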
@@ -578,56 +712,116 @@ def forward(
             return output.fill_(0)
         has_prefill = attn_metadata.has_prefill
         num_actual_tokens = attn_metadata.num_actual_tokens
+        cos = attn_metadata.cos
+        sin = attn_metadata.sin
+        cum_query_lens = attn_metadata.cum_query_lens
+        seq_lens = attn_metadata.seq_lens
+        actual_seq_lengths_query = cum_query_lens
+        actual_seq_lengths_key = seq_lens
         hidden_states = hidden_states[:num_actual_tokens]
         if self.enable_sfa_cp:
             need_gather_q_kv = False
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
         output = output[:num_actual_tokens]
-        assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
-        maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
-                           dependency=hidden_states,
-                           enabled=self.enable_prefetch)
-        qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
-        q_c, kv_no_split = qkv_lora.split(
-            [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
-            dim=-1,
-        )
-        q_c = self.q_a_layernorm(q_c)
-
-        # Process for Flash Comm V1
-        if need_gather_q_kv:
-            q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                q_c.contiguous(), need_gather_q_kv)
-            kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                kv_no_split.contiguous(), need_gather_q_kv)
-
-        if has_prefill:
-            wait_for_kv_layer_from_connector(layer_name)
 
-        cos = attn_metadata.cos
-        sin = attn_metadata.sin
-        slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
-        slot_mapping_cp = None
-        actual_seq_lengths_query = attn_metadata.cum_query_lens
-        actual_seq_lengths_key = attn_metadata.seq_lens
-        if self.enable_sfa_cp:
-            assert attn_metadata.sfa_cp_context is not None
-            slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
-            actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
-            actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
-
-        self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
-                     slot_mapping_cp)
-
-        if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
-            if is_hidden_layer(self.vllm_config, self.q_proj):
-                reach_layer_for_shared_weight_series(self.q_proj)
-            if is_hidden_layer(self.vllm_config, self.o_proj):
-                reach_layer_for_shared_weight_series(self.o_proj)
-
-        ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
-        q_pe = self.rope_single(q_pe, cos, sin)
+        from vllm.forward_context import get_forward_context
+        forward_context = get_forward_context()
+        if self.enable_mlapo and not forward_context.with_prefill:
+            hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                hidden_states.contiguous(), need_gather_q_kv)
+            k_nope, k_pe = kv_cache[0], kv_cache[1]
+            ql_nope = torch.empty(
+                (num_actual_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+            )
+            q_pe = torch.empty(
+                (num_actual_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]),
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+            )
+            q_c = torch.empty(
+                (num_actual_tokens, self.q_lora_rank),
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+            )
+            torch.ops._C_ascend.mla_preprocess(
+                hidden_states,
+                self.wd_qkv,
+                self.deq_scale_qkv,
+                self.gamma1,
+                self.beta1,
+                self.wu_q,
+                self.qb_deq_scl,
+                self.gamma2,
+                cos,
+                sin,
+                self.W_UK_T,
+                k_nope,
+                k_pe,
+                attn_metadata.slot_mapping[:num_actual_tokens].flatten(),
+                quant_scale0=self.quant_scale0,
+                quant_offset0=self.quant_offset0,
+                bias0=self.quant_bias_qkv,
+                quant_scale1=self.quant_scale1,
+                quant_offset1=self.quant_offset1,
+                bias1=self.qb_qt_bias,
+                ctkv_scale=self.ctkv_scale,
+                q_nope_scale=self.q_nope_scale,
+                cache_mode="krope_ctkv",
+                quant_mode="per_tensor_quant_asymm",
+                enable_inner_out=True,
+                q_out0=ql_nope,
+                kv_cache_out0=k_nope,
+                q_out1=q_pe,
+                kv_cache_out1=k_pe,
+                inner_out=q_c,
+            )
+        else:
+            assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
+            maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
+                               dependency=hidden_states,
+                               enabled=self.enable_prefetch)
+            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
+            q_c, kv_no_split = qkv_lora.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+                dim=-1,
+            )
+            q_c = self.q_a_layernorm(q_c)
+            # Process for Flash Comm V1
+            if need_gather_q_kv:
+                q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                    q_c.contiguous(), need_gather_q_kv)
+                kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                    kv_no_split.contiguous(), need_gather_q_kv)
+
+            if has_prefill:
+                wait_for_kv_layer_from_connector(layer_name)
+
+            cos = attn_metadata.cos
+            sin = attn_metadata.sin
+            slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
+            slot_mapping_cp = None
+            actual_seq_lengths_query = attn_metadata.cum_query_lens
+            actual_seq_lengths_key = attn_metadata.seq_lens
+            if self.enable_sfa_cp:
+                assert attn_metadata.sfa_cp_context is not None
+                slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
+                actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
+                actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
+
+            self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
+                         slot_mapping_cp)
+
+            if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
+                if is_hidden_layer(self.vllm_config, self.q_proj):
+                    reach_layer_for_shared_weight_series(self.q_proj)
+                if is_hidden_layer(self.vllm_config, self.o_proj):
+                    reach_layer_for_shared_weight_series(self.o_proj)
+
+            ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
+            q_pe = self.rope_single(q_pe, cos, sin)
 
         topk_indices = self.indexer_select(
             x=hidden_states,
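
For orientation, a small sketch of the buffers the fused decode path pre-allocates for torch.ops._C_ascend.mla_preprocess above. The dimension values are placeholders; only the shapes follow the torch.empty calls in the diff.

import torch

num_actual_tokens, num_heads = 4, 128                  # placeholders
kv_lora_rank, qk_rope_head_dim, q_lora_rank = 512, 64, 1536

ql_nope = torch.empty(num_actual_tokens, num_heads, kv_lora_rank)   # q_out0: (tokens, W_UK_T.shape[0], k_nope.shape[-1])
q_pe = torch.empty(num_actual_tokens, num_heads, qk_rope_head_dim)  # q_out1: (tokens, W_UK_T.shape[0], k_pe.shape[-1])
q_c = torch.empty(num_actual_tokens, q_lora_rank)                   # inner_out: (tokens, q_lora_rank)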
