2 changes: 0 additions & 2 deletions .github/workflows/vllm_ascend_test_310p.yaml
@@ -112,6 +112,4 @@ jobs:
        run: |
          if [[ "${{ matrix.os }}" == "linux-aarch64-310p-1" ]]; then
            pytest -sv tests/e2e/310p/test_offline_inference_310p.py
          else
            pytest -sv tests/e2e/310p/test_offline_inference_parallel_310p.py
          fi
2 changes: 1 addition & 1 deletion tests/e2e/310p/test_offline_inference_310p.py
@@ -48,7 +48,7 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
 VL_MODELS = ["Qwen/Qwen2.5-VL-3B-Instruct"]


-@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("model", VL_MODELS)
 @pytest.mark.parametrize("dtype", ["float16"])
 def test_vl_model_with_samples(model: str, dtype: str) -> None:
     example_prompts = [
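For readers without the test file open: the one-line fix above re-parametrizes the vision-language test over VL_MODELS instead of the text-only MODELS list, so pytest generates one case per VL model rather than running the VL test against text models. A minimal, self-contained sketch of the pattern, with a hypothetical MODELS value and a trivial body standing in for the real offline-inference check:

import pytest

MODELS = ["some-text-only-model"]  # hypothetical stand-in for the text-only list
VL_MODELS = ["Qwen/Qwen2.5-VL-3B-Instruct"]  # as defined in the test file


@pytest.mark.parametrize("model", VL_MODELS)  # previously MODELS, which pointed at text-only models
@pytest.mark.parametrize("dtype", ["float16"])
def test_vl_model_with_samples(model: str, dtype: str) -> None:
    # pytest expands the stacked decorators into one case per (dtype, model) pair,
    # e.g. test_vl_model_with_samples[Qwen/Qwen2.5-VL-3B-Instruct-float16].
    assert model in VL_MODELS and dtype == "float16"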
85 changes: 85 additions & 0 deletions vllm_ascend/worker/model_runner_v1.py
@@ -3183,6 +3183,87 @@ def initialize_kv_cache_tensors_deepseek_mla(

         return kv_caches

+    def _initialize_kv_cache_tensors_310p(
+            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
+        kv_cache_sizes = {}
+        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
+            assert len(kv_cache_tensor.shared_by) == 1, (
+                "KV cache tensor shared by multiple layers is not supported in "
+                "310p NPU.")
+            kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
+
+        kv_caches: Dict[str, torch.Tensor] = {}
+        for group in self._kv_cache_spec_attn_group_iterator():
+            kv_cache_spec = group.kv_cache_spec
+            attn_backend = group.backend
+            for layer_name in group.layer_names:
+                if layer_name in self.runner_only_attn_layers:
+                    continue
+                tensor_size = kv_cache_sizes[layer_name]
+                assert tensor_size % kv_cache_spec.page_size_bytes == 0
+                num_blocks = tensor_size // kv_cache_spec.page_size_bytes
+
+                # `num_blocks` is the number of blocks the model runner can use.
+                # `kv_cache_config.num_blocks` is the number of blocks that
+                # KVCacheManager may allocate.
+                # Since different GPUs may have different number of layers and
+                # different memory capacities, `num_blocks` can be different on
+                # different GPUs, and `kv_cache_config.num_blocks` is set to
+                # the min of all `num_blocks`. Verify it here.
+                assert num_blocks >= kv_cache_config.num_blocks
+
+                # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
+                # encounter OOM issue
+                if isinstance(kv_cache_spec, FullAttentionSpec):
+                    if self.vllm_config.additional_config.get(
+                            "kv_cache_dtype", None) == 'int8':
+                        kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
+                            num_blocks, kv_cache_spec.block_size,
+                            kv_cache_spec.num_kv_heads,
+                            kv_cache_spec.head_size)
+                    elif hasattr(attn_backend, "get_supported_block_size"
+                                 ) and self.use_hybrid_blocks:
+                        block_size = attn_backend.get_supported_block_size()[0]
+
+                        block_size_chunk = kv_cache_spec.block_size // block_size
+                        kv_cache_shape = attn_backend.get_kv_cache_shape(
+                            num_blocks * block_size_chunk, block_size,
+                            kv_cache_spec.num_kv_heads,
+                            kv_cache_spec.head_size)
+                    else:
+                        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
+                            num_blocks, kv_cache_spec.block_size,
+                            kv_cache_spec.num_kv_heads,
+                            kv_cache_spec.head_size)
+                    dtype = kv_cache_spec.dtype
+
+                    if "attn" in layer_name:
+                        # for self_attn, sliding window attn
+                        if self.vllm_config.kv_transfer_config is None:
+                            k_tensor = torch.zeros(kv_cache_shape[1:],
+                                                   dtype=dtype,
+                                                   device=self.device)
+                            v_tensor = torch.zeros(kv_cache_shape[1:],
+                                                   dtype=dtype,
+                                                   device=self.device)
+                            k_cache = torch_npu.npu_format_cast(
+                                k_tensor, ACL_FORMAT)
+                            v_cache = torch_npu.npu_format_cast(
+                                v_tensor, ACL_FORMAT)
+
+                            kv_caches[layer_name] = (k_cache, v_cache)
+                        else:
+                            raise ValueError(
+                                "KV cache transfer is not supported for 310p.")
+                else:
+                    raise ValueError("Unknown KV cache spec type.")
+
+        bind_kv_cache(kv_caches,
+                      self.compilation_config.static_forward_context,
+                      self.kv_caches)
+
+        return kv_caches
+
     def initialize_kv_cache_tensors(
             self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
         """
@@ -3194,6 +3275,10 @@ def initialize_kv_cache_tensors(
             Dict[str, torch.Tensor]: A map between layer names to their
             corresponding memory buffer for KV cache.
         """
+
+        if is_310p():
+            return self._initialize_kv_cache_tensors_310p(kv_cache_config)
+
         # init kv cache tensors
         kv_cache_raw_tensors: dict[str, Union[torch.Tensor,
                                               Optional[torch.Tensor]]] = {}
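To make the new 310P path easier to follow outside the full model runner, here is a small standalone sketch of the allocation scheme it implements: derive the per-layer block count from the tensor size and the spec's page size, allocate K and V as separate tensors rather than one fused KV buffer, apply a format cast, and collect the pair per layer. LayerSpec, nz_format_cast, and the CPU device are illustrative stand-ins; the real code takes shapes from the attention backend and casts with torch_npu.npu_format_cast(..., ACL_FORMAT) on the NPU.

from dataclasses import dataclass
from typing import Dict, Tuple

import torch


@dataclass
class LayerSpec:
    """Illustrative stand-in for a FullAttentionSpec-like object."""
    block_size: int   # tokens per KV cache block
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype

    @property
    def page_size_bytes(self) -> int:
        # K and V each store block_size * num_kv_heads * head_size elements per block.
        elem_size = torch.tensor([], dtype=self.dtype).element_size()
        return 2 * self.block_size * self.num_kv_heads * self.head_size * elem_size


def nz_format_cast(t: torch.Tensor) -> torch.Tensor:
    """Placeholder for torch_npu.npu_format_cast(t, ACL_FORMAT) on a real 310P device."""
    return t


def init_kv_caches_310p_style(
    kv_cache_sizes: Dict[str, int],
    spec: LayerSpec,
    device: str = "cpu",
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
    kv_caches: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
    for layer_name, tensor_size in kv_cache_sizes.items():
        # Each layer's tensor budget must be a whole number of pages.
        assert tensor_size % spec.page_size_bytes == 0
        num_blocks = tensor_size // spec.page_size_bytes

        # K and V are allocated separately on 310P, then format-cast individually.
        shape = (num_blocks, spec.block_size, spec.num_kv_heads, spec.head_size)
        k_cache = nz_format_cast(torch.zeros(shape, dtype=spec.dtype, device=device))
        v_cache = nz_format_cast(torch.zeros(shape, dtype=spec.dtype, device=device))
        kv_caches[layer_name] = (k_cache, v_cache)
    return kv_caches


if __name__ == "__main__":
    spec = LayerSpec(block_size=128, num_kv_heads=8, head_size=128, dtype=torch.float16)
    caches = init_kv_caches_310p_style(
        {"model.layers.0.self_attn.attn": 64 * spec.page_size_bytes}, spec)
    k, v = caches["model.layers.0.self_attn.attn"]
    print(k.shape, v.shape)  # 64 blocks each for K and V

Allocating K and V as two tensors, instead of slicing views out of one shared raw buffer as the default path does, is what lets this path hand each tensor to the format cast on its own before binding the pair into the forward context.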
5 changes: 3 additions & 2 deletions vllm_ascend/worker/worker_v1.py
@@ -47,7 +47,7 @@
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import (init_ascend_soc_version,
+from vllm_ascend.utils import (init_ascend_soc_version, is_310p,
                                prefill_context_parallel_enable,
                                register_ascend_customop, sleep_mode_enabled,
                                try_register_lib, vllm_version_is)
@@ -332,7 +332,8 @@ def compile_or_warm_up_model(self) -> None:
             self.model_runner.capture_model()
         # Call ATB matmul to warm up; otherwise, the first operation (ReshapeAndCache)
         # may cause performance degradation at runtime.
-        self._warm_up_atb()
+        if not is_310p():
+            self._warm_up_atb()
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         NPUPlatform.seed_everything(self.model_config.seed)
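The worker change applies the same device gating as the model runner: ATB matmul warm-up is skipped when running on 310P. A condensed sketch of the pattern, with is_310p stubbed so the snippet runs anywhere (the real helper lives in vllm_ascend.utils):

def is_310p() -> bool:
    # Stand-in for vllm_ascend.utils.is_310p(); assume a non-310P SoC here.
    return False


class WorkerSketch:
    """Stripped-down stand-in for the NPU worker, keeping only the gating logic."""

    def _warm_up_atb(self) -> None:
        # The real method issues an ATB matmul so that ReshapeAndCache is not
        # the first kernel launched at runtime.
        print("warming up ATB matmul")

    def compile_or_warm_up_model(self) -> None:
        # The warm-up only applies off 310P; 310P skips it entirely.
        if not is_310p():
            self._warm_up_atb()


WorkerSketch().compile_or_warm_up_model()  # prints the warm-up message on non-310P devices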