diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index d4b4b25bf74..1cd55c61bcf 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -24,7 +24,8 @@
 from copy import deepcopy
 from dataclasses import dataclass
 from multiprocessing import Manager
-from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
+                    TypeAlias, Union)
 
 import numpy as np
 import regex as re
@@ -32,7 +33,8 @@
 import torch.distributed as dist
 import torch.nn as nn
 from tqdm import tqdm  # type: ignore
-from vllm.attention.backends.abstract import AttentionBackend, AttentionType
+from vllm.attention.backends.abstract import (AttentionBackend,
+                                              AttentionMetadata, AttentionType)
 from vllm.attention.layer import Attention, MLAAttention
 from vllm.attention.selector import get_attn_backend
 from vllm.compilation.counter import compilation_counter
@@ -48,14 +50,14 @@
                                   get_pcp_group, get_pp_group, get_tp_group,
                                   is_global_first_rank)
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.model_loader import get_model
 from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import LazyLoader
-from vllm.utils.math_utils import cdiv
+from vllm.utils.math_utils import cdiv, round_up
 from vllm.utils.mem_utils import DeviceMemoryProfiler
 from vllm.utils.torch_utils import get_dtype_size
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
@@ -80,6 +82,7 @@
 from vllm.v1.worker.gpu_model_runner import (AsyncGPUModelRunnerOutput,
                                              GPUModelRunner)
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
+from vllm.v1.worker.ubatch_utils import UBatchSlices
 from vllm.v1.worker.utils import AttentionGroup
 
 import vllm_ascend.envs as envs_ascend
@@ -133,6 +136,10 @@
 import torch_npu
 
+AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
+# a list of per-ubatch dicts when ubatching is enabled
+PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict
+
 # if true, allow tensor initialization and casting with internal format (e.g., NZ)
 torch.npu.config.allow_internal_format = True
 
@@ -573,6 +580,155 @@ def generate_kv_idx(self, scheduler_output):
         self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to(
             torch.float32).argsort().to(torch.int32)
 
+    def _build_attention_metadata(
+        self,
+        num_tokens: int,
+        num_reqs: int,
+        max_query_len: int,
+        num_tokens_padded: int | None = None,
+        num_reqs_padded: int | None = None,
+        ubatch_slices: UBatchSlices | None = None,
+        logits_indices: torch.Tensor | None = None,
+        use_spec_decode: bool = False,
+        for_cudagraph_capture: bool = False,
+        num_scheduled_tokens: dict[str, int] | None = None,
+        cascade_attn_prefix_lens: list[list[int]] | None = None,
+    ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
+        """
+        Build the attention metadata for each KV cache group; all layers in
+        the same group share the same metadata object.
+
+        :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
+        """
+        self.attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
+        num_tokens_padded = num_tokens_padded or num_tokens
+        num_reqs_padded = num_reqs_padded or num_reqs
+        attn_metadata: PerLayerAttnMetadata = {}
+        if ubatch_slices is not None:
+            attn_metadata = [dict() for _ in range(len(ubatch_slices))]
+
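+        # Refresh the per-request accepted-token counts consumed by the
+        # spec-decode metadata builders below; padded requests default to 1.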
+        if use_spec_decode:
+            self.num_accepted_tokens.np[:num_reqs] = (
+                self.input_batch.num_accepted_tokens_cpu[:num_reqs])
+            self.num_accepted_tokens.np[num_reqs:].fill(1)
+            self.num_accepted_tokens.copy_to_gpu()
+
+        # Used in the loop below; these views use the padded shapes.
+        query_start_loc = self.query_start_loc.gpu[:num_reqs_padded + 1]
+        query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs_padded + 1]
+        seq_lens = self.seq_lens.gpu[:num_reqs_padded]
+        seq_lens_cpu = self.seq_lens.cpu[:num_reqs_padded]
+        num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[:
+                                                                                  num_reqs_padded]
+
+        spec_decode_common_attn_metadata = None
+
+        # Prepare the attention metadata for each KV cache group and make
+        # layers in the same group share the same metadata.
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+            long_seq_metadata = self._generate_pcp_metadata(num_tokens)
+            if for_cudagraph_capture and long_seq_metadata is not None:
+                num_computed_tokens_of_pcp_dcp = [[
+                    [0] * self.dcp_size for _ in range(self.pcp_size)
+                ] for _ in range(num_tokens)]
+                long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
+
+            blk_table = self.input_batch.block_table[kv_cache_gid]
+            blk_table_tensor = blk_table.get_device_tensor()[:num_reqs_padded]
+            slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
+            slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
+            blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
+
+            attn_state = self.attn_state
+            if for_cudagraph_capture:
+                attn_state = AscendAttentionState.DecodeOnly
+                if self.speculative_config and self.speculative_config.method == "mtp":
+                    attn_state = AscendAttentionState.SpecDecoding
+            common_attn_metadata = AscendCommonAttentionMetadata(
+                query_start_loc=query_start_loc,
+                query_start_loc_cpu=query_start_loc_cpu,
+                seq_lens_cpu=seq_lens_cpu,
+                seq_lens=seq_lens,
+                num_reqs=num_reqs_padded,
+                # TODO
+                num_actual_tokens=num_tokens,
+                # TODO
+                num_input_tokens=num_tokens_padded,
+                # TODO
+                actual_seq_lengths_q=self.actual_seq_lengths_q,
+                block_table_tensor=blk_table_tensor,
+                slot_mapping=slot_mapping,
+                max_query_len=max_query_len,
+                # max_seq_len=max_seq_len,
+                num_computed_tokens_cpu=num_computed_tokens_cpu,
+                # TODO
+                positions=self.positions.gpu,
+                attn_mask=self.attn_mask,
+                spec_attn_mask=self.spec_attn_mask,
+                attn_state=attn_state,
+                decode_token_per_req=self.decode_token_per_req,
+                prefill_context_parallel_metadata=None,
+            )
+            if self.pcp_size > 1 and for_cudagraph_capture:
+                common_attn_metadata.block_table_tensor = \
+                    blk_table_tensor[:num_reqs * self.decode_threshold]
+
+            if self.speculative_config and self.pcp_size > 1 and not for_cudagraph_capture:
+                # For pcp + spec decode, we flatten block_table to avoid an
+                # irregular spec_attn_mask shape, e.g., with num_decode_req=2,
+                # num_prefill_req=3, num_speculative_tokens=1:
+                # original block_table: [d0, d1, p0, p1, p2]
+                #   (num_reqs_d + num_reqs_p, max_num_blocks),
+                # flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
+                #   (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks).
+                ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs + 1] - \
+                    self.query_start_loc_pcp_full.cpu[:num_reqs]
+                num_prefill_reqs = (ori_query_lens
+                                    > self.decode_threshold).sum().item()
+                num_decode_reqs = num_reqs - num_prefill_reqs
+                num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold
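+                # Relocate the prefill rows first (the clone() avoids an
+                # overlapping device-to-device copy), then expand each decode
+                # row decode_threshold times in place.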
+                blk_table_tensor[
+                    num_decode_reqs_flatten:num_decode_reqs_flatten +
+                    num_prefill_reqs].copy_(
+                        blk_table_tensor[num_decode_reqs:num_decode_reqs +
+                                         num_prefill_reqs].clone(),
+                        non_blocking=True)
+                blk_table_tensor[:num_decode_reqs_flatten].copy_(
+                    blk_table_tensor[:num_decode_reqs].repeat_interleave(
+                        self.decode_threshold, dim=0),
+                    non_blocking=True)
+                common_attn_metadata.block_table_tensor = \
+                    blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs]
+
+            # TODO(zhenwenqi): this differs from gpu_model_runner
+            if self.speculative_config and spec_decode_common_attn_metadata is None:
+                spec_decode_common_attn_metadata = common_attn_metadata
+
+            for attn_group in self.attn_groups[kv_cache_gid]:
+                builder = attn_group.get_metadata_builder()
+                extra_attn_metadata_args = {}
+                if use_spec_decode and isinstance(builder,
+                                                  GDNAttentionMetadataBuilder):
+                    patch_torch_npu_argsort()
+                    extra_attn_metadata_args = dict(
+                        num_accepted_tokens=self.num_accepted_tokens.
+                        gpu[:num_reqs_padded],
+                        num_decode_draft_tokens_cpu=self.
+                        num_decode_draft_tokens.cpu[:num_reqs_padded],
+                    )
+                if for_cudagraph_capture:
+                    attn_metadata_i = builder.build_for_graph_capture(
+                        common_attn_metadata, attn_state, self.get_model())
+                else:
+                    attn_metadata_i = builder.build(
+                        common_prefix_len=0,
+                        common_attn_metadata=common_attn_metadata,
+                        model=self.get_model(),
+                        **extra_attn_metadata_args,
+                    )
+                for layer_name in attn_group.layer_names:
+                    attn_metadata[layer_name] = attn_metadata_i
+        return attn_metadata, spec_decode_common_attn_metadata
+
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
@@ -1978,119 +2134,218 @@ def _generate_dummy_run_hidden_states(self, input_ids, positions,
             hidden_states = hidden_states
         return hidden_states
 
+    # Padding for sequence parallelism (SP)
+    def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
+        # Pad the token count to a multiple of tensor_parallel_size when
+        # collective fusion for SP is enabled.
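+        # e.g., with tp_size=4, 13 scheduled tokens are padded up to 16.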
+        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+        if enable_sp():
+            return round_up(num_scheduled_tokens, tp_size)
+        return num_scheduled_tokens
+
+    def _determine_batch_execution_and_padding(
+        self,
+        num_tokens: int,
+        num_reqs: int,
+        num_scheduled_tokens_np: np.ndarray,
+        max_num_scheduled_tokens: int,
+        use_cascade_attn: bool,
+        allow_microbatching: bool = False,
+        force_eager: bool = False,
+        # For cudagraph capture. TODO(lucas): refactor how we capture
+        # cudagraphs (will be improved in model runner v2).
+        force_uniform_decode: bool | None = None,
+        force_has_lora: bool | None = None,
+        with_prefill: bool = False
+    ) -> tuple[CUDAGraphMode, BatchDescriptor, UBatchSlices | None,
+               torch.Tensor | None]:
+        num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
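+        # A uniform decode batch is one where every request contributes
+        # exactly uniform_decode_query_len tokens: 1 for pure decode, or
+        # 1 + num_speculative_tokens for speculative decode.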
+        uniform_decode = (
+            ((max_num_scheduled_tokens == self.uniform_decode_query_len) and
+             (num_tokens_padded == max_num_scheduled_tokens * num_reqs))
+            if force_uniform_decode is None else force_uniform_decode)
+
+        has_lora = (len(self.input_batch.lora_id_to_lora_request) > 0
+                    if force_has_lora is None else force_has_lora)
+
+        def dispatch_cudagraph(num_tokens, has_lora, use_cascade_attn,
+                               uniform_decode):
+            if not force_eager:
+                return self.cudagraph_dispatcher.dispatch(
+                    num_tokens=num_tokens,
+                    has_lora=has_lora,
+                    use_cascade_attn=use_cascade_attn,
+                    uniform_decode=uniform_decode,
+                )
+            else:
+                return (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
+
+        cudagraph_mode, batch_descriptor = dispatch_cudagraph(
+            num_tokens_padded, has_lora, use_cascade_attn, uniform_decode)
+        num_tokens_padded = batch_descriptor.num_tokens
+
+        # Extra coordination is needed when running data-parallel, since the
+        # padded token count must be agreed upon across ranks.
+        ubatch_slices, num_tokens_across_dp = None, None
+        if self.vllm_config.parallel_config.data_parallel_size > 1:
+            # Disable DP padding when running eager to avoid excessive padding
+            # when running prefills. This lets us set cudagraph_mode="NONE" on
+            # the prefiller in a P/D setup and still use CUDA graphs (enabled
+            # by this padding) on the decoder.
+
+            if self._skip_all_reduce_acorss_dp_group():
+                num_tokens_across_dp = torch.tensor([num_tokens_padded] *
+                                                    self.dp_size,
+                                                    device="cpu",
+                                                    dtype=torch.int32)
+            else:
+                # TODO(zhenwenqi): this differs from the GPU model runner
+                _, num_tokens_across_dp, with_prefill = self._sync_metadata_across_dp(
+                    num_tokens_padded, with_prefill)
+
+            # Extract DP padding if there is any
+            if num_tokens_across_dp is not None:
+                dp_rank = self.parallel_config.data_parallel_rank
+                num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
+
+                # Re-dispatch with DP padding
+                cudagraph_mode, batch_descriptor = dispatch_cudagraph(
+                    num_tokens_padded, has_lora, use_cascade_attn,
+                    uniform_decode)
+                # Assert that the agreed-upon token count is respected;
+                # otherwise num_tokens_across_dp would no longer be valid.
+                assert batch_descriptor.num_tokens == num_tokens_padded
+
+        return cudagraph_mode, batch_descriptor, ubatch_slices, num_tokens_across_dp
+
     @torch.inference_mode()
     def _dummy_run(
         self,
         num_tokens: int,
-        with_prefill: bool = False,
-        aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
+        cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
         uniform_decode: bool = False,
+        allow_microbatching: bool = False,
+        skip_eplb: bool = False,
         is_profile: bool = False,
+        create_mixed_batch: bool = False,
+        remove_lora: bool = True,
+        activate_lora: bool = False,
+        is_graph_capturing: bool = False,
     ) -> torch.Tensor:
-        # only support eager mode and piecewise graph now
-        assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in {
-            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
-        }
-        # In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs.
-        # If sequence parallelism is enabled, it is essential to ensure that num_tokens is divisible by tp_size.
-        if self.use_aclgraph and enable_sp(self.vllm_config):
-            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
-            num_tokens = math.ceil(num_tokens / tp_size) * tp_size
-        # Force dummy run on prefill stage when this node is deemed as kv producer.
+        with_prefill = False
         if self.is_kv_producer and not self.is_kv_consumer:
             with_prefill = True
-
-        # Padding for DP
-        (num_tokens, num_tokens_across_dp,
-         with_prefill) = self._sync_metadata_across_dp(num_tokens,
-                                                       with_prefill)
-
-        # If cudagraph_mode.decode_mode() == FULL and
-        # cudagraph_mode.seperate_routine(). This means that we are using
-        # different graphs and/or modes for mixed prefill-decode batches vs.
-        # uniform decode batches. A uniform decode batch means that all
-        # requests have identical query length, except a potential virtual
-        # request (shorter) in the batch account for padding.
-        # Uniform decode batch could either be common pure decode, where
-        # max_query_len == 1, or speculative decode, where
-        # max_query_len == 1 + num_spec_decode_tokens.
-
-        # When setting max_query_len = 1, we switch to and capture the optimized
-        # routine of FA2 for pure decode, i.e., Flashdecode + an optimization
-        # for GQA/MQA.
-        max_query_len = self.uniform_decode_query_len if uniform_decode else \
-            num_tokens
-
-        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
-        # for dummy run with LoRA so that the num_reqs collectively
-        # has num_tokens in total.
+        if is_profile:
+            with_prefill = True
+        if not is_profile and self.dynamic_eplb:
+            self.eplb_updator.forward_before()
+        # Any valid aclgraph runtime mode (NONE, PIECEWISE, FULL) is accepted.
+        assert cudagraph_runtime_mode is None or cudagraph_runtime_mode.valid_runtime_modes(
+        )
+        max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens
         assert num_tokens <= self.scheduler_config.max_num_batched_tokens
-        max_num_reqs = self.max_num_reqs
-        if uniform_decode:
-            num_reqs = cdiv(num_tokens, max_query_len)
+        max_num_reqs = self.scheduler_config.max_num_seqs
+        if create_mixed_batch:
+            assert not uniform_decode
+            # Create a mixed batch: the first half is decode tokens, the
+            # second half is a single prefill.
+            num_decode_tokens = min(max_num_reqs - 1, num_tokens // 2)
+            num_prefill_tokens = num_tokens - num_decode_tokens
+            num_reqs = num_decode_tokens + 1
+
+            # Create decode requests (1 token each) followed by the prefill
+            # request.
+            num_scheduled_tokens_list = [1] * num_decode_tokens + [
+                num_prefill_tokens
+            ]
+            # Note: max_query_len is overridden to the number of prefill tokens.
+            max_query_len = num_prefill_tokens
+        elif uniform_decode:
+            assert not create_mixed_batch
+            num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len))
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
                 num_scheduled_tokens_list[-1] = num_tokens % max_query_len
         else:
-            if with_prefill:
-                num_reqs = num_tokens
-            else:
-                num_reqs = (num_tokens + self.decode_token_per_req -
-                            1) // self.decode_token_per_req
-            num_reqs = min(num_reqs, max_num_reqs)
+            num_reqs = min(num_tokens, max_num_reqs)
             min_tokens_per_req = num_tokens // num_reqs
             num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
             num_scheduled_tokens_list[-1] += num_tokens % num_reqs
 
+        assert sum(num_scheduled_tokens_list) == num_tokens
         assert len(num_scheduled_tokens_list) == num_reqs
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)
+        self.query_lens = torch.from_numpy(num_scheduled_tokens)
+        num_tokens_unpadded = int(num_scheduled_tokens.sum())
         num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
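+        # One dispatch decides the aclgraph runtime mode and the padded batch
+        # shape; under data parallelism the helper may re-pad to the token
+        # count agreed across ranks and re-dispatch internally.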
" + f"Expected {_cudagraph_mode}, but got {cudagraph_runtime_mode}." + ) - if not self.in_profile_run and self.dynamic_eplb: - self.eplb_updator.forward_before() - - has_lora = True if self.lora_config and self.compilation_config.cudagraph_specialize_lora else False - _ag_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch(num_tokens=num_tokens, uniform_decode=uniform_decode, has_lora=has_lora) + num_tokens_padded = batch_desc.num_tokens + num_reqs_padded = (batch_desc.num_reqs + if batch_desc.num_reqs is not None else num_reqs) + attn_metadata: PerLayerAttnMetadata | None = None + if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: + if create_mixed_batch: + # In the mixed batch mode (used for FI warmup), we use + # shorter sequence lengths to run faster. + # TODO(luka) better system for describing dummy batches + seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] + else: + seq_lens = max_query_len # type: ignore[assignment] + self.seq_lens.np[:num_reqs] = seq_lens + self.seq_lens.np[num_reqs:] = 0 + self.seq_lens.copy_to_gpu() - num_tokens_padded = batch_descriptor.num_tokens - num_reqs_padded = (batch_descriptor.num_reqs if - batch_descriptor.num_reqs is not None else num_reqs) - if num_tokens_across_dp is not None and num_tokens_padded != num_tokens: - # pad is needed if the pad of `num_tokens` is triggered inside CudagraphDispatcher - num_tokens_across_dp[:] = num_tokens_padded - num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded) - - # filter out the valid batch descriptor - if aclgraph_runtime_mode is not None: - # we allow forcing NONE when the dispatcher disagrees to support - # warm ups for aclgraph capture - if aclgraph_runtime_mode != CUDAGraphMode.NONE and aclgraph_runtime_mode != _ag_mode: - raise ValueError( - f"Aclgraph runtime mode mismatch at dummy_run. " - f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.") - else: - aclgraph_runtime_mode = _ag_mode - - # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup - # and not supported in ASCEND now. We could remove it in the future. 
+        if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
+            if create_mixed_batch:
+                # In the mixed batch mode (used for FI warmup), we use
+                # shorter sequence lengths to run faster.
+                # TODO(luka) better system for describing dummy batches
+                seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
+            else:
+                seq_lens = max_query_len  # type: ignore[assignment]
+            self.seq_lens.np[:num_reqs] = seq_lens
+            self.seq_lens.np[num_reqs:] = 0
+            self.seq_lens.copy_to_gpu()
 
-        num_tokens_padded = batch_descriptor.num_tokens
-        num_reqs_padded = (batch_descriptor.num_reqs if
-                           batch_descriptor.num_reqs is not None else num_reqs)
-        if num_tokens_across_dp is not None and num_tokens_padded != num_tokens:
-            # pad is needed if the pad of `num_tokens` is triggered inside CudagraphDispatcher
-            num_tokens_across_dp[:] = num_tokens_padded
-            num_scheduled_tokens = num_scheduled_tokens.repeat(num_reqs_padded)
-
-        # filter out the valid batch descriptor
-        if aclgraph_runtime_mode is not None:
-            # we allow forcing NONE when the dispatcher disagrees to support
-            # warm ups for aclgraph capture
-            if aclgraph_runtime_mode != CUDAGraphMode.NONE and aclgraph_runtime_mode != _ag_mode:
-                raise ValueError(
-                    f"Aclgraph runtime mode mismatch at dummy_run. "
-                    f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.")
-        else:
-            aclgraph_runtime_mode = _ag_mode
-
-        # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup
-        # and not supported in ASCEND now. We could remove it in the future.
-        attn_metadata = self._build_dummy_attn_metadata(
-            False,
-            num_reqs=num_reqs_padded,
-            num_tokens=num_tokens_padded,
-            max_query_len=max_query_len,
-            aclgraph_runtime_mode=aclgraph_runtime_mode,
-            force_attention=force_attention,
-            num_scheduled_tokens=num_scheduled_tokens,
-        )
+            cum_num_tokens, _ = self._get_cumsum_and_arange(
+                num_scheduled_tokens)
+            self.query_start_loc.np[1:num_reqs + 1] = cum_num_tokens
+            self.query_start_loc.copy_to_gpu()
+
+            attn_metadata, _ = self._build_attention_metadata(
+                num_tokens=num_tokens_unpadded,
+                num_reqs=num_reqs_padded,
+                max_query_len=max_query_len,
+                ubatch_slices=ubatch_slices,
+                for_cudagraph_capture=is_graph_capturing,
+            )
 
-        with self.maybe_dummy_run_with_lora(self.lora_config,
-                                            num_scheduled_tokens,
-                                            num_sampled_tokens):
+        with self.maybe_dummy_run_with_lora(
+                self.lora_config,
+                num_scheduled_tokens,
+                num_sampled_tokens,
+                activate_lora,
+                remove_lora,
+        ):
             # Make sure padding doesn't exceed max_num_tokens
+            # TODO(zhenwenqi): multimodal inputs are not handled here; this
+            # differs from the GPU model runner.
             assert num_tokens_padded <= self.max_num_tokens
             if self.is_multimodal_model:
                 input_ids = None
@@ -2154,9 +2409,9 @@ def dummy_drafter_compute_logits(hidden_states):
                     num_tokens_across_dp=num_tokens_across_dp,
                     with_prefill=with_prefill,
                     in_profile_run=is_profile,
-                    num_actual_tokens=0,
-                    aclgraph_runtime_mode=aclgraph_runtime_mode,
-                    batch_descriptor=batch_descriptor,
+                    num_actual_tokens=num_tokens,
+                    aclgraph_runtime_mode=cudagraph_runtime_mode,
+                    batch_descriptor=batch_desc,
                     prefetch_stream=self.prefetch_stream,
                     model_instance=self.model,
                     weight_prefetch_method=self.weight_prefetch_method):
@@ -2171,10 +2426,10 @@
                     with_prefill=with_prefill,
                     num_reqs=num_reqs_padded,
                     num_tokens_across_dp=num_tokens_across_dp,
-                    aclgraph_runtime_mode=aclgraph_runtime_mode,
-                    batch_descriptor=batch_descriptor,
+                    aclgraph_runtime_mode=cudagraph_runtime_mode,
+                    batch_descriptor=batch_desc,
                     dummy_compute_logits=dummy_drafter_compute_logits,
-                    in_graph_capturing=not force_attention)
+                    skip_attn=not force_attention)
         if self.in_profile_run and self.dynamic_eplb:
             self.model.clear_all_moe_loads()
         if not self.in_profile_run and self.dynamic_eplb:
@@ -2202,20 +2457,18 @@ def _dummy_sampler_run(
         # Sometimes, after the model is compiled through the AOT backend,
         # the model output may become a list containing only one Tensor object.
         if isinstance(hidden_states, list) and \
-            len(hidden_states) == 1 and \
-            isinstance(hidden_states[0], torch.Tensor):
+                len(hidden_states) == 1 and \
+                isinstance(hidden_states[0], torch.Tensor):
             hidden_states = hidden_states[0]
-            hidden_states = hidden_states[logit_indices]
-            output = self.model.compute_logits(hidden_states)
+        hidden_states = hidden_states[logit_indices]
+        output = self.model.compute_logits(hidden_states)
         return output
 
     def profile_run(self) -> None:
         mc2_tokens_capacity = get_mc2_tokens_capacity()
         if self.max_num_tokens > mc2_tokens_capacity and \
             select_moe_comm_method(mc2_tokens_capacity, self.vllm_config) == MoECommType.MC2:
-            self._dummy_run(mc2_tokens_capacity,
-                            with_prefill=True,
-                            is_profile=True)
+            self._dummy_run(mc2_tokens_capacity, is_profile=True)
         super().profile_run()
 
     def eplb_warmup(self):
@@ -2979,7 +3232,7 @@ def _capture_aclgraphs(self, compilation_cases: list[int],
             # compilation_case=1 will cause the dynamic shape position to be incorrectly derived.
if not self.get_kv_cache_spec(): self._dummy_run(2, - aclgraph_runtime_mode=CUDAGraphMode.NONE, + cudagraph_runtime_mode=CUDAGraphMode.NONE, force_attention=force_attention, uniform_decode=uniform_decode) # We skip EPLB here since we don't want to record dummy metrics @@ -2991,11 +3244,11 @@ def _capture_aclgraphs(self, compilation_cases: list[int], # different from the case where `FULL` implies capture # attention while `PIECEWISE` implies no attention. self._dummy_run(num_tokens, - aclgraph_runtime_mode=CUDAGraphMode.NONE, + cudagraph_runtime_mode=CUDAGraphMode.NONE, force_attention=force_attention, uniform_decode=uniform_decode) self._dummy_run(num_tokens, - aclgraph_runtime_mode=aclgraph_runtime_mode, + cudagraph_runtime_mode=aclgraph_runtime_mode, force_attention=force_attention, uniform_decode=uniform_decode)