[Feat] Add custom Embedding tensor model parallel #2616
Conversation
Code Review
This pull request introduces tensor parallelism for embeddings and the LM head on Ascend devices, aiming to reduce memory consumption. The implementation adds new configurations, parallel communication groups, and custom operators. While the overall approach is sound, I've identified a few critical issues. The forward pass for embedding tensor parallelism incorrectly uses the LM head's communication group and contains flawed tensor-slicing logic. Additionally, the tests for the new features are incomplete, and there are minor maintainability concerns, such as a non-English comment and excessive debug logging, that should be addressed.
```python
def _get_local_batch_slice(self, tensor: torch.Tensor,
                           batch_sizes: list,
                           local_batch_size: int,
                           rank: int) -> torch.Tensor:
    """Extract local batch portion from gathered tensor.

    Args:
        tensor: The gathered tensor to slice
        batch_sizes: List of batch sizes for each rank
        local_batch_size: Size of current rank's batch
        rank: Current rank index

    Returns:
        Sliced tensor containing only the local batch data
    """
    end_idx = batch_sizes[rank]
    start_idx = end_idx - local_batch_size
    return tensor[start_idx:end_idx]
```
The _get_local_batch_slice method has a flaw in its slicing logic. It calculates start_idx as end_idx - local_batch_size, where end_idx is just batch_sizes[rank]. This is incorrect as batch_sizes is a list of sizes per rank, not cumulative sizes. To get the correct slice for a given rank, you need to sum the batch sizes of all preceding ranks to find the correct starting offset.
```diff
 def _get_local_batch_slice(self, tensor: torch.Tensor,
                            batch_sizes: list,
                            local_batch_size: int,
                            rank: int) -> torch.Tensor:
     """Extract local batch portion from gathered tensor.

     Args:
         tensor: The gathered tensor to slice
         batch_sizes: List of batch sizes for each rank
         local_batch_size: Size of current rank's batch
         rank: Current rank index

     Returns:
         Sliced tensor containing only the local batch data
     """
-    end_idx = batch_sizes[rank]
-    start_idx = end_idx - local_batch_size
+    start_idx = sum(batch_sizes[:rank])
+    end_idx = start_idx + local_batch_size
     return tensor[start_idx:end_idx]
```
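To make the indexing bug concrete, here is a minimal self-contained sketch (the batch sizes are hypothetical, not taken from the PR):

```python
import torch

# Hypothetical per-rank batch sizes: rank 0 contributed 4 tokens,
# rank 1 contributed 2, rank 2 contributed 3.
batch_sizes = [4, 2, 3]
gathered = torch.arange(sum(batch_sizes))  # tokens concatenated in rank order
rank, local_batch_size = 2, 3

# Original logic: misreads batch_sizes[rank] as a cumulative offset.
end_idx = batch_sizes[rank]             # 3
start_idx = end_idx - local_batch_size  # 0
print(gathered[start_idx:end_idx])      # tensor([0, 1, 2]) -- rank 0's tokens, wrong

# Fixed logic: sum the sizes of all preceding ranks first.
start_idx = sum(batch_sizes[:rank])     # 4 + 2 = 6
end_idx = start_idx + local_batch_size  # 9
print(gathered[start_idx:end_idx])      # tensor([6, 7, 8]) -- rank 2's tokens
```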
```python
def _forward_embed_tp(self, input_):
    cu_tokens_across_dp_cpu = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
    global_dp_batch_size = torch.diff(cu_tokens_across_dp_cpu,
                                      prepend=cu_tokens_across_dp_cpu.new_zeros(1))
    logger.info(f"debug input_: {input_.shape} \n global_dp_batch_size: {global_dp_batch_size}\n ")
    lmhead_group_batch_size = [global_dp_batch_size[x] for x in get_lmhead_tp_group().ranks]
    local_batch_size = input_.size(0)
    gathered_input = [torch.empty(batch_size, dtype=input_.dtype, device='npu')
                      for batch_size in lmhead_group_batch_size]
    torch.distributed.all_gather(
        gathered_input, input_, group=get_lmhead_tp_group().device_group)
    complete_input = torch.cat(gathered_input, dim=0)
    logger.info(f"all_gather_down complete_input: {complete_input.shape}")
    masked_input, input_mask = get_masked_input_and_mask(
        complete_input, self.shard_indices.org_vocab_start_index,
        self.shard_indices.org_vocab_end_index,
        self.shard_indices.num_org_vocab_padding,
        self.shard_indices.added_vocab_start_index,
        self.shard_indices.added_vocab_end_index)
    output = self.quant_method.embedding(self, masked_input.long())
    output.masked_fill_(input_mask.unsqueeze(-1), 0)
    output = tensor_model_parallel_all_reduce(output)
    # output = output[lmhead_group_batch_size[get_lmhead_tp_group().rank_in_group]-local_batch_size :lmhead_group_batch_size[get_lmhead_tp_group().rank_in_group]]
    # Extract the local batch portion from the gathered output
    lmhead_tp_group = get_lmhead_tp_group()
    output = self._get_local_batch_slice(
        output,
        lmhead_group_batch_size,
        local_batch_size,
        lmhead_tp_group.rank_in_group
    )
    logger.info(f"rank:{get_dp_group().rank_in_group} output: {output.shape}")
    return output


def vocab_parallel_embedding_forward(self, input_):
    if self.tp_size > 1:
        masked_input, input_mask = get_masked_input_and_mask(
            input_, self.shard_indices.org_vocab_start_index,
            self.shard_indices.org_vocab_end_index,
            self.shard_indices.num_org_vocab_padding,
            self.shard_indices.added_vocab_start_index,
            self.shard_indices.added_vocab_end_index)
    else:
        masked_input = input_
    # Get the embeddings.
    output_parallel = self.quant_method.embedding(self, masked_input.long())
    # Mask the output embedding.
    if self.tp_size > 1:
        output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
    # Reduce across all the model parallel GPUs.
    output = tensor_model_parallel_all_reduce(output_parallel)
    return output


VocabParallelEmbedding.forward = vocab_parallel_embedding_forward
```
The _forward_embed_tp method incorrectly uses get_lmhead_tp_group() for its operations. When embedding_tp_enable() is true, it should be using the embedding tensor parallel group obtained via get_emtp_group(). This is a critical bug that will lead to incorrect communication patterns and failures.
```python
def _forward_embed_tp(self, input_):
    cu_tokens_across_dp_cpu = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
    global_dp_batch_size = torch.diff(cu_tokens_across_dp_cpu,
                                      prepend=cu_tokens_across_dp_cpu.new_zeros(1))
    logger.info(f"debug input_: {input_.shape} \n global_dp_batch_size: {global_dp_batch_size}\n ")
    emtp_group = get_emtp_group()
    emtp_group_batch_size = [global_dp_batch_size[x] for x in emtp_group.ranks]
    local_batch_size = input_.size(0)
    gathered_input = [torch.empty(batch_size, dtype=input_.dtype, device='npu')
                      for batch_size in emtp_group_batch_size]
    torch.distributed.all_gather(
        gathered_input, input_, group=emtp_group.device_group)
    complete_input = torch.cat(gathered_input, dim=0)
    masked_input, input_mask = get_masked_input_and_mask(
        complete_input, self.shard_indices.org_vocab_start_index,
        self.shard_indices.org_vocab_end_index,
        self.shard_indices.num_org_vocab_padding,
        self.shard_indices.added_vocab_start_index,
        self.shard_indices.added_vocab_end_index)
    logger.info(f"all_gather_down complete_input: {complete_input.shape}")
    output = self.quant_method.embedding(self, masked_input.long())
    output.masked_fill_(input_mask.unsqueeze(-1), 0)
    output = tensor_model_parallel_all_reduce(output)
    # Extract the local batch portion from the gathered output
    output = self._get_local_batch_slice(
        output,
        emtp_group_batch_size,
        local_batch_size,
        emtp_group.rank_in_group
    )
    logger.info(f"rank:{get_dp_group().rank_in_group} output: {output.shape}")
    return output
```
fix
```python
def test_init_ascend_model_parallel(mock_distributed, parallel_config):
    with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'):
        parallel_config.lmhead_tensor_parallel_size = 2
        init_ascend_model_parallel(parallel_config)

        mc2_group = get_mc2_group()
        assert mc2_group is not None
        lmheadtp_group = get_lmhead_tp_group()
        assert lmheadtp_group is not None

        destroy_ascend_model_parallel()
        assert _MC2 is None
        assert _LMTP is None
```
The test test_init_ascend_model_parallel only validates the initialization for lmhead_tensor_parallel_size. The new embedding_tensor_parallel_size feature, which is a core part of this PR, is not tested here. This leaves a significant gap in test coverage. Please add a similar test case for embedding_tensor_parallel_size to ensure it is initialized and destroyed correctly.
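For illustration, a sketch of such a test, mirroring the lmhead case above and reusing its fixtures (the `_EMTP` sentinel name is an assumption by analogy with `_LMTP`; `get_emtp_group` is the accessor used elsewhere in this PR):

```python
def test_init_ascend_model_parallel_embedding_tp(mock_distributed, parallel_config):
    with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'):
        parallel_config.embedding_tensor_parallel_size = 2
        init_ascend_model_parallel(parallel_config)

        # The embedding TP group should be created and retrievable.
        emtp_group = get_emtp_group()
        assert emtp_group is not None

        destroy_ascend_model_parallel()
        assert _EMTP is None
```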
```python
with self.assertRaises(AssertionError):
    test_vllm_config.additional_config = {
        "lmhead_tensor_parallel_size": 2,
        "refresh": True
    }
    test_vllm_config.parallel_config = ParallelConfig(
        data_parallel_size=4, tensor_parallel_size=2)
    init_ascend_config(test_vllm_config)
```
```python
    get_world_group().local_rank,
    backend,
    group_name="emtp")
# 输出日志:成功建立Embedding 通信并行组
```
This comment is in Chinese, while the rest of the codebase is in English. For better maintainability and readability for all contributors, please translate it to English.
```diff
-# 输出日志:成功建立Embedding 通信并行组
+# Log success in establishing the embedding communication parallel group
```
```python
def forward(self, input_):
    if embedding_tp_enable():
        logger.info(f"rank:{get_dp_group().rank_in_group} embedding_tp_enable")
```
This file contains several logger.info calls that seem to be for debugging purposes. These logs can be very noisy in a production environment and may impact performance. Please consider changing them to logger.debug or removing them if they are no longer needed.
```diff
-logger.info(f"rank:{get_dp_group().rank_in_group} embedding_tp_enable")
+logger.debug(f"rank:{get_dp_group().rank_in_group} embedding_tp_enable")
```
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
This pull request has conflicts, please resolve those before we can evaluate the pull request.
please rebase and fix CI
Force-pushed from 0fe35e0 to 213c27e
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Force-pushed from 479fbce to b1335de
@wangxiyuan Please add the "ready-to-test" label.
```python
def _forward_embed_tp(self, input_):
    if get_ascend_config(
    ).torchair_graph_config.enabled is False and not self.is_decode_only:
```
Custom ops don't work with torchair, so you can just skip this check.
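A possible simplification along those lines (a sketch only; it assumes the decode-only condition should remain while the torchair condition is dropped):

```diff
-        if get_ascend_config(
-        ).torchair_graph_config.enabled is False and not self.is_decode_only:
+        if not self.is_decode_only:
```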
```python
    backend,
    group_name="lmheadtp")

embedding_tensor_parallel_size = get_ascend_config(
```
These TP communication groups may be consolidated, given their group creation logic is similar.
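For example, a shared helper could cover both cases (a sketch; the helper name and rank lists are illustrative, while `init_model_parallel_group` and `get_world_group` appear elsewhere in this PR):

```python
def _init_tp_subgroup(group_ranks, backend, group_name):
    # Both the lmhead and embedding TP groups follow the same creation
    # pattern, differing only in their ranks and group name.
    return init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name=group_name)

# Illustrative usage:
# _LMTP = _init_tp_subgroup(lmhead_group_ranks, backend, "lmheadtp")
# _EMTP = _init_tp_subgroup(emtp_group_ranks, backend, "emtp")
```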
This pull request has conflicts, please resolve those before we can evaluate the pull request.
Signed-off-by: zzhx1 <[email protected]>
Signed-off-by: zzhx1 <[email protected]>
Signed-off-by: zzhx1 <[email protected]>
Signed-off-by: zzhx1 <[email protected]>
What this PR does / why we need it?
Similar to #2309, this PR introduces embedding tensor model parallelism to reduce memory consumption. It supports both eager mode and graph mode.
Does this PR introduce any user-facing change?
This PR introduces one new config in `additional_config`. Example:

```
--additional_config={"embedding_tensor_parallel_size": 8}
```

How was this patch tested?