Skip to content

Commit 4b90aff

Browse files
committed
fix shared by
Signed-off-by: MengqingCao <[email protected]>
1 parent 248ee7f commit 4b90aff

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2922,22 +2922,18 @@ def initialize_kv_cache_tensors(
29222922
layer_name = kv_cache_tensor.shared_by[idx]
29232923
if "linear_attn" in layer_name:
29242924
# for mamba linear attention
2925-
for layer_name_inner in kv_cache_tensor.shared_by:
2926-
if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
2927-
layer_name_inner in kv_cache_raw_tensors.keys():
2928-
continue
2929-
if self.vllm_config.kv_transfer_config is None:
2930-
tensor = torch.zeros(kv_cache_tensor.size,
2931-
dtype=torch.int8,
2932-
device=self.device)
2933-
else:
2934-
cache_size_aligned = kv_cache_tensor.size + alignment
2935-
tensor = torch.zeros(cache_size_aligned,
2936-
dtype=torch.int8,
2937-
device=self.device)
2938-
tensor = self._align_memory(
2939-
tensor, alignment)[:kv_cache_tensor.size]
2940-
kv_cache_raw_tensors[layer_name_inner] = tensor
2925+
if self.vllm_config.kv_transfer_config is None:
2926+
tensor = torch.zeros(kv_cache_tensor.size,
2927+
dtype=torch.int8,
2928+
device=self.device)
2929+
else:
2930+
cache_size_aligned = kv_cache_tensor.size + alignment
2931+
tensor = torch.zeros(cache_size_aligned,
2932+
dtype=torch.int8,
2933+
device=self.device)
2934+
tensor = self._align_memory(
2935+
tensor, alignment)[:kv_cache_tensor.size]
2936+
kv_cache_raw_tensors[layer_name] = tensor
29412937
elif "attn" in layer_name:
29422938
# for other attentions, e.g., self_attn, sliding window attn
29432939
if self.vllm_config.kv_transfer_config is None:
@@ -2961,6 +2957,11 @@ def initialize_kv_cache_tensors(
29612957
v_tensor = self._align_memory(v_tensor,
29622958
alignment)[:cache_size]
29632959
kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)
2960+
for layer_name_inner in kv_cache_tensor.shared_by:
2961+
if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) and \
2962+
layer_name_inner not in kv_cache_raw_tensors.keys():
2963+
kv_cache_raw_tensors[layer_name_inner] = tensor
2964+
break
29642965

29652966
layer_names = set()
29662967
for group in kv_cache_config.kv_cache_groups:

0 commit comments

Comments
 (0)