
Commit b040e39

ascend 950 support qwen dense model
Signed-off-by: wangyao <[email protected]>
1 parent e38ef2c commit b040e39

File tree

9 files changed: +144 -27 lines changed
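In short, the commit adds an `is_Ascend950()` helper to vllm_ascend/utils.py (SoC version 260, exposed as `AscendSocVersion.A5`) and uses it to gate Ascend 950-specific paths: KV-cache writes switch to `torch_npu.npu_scatter_pa_kv_cache`, attention goes through a new `_forward_ascend_950` helper built on `torch_npu.npu_fused_infer_attention_score_v2`, the MRoPE and fused top-k/top-p fast paths fall back to their generic implementations, the ATB warm-up is skipped, and the affected unit tests patch the new check.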

tests/ut/attention/test_attention_v1.py

Lines changed: 16 additions & 0 deletions

@@ -289,6 +289,8 @@ def test_forward_no_attn_metadata(self):
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu._npu_flash_attention')
+    @patch('vllm_ascend.attention.attention_v1.is_Ascend950',
+           return_value=False)
     def test_forward_prefill_no_cache(self, mock_flash_attention,
                                       mock_reshape_cache,
                                       mock_get_forward_context):

@@ -321,6 +323,8 @@ def test_forward_prefill_no_cache(self, mock_flash_attention,
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    @patch('vllm_ascend.attention.attention_v1.is_Ascend950',
+           return_value=False)
     def test_forward_prefill_cache_hit(self, mock_get_forward_context,
                                        mock_npu_fused_infer_attention_score,
                                        mock_npu_reshape_and_cache):

@@ -357,6 +361,8 @@ def test_forward_prefill_cache_hit(self, mock_get_forward_context,
     @patch('torch_npu._npu_paged_attention')
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    @patch('vllm_ascend.attention.attention_v1.is_Ascend950',
+           return_value=False)
     def test_forward_decode_only(self, mock_get_forward_context,
                                  mock_npu_reshape_and_cache,
                                  mock_paged_attention):

@@ -388,6 +394,8 @@ def test_forward_decode_only(self, mock_get_forward_context,
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('torch_npu._npu_reshape_and_cache')
+    @patch('vllm_ascend.attention.attention_v1.is_Ascend950',
+           return_value=False)
     def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache,
                                      mock_fused_infer_attention_score,
                                      mock_get_forward_context):

@@ -421,6 +429,8 @@ def test_forward_decode_only_swa(self, mock_npu_reshape_and_cache,
     @patch('torch_npu._npu_paged_attention')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('torch_npu._npu_reshape_and_cache')
+    @patch('vllm_ascend.attention.attention_v1.is_Ascend950',
+           return_value=False)
     def test_forward_decode_only_swa_seq_len_mismatch(
             self, mock_npu_reshape_and_cache, mock_fused_infer_attention_score,
             mock_paged_attention, mock_get_forward_context):

@@ -458,6 +468,8 @@ def test_forward_decode_only_swa_seq_len_mismatch(
     @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
+    @patch('vllm_ascend.attention.attention_v1.is_Ascend950',
+           return_value=False)
     def test_forward_head_size_192(self, mock_vanilla_prefill,
                                    mock_npu_reshape_and_cache, mock_is_310p,
                                    mock_get_forward_context):

@@ -493,6 +505,8 @@ def test_forward_head_size_192(self, mock_vanilla_prefill,
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('torch_npu._npu_reshape_and_cache')
+    @patch('vllm_ascend.attention.attention_v1.is_Ascend950',
+           return_value=False)
     def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
                                          mock_npu_fused_infer_attention_score,
                                          mock_get_forward_context):

@@ -529,6 +543,8 @@ def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    @patch('vllm_ascend.attention.attention_v1.is_Ascend950',
+           return_value=False)
     def test_forward_310p_device(self, mock_get_forward_context, mock_is_310p,
                                  mock_npu_fused_infer_attention_score,
                                  mock_npu_reshape_and_cache,

tests/ut/ops/test_rotary_embedding.py

Lines changed: 2 additions & 0 deletions

@@ -428,6 +428,7 @@ def _create_vllm_config(self):
     @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
     @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
     @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
+    @patch('vllm_ascend.ops.rotary_embedding.is_Ascend950', return_value=False)
     def test_forward_oot_1d_positions(self, mock_npu_mrope):
         mock_npu_mrope.return_value = (torch.zeros_like(self.query),
                                        torch.zeros_like(self.key))

@@ -447,6 +448,7 @@ def test_forward_oot_1d_positions(self, mock_npu_mrope):
     @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
     @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
     @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
+    @patch('vllm_ascend.ops.rotary_embedding.is_Ascend950', return_value=False)
     def test_forward_oot_2d_positions(self, mock_npu_mrope):
         mock_npu_mrope.return_value = (torch.zeros_like(self.query),
                                        torch.zeros_like(self.key))

tests/ut/sample/test_sampler.py

Lines changed: 2 additions & 0 deletions

@@ -1,4 +1,5 @@
 from unittest import mock
+from unittest.mock import patch
 
 import torch
 

@@ -18,6 +19,7 @@ def test_init_with_raw_logprobs(self):
 class TestAscendTopKTopPSampler(TestBase):
 
     @mock.patch("torch_npu.npu_top_k_top_p")
+    @patch('vllm_ascend.sample.sampler.is_Ascend950')
     def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
         mock_npu_op.return_value = (torch.randn(1, 3))
         sampler = AscendTopKTopPSampler()

tests/ut/worker/test_worker_v1.py

Lines changed: 2 additions & 0 deletions

@@ -1075,6 +1075,7 @@ def test_load_model_sleep_mode_assertion_error(self, mock_allocator_class):
     @patch("vllm_ascend.worker.worker_v1.NPUPlatform.seed_everything")
     @patch("vllm_ascend.worker.worker_v1.logger")
     @patch("vllm_ascend.worker.worker_v1.NPUWorker._warm_up_atb")
+    @patch('vllm_ascend.worker.worker_v1.is_Ascend950', return_value=False)
     def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb,
                                                       mock_logger,
                                                       mock_seed_everything):

@@ -1124,6 +1125,7 @@ def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb,
     @patch("vllm_ascend.worker.worker_v1.NPUPlatform.seed_everything")
     @patch("vllm_ascend.worker.worker_v1.logger")
     @patch("vllm_ascend.worker.worker_v1.NPUWorker._warm_up_atb")
+    @patch('vllm_ascend.worker.worker_v1', return_value=False)
     def test_compile_or_warm_up_model_with_graph_capture(
             self, mock_warm_up_atb, mock_logger, mock_seed_everything):
         """Test compile_or_warm_up_model method - with graph capture enabled"""

vllm_ascend/attention/attention_v1.py

Lines changed: 101 additions & 18 deletions

@@ -50,7 +50,7 @@
                                           update_graph_params_workspaces)
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
-                               nd_to_nz_2d, nd_to_nz_spec,
+                               is_Ascend950, nd_to_nz_2d, nd_to_nz_spec,
                                prefill_context_parallel_enable,
                                weak_ref_tensors)

@@ -1448,6 +1448,69 @@ def _load_kv_for_chunk(self, attn_metadata, kv_cache,
         )
         return key, value
 
+    def _forward_ascend_950(self, query: torch.Tensor, key: torch.Tensor,
+                            value: torch.Tensor, attn_metadata: AscendMetadata,
+                            output: torch.Tensor) -> torch.Tensor:
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            num_tokens = attn_metadata.query_start_loc[-1]
+            output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query[:num_tokens],
+                key[:num_tokens],
+                value[:num_tokens],
+                atten_mask=attn_metadata.attn_mask.to(torch.bool),
+                actual_seq_qlen=attn_metadata.query_lens.cumsum(0),
+                actual_seq_kvlen=attn_metadata.seq_lens.cumsum(0),
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                input_layout="TND",
+                softmax_scale=self.scale)
+            return output[:num_tokens]
+        else:
+            batch_size = attn_metadata.query_lens.shape[0]
+            block_table = attn_metadata.block_tables[:batch_size, :]
+            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                query = query.view(batch_size, 1, self.num_heads * self.head_size)
+                key = self.key_cache.flatten(2, 3).contiguous()
+                value = self.value_cache.flatten(2, 3).contiguous()
+                atten_mask=None
+                actual_seq_qlen=None
+                actual_seq_kvlen=attn_metadata.seq_lens
+                sparse_mode=0
+                input_layout="BSH"
+            else:
+                num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                key = self.key_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                value = self.value_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                input_layout="TND"
+                atten_mask=attn_metadata.attn_mask.to(torch.bool)
+                if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
+                    actual_seq_qlen=attn_metadata.query_lens.cumsum(0)
+                    actual_seq_kvlen=attn_metadata.seq_lens
+                    sparse_mode=2
+                else:
+                    actual_seq_qlen=attn_metadata.actual_seq_lengths_q
+                    actual_seq_kvlen=attn_metadata.seq_lens_list
+                    sparse_mode=0
+            output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query=query,
+                key=key,
+                value=value,
+                block_table=block_table,
+                atten_mask=atten_mask,
+                actual_seq_qlen=actual_seq_qlen,
+                actual_seq_kvlen=actual_seq_kvlen,
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                softmax_scale=self.scale,
+                sparse_mode=sparse_mode,
+                block_size=block_size,
+                input_layout=input_layout)
+            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                output = output.view(batch_size, self.num_heads, self.head_size)
+            return output
+
     def forward(
         self,
         layer: AttentionLayer,
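The new `_forward_ascend_950` helper covers three cases, all through `torch_npu.npu_fused_infer_attention_score_v2`: prefill without a cache runs directly on the packed query/key/value in TND layout with cumulative query and KV lengths; decode-only flattens the block-layout caches and queries one token per sequence in BSH layout; the remaining states read the caches through the block table in TND layout, with `sparse_mode=2` for the prefill cache-hit case.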
@@ -1507,12 +1570,20 @@ def forward(
         if has_decode:
             slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \
                 if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_decode_tokens],
-                value=value[:num_decode_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slot_mapping)
+            if is_Ascend950():
+                num_tokens = slot_mapping.shape[0]
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_tokens],
+                    value=value[:num_tokens].contiguous(),
+                    slot_mapping=slot_mapping,
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slot_mapping)
 
         if has_prefill:
             if self.pcp_size > 1:

@@ -1526,22 +1597,34 @@ def forward(
                 key, value = all_kv.split([self.head_size, self.head_size],
                                           dim=-1)
 
-            torch_npu._npu_reshape_and_cache(
-                key=key[self.pcp_size * num_decode_tokens:attn_metadata.
-                        num_actual_tokens_pcp_padded],
-                value=value[self.pcp_size *
+            if is_Ascend950():
+                num_tokens = attn_metadata.slot_mapping.shape[0]
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_tokens],
+                    value=value[:num_tokens].contiguous(),
+                    slot_mapping=attn_metadata.slot_mapping,
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[self.pcp_size *
                             num_decode_tokens:attn_metadata.
                             num_actual_tokens_pcp_padded],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=attn_metadata.
-                slot_mapping[self.pcp_size *
-                             num_decode_tokens:attn_metadata.
-                             num_actual_tokens_pcp_padded])
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded])
 
         forward_context: ForwardContext = get_forward_context()
         if not forward_context.capturing:
-            if self.pcp_size * self.dcp_size > 1:
+            if is_Ascend950():
+                intermediate_output =self._forward_ascend_950(
+                    query, key, value, attn_metadata, output)
+            elif self.pcp_size * self.dcp_size > 1:
                 intermediate_output = self._forward_pcp_dcp(
                     query, key, value, kv_cache, attn_metadata, output)
             elif attn_type == AttentionType.ENCODER_ONLY:
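For readers skimming the forward() changes, the cache-write dispatch reduces to the sketch below. It is illustrative only (the real code also handles the prefill-context-parallel slicing shown in the diff); the argument names follow the diff above, and `is_Ascend950` is the new helper from vllm_ascend/utils.py.

# Illustrative sketch of the KV-cache write dispatch added in this commit.
# Not the literal forward() code: PCP/DCP slicing of key/value is omitted.
import torch_npu
from vllm_ascend.utils import is_Ascend950

def write_kv_cache(key, value, slot_mapping, key_cache, value_cache, num_tokens):
    if is_Ascend950():
        # Ascend 950: fused scatter into the paged KV cache.
        torch_npu.npu_scatter_pa_kv_cache(
            key=key[:num_tokens],
            value=value[:num_tokens].contiguous(),
            slot_mapping=slot_mapping,
            out=(key_cache, value_cache))
    else:
        # Other SoCs: pre-existing reshape-and-cache path.
        torch_npu._npu_reshape_and_cache(
            key=key[:num_tokens],
            value=value[:num_tokens],
            key_cache=key_cache,
            value_cache=value_cache,
            slot_indices=slot_mapping)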

vllm_ascend/ops/rotary_embedding.py

Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@
                               YaRNScalingRotaryEmbedding)
 
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import enable_custom_op, is_310p
+from vllm_ascend.utils import enable_custom_op, is_310p, is_Ascend950
 
 
 def _custom_rotary_embedding_enabled(query, neox_style, head_size):

@@ -405,7 +405,7 @@ def forward_oot(
         query: torch.Tensor,
         key: torch.Tensor,
     ):
-        if self.mrope_section != [16, 24, 24]:
+        if self.mrope_section != [16, 24, 24] or is_Ascend950():
            return super().forward_oot(positions, query, key)
 
        import torch_npu
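The practical effect of the rotary-embedding change: the torch_npu MRoPE fast path in `forward_oot` is now taken only for the `[16, 24, 24]` mrope_section layout on non-950 SoCs; on Ascend 950 the override defers to the upstream implementation via `super().forward_oot(positions, query, key)`.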

vllm_ascend/sample/sampler.py

Lines changed: 3 additions & 3 deletions

@@ -3,7 +3,7 @@
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
 from vllm.v1.sample.sampler import Sampler
 
-from vllm_ascend.utils import is_310p
+from vllm_ascend.utils import is_310p, is_Ascend950
 
 DEFAULT_LOGPROBS_MODE = "raw_logprobs"
 

@@ -25,8 +25,8 @@ def _apply_top_k_top_p(
     p: torch.Tensor,
 ) -> torch.Tensor:
     # npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P
-    if not is_310p() and p is not None and k is not None and 1 <= int(
-            k.max()) <= 1024:
+    if not is_310p() and not is_Ascend950() \
+        and p is not None and k is not None and 1 <= int(k.max()) <= 1024:
         # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
         return torch_npu.npu_top_k_top_p(logits, p, k)
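Likewise for sampling: the fused `torch_npu.npu_top_k_top_p` call (backed by aclnnApplyTopKTopP) is now skipped on Ascend 950 as well as 310P, so those SoCs do not take the fused fast path.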

vllm_ascend/utils.py

Lines changed: 12 additions & 1 deletion

@@ -49,6 +49,7 @@
 
 _CUSTOM_OP_ENABLED = None
 _IS_310P = None
+_IS_ASCEND950 = None
 _SLEEP_MODE_ENABLED = None
 _CURRENT_STREAM = None
 _PREFETCH_STREAM = None

@@ -668,7 +669,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
 class AscendSocVersion(Enum):
     A2 = 0
     A3 = 1
-    UNDEFINED = 2
+    A5 = 2
+    UNDEFINED = 3
 
 
 _ascend_soc_version = None

@@ -681,6 +683,8 @@ def init_ascend_soc_version():
         _ascend_soc_version = AscendSocVersion.A2
     elif 250 <= soc_version <= 255:
         _ascend_soc_version = AscendSocVersion.A3
+    elif soc_version == 260:
+        _ascend_soc_version = AscendSocVersion.A5
     else:
         _ascend_soc_version = AscendSocVersion.UNDEFINED
 

@@ -945,3 +949,10 @@ def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
         reorgnized_batch_ids.append(ranks)
 
     return reorgnized_batch_ids
+
+
+def is_Ascend950():
+    global _IS_ASCEND950
+    if _IS_ASCEND950 is None:
+        _IS_ASCEND950 = (get_ascend_soc_version() == AscendSocVersion.A5)
+    return _IS_ASCEND950
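A minimal usage sketch of the new detection helper, assuming `init_ascend_soc_version()` has already run during platform start-up (it maps SoC version 260 to `AscendSocVersion.A5`, and the result is cached in `_IS_ASCEND950` after the first call):

# Illustrative only; in vllm-ascend the init call happens during worker start-up.
from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
                               init_ascend_soc_version, is_Ascend950)

init_ascend_soc_version()          # reads the SoC version from the NPU runtime
if is_Ascend950():                 # True only when the SoC maps to AscendSocVersion.A5
    print("Ascend 950 detected, using the A5 code paths")
else:
    print("detected SoC:", get_ascend_soc_version())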

vllm_ascend/worker/worker_v1.py

Lines changed: 4 additions & 3 deletions

@@ -47,8 +47,8 @@
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import (init_ascend_soc_version, is_enable_nz,
-                               prefill_context_parallel_enable,
+from vllm_ascend.utils import (init_ascend_soc_version, is_Ascend950,
+                               is_enable_nz, prefill_context_parallel_enable,
                                register_ascend_customop, sleep_mode_enabled,
                                try_register_lib, vllm_version_is)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

@@ -342,7 +342,8 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner.capture_model()
         # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
         # may cause performance degradation at runtime.
-        self._warm_up_atb()
+        if not is_Ascend950():
+            self._warm_up_atb()
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         NPUPlatform.seed_everything(self.model_config.seed)
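Finally, the ATB matmul warm-up is now skipped on Ascend 950. The surrounding comment ties that warm-up to the ReshapeAndCache operator, which the attention changes above bypass on that SoC in favour of `npu_scatter_pa_kv_cache`.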
