[MOE] move weight transpose to wakeup for RL scenarios #4626
Conversation
Code Review
This pull request refactors the weight transposition logic for MoE models, moving it from the weight loading process into the wake_up method. This is intended to support Reinforcement Learning scenarios where weights are updated dynamically. The changes involve modifying how MoE weights are handled in fused_moe.py and worker_v1.py, and updating example and test files accordingly. My review identifies a critical bug in the weight identification logic within the wake_up method and suggests an improvement for correctness and code quality.
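For orientation, here is a minimal sketch of the placement this refactor implies. It is illustrative only: `NPUWorker`, `_transpose_moe_weights`, and `is_moe_model` are assumed names, not the actual identifiers in `fused_moe.py` or `worker_v1.py`, and the helper follows the reviewer's suggested (fixed) form of the loop.

```python
# Illustrative sketch only: class and helper names are assumptions, not the
# PR's actual code.
import torch


def _transpose_moe_weights(model: torch.nn.Module, hidden_size: int) -> None:
    """Swap dims 1 and 2 of fused-MoE weight tensors on the owning modules."""
    for name, param in model.named_parameters():
        if (('w2_weight' in name or 'w13_weight' in name)
                and param.dim() == 3 and param.shape[1] == hidden_size):
            parts = name.split('.')
            parent = model.get_submodule(".".join(parts[:-1]))
            setattr(parent, parts[-1],
                    torch.nn.Parameter(param.transpose(1, 2).contiguous(),
                                       requires_grad=False))


class NPUWorker:  # hypothetical stand-in for the worker class in worker_v1.py
    def __init__(self, model, hidden_size, is_moe_model=True):
        self.model = model
        self.hidden_size = hidden_size
        self.is_moe_model = is_moe_model

    def wake_up(self, tags=None):
        # Transposing here, rather than once at weight-loading time, means
        # weights rewritten by an RL trainer (in the training layout) are
        # re-transposed every time the engine wakes up.
        if self.is_moe_model:
            _transpose_moe_weights(self.model, self.hidden_size)
```

The actual block under review follows.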
```python
for name, param in model.named_parameters():
    if 'w2_weight' in name and param.shape[2] == hidden_size:
        parts = name.split('.')
        param_name = parts[-1]
        parent_module = model.get_submodule(".".join(parts[:-1]))
        w2_data = param.transpose(1, 2)
        w2_data = torch.nn.Parameter(w2_data, requires_grad=False)
        setattr(parent_module, param_name, w2_data)
    elif 'w13_weight' in name and param.shape[1] == hidden_size:
        parts = name.split('.')
        param_name = parts[-1]
        parent_module = model.get_submodule(".".join(parts[:-1]))
        w13_data = param.transpose(1, 2)
        w13_data = torch.nn.Parameter(w13_data, requires_grad=False)
        setattr(parent_module, param_name, w13_data)
```
There are a couple of issues in this block of code:

- [Critical] The condition used to identify the `w2_weight` parameter is incorrect. At this point the shape of `w2_weight` is `(num_experts, hidden_size, intermediate_size)`, so `param.shape[2] == hidden_size` compares `intermediate_size` with `hidden_size`, which is not true in general and will cause this logic to fail for many models. It should be `param.shape[1] == hidden_size`, which identifies the parameter by its hidden dimension.
- [High] After transposing a tensor, it is good practice to call `.contiguous()` to ensure the memory layout is contiguous. This prevents potential errors and performance issues in subsequent operations that expect a contiguous tensor; the `load_weights` method called after this may rely on it. (A short demonstration follows the suggested code below.)
- [Medium] The code paths for transposing `w2_weight` and `w13_weight` are nearly identical. Refactoring them into a single branch or helper function would improve readability and maintainability.
I've provided a suggestion below that fixes the critical bug, adds `.contiguous()`, and refactors the duplicated logic.
```python
for name, param in model.named_parameters():
    # The shape of w2_weight is (num_experts, hidden_size, intermediate_size)
    # The shape of w13_weight is (num_experts, hidden_size, 2 * intermediate_size)
    if ('w2_weight' in name or 'w13_weight' in name) and len(param.shape) == 3 and param.shape[1] == hidden_size:
        parts = name.split('.')
        param_name = parts[-1]
        parent_module = model.get_submodule(".".join(parts[:-1]))
        # Transpose back to training format and ensure contiguity
        new_data = param.transpose(1, 2).contiguous()
        new_param = torch.nn.Parameter(new_data, requires_grad=False)
        setattr(parent_module, param_name, new_param)
```
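To see why the `.contiguous()` call matters, here is a small self-contained check (the shapes are arbitrary example values):

```python
import torch

w2 = torch.randn(8, 4096, 1536)         # (num_experts, hidden, intermediate)
t = w2.transpose(1, 2)                  # a strided view; no data is copied
print(t.shape, t.is_contiguous())       # torch.Size([8, 1536, 4096]) False
print(t.contiguous().is_contiguous())   # True: materializes a compact copy
```

Operations that call `.view()` or hand raw data pointers to kernels generally require contiguous storage, which is why the reviewer flags this.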
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Force-pushed from 2c0ed50 to f157dab
Signed-off-by: lhp-deep <[email protected]>
Force-pushed from 1a158d9 to 2c63755
What this PR does / why we need it?
In reinforcement learning scenarios, the current inference path applies a transpose operation to MoE weights during weight loading. For a cleaner architecture, this PR moves the weight-transpose step into the wake_up method, so weights that are updated dynamically by the trainer are re-transposed when the engine wakes up.
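A hedged sketch of the RL loop this change targets; the model name is a placeholder, and the exact sleep/wake API surface depends on the vLLM version in use:

```python
from vllm import LLM

llm = LLM(model="some/moe-model", enable_sleep_mode=True)  # placeholder model

outputs = llm.generate(["rollout prompt"])  # collect rollouts for the trainer

llm.sleep(level=1)   # release weights/KV cache while training runs
# ... the trainer writes updated MoE weights in the training layout ...
llm.wake_up()        # with this PR, the MoE weight transpose happens here
                     # instead of once at initial weight loading
```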
Does this PR introduce any user-facing change?
How was this patch tested?