[Feat] Flashcomm2 use o_shared linear #4188
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces support for shared o_proj linear layers for Flashcomm2, which involves changes across configuration, distributed state management, and the attention mechanism. The core logic for shared weights is implemented in vllm_ascend/torchair/ops/shared_weight_layer.py, which has been refactored for better usability.
My review focuses on ensuring the correctness and robustness of the new feature. I've identified a few critical issues:
- Incorrect validation logic for the new flashcomm2_oproj_shared configuration that could lead to silent failures.
- A potential crash in the shared weight layer logic when handling a series with a single layer.
I have provided suggestions to fix these issues. The rest of the changes look good and the refactoring of the shared weight layer API is a nice improvement.
vllm_ascend/ascend_config.py
Outdated
    if self.flashcomm2_oproj_tensor_parallel_size is None:
        raise AssertionError(
            "flashcomm2_oproj_shared must be enabled simultaneously with flashcomm2_oproj_tensor_parallel_size"
        )
The validation if self.flashcomm2_oproj_tensor_parallel_size is None: is incorrect. The value of self.flashcomm2_oproj_tensor_parallel_size is an integer returned from get_flashcomm2_config_and_validate (which gets it from an environment variable with a default of 0), so it will never be None. The check should be against 0, as flashcomm2_oproj_shared requires flashcomm2_oproj_tensor_parallel_size to be greater than 0.
Suggested change:
-    if self.flashcomm2_oproj_tensor_parallel_size is None:
-        raise AssertionError(
-            "flashcomm2_oproj_shared must be enabled simultaneously with flashcomm2_oproj_tensor_parallel_size"
-        )
+    if self.flashcomm2_oproj_tensor_parallel_size == 0:
+        raise AssertionError(
+            "flashcomm2_oproj_shared must be enabled with flashcomm2_oproj_tensor_parallel_size > 0"
+        )
    self.layers.sort(key=lambda x: x.layer_idx)
    self.num_layers = len(self.layers)
    assert self.num_layers > 0, "No layers in the series"
    assert self.prefetch_step >= 0 and self.prefetch_step <= self.num_layers - 2, "prefetch_step must be in [0, num_layers - 2]"
The assertion self.prefetch_step <= self.num_layers - 2 will cause a crash if a shared weight series contains only one layer (self.num_layers == 1), because self.num_layers - 2 would be -1. For a single-layer series, prefetching is not applicable, and prefetch_step should be 0. To prevent this crash, the assertion should be adjusted to handle this edge case.
Suggested change:
-    assert self.prefetch_step >= 0 and self.prefetch_step <= self.num_layers - 2, "prefetch_step must be in [0, num_layers - 2]"
+    assert self.prefetch_step >= 0 and self.prefetch_step <= max(0, self.num_layers - 2), "prefetch_step must be in [0, num_layers - 2]"
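For illustration, a minimal check of the boundary behavior (the helper below is hypothetical and not part of shared_weight_layer.py; it only mirrors the suggested clamp):

```python
def valid_prefetch_steps(num_layers: int) -> range:
    # With the suggested max(0, ...) clamp, a single-layer series permits only
    # prefetch_step == 0 instead of asserting against an upper bound of -1.
    return range(0, max(0, num_layers - 2) + 1)

assert list(valid_prefetch_steps(1)) == [0]        # single layer: no prefetching
assert list(valid_prefetch_steps(4)) == [0, 1, 2]  # prefetch up to num_layers - 2
```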
vllm_ascend/utils.py
Outdated
    if flashcomm2_oproj_shared:
        if flashcomm2_oproj_tp_size is None:
            raise AssertionError(
                "flashcomm2_oproj_shared must be enabled simultaneously with flashcomm2_oproj_tensor_parallel_size"
            )
        logger.info("Enable Flashcomm2 with flashcomm2_oproj_shared")
This validation logic for flashcomm2_oproj_shared is redundant with the logic in vllm_ascend/ascend_config.py. It's better to have validation in one place to avoid inconsistencies. Since ascend_config.py is the configuration entry point, it's a better place for this check. Additionally, the check if flashcomm2_oproj_tp_size is None: is incorrect, as flashcomm2_oproj_tp_size is an integer. I've suggested a fix in ascend_config.py and recommend removing this redundant block.
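A minimal sketch of what a single consolidated check could look like (the option names come from this PR, but the standalone helper below is hypothetical and only illustrates the intent):

```python
def validate_flashcomm2_oproj_shared(oproj_shared: bool, oproj_tp_size: int) -> None:
    # Hypothetical helper: keep the validation in one place (ascend_config.py) so the
    # duplicate check in utils.py can be dropped. oproj_tp_size defaults to 0, not None.
    if oproj_shared and oproj_tp_size == 0:
        raise AssertionError(
            "flashcomm2_oproj_shared requires flashcomm2_oproj_tensor_parallel_size > 0")
```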
This pull request has conflicts, please resolve those before we can evaluate the pull request.
@wangxiyuan this PR is ready, please check again, and if there are no issues, please help merge it.
        kv_no_split.contiguous(), need_gather_q_kv)

    if self.fc2_enable and is_hidden_layer(self.vllm_config, self.o_proj):
        reach_layer_for_shared_weight_series(self.o_proj)
Why is the first broadcast performed here? I think it's not general enough, because other models in the profile-run phase are not aware of the information related to the o_proj layer. Should the first broadcast be performed after post_process_after_loading_for_shared_weight_series instead?
That's already included in post_process_after_loading_for_shared_weight_series. See https://github.com/vllm-project/vllm-ascend/blob/main/vllm_ascend/torchair/ops/shared_weight_layer.py#L73
OK, understood. So why do we broadcast in the profile run if the first broadcast is already included in post_process_after_loading_for_shared_weight_series?
To handle multi-DP cases. When some DP ranks are running dummy_run, they should still broadcast their weights to the DP ranks that are executing the model.
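A minimal sketch of that constraint (the import path and reach_layer_for_shared_weight_series come from this PR; the dummy-run driver below is a hypothetical illustration, not the actual model-runner code):

```python
from vllm_ascend.torchair.ops.shared_weight_layer import (
    reach_layer_for_shared_weight_series)

def dummy_run_shared_weight_sync(o_proj_layers):
    # Even when this DP rank only executes a dummy batch, it must still "reach" every
    # layer in the shared-weight series so that it joins the collective broadcast that
    # the DP ranks running real batches are waiting on; otherwise those ranks would hang.
    for o_proj in o_proj_layers:
        reach_layer_for_shared_weight_series(o_proj)
```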
from typing import Callable, Optional

import torch
import torch.distributed as dist
This is not an op; we should consider moving it to a better place.
@wangxiyuan this PR is ready, please help merge it in.
What this PR does / why we need it?
The Flashcomm2 technical report mentions that FC2 introduces fully redundant storage of the o_proj matrix, which puts pressure on memory. The report therefore proposed a compromise solution, otp2, but it introduces additional reduce-scatter communication.
We propose a shared-linear feature (#2931) that distributes the weights layer by layer across the cards, avoiding the need for TP splitting, which solves the memory issue.
This PR depends on #3232 and #2931.
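A minimal sketch of the idea, assuming a round-robin ownership and on-demand broadcast scheme (the helpers below are hypothetical; the actual implementation lives in vllm_ascend/torchair/ops/shared_weight_layer.py and additionally prefetches a few layers ahead to overlap communication):

```python
import torch
import torch.distributed as dist

def owner_rank(layer_idx: int) -> int:
    # Assumed scheme: each layer's full o_proj weight is stored on exactly one rank,
    # instead of every rank holding a redundant full copy or TP-splitting the matrix.
    return layer_idx % dist.get_world_size()

def fetch_shared_weight(weight: torch.Tensor, layer_idx: int) -> torch.Tensor:
    # Before a layer runs, its full weight is broadcast from the owning rank to the
    # others, so no reduce-scatter is needed after the o_proj matmul. On non-owner
    # ranks `weight` is only a same-shape placeholder used to allocate the buffer.
    src = owner_rank(layer_idx)
    buf = weight if dist.get_rank() == src else torch.empty_like(weight)
    dist.broadcast(buf, src=src)
    return buf
```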
Flashcomm2 flowchart
Does this PR introduce any user-facing change?
Yes: the feature is enabled through environment variables.