Skip to content

Conversation

@lidenghui1110
Copy link
Contributor

@lidenghui1110 lidenghui1110 commented Aug 28, 2025

What this PR does / why we need it?

Similar to #2309 , this PR introduces Embedding tensor model parallel to achieve decreasing of memory consumption. It support both eager mode and graph mode.

Does this PR introduce any user-facing change?

This PR introduces one new config in additional_config.

Name Effect Required Type Constraints
embedding_tensor_parallel_size Split the vocab column dimension into embedding_tensor_parallel_size pieces No int default value is None, once this value is set, the feature will be enabled, vocab_size must be divisible by this value.

example

--additional_config={"embedding_tensor_parallel_size": 8}

How was this patch tested?

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces tensor parallelism for embeddings and the LM head on Ascend devices, aiming to reduce memory consumption. The implementation adds new configurations, parallel communication groups, and custom operators. While the overall approach is sound, I've identified a few critical issues. The forward pass for embedding tensor parallelism incorrectly uses the LM head's communication group and has a flawed tensor slicing logic. Additionally, the tests for the new features are incomplete, and there are some minor maintainability concerns like a non-English comment and excessive debug logging that should be addressed.

Comment on lines 155 to 157
def _get_local_batch_slice(self, tensor: torch.Tensor,
batch_sizes: list,
local_batch_size: int,
rank: int) -> torch.Tensor:
"""Extract local batch portion from gathered tensor.
Args:
tensor: The gathered tensor to slice
batch_sizes: List of batch sizes for each rank
local_batch_size: Size of current rank's batch
rank: Current rank index
Returns:
Sliced tensor containing only the local batch data
"""
end_idx = batch_sizes[rank]
start_idx = end_idx - local_batch_size
return tensor[start_idx:end_idx]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The _get_local_batch_slice method has a flaw in its slicing logic. It calculates start_idx as end_idx - local_batch_size, where end_idx is just batch_sizes[rank]. This is incorrect as batch_sizes is a list of sizes per rank, not cumulative sizes. To get the correct slice for a given rank, you need to sum the batch sizes of all preceding ranks to find the correct starting offset.

Suggested change
def _get_local_batch_slice(self, tensor: torch.Tensor,
batch_sizes: list,
local_batch_size: int,
rank: int) -> torch.Tensor:
"""Extract local batch portion from gathered tensor.
Args:
tensor: The gathered tensor to slice
batch_sizes: List of batch sizes for each rank
local_batch_size: Size of current rank's batch
rank: Current rank index
Returns:
Sliced tensor containing only the local batch data
"""
end_idx = batch_sizes[rank]
start_idx = end_idx - local_batch_size
return tensor[start_idx:end_idx]
def _get_local_batch_slice(self, tensor: torch.Tensor,
batch_sizes: list,
local_batch_size: int,
rank: int) -> torch.Tensor:
"""Extract local batch portion from gathered tensor.
Args:
tensor: The gathered tensor to slice
batch_sizes: List of batch sizes for each rank
local_batch_size: Size of current rank's batch
rank: Current rank index
Returns:
Sliced tensor containing only the local batch data
"""
start_idx = sum(batch_sizes[:rank])
end_idx = start_idx + local_batch_size
return tensor[start_idx:end_idx]

Comment on lines 181 to 172
def _forward_embed_tp(self, input_):
cu_tokens_across_dp_cpu = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
global_dp_batch_size = torch.diff(cu_tokens_across_dp_cpu, prepend=cu_tokens_across_dp_cpu.new_zeros(1))
logger.info(f"debug input_: {input_.shape} \n global_dp_batch_size: {global_dp_batch_size}\n ")
lmhead_group_batch_size = [global_dp_batch_size[x] for x in get_lmhead_tp_group().ranks]
local_batch_size = input_.size(0)
gathered_input = [torch.empty(batch_size, dtype=input_.dtype, device='npu') for batch_size in lmhead_group_batch_size]
torch.distributed.all_gather(
gathered_input, input_, group=get_lmhead_tp_group().device_group)
complete_input = torch.cat(gathered_input, dim=0)
masked_input, input_mask = get_masked_input_and_mask(
input_, self.shard_indices.org_vocab_start_index,
complete_input, self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index)
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.quant_method.embedding(self, masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output


VocabParallelEmbedding.forward = vocab_parallel_embedding_forward
logger.info(f"all_gather_down complete_input: {complete_input.shape}")

output = self.quant_method.embedding(self, masked_input.long())
output.masked_fill_(input_mask.unsqueeze(-1), 0)
output = tensor_model_parallel_all_reduce(output)
# output = output[lmhead_group_batch_size[get_lmhead_tp_group().rank_in_group]-local_batch_size :lmhead_group_batch_size[get_lmhead_tp_group().rank_in_group]]
# Extract the local batch portion from the gathered output
lmhead_tp_group = get_lmhead_tp_group()
output = self._get_local_batch_slice(
output,
lmhead_group_batch_size,
local_batch_size,
lmhead_tp_group.rank_in_group
)
logger.info(f"rank:{get_dp_group().rank_in_group} output: {output.shape}")
return output
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The _forward_embed_tp method incorrectly uses get_lmhead_tp_group() for its operations. When embedding_tp_enable() is true, it should be using the embedding tensor parallel group obtained via get_emtp_group(). This is a critical bug that will lead to incorrect communication patterns and failures.

    def _forward_embed_tp(self, input_):
        cu_tokens_across_dp_cpu = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
        global_dp_batch_size = torch.diff(cu_tokens_across_dp_cpu, prepend=cu_tokens_across_dp_cpu.new_zeros(1))
        logger.info(f"debug input_: {input_.shape} \n global_dp_batch_size: {global_dp_batch_size}\n ")
        emtp_group = get_emtp_group()
        emtp_group_batch_size = [global_dp_batch_size[x] for x in emtp_group.ranks]
        local_batch_size = input_.size(0)
        gathered_input = [torch.empty(batch_size, dtype=input_.dtype, device='npu') for batch_size in emtp_group_batch_size]
        torch.distributed.all_gather(
            gathered_input, input_, group=emtp_group.device_group)
        complete_input = torch.cat(gathered_input, dim=0)
        masked_input, input_mask = get_masked_input_and_mask(
            complete_input, self.shard_indices.org_vocab_start_index,
            self.shard_indices.org_vocab_end_index,
            self.shard_indices.num_org_vocab_padding,
            self.shard_indices.added_vocab_start_index,
            self.shard_indices.added_vocab_end_index)
        logger.info(f"all_gather_down complete_input: {complete_input.shape}")
        
        output = self.quant_method.embedding(self, masked_input.long())
        output.masked_fill_(input_mask.unsqueeze(-1), 0)
        output = tensor_model_parallel_all_reduce(output)
        # Extract the local batch portion from the gathered output
        output = self._get_local_batch_slice(
            output, 
            emtp_group_batch_size, 
            local_batch_size, 
            emtp_group.rank_in_group
        )
        logger.info(f"rank:{get_dp_group().rank_in_group}  output: {output.shape}")
        return output

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

Comment on lines 29 to 74
def test_init_ascend_model_parallel(mock_distributed, parallel_config):
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'):
parallel_config.lmhead_tensor_parallel_size = 2
init_ascend_model_parallel(parallel_config)

mc2_group = get_mc2_group()
assert mc2_group is not None
lmheadtp_group = get_lmhead_tp_group()
assert lmheadtp_group is not None

destroy_ascend_model_parallel()
assert _MC2 is None
assert _LMTP is None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The test test_init_ascend_model_parallel only validates the initialization for lmhead_tensor_parallel_size. The new embedding_tensor_parallel_size feature, which is a core part of this PR, is not tested here. This leaves a significant gap in test coverage. Please add a similar test case for embedding_tensor_parallel_size to ensure it is initialized and destroyed correctly.

Comment on lines 308 to 338
with self.assertRaises(AssertionError):
test_vllm_config.additional_config = {
"lmhead_tensor_parallel_size": 2,
"refresh": True
}
test_vllm_config.parallel_config = ParallelConfig(
data_parallel_size=4, tensor_parallel_size=2)
init_ascend_config(test_vllm_config)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This test case only checks the assertion for lmhead_tensor_parallel_size when tensor_parallel_size != 1. A similar check is needed for embedding_tensor_parallel_size to ensure configuration validation is complete for the new features introduced in this PR.

get_world_group().local_rank,
backend,
group_name="emtp")
# 输出日志:成功建立Embedding 通信并行组
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This comment is in Chinese, which is inconsistent with the rest of the codebase that is in English. For better maintainability and readability for all contributors, please translate it to English.

Suggested change
# 输出日志:成功建立Embedding 通信并行组
# Log success in establishing the embedding communication parallel group


def forward(self, input_):
if embedding_tp_enable():
logger.info(f"rank:{get_dp_group().rank_in_group} embedding_tp_enable")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This file contains several logger.info calls that seem to be for debugging purposes. These logs can be very noisy in a production environment and may impact performance. Please consider changing them to logger.debug or removing them if they are no longer needed.

Suggested change
logger.info(f"rank:{get_dp_group().rank_in_group} embedding_tp_enable")
logger.debug(f"rank:{get_dp_group().rank_in_group} embedding_tp_enable")

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:‌‌

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewer and future developers understand.

If CI fails, you can run linting and testing checks locally according Contributing and Testing.

@github-actions
Copy link

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@github-actions
Copy link

github-actions bot commented Sep 7, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@wangxiyuan
Copy link
Collaborator

please rebase and fix CI

@github-actions
Copy link

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@lidenghui1110 lidenghui1110 changed the title [WIP][Feat] Add custom Embedding tensor model parallel [Feat] Add custom Embedding tensor model parallel Oct 13, 2025
@zzhx1
Copy link
Contributor

zzhx1 commented Oct 13, 2025

@wangxiyuan Please add the "ready-to-test" lable.


def _forward_embed_tp(self, input_):
if get_ascend_config(
).torchair_graph_config.enabled is False and not self.is_decode_only:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

custom ops doesn't work with torchair. you can just skip this check

backend,
group_name="lmheadtp")

embedding_tensor_parallel_size = get_ascend_config(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These TP communication groups may be consolidated, given their group creation logic is similar.

@github-actions
Copy link

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Signed-off-by: zzhx1 <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation module:core module:ops module:tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants