
Commit 3e46e82

BryanChen408cbx authored and committed

feat: enable EPLB for MTP layer

- Implement EPLB support for the MTP layer
- Add support for configuring multiple MTP layers
- Enhance handling of the num_speculative_tokens parameter:
  - Support num_speculative_tokens = 1 (single-token speculative inference)
  - Support num_speculative_tokens > 1 (multi-token speculative inference)

Signed-off-by: chenbaixuan <[email protected]>
1 parent 06f6cc1 commit 3e46e82
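
For orientation, here is a minimal sketch of the global layer indexing this change relies on (illustrative only, not part of the commit, with hypothetical DeepSeek-V3-like sizes): EPLB numbers dense layers first, MoE layers next, and appends the MTP layer(s) after the last MoE layer.

```python
# Hypothetical sizes for illustration; the real values come from the model config.
num_dense_layers = 3   # e.g. config.first_k_dense_replace
num_moe_layers = 58    # config.num_hidden_layers - num_dense_layers (MTP not included)
num_mtp_layers = 1     # currently MTP only supports one layer

moe_ids = range(num_dense_layers, num_dense_layers + num_moe_layers)
mtp_ids = range(num_dense_layers + num_moe_layers,
                num_dense_layers + num_moe_layers + num_mtp_layers)

# expert_map_per_layer and log2phy_map_per_layer are keyed by these global ids,
# so the MTP layer lands at index 61 in this example.
print(list(mtp_ids))  # [61]
```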

File tree

4 files changed: +127 -41 lines changed


vllm_ascend/eplb/adaptor/vllm_adaptor.py

Lines changed: 48 additions & 6 deletions
```diff
@@ -21,26 +21,29 @@
 import torch
 import torch.distributed as dist
 from vllm.logger import logger
+from vllm.config import get_current_vllm_config

 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor


 class VllmEplbAdaptor(EplbAdaptor):

-    def __init__(self, model, **args):
+    def __init__(self, model, mtp_instance, num_mtp_layers, **args):
         super().__init__(**args)
         self.model = model
         self.rank_id = dist.get_rank()
         self.world_size = dist.get_world_size()
         self.param_dict = dict(self.model.named_parameters())
+        self.mtp_instance = mtp_instance
+        self.num_mtp_layers = num_mtp_layers
         if self.model.config.model_type == "qwen3_moe":
             self.num_dense_layers = 0
             self.global_expert_num = self.model.config.num_experts
         else:
             self.num_dense_layers = self.model.config.first_k_dense_replace
             self.global_expert_num = self.model.config.n_routed_experts
-        self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
+        self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers  # MTP not included
         self.init_redundancy_expert = get_ascend_config(
         ).init_redundancy_expert
@@ -53,6 +56,16 @@ def __init__(self, model, **args):
         else:
             self.expert_weight_names = ["w13_weight", "w2_weight"]

+        # TODO: init self.mtp_expert_weight_names depending on the model type;
+        # only DeepSeek-V3 W8A8 and Qwen3-MoE are supported here.
+        if any("w13_weight_offset" in name
+               for name, _ in self.mtp_instance.named_parameters()):
+            self.mtp_expert_weight_names = [
+                "w13_weight", "w2_weight", "w13_weight_scale",
+                "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
+            ]
+        else:
+            self.mtp_expert_weight_names = ["w13_weight", "w2_weight"]
+
         self.expert_map_per_layer = dict(
         )  # reference to expert map on device for expert map update
         self.expert_map_per_layer_cpu = dict(
@@ -61,6 +74,12 @@ def __init__(self, model, **args):
             self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \
                 self.model.get_expert_map(self.num_dense_layers + layer_idx)

+        # Currently, MTP only supports one layer.
+        if self.mtp_instance is not None:
+            for mtp_layer_idx in range(self.num_mtp_layers):
+                self.expert_map_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = \
+                    self.mtp_instance.model.get_expert_map(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx)
+
         # TODO: here we set the number of buffer tensors equal to the number of experts in each layer, which can be improved
         num_buffer_tensor = torch.where(
             self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
@@ -76,6 +95,11 @@ def __init__(self, model, **args):
         for layer_idx in range(self.num_moe_layers):
             self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
                 self.model.get_log2phy_map(self.num_dense_layers + layer_idx)
+
+        if self.mtp_instance is not None:
+            for mtp_layer_idx in range(self.num_mtp_layers):
+                self.log2phy_map_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = \
+                    self.mtp_instance.model.get_log2phy_map(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx)

         self.all_topk_ids = []
@@ -103,13 +127,29 @@ def init_expert_param_per_layer(self):
                     name].data[local_expert_id]
                 for name in self.expert_weight_names
             ])
+
+        if self.mtp_instance is not None:
+            mtp_param_dict = dict(self.mtp_instance.named_parameters())
+            self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers] = list()
+            for local_expert_id in range(num_local_expert):
+                for mtp_layer_idx in range(self.num_mtp_layers):
+                    self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx].append([
+                        mtp_param_dict["model.layers." + str(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx) +
+                                       ".mtp_block.mlp.experts." + name].data[local_expert_id]
+                        for name in self.mtp_expert_weight_names
+                    ])

     def get_rank_expert_workload(self) -> torch.Tensor:
         self.moe_load = self.model.get_all_moe_loads()
+        if self.mtp_instance is not None:
+            self.moe_load = torch.cat([self.moe_load, self.mtp_instance.model.get_all_moe_loads().to(device=self.moe_load.device)], dim=0)
         return self.moe_load

     def get_init_expert_map(self, num_moe_layers):
         expert_map = self.model.get_all_expert_map(num_moe_layers)
+        if self.mtp_instance is not None:
+            expert_map = torch.cat([expert_map, self.mtp_instance.model.get_all_expert_map().to(device=expert_map.device)], dim=0)
         if dist.is_initialized():
             world_size = dist.get_world_size()
@@ -261,9 +301,11 @@ def determine_expert_map_all(self):
         local_num_experts = self.global_expert_num // self.world_size

         expert_map_all = torch.full(
-            (self.num_moe_layers, self.world_size, self.global_expert_num),
-            -1,
-            dtype=torch.int32)
+            (self.num_moe_layers if self.mtp_instance is None else (self.num_moe_layers + self.num_mtp_layers),
+             self.world_size,
+             self.global_expert_num),
+            -1,
+            dtype=torch.int32)

         for r in range(self.world_size):
             if r < self.world_size - 1:
@@ -284,6 +326,6 @@ def determine_expert_map_all(self):

             local_ids = torch.arange(local_count, dtype=torch.int32)
             expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(
-                self.num_moe_layers, -1)
+                self.num_moe_layers if self.mtp_instance is None else (self.num_moe_layers + self.num_mtp_layers), -1)

         return expert_map_all
```
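
A self-contained sketch of the map construction above, with toy sizes and a perfectly even expert split (the real `determine_expert_map_all` gives the last rank the remainder), showing how the extra MTP row is carried along:

```python
import torch

# Toy reproduction of the expert-map layout; illustrative sizes, not the real model's.
num_moe_layers, num_mtp_layers = 4, 1
world_size, global_expert_num = 2, 8
num_layers = num_moe_layers + num_mtp_layers  # MTP rows share the MoE layout

expert_map_all = torch.full(
    (num_layers, world_size, global_expert_num), -1, dtype=torch.int32)

local_num_experts = global_expert_num // world_size
for r in range(world_size):
    start, end = r * local_num_experts, (r + 1) * local_num_experts
    local_ids = torch.arange(end - start, dtype=torch.int32)
    expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(num_layers, -1)

# The appended MTP row for rank 0 maps its local experts and leaves the rest -1:
print(expert_map_all[-1, 0])  # tensor([ 0,  1,  2,  3, -1, -1, -1, -1], dtype=torch.int32)
```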

vllm_ascend/eplb/eplb_updator.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -35,9 +35,11 @@ def __init__(self, ascend_config, loader, eplb_process: EplbProcess,
         self.eplb_process = eplb_process
         self.shared_dict = self.eplb_process.shared_dict

-    def set_adaptor(self, adaptor):
+    def set_adaptor(self, adaptor, num_mtp_layers):
         self.adaptor = adaptor
-        self.num_moe_layers = self.adaptor.num_moe_layers
+        self.num_moe_layers = (
+            self.adaptor.num_moe_layers if self.adaptor.mtp_instance is None else self.adaptor.num_moe_layers + num_mtp_layers
+        )
         self.global_expert_num = self.adaptor.global_expert_num

     def init_eplb(self, expert_map_path, process):
@@ -84,6 +86,7 @@ def update_iteration(self):
                 self.expert_map_record_path)

             self.adaptor.model.clear_all_moe_loads()
+            self.adaptor.mtp_instance.model.clear_all_moe_loads()
             if not self.gate_eplb:
                 self.cur_iterations = 0
```
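
Note that `update_iteration` clears the MTP loads unconditionally; a guarded variant, mirroring the `mtp_instance is not None` checks this commit uses in `vllm_adaptor.py` (an assumption on my part, not part of the commit), would look like:

```python
self.adaptor.model.clear_all_moe_loads()
if self.adaptor.mtp_instance is not None:  # hypothetical guard, same pattern as vllm_adaptor.py
    self.adaptor.mtp_instance.model.clear_all_moe_loads()
```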

vllm_ascend/eplb/utils.py

Lines changed: 58 additions & 30 deletions
```diff
@@ -19,45 +19,72 @@

 import torch

+from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictor
+

 def get_expert_map(self, layer_id):
-    return self.model.layers[layer_id].mlp.experts.get_map()
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        return self.model.layers[layer_id].mlp.experts.get_map()
+    else:
+        return self.layers[str(layer_id)].mtp_block.mlp.experts.get_map()


 def get_log2phy_map(self, layer_id):
-    return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
+    else:
+        return self.layers[str(layer_id)].mtp_block.mlp.experts.get_log2phy_map()


 def get_all_expert_map(self, num_moe_layers):
-    all_loads = []
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    for layer_id in range(num_moe_layers):
-        load_tensor = self.get_expert_map(
-            layer_id + num_dense_layers)  # (num_experts_per_layer,)
-        all_loads.append(load_tensor)
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        all_loads = []
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        for layer_id in range(num_moe_layers):
+            load_tensor = self.get_expert_map(
+                layer_id + num_dense_layers)  # (num_experts_per_layer,)
+            all_loads.append(load_tensor)
+    else:
+        all_loads = []
+        for layer_id in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers):
+            load_tensor = self.get_expert_map(layer_id)
+            all_loads.append(load_tensor)

     return torch.stack(all_loads, dim=0)


 def get_all_moe_loads(self):
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    all_moe_loads = torch.stack(
-        [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
-         for layer_id in range(self.num_moe_layers)],
-        dim=0
-    )
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        all_moe_loads = torch.stack(
+            [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
+             for layer_id in range(self.num_moe_layers)],
+            dim=0
+        )
+    else:
+        all_moe_loads = torch.stack(
+            [self.layers[str(idx)].mtp_block.mlp.experts.moe_load \
+             for idx in range(self.mtp_start_layer_idx,
+                              self.mtp_start_layer_idx + self.num_mtp_layers)],
+            dim=0
+        )
     return all_moe_loads


 def clear_all_moe_loads(self):
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    for layer_id in range(self.num_moe_layers):
-        self.model.layers[layer_id +
-                          num_dense_layers].mlp.experts.clear_moe_load()
-
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        for layer_id in range(self.num_moe_layers):
+            self.model.layers[layer_id +
+                              num_dense_layers].mlp.experts.clear_moe_load()
+    else:
+        for layer_id in range(self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers):
+            self.layers[str(layer_id)].mtp_block.mlp.experts.clear_moe_load()
+

 def model_register(model, model_config):
     model.get_expert_map = types.MethodType(get_expert_map, model)
@@ -66,12 +93,13 @@ def model_register(model, model_config):
     model.get_all_moe_loads = types.MethodType(get_all_moe_loads, model)
     model.clear_all_moe_loads = types.MethodType(clear_all_moe_loads, model)

-    config = model_config.hf_config
+    if not isinstance(model, DeepSeekMultiTokenPredictor):
+        config = model_config.hf_config

-    if config.model_type == "qwen3_moe":
-        model.num_moe_layers = config.num_hidden_layers
-    elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
-        model.num_dense_layers = config.first_k_dense_replace
-        model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
-    else:
-        raise NotImplementedError("EPLB is not supported.")
+        if config.model_type == "qwen3_moe":
+            model.num_moe_layers = config.num_hidden_layers
+        elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
+            model.num_dense_layers = config.first_k_dense_replace
+            model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
+        else:
+            raise NotImplementedError("EPLB is not supported.")
```
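
The helpers above are bound onto both the target model and the MTP predictor via `types.MethodType`, with `isinstance` picking which attribute path to walk. A minimal, self-contained illustration of that dispatch pattern (toy classes, not the vLLM ones):

```python
import types

class MainModel:       # stand-in for the MoE model (walks self.model.layers[i])
    pass

class MTPPredictor:    # stand-in for DeepSeekMultiTokenPredictor (walks self.layers[str(i)])
    pass

def describe(self):
    # One function body serves both classes; isinstance selects the layout.
    if not isinstance(self, MTPPredictor):
        return "self.model.layers[layer_id].mlp.experts"
    return "self.layers[str(layer_id)].mtp_block.mlp.experts"

for obj in (MainModel(), MTPPredictor()):
    obj.describe = types.MethodType(describe, obj)  # bind as an instance method
    print(type(obj).__name__, "->", obj.describe())
```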

vllm_ascend/worker/model_runner_v1.py

Lines changed: 16 additions & 3 deletions
```diff
@@ -62,6 +62,7 @@
 from vllm.model_executor.models.interfaces import (SupportsMultiModal,
                                                    supports_mrope,
                                                    supports_transcription)
+from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -3129,9 +3130,17 @@ def _dummy_pooler_run(
     def eplb_warmup(self):
         if self.dynamic_eplb and not self.is_eplb_warmuped:
             self.is_eplb_warmuped = True
-            self.eplb_adaptor = VllmEplbAdaptor(model=self.model)
+            mtp_instance: Optional[DeepSeekMTP] = None
+            if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
+                assert isinstance(self.drafter, MtpProposer) and isinstance(self.drafter.model, DeepSeekMTP)
+                mtp_instance = self.drafter.model
+            self.eplb_adaptor = VllmEplbAdaptor(
+                model=self.model,
+                mtp_instance=mtp_instance,
+                num_mtp_layers=mtp_instance.model.num_mtp_layers
+            )
             self.eplb_loader.set_adator(self.eplb_adaptor)
-            self.eplb_updator.set_adaptor(self.eplb_adaptor)
+            self.eplb_updator.set_adaptor(self.eplb_adaptor, mtp_instance.model.num_mtp_layers)
             self.eplb_updator.warm_up_eplb()

     def load_model(self) -> None:
@@ -3140,7 +3149,7 @@ def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             self.model = get_model(vllm_config=self.vllm_config)
             if self.dynamic_eplb:
-                model_register(self.model, self.model_config)
+                model_register(self.model, self.model_config)
@@ -3154,6 +3163,10 @@ def load_model(self) -> None:
         if self.drafter:
             logger.info("Loading drafter model...")
             self.drafter.load_model(self.model)
+            if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
+                assert isinstance(self.drafter, MtpProposer) and isinstance(self.drafter.model, DeepSeekMTP)
+                mtp_instance = self.drafter.model
+                model_register(mtp_instance.model, self.vllm_config)
             if self.drafter.name == SpecDcodeType.EAGLE3:
                 self.model.set_aux_hidden_state_layers(
                     self.model.get_eagle3_aux_hidden_state_layers())
```
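
As written, `eplb_warmup` reads `mtp_instance.model.num_mtp_layers` even when no MTP drafter is configured. A guarded sketch of the same wiring (my assumption, not part of this commit; `set_adator` is the existing method name in the loader) would be:

```python
# Hypothetical guarded variant of the warmup wiring above.
num_mtp_layers = 0 if mtp_instance is None else mtp_instance.model.num_mtp_layers
self.eplb_adaptor = VllmEplbAdaptor(model=self.model,
                                    mtp_instance=mtp_instance,
                                    num_mtp_layers=num_mtp_layers)
self.eplb_loader.set_adator(self.eplb_adaptor)
self.eplb_updator.set_adaptor(self.eplb_adaptor, num_mtp_layers)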
