
Commit 63bbe41

ascend 950 support qwen dense model
Signed-off-by: wangyao <[email protected]>
1 parent 2b82320 commit 63bbe41
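
In outline, the change gates new Ascend 950 (AscendDeviceType._910_95) code paths behind a device-type check: attention goes through a new _forward_ascend_950 method built on npu_fused_infer_attention_score_v2, KV-cache writes switch to npu_scatter_pa_kv_cache, MRoPE falls back to the base forward_oot, and the ATB warm-up is skipped. A minimal sketch of the dispatch pattern the diff repeats, assuming the helpers exported from vllm_ascend.utils (write_kv_cache is a made-up name; the real code inlines this inside the attention forward):

import torch_npu

from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type


def write_kv_cache(key, value, slot_mapping, key_cache, value_cache):
    # Hypothetical helper showing the per-device dispatch used in this commit.
    if get_ascend_device_type() == AscendDeviceType._910_95:
        # Ascend 950: scatter directly into the paged KV cache.
        torch_npu.npu_scatter_pa_kv_cache(
            key=key,
            value=value.contiguous(),
            slot_mapping=slot_mapping,
            out=(key_cache, value_cache))
    else:
        # Other devices: keep the existing reshape-and-cache path.
        torch_npu._npu_reshape_and_cache(
            key=key,
            value=value,
            key_cache=key_cache,
            value_cache=value_cache,
            slot_indices=slot_mapping)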

File tree: 3 files changed, +113 -19 lines changed

  vllm_ascend/attention/attention_v1.py
  vllm_ascend/ops/rotary_embedding.py
  vllm_ascend/worker/worker_v1.py


vllm_ascend/attention/attention_v1.py

Lines changed: 108 additions & 17 deletions
@@ -1464,6 +1464,72 @@ def _load_kv_for_chunk(self, attn_metadata, kv_cache,
         )
         return key, value
 
+    def _forward_ascend_950(self, query: torch.Tensor, key: torch.Tensor,
+                            value: torch.Tensor, attn_metadata: AscendMetadata,
+                            output: torch.Tensor) -> torch.Tensor:
+        num_tokens = attn_metadata.query_start_loc[-1]
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            output_data, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query[:num_tokens],
+                key[:num_tokens],
+                value[:num_tokens],
+                atten_mask=attn_metadata.attn_mask.to(  # type: ignore
+                    torch.bool),
+                actual_seq_qlen=attn_metadata.query_lens.cumsum(0),
+                actual_seq_kvlen=attn_metadata.seq_lens.cumsum(0),
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                input_layout="TND",
+                softmax_scale=self.scale)
+        else:
+            batch_size = attn_metadata.query_lens.shape[0]
+            block_table = attn_metadata.block_tables[:batch_size, :]
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                query = query[:batch_size]
+                query = query.view(batch_size, 1,
+                                   self.num_heads * self.head_size)
+                key = self.key_cache.flatten(2, 3).contiguous()  # type: ignore
+                value = self.value_cache.flatten(  # type: ignore
+                    2, 3).contiguous()
+                atten_mask = None
+                actual_seq_qlen = None
+                actual_seq_kvlen = attn_metadata.seq_lens
+                sparse_mode = 0
+                input_layout = "BSH"
+            else:
+                query = query[:num_tokens]
+                key = self.key_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                value = self.value_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                input_layout = "TND"
+                atten_mask = attn_metadata.attn_mask
+                actual_seq_qlen = attn_metadata.actual_seq_lengths_q
+                actual_seq_kvlen = attn_metadata.seq_lens_list
+                sparse_mode = 3
+            output_data, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query=query,
+                key=key,
+                value=value,
+                block_table=block_table,
+                atten_mask=atten_mask,
+                actual_seq_qlen=actual_seq_qlen,
+                actual_seq_kvlen=actual_seq_kvlen,
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                softmax_scale=self.scale,
+                sparse_mode=sparse_mode,
+                block_size=block_size,
+                input_layout=input_layout)
+        if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            output[:batch_size] = output_data.view(batch_size,
+                                                   self.num_heads,
+                                                   self.head_size)
+        else:
+            output[:num_tokens] = output_data
+        return output
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -1523,12 +1589,19 @@ def forward(
         if has_decode:
             slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \
                 if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_decode_tokens],
-                value=value[:num_decode_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slot_mapping)
+            if get_ascend_device_type() == AscendDeviceType._910_95:
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens].contiguous(),
+                    slot_mapping=slot_mapping,
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slot_mapping)
 
         if has_prefill:
             if self.pcp_size > 1:
@@ -1542,22 +1615,40 @@ def forward(
                 key, value = all_kv.split([self.head_size, self.head_size],
                                           dim=-1)
 
-            torch_npu._npu_reshape_and_cache(
-                key=key[self.pcp_size * num_decode_tokens:attn_metadata.
-                        num_actual_tokens_pcp_padded],
-                value=value[self.pcp_size *
-                            num_decode_tokens:attn_metadata.
-                            num_actual_tokens_pcp_padded],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=attn_metadata.
-                slot_mapping[self.pcp_size *
-                             num_decode_tokens:attn_metadata.
-                             num_actual_tokens_pcp_padded])
+            if get_ascend_device_type() == AscendDeviceType._910_95:
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[self.pcp_size *
+                            num_decode_tokens:attn_metadata.
+                            num_actual_tokens_pcp_padded],
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded].contiguous(),
+                    slot_mapping=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded],
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[self.pcp_size *
+                            num_decode_tokens:attn_metadata.
+                            num_actual_tokens_pcp_padded],
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded])
 
         forward_context: ForwardContext = get_forward_context()
         if not forward_context.capturing:
-            if self.pcp_size * self.dcp_size > 1:
+            if get_ascend_device_type() == AscendDeviceType._910_95:
+                intermediate_output = self._forward_ascend_950(
+                    query, key, value, attn_metadata, output)
+            elif self.pcp_size * self.dcp_size > 1:
                 intermediate_output = self._forward_pcp_dcp(
                     query, key, value, kv_cache, attn_metadata, output)
             elif attn_type == AttentionType.ENCODER_ONLY:
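
One detail worth noting in _forward_ascend_950: for the packed "TND" layouts, npu_fused_infer_attention_score_v2 receives cumulative end offsets per request rather than raw lengths, which is why the PrefillNoCache branch passes query_lens.cumsum(0) and seq_lens.cumsum(0) (the DecodeOnly branch instead views the query as "BSH" with a single step per request). A small self-contained illustration of the offset computation, with made-up lengths:

import torch

# Hypothetical per-request lengths for a 3-request prefill batch.
query_lens = torch.tensor([4, 2, 3])
seq_lens = torch.tensor([4, 2, 3])

# The kernel expects the end offset of each request in the packed
# token dimension, i.e. [4, 6, 9], not the per-request lengths.
actual_seq_qlen = query_lens.cumsum(0)
actual_seq_kvlen = seq_lens.cumsum(0)

print(actual_seq_qlen.tolist())   # [4, 6, 9]
print(actual_seq_kvlen.tolist())  # [4, 6, 9]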

vllm_ascend/ops/rotary_embedding.py

Lines changed: 2 additions & 1 deletion
@@ -411,7 +411,8 @@ def forward_oot(
         # TODO: This judgment will be removed once the mrope precision issue is fixed
         if self.mrope_section != [
                 16, 24, 24
-        ] or NPUPlatform.get_cpu_architecture() == CpuArchEnum.X86:
+        ] or NPUPlatform.get_cpu_architecture() == CpuArchEnum.X86 or \
+                get_ascend_device_type() == AscendDeviceType._910_95:
             return super().forward_oot(positions, query, key)
 
         import torch_npu

vllm_ascend/worker/worker_v1.py

Lines changed: 3 additions & 1 deletion
@@ -51,6 +51,7 @@
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (check_ascend_device_type, is_enable_nz,
+                               AscendDeviceType, get_ascend_device_type,
                                prefill_context_parallel_enable,
                                register_ascend_customop, sleep_mode_enabled,
                                try_register_lib)
@@ -355,7 +356,8 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner.capture_model()
         # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
         # may cause performance degradation at runtime.
-        self._warm_up_atb()
+        if get_ascend_device_type() != AscendDeviceType._910_95:
+            self._warm_up_atb()
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         NPUPlatform.seed_everything(self.model_config.seed)
