Commit fa1227c

Ascend 950: support Qwen dense model

Signed-off-by: wangyao <[email protected]>
1 parent e38ef2c, commit fa1227c

8 files changed: +138 additions, -32 deletions


tests/ut/ops/test_rotary_embedding.py

Lines changed: 4 additions & 2 deletions

@@ -428,7 +428,8 @@ def _create_vllm_config(self):
     @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
     @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
     @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
-    def test_forward_oot_1d_positions(self, mock_npu_mrope):
+    @patch('vllm_ascend.ops.rotary_embedding.is_Ascend950', return_value=False)
+    def test_forward_oot_1d_positions(self, mock_is_ascend950, mock_npu_mrope):
         mock_npu_mrope.return_value = (torch.zeros_like(self.query),
                                        torch.zeros_like(self.key))

@@ -447,7 +448,8 @@ def test_forward_oot_1d_positions(self, mock_npu_mrope):
     @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
     @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
     @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
-    def test_forward_oot_2d_positions(self, mock_npu_mrope):
+    @patch('vllm_ascend.ops.rotary_embedding.is_Ascend950', return_value=False)
+    def test_forward_oot_2d_positions(self, mock_is_ascend950, mock_npu_mrope):
         mock_npu_mrope.return_value = (torch.zeros_like(self.query),
                                        torch.zeros_like(self.key))

tests/ut/sample/test_sampler.py

Lines changed: 3 additions & 1 deletion

@@ -1,4 +1,5 @@
 from unittest import mock
+from unittest.mock import patch

 import torch

@@ -18,7 +19,8 @@ def test_init_with_raw_logprobs(self):
 class TestAscendTopKTopPSampler(TestBase):

     @mock.patch("torch_npu.npu_top_k_top_p")
-    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
+    @patch('vllm_ascend.sample.sampler.is_Ascend950', return_value=False)
+    def test_npu_topk_topp_called_when_optimized(self, mock_is_ascend950, mock_npu_op):
         mock_npu_op.return_value = (torch.randn(1, 3))
         sampler = AscendTopKTopPSampler()

tests/ut/worker/test_worker_v1.py

Lines changed: 6 additions & 2 deletions

@@ -1075,7 +1075,9 @@ def test_load_model_sleep_mode_assertion_error(self, mock_allocator_class):
     @patch("vllm_ascend.worker.worker_v1.NPUPlatform.seed_everything")
     @patch("vllm_ascend.worker.worker_v1.logger")
     @patch("vllm_ascend.worker.worker_v1.NPUWorker._warm_up_atb")
-    def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb,
+    @patch('vllm_ascend.worker.worker_v1.is_Ascend950', return_value=False)
+    def test_compile_or_warm_up_model_with_eager_mode(self, mock_is_ascend950,
+                                                      mock_warm_up_atb,
                                                       mock_logger,
                                                       mock_seed_everything):
         """Test compile_or_warm_up_model method - eager mode"""

@@ -1124,8 +1126,10 @@ def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb,
     @patch("vllm_ascend.worker.worker_v1.NPUPlatform.seed_everything")
     @patch("vllm_ascend.worker.worker_v1.logger")
     @patch("vllm_ascend.worker.worker_v1.NPUWorker._warm_up_atb")
+    @patch('vllm_ascend.worker.worker_v1.is_Ascend950', return_value=False)
     def test_compile_or_warm_up_model_with_graph_capture(
-            self, mock_warm_up_atb, mock_logger, mock_seed_everything):
+            self, mock_is_ascend950, mock_warm_up_atb, mock_logger,
+            mock_seed_everything):
         """Test compile_or_warm_up_model method - with graph capture enabled"""
         from vllm_ascend.worker.worker_v1 import NPUWorker
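
The tests above stack a new @patch for is_Ascend950 on top of the existing decorators and pass the extra mock as the first mock parameter after self. A minimal, standard-library-only sketch of the ordering rule they rely on (stacked @patch decorators apply bottom-up, so the decorator closest to the function binds to the first mock argument); the os.* targets below are illustrative stand-ins, not part of this commit:

import os
import unittest
from unittest.mock import patch


class PatchOrderingExample(unittest.TestCase):

    @patch("os.getcwd", return_value="/fake/cwd")   # outermost -> last mock argument
    @patch("os.path.exists", return_value=False)    # innermost -> first mock argument
    def test_patch_order(self, mock_exists, mock_getcwd):
        # The innermost @patch (os.path.exists) is bound to mock_exists,
        # mirroring how mock_is_ascend950 precedes mock_warm_up_atb above.
        self.assertFalse(os.path.exists("/anything"))
        self.assertEqual(os.getcwd(), "/fake/cwd")
        mock_exists.assert_called_once_with("/anything")


if __name__ == "__main__":
    unittest.main()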

vllm_ascend/attention/attention_v1.py

Lines changed: 104 additions & 18 deletions

@@ -50,7 +50,7 @@
                                update_graph_params_workspaces)
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
-                               nd_to_nz_2d, nd_to_nz_spec,
+                               is_Ascend950, nd_to_nz_2d, nd_to_nz_spec,
                                prefill_context_parallel_enable,
                                weak_ref_tensors)

@@ -1448,6 +1448,72 @@ def _load_kv_for_chunk(self, attn_metadata, kv_cache,
         )
         return key, value

+    def _forward_ascend_950(self, query: torch.Tensor, key: torch.Tensor,
+                            value: torch.Tensor, attn_metadata: AscendMetadata,
+                            output: torch.Tensor) -> torch.Tensor:
+        num_tokens = attn_metadata.query_start_loc[-1]
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query[:num_tokens],
+                key[:num_tokens],
+                value[:num_tokens],
+                atten_mask=attn_metadata.attn_mask.to(torch.bool),  # type: ignore
+                actual_seq_qlen=attn_metadata.query_lens.cumsum(0),
+                actual_seq_kvlen=attn_metadata.seq_lens.cumsum(0),
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                input_layout="TND",
+                softmax_scale=self.scale)
+            return output[:num_tokens]
+        else:
+            batch_size = attn_metadata.query_lens.shape[0]
+            block_table = attn_metadata.block_tables[:batch_size, :]
+            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                query = query.view(batch_size, 1,
+                                   self.num_heads * self.head_size)
+                key = self.key_cache.flatten(2, 3).contiguous()  # type: ignore
+                value = self.value_cache.flatten(2, 3).contiguous()  # type: ignore
+                atten_mask = None
+                actual_seq_qlen = None
+                actual_seq_kvlen = attn_metadata.seq_lens
+                sparse_mode = 0
+                input_layout = "BSH"
+            else:
+                num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                key = self.key_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                value = self.value_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                input_layout = "TND"
+                atten_mask = attn_metadata.attn_mask.to(torch.bool)  # type: ignore
+                if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
+                    actual_seq_qlen = attn_metadata.query_lens.cumsum(0)
+                    actual_seq_kvlen = attn_metadata.seq_lens
+                    sparse_mode = 2
+                else:
+                    query = query[:num_tokens]
+                    actual_seq_qlen = attn_metadata.actual_seq_lengths_q
+                    actual_seq_kvlen = attn_metadata.seq_lens_list
+                    sparse_mode = 0
+            output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query=query,
+                key=key,
+                value=value,
+                block_table=block_table,
+                atten_mask=atten_mask,
+                actual_seq_qlen=actual_seq_qlen,
+                actual_seq_kvlen=actual_seq_kvlen,
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                softmax_scale=self.scale,
+                sparse_mode=sparse_mode,
+                block_size=block_size,
+                input_layout=input_layout)
+            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                output = output.view(batch_size, self.num_heads,
+                                     self.head_size)
+            return output
+
     def forward(
         self,
         layer: AttentionLayer,

@@ -1507,12 +1573,20 @@ def forward(
         if has_decode:
             slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \
                 if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_decode_tokens],
-                value=value[:num_decode_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slot_mapping)
+            if is_Ascend950():
+                num_tokens = slot_mapping.shape[0]
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_tokens],
+                    value=value[:num_tokens].contiguous(),
+                    slot_mapping=slot_mapping,
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slot_mapping)

         if has_prefill:
             if self.pcp_size > 1:

@@ -1526,22 +1600,34 @@
                 key, value = all_kv.split([self.head_size, self.head_size],
                                           dim=-1)

-            torch_npu._npu_reshape_and_cache(
-                key=key[self.pcp_size * num_decode_tokens:attn_metadata.
-                        num_actual_tokens_pcp_padded],
-                value=value[self.pcp_size *
+            if is_Ascend950():
+                num_tokens = attn_metadata.slot_mapping.shape[0]
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_tokens],
+                    value=value[:num_tokens].contiguous(),
+                    slot_mapping=attn_metadata.slot_mapping,
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[self.pcp_size *
                             num_decode_tokens:attn_metadata.
                             num_actual_tokens_pcp_padded],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=attn_metadata.
-                slot_mapping[self.pcp_size *
-                             num_decode_tokens:attn_metadata.
-                             num_actual_tokens_pcp_padded])
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded])

         forward_context: ForwardContext = get_forward_context()
         if not forward_context.capturing:
-            if self.pcp_size * self.dcp_size > 1:
+            if is_Ascend950():
+                intermediate_output = self._forward_ascend_950(
+                    query, key, value, attn_metadata, output)
+            elif self.pcp_size * self.dcp_size > 1:
                 intermediate_output = self._forward_pcp_dcp(
                     query, key, value, kv_cache, attn_metadata, output)
             elif attn_type == AttentionType.ENCODER_ONLY:
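
The forward() change follows a simple capability-dispatch pattern: the SoC is probed once, and the Ascend 950 kernel path is tried before the pcp/dcp and encoder-only branches. A standalone sketch of that control flow in pure Python (the helper name and return strings are illustrative stand-ins, not vllm_ascend APIs; only the SoC value 260 for A5 comes from this commit):

from functools import lru_cache

SOC_VERSION_A5 = 260  # per the utils.py hunks, SoC version 260 maps to A5 (Ascend 950)


@lru_cache(maxsize=1)
def is_ascend950_sketch(soc_version: int) -> bool:
    # Memoized, mirroring the module-level cache used by is_Ascend950().
    return soc_version == SOC_VERSION_A5


def run_attention(soc_version: int, pcp_size: int, dcp_size: int) -> str:
    # Dispatch order mirrors the diff: 950 path first, then pcp/dcp, then default.
    if is_ascend950_sketch(soc_version):
        return "npu_fused_infer_attention_score_v2 path"
    if pcp_size * dcp_size > 1:
        return "pcp/dcp path"
    return "default path"


if __name__ == "__main__":
    print(run_attention(260, 1, 1))  # -> 950 path
    print(run_attention(123, 2, 1))  # -> pcp/dcp path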

vllm_ascend/ops/rotary_embedding.py

Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@
                              YaRNScalingRotaryEmbedding)

 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import enable_custom_op, is_310p
+from vllm_ascend.utils import enable_custom_op, is_310p, is_Ascend950


 def _custom_rotary_embedding_enabled(query, neox_style, head_size):

@@ -405,7 +405,7 @@ def forward_oot(
         query: torch.Tensor,
         key: torch.Tensor,
     ):
-        if self.mrope_section != [16, 24, 24]:
+        if self.mrope_section != [16, 24, 24] or is_Ascend950():
             return super().forward_oot(positions, query, key)

         import torch_npu

vllm_ascend/sample/sampler.py

Lines changed: 3 additions & 3 deletions

@@ -3,7 +3,7 @@
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
 from vllm.v1.sample.sampler import Sampler

-from vllm_ascend.utils import is_310p
+from vllm_ascend.utils import is_310p, is_Ascend950

 DEFAULT_LOGPROBS_MODE = "raw_logprobs"

@@ -25,8 +25,8 @@ def _apply_top_k_top_p(
     p: torch.Tensor,
 ) -> torch.Tensor:
     # npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P
-    if not is_310p() and p is not None and k is not None and 1 <= int(
-            k.max()) <= 1024:
+    if not is_310p() and not is_Ascend950() \
+            and p is not None and k is not None and 1 <= int(k.max()) <= 1024:
         # npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
         return torch_npu.npu_top_k_top_p(logits, p, k)
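
The reworked guard keeps the fused NPU kernel off both 310P and Ascend 950 hardware, and still requires p and k to be present with k in [1, 1024]; everything else falls back to the generic sampler path. A hedged sketch of just that predicate (use_fused_topk_topp is an illustrative helper, not a vllm_ascend function; torch is assumed available in the environment):

from typing import Optional

import torch


def use_fused_topk_topp(is_310p: bool, is_ascend950: bool,
                        k: Optional[torch.Tensor],
                        p: Optional[torch.Tensor]) -> bool:
    # Mirrors the condition in _apply_top_k_top_p above.
    return (not is_310p and not is_ascend950 and p is not None
            and k is not None and 1 <= int(k.max()) <= 1024)


if __name__ == "__main__":
    print(use_fused_topk_topp(False, False, torch.tensor([50]), torch.tensor([0.9])))  # True
    print(use_fused_topk_topp(False, True, torch.tensor([50]), torch.tensor([0.9])))   # False on Ascend 950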

vllm_ascend/utils.py

Lines changed: 12 additions & 1 deletion

@@ -49,6 +49,7 @@

 _CUSTOM_OP_ENABLED = None
 _IS_310P = None
+_IS_ASCEND950 = None
 _SLEEP_MODE_ENABLED = None
 _CURRENT_STREAM = None
 _PREFETCH_STREAM = None

@@ -668,7 +669,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
 class AscendSocVersion(Enum):
     A2 = 0
     A3 = 1
-    UNDEFINED = 2
+    A5 = 2
+    UNDEFINED = 3


 _ascend_soc_version = None

@@ -681,6 +683,8 @@ def init_ascend_soc_version():
         _ascend_soc_version = AscendSocVersion.A2
     elif 250 <= soc_version <= 255:
         _ascend_soc_version = AscendSocVersion.A3
+    elif soc_version == 260:
+        _ascend_soc_version = AscendSocVersion.A5
     else:
         _ascend_soc_version = AscendSocVersion.UNDEFINED

@@ -945,3 +949,10 @@ def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
         reorgnized_batch_ids.append(ranks)

     return reorgnized_batch_ids
+
+
+def is_Ascend950():
+    global _IS_ASCEND950
+    if _IS_ASCEND950 is None:
+        _IS_ASCEND950 = (get_ascend_soc_version() == AscendSocVersion.A5)
+    return _IS_ASCEND950
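
For reference, the version mapping introduced here as a standalone sketch (only the branches visible in the hunks are reproduced; the A2 condition sits outside the diff and is omitted, and the enum follows the new A5 = 2 / UNDEFINED = 3 layout):

from enum import Enum


class AscendSocVersionSketch(Enum):
    # Mirrors the updated AscendSocVersion: A5 (Ascend 950) takes value 2,
    # pushing UNDEFINED to 3.
    A2 = 0
    A3 = 1
    A5 = 2
    UNDEFINED = 3


def classify_soc(soc_version: int) -> AscendSocVersionSketch:
    # Only the ranges shown in the diff are reproduced here; the A2 branch
    # is elided because its condition is not part of this hunk.
    if 250 <= soc_version <= 255:
        return AscendSocVersionSketch.A3
    if soc_version == 260:  # new in this commit: Ascend 950 reports SoC version 260
        return AscendSocVersionSketch.A5
    return AscendSocVersionSketch.UNDEFINED


assert classify_soc(260) is AscendSocVersionSketch.A5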

vllm_ascend/worker/worker_v1.py

Lines changed: 4 additions & 3 deletions

@@ -47,8 +47,8 @@
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import (init_ascend_soc_version, is_enable_nz,
-                               prefill_context_parallel_enable,
+from vllm_ascend.utils import (init_ascend_soc_version, is_Ascend950,
+                               is_enable_nz, prefill_context_parallel_enable,
                                register_ascend_customop, sleep_mode_enabled,
                                try_register_lib, vllm_version_is)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

@@ -342,7 +342,8 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner.capture_model()
         # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
         # may cause performance degradation at runtime.
-        self._warm_up_atb()
+        if not is_Ascend950():
+            self._warm_up_atb()
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         NPUPlatform.seed_everything(self.model_config.seed)
