File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed
Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -2823,7 +2823,7 @@ def _allocate_kv_cache_tensors(
28232823 v_tensor = self ._align_memory (
28242824 v_tensor , alignment )[:v_tensor_size ]
28252825 #### k cache: for deepseek sparse attention
2826- if dsa_k_cache_factor is not None :
2826+ if dsa_k_cache_factor is not None and dsa_k_cache_size is not None :
28272827 k_cache_tensor = torch .zeros (dsa_k_cache_size +
28282828 alignment ,
28292829 dtype = torch .int8 ,
@@ -2848,6 +2848,8 @@ def _allocate_kv_cache_tensors(
28482848 assert layer_names == set (kv_cache_raw_tensors .keys (
28492849 )), "Some layers are not correctly initialized"
28502850
2851+ return kv_cache_raw_tensors
2852+
28512853 def _reshape_kv_cache_tensors (
28522854 self ,
28532855 kv_cache_config : KVCacheConfig ,
@@ -2926,7 +2928,7 @@ def _reshape_kv_cache_tensors(
29262928 k_cache = self ._convert_torch_format (k_cache )
29272929 v_cache = raw_v_tensor .view (dtype ).view (kv_cache_shape [1 :])
29282930 v_cache = self ._convert_torch_format (v_cache )
2929- if self .use_sparse :
2931+ if self .use_sparse and raw_dsa_k_cache is not None :
29302932 dsa_k_cache_shape = (num_blocks , block_size , 1 , 128 )
29312933 dsa_k_cache = raw_dsa_k_cache .view (dtype ).view (
29322934 dsa_k_cache_shape )
You can’t perform that action at this time.
0 commit comments