[Feature] Support multi graphs for torchair #4757
base: main
Conversation
Signed-off-by: Jade Zheng <[email protected]>
Code Review
This pull request enables support for multiple Torchair graphs with MTP by removing a hardcoded graph size configuration and refining the graph size alignment logic. The changes look good overall, but I've identified an inconsistency in the alignment logic that should be addressed. The new alignment method introduced in _align_graph_size_divisible_by_tp_size is an improvement, but the old, less optimal logic is still used elsewhere, which could lead to issues. My review comment provides details on how to resolve this.
cur_graph_batch_size = (graph_batch_size + lcm_size -
                        1) // lcm_size * lcm_size
The alignment logic for cur_graph_batch_size has been updated here to round up to the nearest multiple of lcm(tp_size, self.decode_token_per_req). This is a good improvement as it can result in smaller, more appropriate batch sizes compared to the previous lcm(tp_size, graph_batch_size) logic, especially when graph_batch_size is large.
However, the old logic is still being used in _init_mc2_tokens_capacity via calculate_new_torchair_graph_batch_size. This inconsistency could lead to an unnecessarily large mc2_tokens_capacity, potentially causing memory issues or hitting device limits.
To ensure consistency and correctness, _init_mc2_tokens_capacity should be updated to use the same alignment logic. After this change, calculate_new_torchair_graph_batch_size would become dead code and could be removed for better code hygiene.
For example, _init_mc2_tokens_capacity could be updated as follows:
def _init_mc2_tokens_capacity(self):
    # NOTE: To be clear, we need to make sure that during graph capture, the number of
    # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
    # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
    max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
    tp_size = self.parallel_config.tensor_parallel_size
    # Use the new alignment logic
    lcm_size = math.lcm(tp_size, self.uniform_decode_query_len)
    max_graph_batch_size = (max_num_tokens + lcm_size -
                            1) // lcm_size * lcm_size
    self.mc2_tokens_capacity = max_graph_batch_size
    if get_ascend_device_type(
    ) == AscendDeviceType._910_93 and self.mc2_tokens_capacity > 512:
        logger.error(
            f"A3: the max number of tokens must be smaller than 512, but it is now {self.mc2_tokens_capacity}"
        )
    if get_ascend_device_type(
    ) == AscendDeviceType._910B and self.mc2_tokens_capacity > 256:
        logger.error(
            f"A2: the max number of tokens must be smaller than 256, but it is now {self.mc2_tokens_capacity}"
        )
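To make the difference concrete, here is a small standalone sketch (not part of the PR) comparing the two alignment strategies. The numeric values for tp_size, decode_token_per_req, and graph_batch_size are made up purely for illustration:

```python
import math

# Illustrative values only: TP degree 4 and MTP with one speculative token,
# i.e. 2 decode tokens per request, and a requested graph batch size of 30.
tp_size = 4
decode_token_per_req = 2
graph_batch_size = 30

# Old behaviour: align to lcm(tp_size, graph_batch_size).
old_aligned = math.lcm(tp_size, graph_batch_size)  # 60

# New behaviour: round up to the nearest multiple of
# lcm(tp_size, decode_token_per_req).
lcm_size = math.lcm(tp_size, decode_token_per_req)  # 4
new_aligned = (graph_batch_size + lcm_size - 1) // lcm_size * lcm_size  # 32

print(old_aligned, new_aligned)  # 60 32
```

With large requested sizes the old rule can inflate the captured graph well beyond what is needed, which is exactly the concern raised for mc2_tokens_capacity above.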
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
What this PR does / why we need it?
This PR supports configuring multiple Torchair graphs when MTP is enabled.
Does this PR introduce any user-facing change?
When MTP is enabled, users can configure multiple Torchair graph sizes; the configured sizes may still be adjusted automatically based on the runtime environment.
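As a rough illustration (not taken from the PR itself), requesting multiple graph sizes under MTP might look like the sketch below. The `torchair_graph_config` / `graph_batch_sizes` keys and the `speculative_config` values follow vllm-ascend's documented `additional_config` convention, but they are assumptions here and should be verified against the current docs:

```python
from vllm import LLM

# Hedged sketch: multiple Torchair graph sizes with MTP speculative decoding.
llm = LLM(
    model="path/to/model",  # placeholder model path
    tensor_parallel_size=4,
    speculative_config={
        "method": "deepseek_mtp",      # assumption: MTP-style speculation
        "num_speculative_tokens": 1,
    },
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            # Multiple graph sizes; they may be realigned automatically,
            # e.g. to multiples of lcm(tp_size, decode_token_per_req).
            "graph_batch_sizes": [4, 8, 16, 32],
        },
    },
)
```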
How was this patch tested?