
Conversation


@BryanChen408 BryanChen408 commented Nov 25, 2025

What this PR does / why we need it?

This PR enhances EPLB to support one or more MTP layers. Previously, EPLB only handled the main model; it can now also balance experts in the MTP layers used for speculative decoding, covering both num_speculative_tokens = 1 and num_speculative_tokens > 1.
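
For context only (not part of this PR's diff), here is a minimal, hedged sketch of enabling DeepSeek MTP speculative decoding in vLLM. The method name 'deepseek_mtp' and the num_speculative_tokens field match the code touched by this PR; the model path is hypothetical and the exact speculative_config API may differ between vLLM versions.

    # Hedged sketch: the MTP drafter whose MoE experts EPLB can now rebalance.
    from vllm import LLM

    llm = LLM(
        model="deepseek-ai/DeepSeek-V3",    # hypothetical model path
        speculative_config={
            "method": "deepseek_mtp",        # matches the check in this PR
            "num_speculative_tokens": 2,     # 1 and >1 are both supported now
        },
    )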

Does this PR introduce any user-facing change?

No, this PR does not introduce any user-facing changes.

How was this patch tested?

@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for MTP layers in EPLB. The changes are mostly correct, but there are several critical issues related to potential None value access and incorrect logic when MTP is not used or when multiple MTP layers are present. These issues could lead to crashes or incorrect behavior. I've provided suggestions to fix these problems.
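
To make the inline comments below easier to follow: the diff addresses the k-th MTP MoE layer at the global index num_dense_layers + num_moe_layers + k and looks its parameters up under "model.layers.<index>.mtp_block.mlp.experts.*". A tiny illustrative sketch of that convention follows; only the indexing formula and the parameter-name prefix come from the diff, the layer counts are hypothetical.

    # Illustrative only: global layer ids for MTP layers (counts are made up).
    num_dense_layers = 3   # hypothetical dense decoder layers in the main model
    num_moe_layers = 58    # hypothetical MoE decoder layers in the main model
    num_mtp_layers = 2     # hypothetical number of MTP layers

    for mtp_layer_idx in range(num_mtp_layers):
        layer_id = num_dense_layers + num_moe_layers + mtp_layer_idx
        print(f"MTP layer {mtp_layer_idx} -> model.layers.{layer_id}.mtp_block.mlp.experts.*")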

Comment on lines +59 to +67
# TODO: init self.mtp_expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
if any("w13_weight_offset" in name for name, _ in self.mtp_instance.named_parameters()):
    self.mtp_expert_weight_names = [
        "w13_weight", "w2_weight", "w13_weight_scale",
        "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
    ]
else:
    self.mtp_expert_weight_names = ["w13_weight", "w2_weight"]


critical

The code accesses self.mtp_instance.named_parameters() without checking if self.mtp_instance is None. This will cause an AttributeError when mtp_instance is not provided during initialization. The block should be guarded with a check for self.mtp_instance.

Suggested change
-# TODO: init self.mtp_expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
-if any("w13_weight_offset" in name for name, _ in self.mtp_instance.named_parameters()):
-    self.mtp_expert_weight_names = [
-        "w13_weight", "w2_weight", "w13_weight_scale",
-        "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
-    ]
-else:
-    self.mtp_expert_weight_names = ["w13_weight", "w2_weight"]
+# TODO: init self.mtp_expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
+if self.mtp_instance is not None:
+    if any("w13_weight_offset" in name for name, _ in self.mtp_instance.named_parameters()):
+        self.mtp_expert_weight_names = [
+            "w13_weight", "w2_weight", "w13_weight_scale",
+            "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
+        ]
+    else:
+        self.mtp_expert_weight_names = ["w13_weight", "w2_weight"]
+else:
+    self.mtp_expert_weight_names = []

Comment on lines +131 to +141
if self.mtp_instance is not None:
    mtp_param_dict = dict(self.mtp_instance.named_parameters())
    self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers] = list()
    for local_expert_id in range(num_local_expert):
        for mtp_layer_idx in range(self.num_mtp_layers):
            self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx].append([
                mtp_param_dict["model.layers." + str(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx) +
                               ".mtp_block.mlp.experts." +
                               name].data[local_expert_id]
                for name in self.mtp_expert_weight_names
            ])

critical

The initialization of self.expert_param_per_layer for MTP layers is incorrect. It only initializes a list for the first MTP layer. If num_mtp_layers > 1, this will raise a KeyError when trying to access subsequent layers. The initialization should be done for all MTP layers.

Suggested change
-if self.mtp_instance is not None:
-    mtp_param_dict = dict(self.mtp_instance.named_parameters())
-    self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers] = list()
-    for local_expert_id in range(num_local_expert):
-        for mtp_layer_idx in range(self.num_mtp_layers):
-            self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx].append([
-                mtp_param_dict["model.layers." + str(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx) +
-                               ".mtp_block.mlp.experts." +
-                               name].data[local_expert_id]
-                for name in self.mtp_expert_weight_names
-            ])
+if self.mtp_instance is not None:
+    mtp_param_dict = dict(self.mtp_instance.named_parameters())
+    for mtp_layer_idx in range(self.num_mtp_layers):
+        self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx] = list()
+    for local_expert_id in range(num_local_expert):
+        for mtp_layer_idx in range(self.num_mtp_layers):
+            self.expert_param_per_layer[self.num_dense_layers + self.num_moe_layers + mtp_layer_idx].append([
+                mtp_param_dict["model.layers." + str(self.num_dense_layers + self.num_moe_layers + mtp_layer_idx) +
+                               ".mtp_block.mlp.experts." +
+                               name].data[local_expert_id]
+                for name in self.mtp_expert_weight_names
+            ])

self.expert_map_record_path)

self.adaptor.model.clear_all_moe_loads()
self.adaptor.mtp_instance.model.clear_all_moe_loads()

critical

The code accesses self.adaptor.mtp_instance without checking if it is None. This will cause an AttributeError if no MTP instance is used. This call should be guarded with if self.adaptor.mtp_instance is not None:.

Suggested change
-self.adaptor.mtp_instance.model.clear_all_moe_loads()
+if self.adaptor.mtp_instance is not None:
+    self.adaptor.mtp_instance.model.clear_all_moe_loads()

return self.layers[str(layer_id)].mtp_block.mlp.experts.get_log2phy_map()


def get_all_expert_map(self, num_moe_layers):

critical

The function get_all_expert_map is defined to take num_moe_layers as a required argument. However, it is called without arguments for MTP models, which will cause a TypeError. The num_moe_layers argument is not used for MTP models, so it should be made optional.

Suggested change
-def get_all_expert_map(self, num_moe_layers):
+def get_all_expert_map(self, num_moe_layers=None):
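
A short, self-contained illustration of the two call patterns described above; the function body is a placeholder, not the adaptor's real implementation.

    # Hypothetical stand-in showing why the default argument is needed.
    def get_all_expert_map(num_moe_layers=None):
        if num_moe_layers is None:
            return "MTP path: layer count derived from the MTP instance"
        return f"main-model path: {num_moe_layers} MoE layers"

    get_all_expert_map(58)  # main model passes num_moe_layers
    get_all_expert_map()    # MTP path omits it; without the default this raises TypeError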

Comment on lines 3137 to 3143
self.eplb_adaptor = VllmEplbAdaptor(
    model=self.model,
    mtp_instance=mtp_instance,
    num_mtp_layers=mtp_instance.model.num_mtp_layers
)
self.eplb_loader.set_adator(self.eplb_adaptor)
self.eplb_updator.set_adaptor(self.eplb_adaptor)
self.eplb_updator.set_adaptor(self.eplb_adaptor, mtp_instance.model.num_mtp_layers)

critical

The code accesses mtp_instance.model without checking if mtp_instance is None. This will raise an AttributeError when speculative decoding with deepseek_mtp is not used. You should conditionally get num_mtp_layers and pass it to the adaptor and updator.

Suggested change
-self.eplb_adaptor = VllmEplbAdaptor(
-    model=self.model,
-    mtp_instance=mtp_instance,
-    num_mtp_layers=mtp_instance.model.num_mtp_layers
-)
-self.eplb_loader.set_adator(self.eplb_adaptor)
-self.eplb_updator.set_adaptor(self.eplb_adaptor)
-self.eplb_updator.set_adaptor(self.eplb_adaptor, mtp_instance.model.num_mtp_layers)
+num_mtp_layers = mtp_instance.model.num_mtp_layers if mtp_instance is not None else 0
+self.eplb_adaptor = VllmEplbAdaptor(
+    model=self.model,
+    mtp_instance=mtp_instance,
+    num_mtp_layers=num_mtp_layers
+)
+self.eplb_loader.set_adator(self.eplb_adaptor)
+self.eplb_updator.set_adaptor(self.eplb_adaptor, num_mtp_layers)

Comment on lines 3166 to 3169
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
    assert isinstance(self.drafter, MtpProposer) and isinstance(self.drafter.model, DeepSeekMTP)
    mtp_instance=self.drafter.model
model_register(mtp_instance.model, self.vllm_config)

critical

The variable mtp_instance is only defined within the if block, but it is used outside of it in the model_register call. This will lead to a NameError if self.speculative_config.method is not 'deepseek_mtp'. The model_register call should be moved inside the if block.

Suggested change
-if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
-    assert isinstance(self.drafter, MtpProposer) and isinstance(self.drafter.model, DeepSeekMTP)
-    mtp_instance=self.drafter.model
-model_register(mtp_instance.model, self.vllm_config)
+if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
+    assert isinstance(self.drafter, MtpProposer) and isinstance(self.drafter.model, DeepSeekMTP)
+    mtp_instance=self.drafter.model
+    model_register(mtp_instance.model, self.vllm_config)

@BryanChen408 BryanChen408 force-pushed the feature/eplb+mtp3 branch 3 times, most recently from 3e46e82 to 9c2f938 on November 26, 2025 at 03:51
- Implement EPLB support for MTP layers
- Add support for multiple MTP layers configuration
- Enhance handling of num_speculative_tokens parameter:
  - Support num_speculative_tokens = 1 (single token speculative inference)
  - Support num_speculative_tokens > 1 (multiple tokens speculative inference)

Signed-off-by: chenbaiuan <[email protected]>
@github-actions

This pull request has conflicts; please resolve them before we can evaluate the pull request.

