@@ -3355,26 +3355,14 @@ def _reshape_kv_cache_tensors(
33553355 else :
33563356 # k_cache: nope_cache v_cache: rope_cache
33573357 mla_num_blocks , mla_block_size , num_kv_heads , _ = kv_cache_shape
3358- if not self .use_sparse :
3359- k_shape = [
3360- mla_num_blocks , mla_block_size , num_kv_heads ,
3361- self .model_config .hf_text_config .kv_lora_rank
3362- ]
3363- v_shape = [
3364- mla_num_blocks , mla_block_size , num_kv_heads ,
3365- self .model_config .hf_text_config .
3366- qk_rope_head_dim
3367- ]
3368- else :
3369- k_shape = [
3370- mla_num_blocks , mla_block_size , num_kv_heads ,
3371- self .model_config .hf_text_config .kv_lora_rank
3372- ]
3373- v_shape = [
3374- mla_num_blocks , mla_block_size , num_kv_heads ,
3375- self .model_config .hf_text_config .
3376- qk_rope_head_dim
3377- ]
3358+ k_shape = [
3359+ mla_num_blocks , mla_block_size , num_kv_heads ,
3360+ self .model_config .hf_text_config .kv_lora_rank
3361+ ]
3362+ v_shape = [
3363+ mla_num_blocks , mla_block_size , num_kv_heads ,
3364+ self .model_config .hf_text_config .qk_rope_head_dim
3365+ ]
33783366 k_cache = raw_k_tensor .view (dtype ).view (k_shape )
33793367 k_cache = self ._convert_torch_format (k_cache )
33803368 v_cache = raw_v_tensor .view (dtype ).view (v_shape )
0 commit comments