Skip to content

Commit ccc6ef1

Browse files
committed
Fix the bug in sfa-cp when MTP is enabled.
Signed-off-by: zzhx1 <[email protected]>
1 parent a86bdf8 commit ccc6ef1

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

vllm_ascend/attention/sfa_v1.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -355,16 +355,18 @@ def __init__(
         self._replace_linear_class_for_sfa_cp()
         from vllm_ascend.distributed.parallel_state import \
             get_shared_weight_group
-        register_layer_to_shared_weight_series(
-            series_name="q_proj",
-            group=get_shared_weight_group(),
-            layer=self.q_proj,
-            prefetch_step=1)
-        register_layer_to_shared_weight_series(
-            series_name="o_proj",
-            group=get_shared_weight_group(),
-            layer=self.o_proj,
-            prefetch_step=1)
+        if is_hidden_layer(self.model_config.hf_config, self.q_proj):
+            register_layer_to_shared_weight_series(
+                series_name="q_proj",
+                group=get_shared_weight_group(),
+                layer=self.q_proj,
+                prefetch_step=1)
+        if is_hidden_layer(self.model_config.hf_config, self.o_proj):
+            register_layer_to_shared_weight_series(
+                series_name="o_proj",
+                group=get_shared_weight_group(),
+                layer=self.o_proj,
+                prefetch_step=1)

     # indexer param
     self.n_head: int = self.indexer.n_head  # 64

0 commit comments

Comments
 (0)