File tree Expand file tree Collapse file tree 1 file changed +12
-10
lines changed
Expand file tree Collapse file tree 1 file changed +12
-10
lines changed Original file line number Diff line number Diff line change @@ -355,16 +355,18 @@ def __init__(
355355 self ._replace_linear_class_for_sfa_cp ()
356356 from vllm_ascend .distributed .parallel_state import \
357357 get_shared_weight_group
358- register_layer_to_shared_weight_series (
359- series_name = "q_proj" ,
360- group = get_shared_weight_group (),
361- layer = self .q_proj ,
362- prefetch_step = 1 )
363- register_layer_to_shared_weight_series (
364- series_name = "o_proj" ,
365- group = get_shared_weight_group (),
366- layer = self .o_proj ,
367- prefetch_step = 1 )
358+ if is_hidden_layer (self .model_config .hf_config , self .q_proj ):
359+ register_layer_to_shared_weight_series (
360+ series_name = "q_proj" ,
361+ group = get_shared_weight_group (),
362+ layer = self .q_proj ,
363+ prefetch_step = 1 )
364+ if is_hidden_layer (self .model_config .hf_config , self .o_proj ):
365+ register_layer_to_shared_weight_series (
366+ series_name = "o_proj" ,
367+ group = get_shared_weight_group (),
368+ layer = self .o_proj ,
369+ prefetch_step = 1 )
368370
369371 # indexer param
370372 self .n_head : int = self .indexer .n_head # 64
You can’t perform that action at this time.
0 commit comments