From 56d156ddcb928e354a86896f82821015caaae327 Mon Sep 17 00:00:00 2001
From: wangyao
Date: Tue, 2 Dec 2025 17:19:00 +0800
Subject: [PATCH] ascend 950 support qwen dense model

Signed-off-by: wangyao
---
 setup.py                              |  1 +
 vllm_ascend/attention/attention_v1.py | 61 +++++++++++++++++++--------
 vllm_ascend/ops/rotary_embedding.py   |  3 +-
 vllm_ascend/worker/worker_v1.py       |  6 ++-
 4 files changed, 51 insertions(+), 20 deletions(-)

diff --git a/setup.py b/setup.py
index 890b5228e57..b803124cf8f 100644
--- a/setup.py
+++ b/setup.py
@@ -160,6 +160,7 @@ def gen_build_info():
         "ascend310p3vir02": "_310P",
         "ascend310p3vir04": "_310P",
         "ascend310p3vir08": "_310P",
+        "ascend910_9579": "_910_95",
     }
 
     assert soc_version in soc_to_device, f"Undefined soc_version: {soc_version}. Please file an issue to vllm-ascend."
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index b524e648cf0..b0cfd26e530 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -41,7 +41,9 @@
                                                split_decodes_and_prefills)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
-from vllm_ascend.utils import prefill_context_parallel_enable, weak_ref_tensors
+from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
+                               prefill_context_parallel_enable,
+                               weak_ref_tensors)
 
 # isort: off
 if prefill_context_parallel_enable():
@@ -1413,12 +1415,20 @@ def forward(
         if has_decode:
             slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \
                 if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_decode_tokens],
-                value=value[:num_decode_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slot_mapping)
+            if get_ascend_device_type() == AscendDeviceType._910_95:
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slot_mapping)
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slot_mapping)
 
         if has_prefill:
             if self.pcp_size > 1:
@@ -1432,18 +1442,35 @@
             key, value = all_kv.split([self.head_size, self.head_size],
                                       dim=-1)
 
-            torch_npu._npu_reshape_and_cache(
-                key=key[self.pcp_size * num_decode_tokens:attn_metadata.
-                        num_actual_tokens_pcp_padded],
-                value=value[self.pcp_size *
+            if get_ascend_device_type() == AscendDeviceType._910_95:
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[self.pcp_size *
                             num_decode_tokens:attn_metadata.
                             num_actual_tokens_pcp_padded],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=attn_metadata.
-                slot_mapping[self.pcp_size *
-                             num_decode_tokens:attn_metadata.
-                             num_actual_tokens_pcp_padded])
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded].contiguous(),
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_mapping=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded],
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[self.pcp_size *
+                            num_decode_tokens:attn_metadata.
+                            num_actual_tokens_pcp_padded],
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded])
 
         forward_context: ForwardContext = get_forward_context()
         if not forward_context.capturing:
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index ee9dd9f9302..dccbdc34e82 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -414,7 +414,8 @@ def forward_oot(
         # TODO: This judgment will be removed once the mrope precision issue is fixed
         if self.mrope_section != [
                 16, 24, 24
-        ] or NPUPlatform.get_cpu_architecture() == CpuArchEnum.X86:
+        ] or NPUPlatform.get_cpu_architecture() == CpuArchEnum.X86 or \
+            get_ascend_device_type() == AscendDeviceType._910_95:
             return super().forward_oot(positions, query, key)
 
         import torch_npu
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 41b6abb9fd9..29efde09074 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -52,7 +52,8 @@
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import (check_ascend_device_type, is_enable_nz,
+from vllm_ascend.utils import (AscendDeviceType, check_ascend_device_type,
+                               get_ascend_device_type, is_enable_nz,
                                prefill_context_parallel_enable,
                                register_ascend_customop, sleep_mode_enabled,
                                try_register_lib)
@@ -344,7 +345,8 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner.capture_model()
         # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
         # may cause performance degradation at runtime.
-        self._warm_up_atb()
+        if get_ascend_device_type() != AscendDeviceType._910_95:
+            self._warm_up_atb()
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         NPUPlatform.seed_everything(self.model_config.seed)