13 changes: 4 additions & 9 deletions vllm_ascend/torchair/torchair_model_runner.py
@@ -501,13 +501,7 @@ def select_torchair_padded_batch_size(self, batch_size: int):
     def update_torchair_graph_batch_sizes(self):
         # return graph_batch_sizes according to the max number of tokens
         # first pad according to the number of requests
-        if self.is_kv_consumer and self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
-            # pd disaggregation scenario may incorrectly calculate the batch in mtp scenario, so we force set it to max_num_reqs
-            self.torchair_graph_batch_sizes = [self.max_num_reqs]
-            logger.warning(
-                f"is kv_consumer, torch_graph_batch_sizes sets to [max_num_seqs] {[self.max_num_reqs]}"
-            )
-        elif len(self.torchair_graph_batch_sizes) == 0:
+        if len(self.torchair_graph_batch_sizes) == 0:
             self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
         else:
             self.torchair_graph_batch_sizes = sorted(
@@ -537,10 +531,11 @@ def update_torchair_graph_batch_sizes(self):

     def _align_graph_size_divisible_by_tp_size(self):
         tp_size = self.parallel_config.tensor_parallel_size
+        lcm_size = math.lcm(tp_size, self.decode_token_per_req)
         new_graph_batch_sizes = []
         for graph_batch_size in self.torchair_graph_batch_sizes:
-            cur_graph_batch_size = self.calculate_new_torchair_graph_batch_size(
-                graph_batch_size, tp_size)
+            cur_graph_batch_size = (graph_batch_size + lcm_size -
+                                    1) // lcm_size * lcm_size
Comment on lines +537 to +538
Contributor (high)

The alignment logic for cur_graph_batch_size has been updated here to round up to the nearest multiple of lcm(tp_size, self.decode_token_per_req). This is a good improvement as it can result in smaller, more appropriate batch sizes compared to the previous lcm(tp_size, graph_batch_size) logic, especially when graph_batch_size is large.
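
As a rough illustration of the difference (example values only, not taken from the PR):

    import math

    # Illustrative values, not from the PR.
    tp_size = 16
    decode_token_per_req = 1
    graph_batch_size = 24

    # New logic in this PR: round up to a multiple of lcm(tp_size, decode_token_per_req).
    lcm_size = math.lcm(tp_size, decode_token_per_req)                    # 16
    new_size = (graph_batch_size + lcm_size - 1) // lcm_size * lcm_size   # 32

    # Previous behavior as described above: lcm(tp_size, graph_batch_size).
    old_size = math.lcm(tp_size, graph_batch_size)                        # 48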

However, the old logic is still being used in _init_mc2_tokens_capacity via calculate_new_torchair_graph_batch_size. This inconsistency could lead to an unnecessarily large mc2_tokens_capacity, potentially causing memory issues or hitting device limits.

To ensure consistency and correctness, _init_mc2_tokens_capacity should be updated to use the same alignment logic. After this change, calculate_new_torchair_graph_batch_size would become dead code and could be removed for better code hygiene.

For example, _init_mc2_tokens_capacity could be updated as follows:

    def _init_mc2_tokens_capacity(self):
        # NOTE: To be clear, we need to make sure that during graph capture, the number of
        # tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
        # the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
        max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
        tp_size = self.parallel_config.tensor_parallel_size
        # Use the new alignment logic
        lcm_size = math.lcm(tp_size, self.uniform_decode_query_len)
        max_graph_batch_size = (max_num_tokens + lcm_size - 1) // lcm_size * lcm_size
        self.mc2_tokens_capacity = max_graph_batch_size

        if get_ascend_device_type(
        ) == AscendDeviceType._910_93 and self.mc2_tokens_capacity > 512:
            logger.error(
                f"A3: the max number of tokens must smaller then 512, but now is {self.mc2_tokens_capacity}"
            )
        if get_ascend_device_type(
        ) == AscendDeviceType._910B and self.mc2_tokens_capacity > 256:
            logger.error(
                f"A2: the max number of tokens must smaller then 256, but now is {self.mc2_tokens_capacity}"
            )

             if cur_graph_batch_size not in new_graph_batch_sizes and \
                cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
                 new_graph_batch_sizes.append(cur_graph_batch_size)
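
For reference, here is a minimal standalone sketch of the alignment the method performs after this change. The helper name and example numbers are chosen for illustration; the real method operates on self.torchair_graph_batch_sizes and self.scheduler_config.max_num_batched_tokens:

    import math

    def align_graph_batch_sizes(sizes, tp_size, decode_token_per_req,
                                max_num_batched_tokens):
        # Round each candidate size up to a multiple of lcm(tp_size, decode_token_per_req),
        # dropping duplicates and anything above the scheduler's token budget.
        lcm_size = math.lcm(tp_size, decode_token_per_req)
        aligned = []
        for size in sizes:
            cur = (size + lcm_size - 1) // lcm_size * lcm_size
            if cur not in aligned and cur <= max_num_batched_tokens:
                aligned.append(cur)
        return aligned

    # Example values only: tp_size=4, decode_token_per_req=2 -> lcm_size=4.
    print(align_graph_batch_sizes([1, 8, 30, 64], 4, 2, 32))  # [4, 8, 32]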