@@ -2590,6 +2590,33 @@ def initialize_kv_cache_tensors(
25902590 Dict[str, torch.Tensor]: A map between layer names to their
25912591 corresponding memory buffer for KV cache.
25922592 """
2593+ # Initialize the memory buffer for KV cache
2594+ kv_cache_raw_tensors = self ._allocate_kv_cache_tensors (kv_cache_config )
2595+ # Change the memory buffer to the desired shape
2596+ kv_caches = self ._reshape_kv_cache_tensors (kv_cache_config ,
2597+ kv_cache_raw_tensors )
2598+ bind_kv_cache (kv_caches ,
2599+ self .compilation_config .static_forward_context ,
2600+ self .kv_caches )
2601+ return kv_caches
2602+
2603+ def _allocate_kv_cache_tensors (
2604+ self , kv_cache_config : KVCacheConfig ) -> dict [str , torch .Tensor ]:
2605+ """
2606+ Initializes the KV cache buffer with the correct size. The buffer needs
2607+ to be reshaped to the desired shape before being used by the models.
2608+
2609+ NOTE: To support prefill disaggregation, we need to split kvcache tensor into
2610+ k_cahce and v cache, and the addr of both are aligned by 2M
2611+
2612+ Args:
2613+ kv_cache_config: The KV cache config
2614+ Returns:
2615+ dict[str, torch.Tensor]: A map between layer names to their
2616+ corresponding memory buffer for KV cache.
2617+ dict[str, tuple(torch.Tensor, torch.Tensor)] A map between layer names
2618+ to their corresponding memory buffer for K cache and V cache.
2619+ """
25932620 # init kv cache tensors
25942621 kv_cache_raw_tensors : dict [str , Union [torch .Tensor ,
25952622 Optional [torch .Tensor ]]] = {}
@@ -2666,6 +2693,24 @@ def initialize_kv_cache_tensors(
26662693 assert layer_names == set (kv_cache_raw_tensors .keys (
26672694 )), "Some layers are not correctly initialized"
26682695
2696+ return kv_cache_raw_tensors
2697+
2698+ def _reshape_kv_cache_tensors (
2699+ self ,
2700+ kv_cache_config : KVCacheConfig ,
2701+ kv_cache_raw_tensors : dict [str , torch .Tensor ],
2702+ ) -> dict [str , torch .Tensor ]:
2703+ """
2704+ Reshape the KV cache tensors to the desired shape and dtype.
2705+
2706+ Args:
2707+ kv_cache_config: The KV cache config
2708+ kv_cache_raw_tensors: The KV cache buffer of each layer, with
2709+ correct size but uninitialized shape.
2710+ Returns:
2711+ Dict[str, torch.Tensor]: A map between layer names to their
2712+ corresponding memory buffer for KV cache.
2713+ """
26692714 kv_caches : Dict [str , torch .Tensor ] = {}
26702715 for group in self ._kv_cache_spec_attn_group_iterator_dispatcher ():
26712716 if vllm_version_is ("0.10.2" ):
@@ -2782,10 +2827,6 @@ def initialize_kv_cache_tensors(
27822827 else :
27832828 raise ValueError ("Unknown KV cache spec type." )
27842829
2785- bind_kv_cache (kv_caches ,
2786- self .compilation_config .static_forward_context ,
2787- self .kv_caches )
2788-
27892830 return kv_caches
27902831
27912832 def may_reinitialize_input_batch (self ,
0 commit comments