@@ -2922,22 +2922,18 @@ def initialize_kv_cache_tensors(
29222922 layer_name = kv_cache_tensor .shared_by [idx ]
29232923 if "linear_attn" in layer_name :
29242924 # for mamba linear attention
2925- for layer_name_inner in kv_cache_tensor .shared_by :
2926- if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner ) or \
2927- layer_name_inner in kv_cache_raw_tensors .keys ():
2928- continue
2929- if self .vllm_config .kv_transfer_config is None :
2930- tensor = torch .zeros (kv_cache_tensor .size ,
2931- dtype = torch .int8 ,
2932- device = self .device )
2933- else :
2934- cache_size_aligned = kv_cache_tensor .size + alignment
2935- tensor = torch .zeros (cache_size_aligned ,
2936- dtype = torch .int8 ,
2937- device = self .device )
2938- tensor = self ._align_memory (
2939- tensor , alignment )[:kv_cache_tensor .size ]
2940- kv_cache_raw_tensors [layer_name_inner ] = tensor
2925+ if self .vllm_config .kv_transfer_config is None :
2926+ tensor = torch .zeros (kv_cache_tensor .size ,
2927+ dtype = torch .int8 ,
2928+ device = self .device )
2929+ else :
2930+ cache_size_aligned = kv_cache_tensor .size + alignment
2931+ tensor = torch .zeros (cache_size_aligned ,
2932+ dtype = torch .int8 ,
2933+ device = self .device )
2934+ tensor = self ._align_memory (
2935+ tensor , alignment )[:kv_cache_tensor .size ]
2936+ kv_cache_raw_tensors [layer_name ] = tensor
29412937 elif "attn" in layer_name :
29422938 # for other attentions, e.g., self_attn, sliding window attn
29432939 if self .vllm_config .kv_transfer_config is None :
@@ -2961,6 +2957,11 @@ def initialize_kv_cache_tensors(
29612957 v_tensor = self ._align_memory (v_tensor ,
29622958 alignment )[:cache_size ]
29632959 kv_cache_raw_tensors [layer_name ] = (k_tensor , v_tensor )
2960+ for layer_name_inner in kv_cache_tensor .shared_by :
2961+ if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner ) and \
2962+ layer_name_inner not in kv_cache_raw_tensors .keys ():
2963+ kv_cache_raw_tensors [layer_name_inner ] = tensor
2964+ break
29642965
29652966 layer_names = set ()
29662967 for group in kv_cache_config .kv_cache_groups :
0 commit comments