@@ -2775,7 +2775,8 @@ def _allocate_kv_cache_tensors(
             for idx in range(len(kv_cache_tensor.shared_by)):
                 layer_name = kv_cache_tensor.shared_by[idx]
                 print(30 * "-", f"layer_name: {layer_name}")
-                if "linear_attn" in layer_name:
+                if "linear_attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys():
+                    print(30 * "|", f"layer_name: {layer_name}")
                     # for mamba linear attention
                     if self.vllm_config.kv_transfer_config is None:
                         tensor = torch.zeros(kv_cache_tensor.size,
@@ -2789,6 +2790,10 @@ def _allocate_kv_cache_tensors(
                         tensor = self._align_memory(
                             tensor, alignment)[:kv_cache_tensor.size]
                     kv_cache_raw_tensors[layer_name] = tensor
+                    for layer_name_inner in kv_cache_tensor.shared_by:
+                        # share the kv cache between the linear_attn specs in the same group
+                        if "linear_attn" in layer_name_inner:
+                            kv_cache_raw_tensors[layer_name_inner] = tensor
                 elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys():
                     print(30 * "/", f"layer_name: {layer_name}")
                     # NOTE: We need to init k cache tensor (nope cache tensor in mla) and
@@ -3002,10 +3007,6 @@ def _reshape_kv_cache_tensors(
             else:
                 raise ValueError("Unknown KV cache spec type.")

-        bind_kv_cache(kv_caches,
-                      self.compilation_config.static_forward_context,
-                      self.kv_caches)
-
         return kv_caches

     def may_reinitialize_input_batch(self,