From eba8d344c1a18795b673d87b9b1d4371aa6a26a5 Mon Sep 17 00:00:00 2001
From: leo-pony
Date: Mon, 27 Oct 2025 11:36:43 +0800
Subject: [PATCH] fix issues of 310p on main

Signed-off-by: leo-pony
---
 .github/workflows/vllm_ascend_test_310p.yaml  |  2 -
 tests/e2e/310p/test_offline_inference_310p.py |  2 +-
 vllm_ascend/worker/model_runner_v1.py         | 85 +++++++++++++++++++
 vllm_ascend/worker/worker_v1.py               |  5 +-
 4 files changed, 89 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/vllm_ascend_test_310p.yaml b/.github/workflows/vllm_ascend_test_310p.yaml
index 1de447fc31d..243f65bb93d 100644
--- a/.github/workflows/vllm_ascend_test_310p.yaml
+++ b/.github/workflows/vllm_ascend_test_310p.yaml
@@ -112,6 +112,4 @@ jobs:
         run: |
           if [[ "${{ matrix.os }}" == "linux-aarch64-310p-1" ]]; then
             pytest -sv tests/e2e/310p/test_offline_inference_310p.py
-          else
-            pytest -sv tests/e2e/310p/test_offline_inference_parallel_310p.py
           fi
diff --git a/tests/e2e/310p/test_offline_inference_310p.py b/tests/e2e/310p/test_offline_inference_310p.py
index 31f7eb92120..47460b59f66 100644
--- a/tests/e2e/310p/test_offline_inference_310p.py
+++ b/tests/e2e/310p/test_offline_inference_310p.py
@@ -48,7 +48,7 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
 
 VL_MODELS = ["Qwen/Qwen2.5-VL-3B-Instruct"]
 
-@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("model", VL_MODELS)
 @pytest.mark.parametrize("dtype", ["float16"])
 def test_vl_model_with_samples(model: str, dtype: str) -> None:
     example_prompts = [
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index f30a9a39b44..abc829895a5 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -3183,6 +3183,87 @@ def initialize_kv_cache_tensors_deepseek_mla(
 
         return kv_caches
 
+    def _initialize_kv_cache_tensors_310p(
+            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
+        kv_cache_sizes = {}
+        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+            assert len(kv_cache_tensor.shared_by) == 1, (
+                "KV cache tensor shared by multiple layers is not supported on "
+                "310P NPU.")
+            kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
+
+        kv_caches: Dict[str, torch.Tensor] = {}
+        for group in self._kv_cache_spec_attn_group_iterator():
+            kv_cache_spec = group.kv_cache_spec
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
+                if layer_name in self.runner_only_attn_layers:
+                    continue
+                tensor_size = kv_cache_sizes[layer_name]
+                assert tensor_size % kv_cache_spec.page_size_bytes == 0
+                num_blocks = tensor_size // kv_cache_spec.page_size_bytes
+
+                # `num_blocks` is the number of blocks the model runner can use.
+                # `kv_cache_config.num_blocks` is the number of blocks that
+                # KVCacheManager may allocate.
+                # Since different NPUs may have different numbers of layers and
+                # different memory capacities, `num_blocks` can be different on
+                # different NPUs, and `kv_cache_config.num_blocks` is set to
+                # the min of all `num_blocks`. Verify it here.
+                assert num_blocks >= kv_cache_config.num_blocks
+
+                # TODO: remove this after the OOM issue is located and fixed;
+                # otherwise, some models may encounter OOM issues.
+                if isinstance(kv_cache_spec, FullAttentionSpec):
+                    if self.vllm_config.additional_config.get(
+                            "kv_cache_dtype", None) == 'int8':
+                        kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
+                            num_blocks, kv_cache_spec.block_size,
+                            kv_cache_spec.num_kv_heads,
+                            kv_cache_spec.head_size)
+                    elif hasattr(attn_backend, "get_supported_block_size"
+                                 ) and self.use_hybrid_blocks:
+                        block_size = attn_backend.get_supported_block_size()[0]
+
+                        block_size_chunk = kv_cache_spec.block_size // block_size
+                        kv_cache_shape = attn_backend.get_kv_cache_shape(
+                            num_blocks * block_size_chunk, block_size,
+                            kv_cache_spec.num_kv_heads,
+                            kv_cache_spec.head_size)
+                    else:
+                        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
+                            num_blocks, kv_cache_spec.block_size,
+                            kv_cache_spec.num_kv_heads,
+                            kv_cache_spec.head_size)
+                    dtype = kv_cache_spec.dtype
+
+                    if "attn" in layer_name:
+                        # for self_attn, sliding window attn
+                        if self.vllm_config.kv_transfer_config is None:
+                            k_tensor = torch.zeros(kv_cache_shape[1:],
+                                                   dtype=dtype,
+                                                   device=self.device)
+                            v_tensor = torch.zeros(kv_cache_shape[1:],
+                                                   dtype=dtype,
+                                                   device=self.device)
+                            k_cache = torch_npu.npu_format_cast(
+                                k_tensor, ACL_FORMAT)
+                            v_cache = torch_npu.npu_format_cast(
+                                v_tensor, ACL_FORMAT)
+
+                            kv_caches[layer_name] = (k_cache, v_cache)
+                        else:
+                            raise ValueError(
+                                "KV cache transfer is not supported on 310P.")
+                else:
+                    raise ValueError("Unknown KV cache spec type.")
+
+        bind_kv_cache(kv_caches,
+                      self.compilation_config.static_forward_context,
+                      self.kv_caches)
+
+        return kv_caches
+
     def initialize_kv_cache_tensors(
             self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
         """
@@ -3194,6 +3275,10 @@ def initialize_kv_cache_tensors(
             Dict[str, torch.Tensor]: A map between layer names to their
             corresponding memory buffer for KV cache.
         """
+
+        if is_310p():
+            return self._initialize_kv_cache_tensors_310p(kv_cache_config)
+
         # init kv cache tensors
         kv_cache_raw_tensors: dict[str, Union[torch.Tensor,
                                               Optional[torch.Tensor]]] = {}
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index e8729925fad..5fe146d6f64 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -47,7 +47,7 @@
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import (init_ascend_soc_version,
+from vllm_ascend.utils import (init_ascend_soc_version, is_310p,
                                prefill_context_parallel_enable,
                                register_ascend_customop, sleep_mode_enabled,
                                try_register_lib, vllm_version_is)
@@ -332,7 +332,8 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner.capture_model()
         # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
         # may cause performance degradation at runtime.
-        self._warm_up_atb()
+        if not is_310p():
+            self._warm_up_atb()
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         NPUPlatform.seed_everything(self.model_config.seed)
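
Note for reviewers: the sizing logic in _initialize_kv_cache_tensors_310p boils down to
dividing each layer's byte budget by the spec's page size to get num_blocks, then
allocating K and V as two separate tensors (instead of one stacked tensor) before
casting them to ACL ND format. The standalone sketch below only illustrates that
arithmetic with hypothetical parameter values; it uses plain torch on CPU and omits
the torch_npu.npu_format_cast step, so it is an approximation of the shape handling,
not the code added above.

# Standalone sketch (CPU-only, hypothetical values): per-layer KV cache sizing
# in the 310P style, without the torch_npu ACL format cast.
import torch


def allocate_310p_style_kv_cache(tensor_size: int, page_size_bytes: int,
                                 block_size: int, num_kv_heads: int,
                                 head_size: int,
                                 dtype: torch.dtype = torch.float16):
    """Return separate K and V cache tensors for one attention layer."""
    # Each layer's byte budget must be an exact multiple of the page size.
    assert tensor_size % page_size_bytes == 0
    num_blocks = tensor_size // page_size_bytes

    # Shape (2, num_blocks, block_size, num_kv_heads, head_size) mirrors the
    # usual stacked K/V layout; on 310P the two planes are allocated as two
    # separate tensors, so the leading "2" is dropped for each allocation.
    kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
    k_cache = torch.zeros(kv_cache_shape[1:], dtype=dtype)
    v_cache = torch.zeros(kv_cache_shape[1:], dtype=dtype)
    return k_cache, v_cache


# Example: 16-token blocks, 2 KV heads, head size 128, fp16 (2 bytes/element).
page_size_bytes = 2 * 16 * 2 * 128 * 2  # K+V planes * block_size * heads * head_size * bytes
k, v = allocate_310p_style_kv_cache(tensor_size=128 * page_size_bytes,
                                    page_size_bytes=page_size_bytes,
                                    block_size=16,
                                    num_kv_heads=2,
                                    head_size=128)
print(k.shape, v.shape)  # torch.Size([128, 16, 2, 128]) for each of K and V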