
Commit 92050f5

ascend 950 support qwen dense model
Signed-off-by: wangyao <[email protected]>
1 parent e38ef2c commit 92050f5

File tree

5 files changed: 138 additions, 56 deletions


vllm_ascend/attention/attention_v1.py

Lines changed: 117 additions & 47 deletions
@@ -50,7 +50,7 @@
                                update_graph_params_workspaces)
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
-                               nd_to_nz_2d, nd_to_nz_spec,
+                               is_Ascend950, nd_to_nz_2d, nd_to_nz_spec,
                                prefill_context_parallel_enable,
                                weak_ref_tensors)

@@ -703,15 +703,29 @@ def _forward_prefill_no_cache(
             mask = torch_npu.npu_format_cast(mask.contiguous(),
                                              ACL_FORMAT_FRACTAL_NZ)

-        torch_npu._npu_flash_attention(query=query,
-                                       key=key,
-                                       value=value,
-                                       mask=mask,
-                                       seq_len=attn_metadata.seq_lens,
-                                       scale_value=self.scale,
-                                       num_heads=self.num_heads,
-                                       num_kv_heads=self.num_kv_heads,
-                                       out=output)
+        if is_Ascend950():
+            num_tokens = attn_metadata.query_start_loc[-1]
+            output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query[:num_tokens],
+                key[:num_tokens],
+                value[:num_tokens],
+                atten_mask=mask.to(torch.bool),
+                actual_seq_qlen=attn_metadata.query_lens.cumsum(0),
+                actual_seq_kvlen=attn_metadata.seq_lens.cumsum(0),
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                input_layout="TND",
+                softmax_scale=self.scale)
+        else:
+            torch_npu._npu_flash_attention(query=query,
+                                           key=key,
+                                           value=value,
+                                           mask=mask,
+                                           seq_len=attn_metadata.seq_lens,
+                                           scale_value=self.scale,
+                                           num_heads=self.num_heads,
+                                           num_kv_heads=self.num_kv_heads,
+                                           out=output)
         assert output is not None
         return output[:num_tokens]
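For orientation, the new _v2 prefill call packs every request into a single token dimension (input_layout="TND") and marks request boundaries with cumulative end offsets instead of a padded batch. Below is a minimal sketch of that bookkeeping, with made-up lengths and only plain torch, not the NPU kernel itself:

import torch

# Hypothetical per-request lengths for three prefill requests.
query_lens = torch.tensor([3, 5, 2])   # new query tokens per request
seq_lens = torch.tensor([3, 5, 2])     # full KV lengths (no cached prefix here)

# The v2 kernel takes cumulative end offsets, built with cumsum(0) as in the diff.
actual_seq_qlen = query_lens.cumsum(0)   # tensor([ 3,  8, 10])
actual_seq_kvlen = seq_lens.cumsum(0)    # tensor([ 3,  8, 10])

# Total packed tokens; query/key/value are sliced to this length before the call.
num_tokens = int(actual_seq_qlen[-1])    # 10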

@@ -729,6 +743,27 @@ def _forward_prefill_cache_hit(
         block_table = attn_metadata.block_tables[:batch_size, :]
         num_block, block_size, _, _ = self.key_cache.shape  # type: ignore

+        if is_Ascend950():
+            compress_mask = compress_mask.to(torch.bool)
+            key = self.key_cache.transpose(1, 2)  # type: ignore
+            value = self.value_cache.transpose(1, 2)  # type: ignore
+            block_size = self.block_size
+            output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query=query,
+                key=key,
+                value=value,
+                block_table=block_table,
+                atten_mask=compress_mask,
+                actual_seq_qlen=attn_metadata.query_lens.cumsum(0),
+                actual_seq_kvlen=attn_metadata.seq_lens,
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                softmax_scale=self.scale,
+                spare_mode=2,
+                block_size=block_size,
+                input_layout="TND")
+            return output
+
         if block_size == 128:
             # TODO:The npu_fused_infer_attention_score op is planned to
             # be utilized in a wider range in upcoming versions.
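The cache-hit branch reads K/V through vLLM's paged layout: block_table maps each request's logical blocks onto physical blocks of the cache, and the fused kernel performs the gather internally. A CPU-only illustration of what that mapping means, with invented sizes and names rather than the vLLM-Ascend tensors:

import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 4, 1, 2
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size)

# One request whose 10 cached tokens live in physical blocks 6, 2 and 5, in order.
block_table = torch.tensor([6, 2, 5])
seq_len = 10

# Gather the blocks and drop the padding slots in the last block.
gathered = key_cache[block_table].reshape(-1, num_kv_heads, head_size)[:seq_len]
# 'gathered' is the request's contiguous key history of length seq_len.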
@@ -777,18 +812,20 @@ def _forward_decode_only(
             # seq_lens_tensor needs to be transferred to the device for 310P.
             attn_metadata.seq_lens = \
                 attn_metadata.seq_lens.to(device=query.device)
+
+        batch_size = attn_metadata.seq_lens.shape[0]
+        block_size = 128
+        key = self.key_cache
+        value = self.value_cache
+        if self.key_cache is not None and self.value_cache is not None:
+            block_size = self.key_cache.shape[1]
+            key = self.key_cache.flatten(2, 3).contiguous()
+            value = self.value_cache.flatten(2, 3).contiguous()
+
         if self.sliding_window is not None and attn_metadata.seq_lens.shape[
                 0] == query.size(0):
-            batch_size = attn_metadata.seq_lens.shape[0]
-            block_size = 128
-            query = query.view(batch_size, 1, self.num_heads * self.head_size)
-            key = self.key_cache
-            value = self.value_cache
-            if self.key_cache is not None and self.value_cache is not None:
-                block_size = self.key_cache.shape[1]
-                key = self.key_cache.flatten(2, 3).contiguous()
-                value = self.value_cache.flatten(2, 3).contiguous()
-
+            query = query.view(batch_size, 1,
+                               self.num_heads * self.head_size)
             output, _ = torch_npu.npu_fused_infer_attention_score(
                 query,
                 key,
@@ -805,16 +842,33 @@ def _forward_decode_only(

             output = output.view(batch_size, self.num_heads, self.head_size)
         else:
-            torch_npu._npu_paged_attention(
-                query=query,
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                num_kv_heads=self.num_kv_heads,
-                num_heads=self.num_heads,
-                scale_value=self.scale,
-                block_table=attn_metadata.block_tables,
-                context_lens=attn_metadata.seq_lens,
-                out=output)
+            if is_Ascend950():
+                query = query.view(batch_size, 1,
+                                   self.num_heads * self.head_size)
+                output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                    query=query,
+                    key=key,
+                    value=value,
+                    actual_seq_kvlen=attn_metadata.seq_lens,
+                    num_query_heads=self.num_heads,
+                    num_key_value_heads=self.num_kv_heads,
+                    block_table=attn_metadata.block_tables[:batch_size],
+                    block_size=block_size,
+                    softmax_scale=self.scale,
+                    input_layout="BSH")
+                output = output.view(batch_size, self.num_heads,
+                                     self.head_size)
+            else:
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
         return output

     def _forward_v1_style(
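The decode branches above differ mainly in layout bookkeeping: for the BSH path the query is viewed as (batch, 1, hidden) since decode emits one token per sequence, and the paged K/V caches are consumed with their head dimensions folded together. A reshape-only sketch with invented sizes, for orientation:

import torch

batch_size, num_heads, num_kv_heads, head_size = 4, 8, 2, 16
num_blocks, block_size = 10, 128

# One new token per sequence, hence sequence length 1 in the BSH layout.
query = torch.randn(batch_size, num_heads * head_size)
q_bsh = query.view(batch_size, 1, num_heads * head_size)   # (B, S=1, H)

# Key cache stored as (blocks, block_size, kv_heads, head_size); the kernel
# consumes it with the last two dims folded together.
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size)
key = key_cache.flatten(2, 3).contiguous()   # (blocks, block_size, kv_heads * head_size)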
@@ -862,7 +916,6 @@ def _forward_v1_style(
                 num_block, block_size, -1)
             value = self.value_cache.view(  # type: ignore
                 num_block, block_size, -1)
-
             output, _ = torch_npu.npu_fused_infer_attention_score(
                 query=query,
                 key=key,
@@ -1507,12 +1560,20 @@ def forward(
         if has_decode:
             slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \
                 if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_decode_tokens],
-                value=value[:num_decode_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slot_mapping)
+            if is_Ascend950():
+                num_tokens = slot_mapping.shape[0]
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_tokens],
+                    value=value[:num_tokens].contiguous(),
+                    slot_mapping=slot_mapping,
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slot_mapping)

         if has_prefill:
             if self.pcp_size > 1:
@@ -1526,18 +1587,27 @@ def forward(
                 key, value = all_kv.split([self.head_size, self.head_size],
                                           dim=-1)

-            torch_npu._npu_reshape_and_cache(
-                key=key[self.pcp_size * num_decode_tokens:attn_metadata.
-                        num_actual_tokens_pcp_padded],
-                value=value[self.pcp_size *
+            if is_Ascend950():
+                num_tokens = attn_metadata.slot_mapping.shape[0]
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_tokens],
+                    value=value[:num_tokens].contiguous(),
+                    slot_mapping=attn_metadata.slot_mapping,
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[self.pcp_size *
                             num_decode_tokens:attn_metadata.
                             num_actual_tokens_pcp_padded],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=attn_metadata.
-                slot_mapping[self.pcp_size *
-                             num_decode_tokens:attn_metadata.
-                             num_actual_tokens_pcp_padded])
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded])

             forward_context: ForwardContext = get_forward_context()
             if not forward_context.capturing:
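Both cache-write paths above do the same logical work: scatter each new token's K/V into the cache slot named by slot_mapping, where a slot index is block_id * block_size + offset. A CPU-only sketch of that scatter with invented shapes (the real write is performed by npu_scatter_pa_kv_cache or _npu_reshape_and_cache on the NPU):

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
value_cache = torch.zeros_like(key_cache)

key = torch.randn(3, num_kv_heads, head_size)    # K for 3 new tokens
value = torch.randn(3, num_kv_heads, head_size)
slot_mapping = torch.tensor([5, 17, 40])          # slot = block_id * block_size + offset

flat_k = key_cache.view(num_blocks * block_size, num_kv_heads, head_size)
flat_v = value_cache.view(num_blocks * block_size, num_kv_heads, head_size)
flat_k.index_copy_(0, slot_mapping, key)          # write each token into its slot
flat_v.index_copy_(0, slot_mapping, value)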

vllm_ascend/ops/rotary_embedding.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@
                                            YaRNScalingRotaryEmbedding)

 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import enable_custom_op, is_310p
+from vllm_ascend.utils import enable_custom_op, is_310p, is_Ascend950


 def _custom_rotary_embedding_enabled(query, neox_style, head_size):
@@ -405,7 +405,7 @@ def forward_oot(
         query: torch.Tensor,
         key: torch.Tensor,
     ):
-        if self.mrope_section != [16, 24, 24]:
+        if self.mrope_section != [16, 24, 24] or is_Ascend950():
            return super().forward_oot(positions, query, key)

        import torch_npu

vllm_ascend/sample/sampler.py

Lines changed: 3 additions & 3 deletions
@@ -3,7 +3,7 @@
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
 from vllm.v1.sample.sampler import Sampler

-from vllm_ascend.utils import is_310p
+from vllm_ascend.utils import is_310p, is_Ascend950

 DEFAULT_LOGPROBS_MODE = "raw_logprobs"

@@ -25,8 +25,8 @@ def _apply_top_k_top_p(
     p: torch.Tensor,
 ) -> torch.Tensor:
     # npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P
-    if not is_310p() and p is not None and k is not None and 1 <= int(
-            k.max()) <= 1024:
+    if not is_310p() and not is_Ascend950() \
+            and p is not None and k is not None and 1 <= int(k.max()) <= 1024:
         # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
         return torch_npu.npu_top_k_top_p(logits, p, k)
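With the aclnn fast path now skipped on 310P and Ascend 950, sampling falls back to vLLM's generic top-k/top-p handling. Purely for orientation, one common CPU formulation of combined top-k/top-p filtering looks like the reference below; this is an illustrative sketch, not the vLLM implementation:

import torch

def top_k_top_p_ref(logits: torch.Tensor, k: int, p: float) -> torch.Tensor:
    # Top-k: drop everything below the k-th largest logit in each row.
    kth = logits.topk(k, dim=-1).values[..., -1, None]
    logits = logits.masked_fill(logits < kth, float("-inf"))

    # Top-p: sort, then drop tokens once the probability mass before them exceeds p.
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    probs = sorted_logits.softmax(dim=-1)
    remove = probs.cumsum(dim=-1) - probs > p
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    # Scatter the filtered values back to their original vocabulary positions.
    return torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)

filtered = top_k_top_p_ref(torch.randn(2, 1000), k=50, p=0.9)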

vllm_ascend/utils.py

Lines changed: 12 additions & 1 deletion
@@ -49,6 +49,7 @@

 _CUSTOM_OP_ENABLED = None
 _IS_310P = None
+_IS_ASCEND950 = None
 _SLEEP_MODE_ENABLED = None
 _CURRENT_STREAM = None
 _PREFETCH_STREAM = None
@@ -668,7 +669,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
 class AscendSocVersion(Enum):
     A2 = 0
     A3 = 1
-    UNDEFINED = 2
+    A5 = 2
+    UNDEFINED = 3


 _ascend_soc_version = None
@@ -681,6 +683,8 @@ def init_ascend_soc_version():
         _ascend_soc_version = AscendSocVersion.A2
     elif 250 <= soc_version <= 255:
         _ascend_soc_version = AscendSocVersion.A3
+    elif soc_version == 260:
+        _ascend_soc_version = AscendSocVersion.A5
     else:
         _ascend_soc_version = AscendSocVersion.UNDEFINED

@@ -945,3 +949,10 @@ def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
         reorgnized_batch_ids.append(ranks)

     return reorgnized_batch_ids
+
+
+def is_Ascend950():
+    global _IS_ASCEND950
+    if _IS_ASCEND950 is None:
+        _IS_ASCEND950 = (get_ascend_soc_version() == AscendSocVersion.A5)
+    return _IS_ASCEND950
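is_Ascend950() follows the same lazy, module-level caching pattern as is_310p(): probe the SoC version once, map value 260 to the new A5 enum member, and memoize the boolean. A self-contained sketch of the pattern, with a stand-in probe in place of the real ACL query and only the branches shown in this diff:

from enum import Enum
from typing import Optional

class SocVersion(Enum):
    A2 = 0
    A3 = 1
    A5 = 2
    UNDEFINED = 3

_IS_A5: Optional[bool] = None

def _probe_soc_version() -> int:
    # Stand-in for the real torch_npu SoC query; per this commit, Ascend 950 reports 260.
    return 260

def is_a5() -> bool:
    global _IS_A5
    if _IS_A5 is None:
        soc = SocVersion.A5 if _probe_soc_version() == 260 else SocVersion.UNDEFINED
        _IS_A5 = (soc == SocVersion.A5)
    return _IS_A5

assert is_a5() is True   # computed on first call, cached afterwards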

vllm_ascend/worker/worker_v1.py

Lines changed: 4 additions & 3 deletions
@@ -47,8 +47,8 @@
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import (init_ascend_soc_version, is_enable_nz,
-                               prefill_context_parallel_enable,
+from vllm_ascend.utils import (init_ascend_soc_version, is_Ascend950,
+                               is_enable_nz, prefill_context_parallel_enable,
                                register_ascend_customop, sleep_mode_enabled,
                                try_register_lib, vllm_version_is)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -342,7 +342,8 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner.capture_model()
         # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
         # may cause performance degradation at runtime.
-        self._warm_up_atb()
+        if not is_Ascend950():
+            self._warm_up_atb()
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         NPUPlatform.seed_everything(self.model_config.seed)
