[Feat] Enable EPLB to support MTP layers #4598
```diff
@@ -19,45 +19,72 @@
 import torch
 
+from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictor
+
 
 def get_expert_map(self, layer_id):
-    return self.model.layers[layer_id].mlp.experts.get_map()
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        return self.model.layers[layer_id].mlp.experts.get_map()
+    else:
+        return self.layers[str(layer_id)].mtp_block.mlp.experts.get_map()
```
Review comment: It is suggested to obtain the `experts` module in each branch and then make a single shared `get_map()` call in the return statement, instead of duplicating the full call in both branches.
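A minimal sketch of that refactor (not part of the PR; it assumes the same module layout used in the diff above):

```python
from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictor


def get_expert_map(self, layer_id):
    # Resolve the experts module once per branch, then call get_map() a single time.
    if not isinstance(self, DeepSeekMultiTokenPredictor):
        experts = self.model.layers[layer_id].mlp.experts
    else:
        experts = self.layers[str(layer_id)].mtp_block.mlp.experts
    return experts.get_map()
```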
```diff
 def get_log2phy_map(self, layer_id):
-    return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        return self.model.layers[layer_id].mlp.experts.get_log2phy_map()
+    else:
+        return self.layers[str(layer_id)].mtp_block.mlp.experts.get_log2phy_map()
 
 
-def get_all_expert_map(self, num_moe_layers):
-    all_loads = []
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    for layer_id in range(num_moe_layers):
-        load_tensor = self.get_expert_map(
-            layer_id + num_dense_layers)  # (num_experts_per_layer,)
-        all_loads.append(load_tensor)
+def get_all_expert_map(self, num_moe_layers=None):
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        all_loads = []
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        for layer_id in range(num_moe_layers):
+            load_tensor = self.get_expert_map(
+                layer_id + num_dense_layers)  # (num_experts_per_layer,)
+            all_loads.append(load_tensor)
+    else:
+        all_loads = []
+        for layer_id in range(self.mtp_start_layer_idx,
+                              self.mtp_start_layer_idx + self.num_mtp_layers):
+            load_tensor = self.get_expert_map(layer_id)
+            all_loads.append(load_tensor)
 
     return torch.stack(all_loads, dim=0)
 
 
 def get_all_moe_loads(self):
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    all_moe_loads = torch.stack(
-        [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
-            for layer_id in range(self.num_moe_layers)],
-        dim=0
-    )
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        all_moe_loads = torch.stack(
+            [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load \
+                for layer_id in range(self.num_moe_layers)],
+            dim=0
+        )
+    else:
+        all_moe_loads = torch.stack(
+            [self.layers[str(idx)].mtp_block.mlp.experts.moe_load \
+                for idx in range(self.mtp_start_layer_idx,
+                                 self.mtp_start_layer_idx + self.num_mtp_layers)],
+            dim=0
+        )
     return all_moe_loads
 
 
 def clear_all_moe_loads(self):
-    num_dense_layers = self.num_dense_layers if hasattr(
-        self, "num_dense_layers") else 0
-    for layer_id in range(self.num_moe_layers):
-        self.model.layers[layer_id +
-                          num_dense_layers].mlp.experts.clear_moe_load()
+    if not isinstance(self, DeepSeekMultiTokenPredictor):
+        num_dense_layers = self.num_dense_layers if hasattr(
+            self, "num_dense_layers") else 0
+        for layer_id in range(self.num_moe_layers):
+            self.model.layers[layer_id +
+                              num_dense_layers].mlp.experts.clear_moe_load()
+    else:
+        for layer_id in range(self.mtp_start_layer_idx,
+                              self.mtp_start_layer_idx + self.num_mtp_layers):
+            self.layers[str(layer_id)].mtp_block.mlp.experts.clear_moe_load()
 
 
 def model_register(model, model_config):
     model.get_expert_map = types.MethodType(get_expert_map, model)
```
```diff
@@ -66,12 +93,13 @@ def model_register(model, model_config):
     model.get_all_moe_loads = types.MethodType(get_all_moe_loads, model)
     model.clear_all_moe_loads = types.MethodType(clear_all_moe_loads, model)
 
-    config = model_config.hf_config
-
-    if config.model_type == "qwen3_moe":
-        model.num_moe_layers = config.num_hidden_layers
-    elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
-        model.num_dense_layers = config.first_k_dense_replace
-        model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
-    else:
-        raise NotImplementedError("EPLB is not supported.")
+    if not isinstance(model, DeepSeekMultiTokenPredictor):
+        config = model_config.hf_config
+
+        if config.model_type == "qwen3_moe":
+            model.num_moe_layers = config.num_hidden_layers
+        elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
+            model.num_dense_layers = config.first_k_dense_replace
+            model.num_moe_layers = config.num_hidden_layers - model.num_dense_layers
+        else:
+            raise NotImplementedError("EPLB is not supported.")
```
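For orientation, a hypothetical usage sketch of the registered hooks; `model` and `mtp_module` stand for an already-constructed DeepSeek model and its DeepSeekMultiTokenPredictor and are not defined in this PR:

```python
model_register(model, model_config)       # non-MTP model: the hf_config branch runs
model_register(mtp_module, model_config)  # DeepSeekMultiTokenPredictor: branch is skipped

# The dense+MoE model is indexed with a num_dense_layers offset; the MTP module
# walks its own mtp_start_layer_idx .. mtp_start_layer_idx + num_mtp_layers range.
expert_maps = model.get_all_expert_map(model.num_moe_layers)
mtp_expert_maps = mtp_module.get_all_expert_map()

# Accumulated MoE loads are read and reset through the same bound methods.
moe_loads = model.get_all_moe_loads()
mtp_module.clear_all_moe_loads()
```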
Review comment: The construction of parameter names for MTP layers is hardcoded and very specific to the current model structure, which makes the code brittle and difficult to maintain or extend to other MTP models. While the suggestion improves readability and removes a redundant calculation, the core issue of the hardcoded path remains. Consider making this more robust, for example by having the MTP model itself expose a method that returns the expert parameter names for a given layer.
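A rough sketch of that direction, assuming the same `mtp_block.mlp.experts` layout as this PR; the method name, the weight-name list, and the parameter-name format are all hypothetical, not existing API:

```python
from vllm.model_executor.models.deepseek_mtp import DeepSeekMultiTokenPredictor


def get_expert_param_names(self: DeepSeekMultiTokenPredictor, layer_id: int,
                           weight_names: list[str]) -> list[str]:
    # Hypothetical helper the MTP model could expose so that EPLB code does not
    # hardcode the "layers.<idx>.mtp_block.mlp.experts" path itself.
    prefix = f"layers.{layer_id}.mtp_block.mlp.experts"
    return [f"{prefix}.{name}" for name in weight_names]


# Illustrative example (weight names are placeholders):
# get_expert_param_names(mtp_module, 61, ["w13_weight", "w2_weight"])
# -> ["layers.61.mtp_block.mlp.experts.w13_weight",
#     "layers.61.mtp_block.mlp.experts.w2_weight"]
```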