
Commit c8d5929

Ascend 950: support Qwen dense model
Signed-off-by: wangyao <[email protected]>
Parent: 2fa3945

File tree: 4 files changed (+51, -20 lines)


setup.py

Lines changed: 1 addition & 0 deletions
@@ -160,6 +160,7 @@ def gen_build_info():
     "ascend310p3vir02": "_310P",
     "ascend310p3vir04": "_310P",
     "ascend310p3vir08": "_310P",
+    "ascend910_9599": "_910_95",
 }

 assert soc_version in soc_to_device, f"Undefined soc_version: {soc_version}. Please file an issue to vllm-ascend."
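The new `ascend910_9599` entry routes that SoC to the `_910_95` device suffix used throughout this commit. A minimal sketch of how this build-time lookup behaves, assuming nothing beyond the mapping shown above (the `resolve_device_suffix` helper is hypothetical, not part of the commit):

```python
# Sketch only: the soc_to_device mapping from gen_build_info, plus a
# hypothetical helper showing how a SoC version resolves to a device
# suffix (or fails fast with the same assertion as the real code).
soc_to_device = {
    "ascend310p3vir02": "_310P",
    "ascend310p3vir04": "_310P",
    "ascend310p3vir08": "_310P",
    "ascend910_9599": "_910_95",  # new: Ascend 950 target added here
}


def resolve_device_suffix(soc_version: str) -> str:
    assert soc_version in soc_to_device, (
        f"Undefined soc_version: {soc_version}. "
        "Please file an issue to vllm-ascend.")
    return soc_to_device[soc_version]


print(resolve_device_suffix("ascend910_9599"))  # -> _910_95
```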

vllm_ascend/attention/attention_v1.py

Lines changed: 44 additions & 17 deletions
@@ -41,7 +41,9 @@
     split_decodes_and_prefills)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
-from vllm_ascend.utils import prefill_context_parallel_enable, weak_ref_tensors
+from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
+                               prefill_context_parallel_enable,
+                               weak_ref_tensors)

 # isort: off
 if prefill_context_parallel_enable():
@@ -1421,12 +1423,20 @@ def forward(
         if has_decode:
             slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \
                 if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_decode_tokens],
-                value=value[:num_decode_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slot_mapping)
+            if get_ascend_device_type() == AscendDeviceType._910_95:
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slot_mapping)
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slot_mapping)

         if has_prefill:
             if self.pcp_size > 1:
@@ -1440,18 +1450,35 @@ def forward(
                 key, value = all_kv.split([self.head_size, self.head_size],
                                           dim=-1)

-            torch_npu._npu_reshape_and_cache(
-                key=key[self.pcp_size * num_decode_tokens:attn_metadata.
-                        num_actual_tokens_pcp_padded],
-                value=value[self.pcp_size *
+            if get_ascend_device_type() == AscendDeviceType._910_95:
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[self.pcp_size *
                             num_decode_tokens:attn_metadata.
                             num_actual_tokens_pcp_padded],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=attn_metadata.
-                slot_mapping[self.pcp_size *
-                             num_decode_tokens:attn_metadata.
-                             num_actual_tokens_pcp_padded])
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded].contiguous(),
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_mapping=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded],
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[self.pcp_size *
+                            num_decode_tokens:attn_metadata.
+                            num_actual_tokens_pcp_padded],
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded])

             forward_context: ForwardContext = get_forward_context()
             if not forward_context.capturing:
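The pattern in both hunks is the same: pick the KV-cache write kernel by SoC. A minimal sketch of the decode-path dispatch, assuming the imports shown in the diff (`write_kv_to_cache` is a hypothetical wrapper, not a function in this commit):

```python
import torch_npu

from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type


def write_kv_to_cache(key, value, key_cache, value_cache, slot_indices):
    # On the 910_95 (Ascend 950) target, scatter KV into the paged cache
    # with the new npu_scatter_pa_kv_cache kernel; all other SoCs keep
    # the legacy _npu_reshape_and_cache path. Argument names mirror the
    # decode branch in the diff above.
    if get_ascend_device_type() == AscendDeviceType._910_95:
        torch_npu.npu_scatter_pa_kv_cache(
            key=key,
            value=value,
            key_cache=key_cache,
            value_cache=value_cache,
            slot_indices=slot_indices)
    else:
        torch_npu._npu_reshape_and_cache(
            key=key,
            value=value,
            key_cache=key_cache,
            value_cache=value_cache,
            slot_indices=slot_indices)
```

Note that the prefill branch differs slightly on the 910_95 path: it calls `.contiguous()` on the value slice and passes `slot_mapping=` plus `out=(self.key_cache, self.value_cache)` rather than `slot_indices=`.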

vllm_ascend/ops/rotary_embedding.py

Lines changed: 2 additions & 1 deletion
@@ -411,7 +411,8 @@ def forward_oot(
     # TODO: This judgment will be removed once the mrope precision issue is fixed
     if self.mrope_section != [
             16, 24, 24
-    ] or NPUPlatform.get_cpu_architecture() == CpuArchEnum.X86:
+    ] or NPUPlatform.get_cpu_architecture() == CpuArchEnum.X86 or \
+            get_ascend_device_type() == AscendDeviceType._910_95:
         return super().forward_oot(positions, query, key)

     import torch_npu
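The change widens the guard so that 910_95 always falls back to the base-class RoPE implementation while the mrope precision issue is open. A dependency-free sketch of the predicate, with the helper name and string-typed arguments as illustrative stand-ins for the real `CpuArchEnum` / `AscendDeviceType` comparisons:

```python
def should_fall_back_to_default_rope(mrope_section: list[int],
                                     cpu_arch: str,
                                     device_type: str) -> bool:
    # Fall back to the generic RoPE path unless the NPU fast path is
    # known to be safe: mrope_section must be [16, 24, 24], the host must
    # not be X86, and (new in this commit) the device must not be 910_95.
    return (mrope_section != [16, 24, 24]
            or cpu_arch == "X86"
            or device_type == "_910_95")


# Example: even with the supported mrope_section on an ARM host,
# a 910_95 device now takes the fallback path.
assert should_fall_back_to_default_rope([16, 24, 24], "ARM", "_910_95")
```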

vllm_ascend/worker/worker_v1.py

Lines changed: 4 additions & 2 deletions
@@ -50,7 +50,8 @@
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import (check_ascend_device_type, is_enable_nz,
+from vllm_ascend.utils import (AscendDeviceType, check_ascend_device_type,
+                               get_ascend_device_type, is_enable_nz,
                                prefill_context_parallel_enable,
                                register_ascend_customop, sleep_mode_enabled,
                                try_register_lib)
@@ -355,7 +356,8 @@ def compile_or_warm_up_model(self) -> None:
         self.model_runner.capture_model()
         # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
         # may cause performance degradation at runtime.
-        self._warm_up_atb()
+        if get_ascend_device_type() != AscendDeviceType._910_95:
+            self._warm_up_atb()
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         NPUPlatform.seed_everything(self.model_config.seed)
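Skipping `_warm_up_atb()` on 910_95 is consistent with the attention change above: the warm-up exists to protect the ATB ReshapeAndCache path, which that target presumably no longer uses. An illustrative sketch (`maybe_warm_up_atb` is hypothetical; the real check is inlined in `compile_or_warm_up_model`):

```python
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type


def maybe_warm_up_atb(worker) -> None:
    # Warm up ATB matmul everywhere except 910_95, where the legacy
    # ReshapeAndCache path that the warm-up guards against is skipped.
    if get_ascend_device_type() != AscendDeviceType._910_95:
        worker._warm_up_atb()
```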
