diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml
index c43fe7c035d..b674ce0ff22 100644
--- a/.github/workflows/_e2e_test.yaml
+++ b/.github/workflows/_e2e_test.yaml
@@ -175,6 +175,7 @@ jobs:
           VLLM_USE_MODELSCOPE: True
         if: ${{ inputs.type == 'full' }}
         run: |
+          pytest -sv tests/e2e/multicard/test_full_graph_mode.py
           pytest -sv tests/e2e/multicard/test_data_parallel.py
           pytest -sv tests/e2e/multicard/test_expert_parallel.py
           # pytest -sv tests/e2e/multicard/test_external_launcher.py
diff --git a/tests/e2e/multicard/test_full_graph_mode.py b/tests/e2e/multicard/test_full_graph_mode.py
index 3b9f2932309..72515ef7e46 100644
--- a/tests/e2e/multicard/test_full_graph_mode.py
+++ b/tests/e2e/multicard/test_full_graph_mode.py
@@ -29,7 +29,7 @@
 from tests.e2e.model_utils import check_outputs_equal
 
 
-def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
+def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_DECODE_ONLY():
     if 'HCCL_OP_EXPANSION_MODE' in os.environ:
         del os.environ['HCCL_OP_EXPANSION_MODE']
     prompts = [
         "Hello, my name is", "The president of the United States is",
         "The capital of France is", "The future of AI is"
     ]
@@ -70,3 +70,45 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
         name_0="vllm_eager_outputs",
         name_1="vllm_fullgraph_outputs",
     )
+
+def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL():
+    if 'HCCL_OP_EXPANSION_MODE' in os.environ:
+        del os.environ['HCCL_OP_EXPANSION_MODE']
+    prompts = [
+        "Hello, my name is", "The president of the United States is",
+        "The capital of France is", "The future of AI is"
+    ]
+    model = "Qwen/Qwen3-30B-A3B"
+    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
+    with VllmRunner(model,
+                    max_model_len=1024,
+                    tensor_parallel_size=2,
+                    enforce_eager=False,
+                    compilation_config={"cudagraph_mode":
+                                        "FULL"}) as runner:
+        vllm_fullgraph_outputs = runner.model.generate(prompts,
+                                                       sampling_params)
+
+    with VllmRunner(
+            model,
+            max_model_len=1024,
+            enforce_eager=True,
+    ) as runner:
+        vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
+
+    vllm_fullgraph_outputs_list = []
+    for output in vllm_fullgraph_outputs:
+        vllm_fullgraph_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    vllm_eager_outputs_list = []
+    for output in vllm_eager_outputs:
+        vllm_eager_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_eager_outputs_list,
+        outputs_1_lst=vllm_fullgraph_outputs_list,
+        name_0="vllm_eager_outputs",
+        name_1="vllm_fullgraph_outputs",
+    )
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 594e825e6c9..07eb56d29cb 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -191,6 +191,7 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None
     seq_lens_list: List[int] = None  # type: ignore
     actual_seq_lengths_q: List[int] = None  # type: ignore
+    query_start_loc_list: List[int] = None  # type: ignore
 
     query_start_loc: torch.Tensor = None
     query_lens: torch.Tensor = None
@@ -217,7 +218,8 @@ class AscendAttentionMetadataBuilder:
     # Does this backend/builder support ACL Graphs for attention (default: no).
     aclgraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+        AttentionCGSupport.ALWAYS
+    # AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
     # Does this backend/builder reorder the batch?
     # If not, set this to None. Otherwise set it to the query
     # length that will be pulled into the front of the batch.
@@ -368,6 +370,7 @@ def build(
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
             block_tables=block_table,
             query_start_loc=query_start_loc,
+            query_start_loc_list=query_start_loc_cpu[1:].tolist(),
             query_lens=query_lens,
             seq_lens=seq_lens,
             seq_lens_list=seq_lens.tolist(),
@@ -449,6 +452,123 @@ def __init__(
         ) if self.dcp_size > 1 else 0
         self.dcp_group = get_dcp_group(
         ).device_group if self.dcp_size > 1 else None
+
+    def full_graph_attention(self,
+                             query: torch.Tensor,
+                             key: torch.Tensor,
+                             value: torch.Tensor,
+                             attn_metadata: AscendMetadata,
+                             output: Optional[torch.Tensor] = None,
+                             num_tokens=0):
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            block_size = 128
+            block_table = None
+            actual_seq_lengths_kv = attn_metadata.query_start_loc_list
+        elif attn_metadata.attn_state == \
+                AscendAttentionState.PrefillCacheHit:
+            batch_size = attn_metadata.query_lens.shape[0]
+            block_table = attn_metadata.block_tables[:batch_size, :]
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            actual_seq_lengths_kv = attn_metadata.seq_lens_list
+        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            block_table = attn_metadata.block_tables
+            actual_seq_lengths_kv = attn_metadata.seq_lens_list
+        # Normal V1 situation.
+        else:
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            key = self.key_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            value = self.value_cache.view(  # type: ignore
+                num_block, block_size, -1)
+            block_table = attn_metadata.block_tables
+            actual_seq_lengths_kv = attn_metadata.seq_lens_list
+
+        num_tokens = attn_metadata.query_start_loc_list[-1]
+        query = query[:num_tokens]
+        graph_params = get_graph_params()
+        query_start_loc = attn_metadata.query_start_loc_list
+        # Prepare tensors for attention output
+        # TODO: Refactor this to step-level instead of layer-level
+
+        # Get workspace from cache or calculate it if not present.
+        workspace = graph_params.workspaces.get(num_tokens)
+        softmax_lse = torch.empty(1,
+                                  dtype=query.dtype,
+                                  device=query.device)
+        if workspace is None:
+            workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                query=query,
+                key=key,
+                value=value,
+                atten_mask=attn_metadata.attn_mask,
+                block_table=block_table,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=query_start_loc,
+                actual_seq_lengths_kv=actual_seq_lengths_kv,
+                num_key_value_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                sparse_mode=3,
+                scale=self.scale)
+            graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
+
+        # Handle graph capturing mode
+        stream = torch_npu.npu.current_stream()
+
+        event = torch.npu.ExternalEvent()
+        event.wait(stream)
+        event.reset(stream)
+        graph_params.events[num_tokens].append(event)
+        graph_params.attn_params[num_tokens].append((
+            weak_ref_tensors(query),
+            weak_ref_tensors(key),
+            weak_ref_tensors(value),
+            weak_ref_tensors(block_table),
+            weak_ref_tensors(attn_metadata.attn_mask),
+            block_size,
+            actual_seq_lengths_kv,
+            query_start_loc,
+            self.num_kv_heads,
+            self.num_heads,
+            self.scale,
+            weak_ref_tensors(output),
+            weak_ref_tensors(softmax_lse)
+        ))
+
+        torch.npu.graph_task_group_begin(stream)
+        torch_npu.npu_fused_infer_attention_score.out(
+            query=query,
+            key=key,
+            value=value,
+            atten_mask=attn_metadata.attn_mask,
+            block_table=block_table,
+            input_layout="TND",
+            block_size=block_size,
+            actual_seq_lengths=query_start_loc,
+            actual_seq_lengths_kv=actual_seq_lengths_kv,
+            num_key_value_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            scale=self.scale,
+            sparse_mode=3,
+            workspace=workspace,
+            out=[output, softmax_lse],
+        )
+
+        output = output.view(num_tokens, self.num_heads,
+                             self.head_size)
+
+        handle = torch.npu.graph_task_group_end(stream)
+        graph_params.handles[num_tokens].append(handle)
+        return output, num_tokens
 
     def _forward_prefill_no_cache(
         self,
@@ -500,7 +620,6 @@ def _forward_prefill_cache_hit(
         batch_size = attn_metadata.query_lens.shape[0]
         block_table = attn_metadata.block_tables[:batch_size, :]
         num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
-
         if torch.version.cann.startswith("8.3") and block_size == 128:
             # TODO:The npu_fused_infer_attention_score op is planned to
             # be utilized in a wider range in upcoming versions.
@@ -545,6 +664,7 @@ def _forward_decode_only(
         attn_metadata: AscendMetadata,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        num_tokens = query.shape[0]
         if is_310p():
             # seq_lens_tensor needs to be transferred to the device for 310P.
             attn_metadata.seq_lens = \
@@ -1136,49 +1256,46 @@ def forward(
                     slot_mapping[self.pcp_size * num_decode_tokens:attn_metadata.
                                  num_actual_tokens_pcp_padded])
-
-        if self.pcp_size * self.dcp_size > 1:
-            intermediate_output = self._forward_pcp_dcp(
-                query, key, value, attn_metadata, output)
-        elif attn_type == AttentionType.ENCODER_ONLY:
-            # TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
-            cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
-            intermediate_output = torch_npu.npu_fusion_attention(
-                query,
-                key,
-                value,
-                head_num=self.num_heads,
-                input_layout="TND",
-                scale=self.scale,
-                sparse_mode=4,
-                atten_mask=attn_metadata.attn_mask,
-                pre_tockens=attn_metadata.max_query_len,
-                next_tockens=attn_metadata.max_query_len,
-                actual_seq_qlen=cum_seq_len,
-                actual_seq_kvlen=cum_seq_len,
-            )[0]
-        # V0-Style scheduler situation.
-        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-            intermediate_output = self._forward_prefill_no_cache(
-                query, key, value, attn_metadata, output, num_tokens)
-        elif attn_metadata.attn_state == \
-                AscendAttentionState.PrefillCacheHit:
-            intermediate_output = self._forward_prefill_cache_hit(
-                query, attn_metadata, output)
-        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            intermediate_output = self._forward_decode_only(
-                query, attn_metadata, output)
-        # Normal V1 situation.
+        forward_context: ForwardContext = get_forward_context()
+        if not forward_context.capturing:
+            if self.pcp_size * self.dcp_size > 1:
+                intermediate_output = self._forward_pcp_dcp(
+                    query, key, value, attn_metadata, output)
+            elif attn_type == AttentionType.ENCODER_ONLY:
+                # TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
+                cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
+                intermediate_output = torch_npu.npu_fusion_attention(
+                    query,
+                    key,
+                    value,
+                    head_num=self.num_heads,
+                    input_layout="TND",
+                    scale=self.scale,
+                    sparse_mode=4,
+                    atten_mask=attn_metadata.attn_mask,
+                    pre_tockens=attn_metadata.max_query_len,
+                    next_tockens=attn_metadata.max_query_len,
+                    actual_seq_qlen=cum_seq_len,
+                    actual_seq_kvlen=cum_seq_len,
+                )[0]
+            # V0-Style scheduler situation.
+            elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+                intermediate_output = self._forward_prefill_no_cache(
+                    query, key, value, attn_metadata, output, num_tokens)
+            elif attn_metadata.attn_state == \
+                    AscendAttentionState.PrefillCacheHit:
+                intermediate_output = self._forward_prefill_cache_hit(
+                    query, attn_metadata, output)
+            elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                intermediate_output = self._forward_decode_only(
+                    query, attn_metadata, output)
+            # Normal V1 situation.
+            else:
+                intermediate_output = self._forward_v1_style(
+                    query, attn_metadata, output)
         else:
-            if torch.version.cann.startswith("8.3"):
-                # npu_fused_infer_attention_score does not support cases
-                # where query.shape[0] != attn_metadata.query_start_loc[-1].
-                # Thus we need unpad it here.
-                num_tokens = attn_metadata.query_start_loc[-1]
-                query = query[:num_tokens]
-            intermediate_output = self._forward_v1_style(
-                query, attn_metadata, output)
-
+            intermediate_output, num_tokens = self.full_graph_attention(
+                query, key, value, attn_metadata, output)
         output[:num_tokens] = intermediate_output[:num_tokens]
 
         return output
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 5548787ffd5..c47a1856fc5 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -196,50 +196,49 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
-    for key, param, handle, event in zip(
-            forward_context.attn_metadata,
-            graph_params.attn_params[runtime_shape],
-            graph_params.handles[runtime_shape],
-            graph_params.events[runtime_shape],
-    ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        seq_lens = forward_context.attn_metadata[key].seq_lens
-        torch_npu_check = version_check()
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (
+                query,
+                key_cache,
+                value,
+                block_tables,
+                attn_mask,
+                block_size,
+                seq_lens,
+                query_start_loc,
+                num_kv_heads,
+                num_heads,
+                scale,
+                attn_output,
+                softmax_lse
+            ) = param
 
-        with torch.npu.stream(update_stream):
+            seq_lens = forward_context.attn_metadata[key].seq_lens_list
+            query_start_loc = forward_context.attn_metadata[key].query_start_loc_list
             torch.npu.graph_task_update_begin(update_stream, handle)
-            if torch_npu_check:
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=key_cache,
-                    value_cache=value_cache,
-                    num_kv_heads=num_kv_heads,
-                    num_heads=num_heads,
-                    scale_value=scale,
-                    block_table=block_table,
-                    context_lens=seq_lens,
-                    out=output,
-                    workspace=graph_params.workspaces.get(runtime_shape))
-            else:
-                torch_npu._npu_paged_attention(query=query,
-                                               key_cache=key_cache,
-                                               value_cache=value_cache,
-                                               num_kv_heads=num_kv_heads,
-                                               num_heads=num_heads,
-                                               scale_value=scale,
-                                               block_table=block_table,
-                                               context_lens=seq_lens,
-                                               out=output)
+            torch_npu.npu_fused_infer_attention_score.out(
+                query=query,
+                key=key_cache,
+                value=value,
+                block_table=block_tables,
+                atten_mask=attn_mask,
+                input_layout="TND",
+                block_size=block_size,
+                actual_seq_lengths=query_start_loc,
+                actual_seq_lengths_kv=seq_lens,
+                num_key_value_heads=num_kv_heads,
+                num_heads=num_heads,
+                scale=scale,
+                sparse_mode=3,
+                workspace=graph_params.workspaces.get(runtime_shape),
+                out=[attn_output, softmax_lse],
+            )
             torch.npu.graph_task_update_end(update_stream)
 
             event.record(update_stream)
@@ -426,11 +425,10 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
                          for size in aclgraph_capture_sizes},
     )
 
-
-def update_graph_params_workspaces(num_tokens: int, workspace: Any):
+def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor):
     global _graph_params
     if _graph_params is not None:
-        _graph_params.workspaces[num_tokens] = workspace
+        _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
 
 
 def get_graph_params():
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 0c49d95648e..203478e750b 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -288,7 +288,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                     "vllm.mla_forward"
                 ])
                 update_aclgraph_sizes(vllm_config)
-            elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or \
+                    compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
                 logger.info(
                     "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
                     "using only ACL Graph mode")
@@ -326,7 +327,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 compilation_config.use_inductor = False
                 compilation_config.splitting_ops.extend(["vllm::mla_forward"])
                 update_aclgraph_sizes(vllm_config)
-            elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or \
+                    compilation_config.cudagraph_mode == CUDAGraphMode.FULL:
                 logger.info(
                     "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
                     "using only ACL Graph mode")
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index efdd7c102ac..80adf1e9d31 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -420,7 +420,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                     device=self.device)
 
         if self.vllm_config.model_config.use_mla and \
-            self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
             self.cos = torch.ones(self.max_num_reqs *
                                   self.decode_token_per_req,
@@ -951,35 +951,8 @@ def _make_attention_mask(self, seq_lens, position,
         # Pooling situation.
         if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
             return self.attn_mask_builder.get_pooling_mask(self.device)
-        # Chunk Prefill situation.
-        elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
-            if self.dcp_size > 1:
-                max_seq_len = max(seq_lens.max().item(), 0)
-                return self.attn_mask_builder.get_attn_mask(
-                    max_seq_len, self.dtype, self.device)
-            elif torch.version.cann.startswith("8.3"):
-                return self.attn_mask_builder.get_splitfuse_attn_mask()
-            else:
-                return self.attn_mask_builder.get_splitfuse_attn_mask(
-                    seq_lens, position, self.dtype, self.device)
-
-        # Prefill without cache situation.
-        elif attn_state == AscendAttentionState.PrefillNoCache:
-            max_seq_len = max(seq_lens.max().item(), 0)
-            return self.attn_mask_builder.get_attn_mask(
-                max_seq_len, self.dtype, self.device)
-        # Prefill with cache hit.
-        elif attn_state == AscendAttentionState.PrefillCacheHit:
-            if torch.version.cann.startswith("8.3"):
-                return self.attn_mask_builder.get_attn_mask(
-                    2048, self.dtype, self.device)
-            else:
-                return self.attn_mask_builder.get_attn_mask(
-                    128, self.dtype, self.device)
-        # Decode-only situation.
         else:
-            return None
-
+            return self.attn_mask_builder.get_splitfuse_attn_mask()
 
     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
         mrope_pos_ptr = 0
         for index, req_id in enumerate(self.input_batch.req_ids):
@@ -2504,6 +2477,7 @@ def _build_dummy_attn_metadata(
         max_query_len: int,
         aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
+        num_scheduled_tokens: Optional[np.ndarray] = None,
     ) -> Optional[dict[str, Any]]:
 
         attn_metadata: Optional[dict[str, Any]] = None
@@ -2517,6 +2491,16 @@
         self.seq_lens_np[:num_reqs] = seq_lens
         self.seq_lens_np[num_reqs:] = 0
 
+        cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
+
+        self.query_start_loc[1:num_reqs + 1] = torch.Tensor(cu_num_tokens)
+        self.query_start_loc_cpu[1:num_reqs + 1] = torch.Tensor(cu_num_tokens)
+
+        assigned_mask_dim = 2048
+        self.attn_mask = torch.triu(
+            torch.ones(assigned_mask_dim, assigned_mask_dim),
+            diagonal=1).to(torch.int8).to(self.device)
+
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
@@ -2746,6 +2730,7 @@ def _dummy_run(
                 max_query_len=max_query_len,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 force_attention=force_attention,
+                num_scheduled_tokens=num_scheduled_tokens,
            )
 
         need_dummy_logits = (not self.in_profile_run
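
Usage note: with CUDAGraphMode.FULL now accepted by the NPU platform, the feature can also be exercised outside the test harness through the regular vLLM entry point. A minimal sketch, assuming an Ascend NPU deployment with vllm and vllm-ascend installed; the model name and sampling settings are borrowed from the new e2e test above, not mandated by the patch:

    from vllm import LLM, SamplingParams

    # Request full ACL graph capture instead of the default piecewise capture,
    # the same compilation_config the new test passes to VllmRunner.
    llm = LLM(
        model="Qwen/Qwen3-30B-A3B",
        tensor_parallel_size=2,
        max_model_len=1024,
        enforce_eager=False,
        compilation_config={"cudagraph_mode": "FULL"},
    )

    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32, temperature=0.0))
    print(outputs[0].outputs[0].text)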