 import torch
 import torch.distributed as dist
 from vllm.logger import logger
+from vllm.config import get_current_vllm_config
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
 
 
 class VllmEplbAdaptor(EplbAdaptor):
 
-    def __init__(self, model, **args):
+    def __init__(self, model, mtp_instance, num_mtp_layers, **args):
         super().__init__(**args)
         self.model = model
         self.rank_id = dist.get_rank()
         self.world_size = dist.get_world_size()
         self.param_dict = dict(self.model.named_parameters())
+        self.mtp_instance = mtp_instance
+        self.num_mtp_layers = num_mtp_layers
         if self.model.config.model_type == "qwen3_moe":
             self.num_dense_layers = 0
             self.global_expert_num = self.model.config.num_experts
         else:
             self.num_dense_layers = self.model.config.first_k_dense_replace
             self.global_expert_num = self.model.config.n_routed_experts
-        self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
+        self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers  # MTP not included
         self.init_redundancy_expert = get_ascend_config(
         ).init_redundancy_expert
 
@@ -53,6 +56,16 @@ def __init__(self, model, **args):
         else:
             self.expert_weight_names = ["w13_weight", "w2_weight"]
 
+        # TODO: init self.mtp_expert_weight_names depending on the model type; only DeepSeek V3 W8A8 and Qwen3-MoE are supported here
+        if any("w13_weight_offset" in name for name, _ in self.mtp_instance.named_parameters()):
+            self.mtp_expert_weight_names = [
+                "w13_weight", "w2_weight", "w13_weight_scale",
+                "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
+            ]
+        else:
+            self.mtp_expert_weight_names = ["w13_weight", "w2_weight"]
+
+
         self.expert_map_per_layer = dict(
         )  # reference to expert map on device for expert map update
         self.expert_map_per_layer_cpu = dict(
@@ -61,6 +74,12 @@ def __init__(self, model, **args):
             self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \
                 self.model.get_expert_map(self.num_dense_layers + layer_idx)
 
+        # Currently, MTP only supports one layer.
+        if self.mtp_instance is not None:
+            for mtp_layer_idx in range(self.num_mtp_layers):
+                self.expert_map_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = \
+                    self.mtp_instance.model.get_expert_map(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx)
+
         # TODO: here we set the number of buffer tensors equal to the number of experts in each layer, which can be improved
         num_buffer_tensor = torch.where(
             self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
@@ -76,6 +95,11 @@ def __init__(self, model, **args):
         for layer_idx in range(self.num_moe_layers):
             self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
                 self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
+
+        if self.mtp_instance is not None:
+            for mtp_layer_idx in range(self.num_mtp_layers):
+                self.log2phy_map_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = \
+                    self.mtp_instance.model.get_log2phy_map(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx)
 
         self.all_topk_ids = []
 
@@ -103,13 +127,29 @@ def init_expert_param_per_layer(self):
                         name].data[local_expert_id]
                     for name in self.expert_weight_names
                 ])
+
+        if self.mtp_instance is not None:
+            mtp_param_dict = dict(self.mtp_instance.named_parameters())
+            self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers] = list()
+            for local_expert_id in range(num_local_expert):
+                for mtp_layer_idx in range(self.num_mtp_layers):
+                    self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx].append([
+                        mtp_param_dict["model.layers." + str(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx) +
+                                       ".mtp_block.mlp.experts." +
+                                       name].data[local_expert_id]
+                        for name in self.mtp_expert_weight_names
+                    ])
 
     def get_rank_expert_workload(self) -> torch.Tensor:
         self.moe_load = self.model.get_all_moe_loads()
+        if self.mtp_instance is not None:
+            self.moe_load = torch.cat([self.moe_load, self.mtp_instance.model.get_all_moe_loads().to(device=self.moe_load.device)], dim=0)
         return self.moe_load
 
     def get_init_expert_map(self, num_moe_layers):
         expert_map = self.model.get_all_expert_map(num_moe_layers)
+        if self.mtp_instance is not None:
+            expert_map = torch.cat([expert_map, self.mtp_instance.model.get_all_expert_map().to(device=expert_map.device)], dim=0)
         if dist.is_initialized():
             world_size = dist.get_world_size()
 
@@ -261,9 +301,11 @@ def determine_expert_map_all(self):
         local_num_experts = self.global_expert_num // self.world_size
 
         expert_map_all = torch.full(
-            (self.num_moe_layers, self.world_size, self.global_expert_num),
-            -1,
-            dtype=torch.int32)
+            (self.num_moe_layers if self.mtp_instance is None else (self.num_moe_layers + self.num_mtp_layers),
+             self.world_size,
+             self.global_expert_num),
+            -1,
+            dtype=torch.int32)
 
         for r in range(self.world_size):
             if r < self.world_size - 1:
@@ -284,6 +326,6 @@ def determine_expert_map_all(self):
 
             local_ids = torch.arange(local_count, dtype=torch.int32)
             expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(
-                self.num_moe_layers, -1)
+                self.num_moe_layers if self.mtp_instance is None else (self.num_moe_layers + self.num_mtp_layers), -1)
 
         return expert_map_all
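
Note: the layer-indexing convention this diff relies on can be illustrated with a small self-contained sketch in plain PyTorch (no vLLM dependencies; the layer counts and expert counts below are made-up illustrative values, not taken from any real config). MTP layers are keyed right after the last main-model MoE layer, and the MTP workload rows are concatenated after the main-model rows so both share the same row order as expert_map_all.

import torch

# Illustrative sizes only (hypothetical, not from a real model config).
num_dense_layers = 3    # e.g. first_k_dense_replace
num_moe_layers = 58     # num_hidden_layers - num_dense_layers (MTP not included)
num_mtp_layers = 1      # the adaptor currently assumes a single MTP layer
world_size = 4
global_expert_num = 256

# Layer keys used by expert_map_per_layer / log2phy_map_per_layer:
# main-model MoE layers first, then MTP layers appended right after them.
moe_layer_keys = [num_dense_layers + i for i in range(num_moe_layers)]
mtp_layer_keys = [num_dense_layers + num_moe_layers + i for i in range(num_mtp_layers)]
assert mtp_layer_keys[0] == moe_layer_keys[-1] + 1

# Shape change in determine_expert_map_all(): with an MTP instance present,
# the map gains num_mtp_layers extra rows in the layer dimension.
total_layers = num_moe_layers + num_mtp_layers
expert_map_all = torch.full((total_layers, world_size, global_expert_num),
                            -1, dtype=torch.int32)
assert expert_map_all.shape == (total_layers, world_size, global_expert_num)

# get_rank_expert_workload(): the MTP load rows are concatenated after the
# main-model load rows, matching the row order of expert_map_all above.
main_load = torch.zeros(num_moe_layers, global_expert_num)
mtp_load = torch.zeros(num_mtp_layers, global_expert_num)
moe_load = torch.cat([main_load, mtp_load], dim=0)
assert moe_load.shape == (total_layers, global_expert_num)

This row layout is also why the expand() call at the end of determine_expert_map_all() switches from num_moe_layers to num_moe_layers + num_mtp_layers whenever mtp_instance is not None.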