Skip to content

Commit 9fc0c32

Browse files
committed
tiny refactor
Signed-off-by: MengqingCao <[email protected]>
1 parent 8022016 commit 9fc0c32

File tree

1 file changed

+45
-4
lines changed

1 file changed

+45
-4
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2590,6 +2590,33 @@ def initialize_kv_cache_tensors(
25902590
Dict[str, torch.Tensor]: A map between layer names to their
25912591
corresponding memory buffer for KV cache.
25922592
"""
2593+
# Initialize the memory buffer for KV cache
2594+
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
2595+
# Change the memory buffer to the desired shape
2596+
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
2597+
kv_cache_raw_tensors)
2598+
bind_kv_cache(kv_caches,
2599+
self.compilation_config.static_forward_context,
2600+
self.kv_caches)
2601+
return kv_caches
2602+
2603+
def _allocate_kv_cache_tensors(
2604+
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
2605+
"""
2606+
Initializes the KV cache buffer with the correct size. The buffer needs
2607+
to be reshaped to the desired shape before being used by the models.
2608+
2609+
NOTE: To support prefill disaggregation, we need to split the kvcache tensor into
2610+
k_cache and v_cache, and the addresses of both are aligned to 2 MB
2611+
2612+
Args:
2613+
kv_cache_config: The KV cache config
2614+
Returns:
2615+
dict[str, torch.Tensor]: A map between layer names to their
2616+
corresponding memory buffer for KV cache.
2617+
dict[str, tuple[torch.Tensor, torch.Tensor]]: A map between layer names
2618+
to their corresponding memory buffer for K cache and V cache.
2619+
"""
25932620
# init kv cache tensors
25942621
kv_cache_raw_tensors: dict[str, Union[torch.Tensor,
25952622
Optional[torch.Tensor]]] = {}
@@ -2666,6 +2693,24 @@ def initialize_kv_cache_tensors(
26662693
assert layer_names == set(kv_cache_raw_tensors.keys(
26672694
)), "Some layers are not correctly initialized"
26682695

2696+
return kv_cache_raw_tensors
2697+
2698+
def _reshape_kv_cache_tensors(
2699+
self,
2700+
kv_cache_config: KVCacheConfig,
2701+
kv_cache_raw_tensors: dict[str, torch.Tensor],
2702+
) -> dict[str, torch.Tensor]:
2703+
"""
2704+
Reshape the KV cache tensors to the desired shape and dtype.
2705+
2706+
Args:
2707+
kv_cache_config: The KV cache config
2708+
kv_cache_raw_tensors: The KV cache buffer of each layer, with
2709+
correct size but uninitialized shape.
2710+
Returns:
2711+
Dict[str, torch.Tensor]: A map between layer names to their
2712+
corresponding memory buffer for KV cache.
2713+
"""
26692714
kv_caches: Dict[str, torch.Tensor] = {}
26702715
for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
26712716
if vllm_version_is("0.10.2"):
@@ -2782,10 +2827,6 @@ def initialize_kv_cache_tensors(
27822827
else:
27832828
raise ValueError("Unknown KV cache spec type.")
27842829

2785-
bind_kv_cache(kv_caches,
2786-
self.compilation_config.static_forward_context,
2787-
self.kv_caches)
2788-
27892830
return kv_caches
27902831

27912832
def may_reinitialize_input_batch(self,

0 commit comments

Comments
 (0)