vllm_ascend/attention/attention_v1.py (5 changes: 3 additions & 2 deletions)
@@ -443,7 +443,8 @@ def _forward_decode_only(
                     block_table=attn_metadata.block_tables,
                     context_lens=attn_metadata.seq_lens,
                     out=output)
-                update_graph_params_workspaces(num_tokens, workspace)
+                update_graph_params_workspaces(
+                    num_tokens, weak_ref_tensors(workspace))
 
             # Handle graph capturing mode
             stream = torch_npu.npu.current_stream()
@@ -459,7 +460,7 @@ def _forward_decode_only(
                 self.num_kv_heads,
                 self.num_heads,
                 self.scale,
-                weak_ref_tensors(attn_metadata.block_tables),
+                attn_metadata.block_tables,
                 attn_metadata.seq_lens,
                 weak_ref_tensors(output),
             ))
vllm_ascend/attention/mla_v1.py (22 changes: 13 additions & 9 deletions)
@@ -25,14 +25,15 @@
                                           split_decodes_and_prefills,
                                           trans_rope_weight, transdata,
                                           wait_for_kv_layer_from_connector)
-from vllm_ascend.compilation.acl_graph import get_graph_params
+from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               update_graph_params_workspaces)
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_enable_nz)
+                               is_enable_nz, weak_ref_tensors)
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -663,7 +664,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                 getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
                         None), AscendW8A8LinearMethod):
             self.enable_mlapo = False
-            logger.warning(
+            logger.warning_once(
                 "Currently mlapo only supports W8A8 quantization in MLA scenario."
                 "Some layers in your model are not quantized with W8A8,"
                 "thus mlapo is disabled for these layers.")
@@ -1035,19 +1036,22 @@ def _forward_decode(
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                     q_nope, k_nope, k_nope, **common_kwargs)
-                graph_params.workspaces[num_tokens] = workspace
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))
 
             attn_output = torch.empty_like(q_nope)
             softmax_lse = torch.empty(num_tokens,
                                       dtype=q_nope.dtype,
                                       device=q_nope.device)
 
             graph_params.attn_params[num_tokens].append(
-                (q_nope, k_nope, q_pe, k_pe, self.num_heads, self.num_kv_heads,
-                 input_layout, spec_attn_mask, sparse_mode, self.scale,
-                 decode_meta.block_table, block_size,
-                 decode_meta.seq_lens_list, actual_seq_lengths, workspace,
-                 attn_output, softmax_lse))
+                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                 weak_ref_tensors(q_pe), weak_ref_tensors(k_pe),
+                 self.num_heads, self.num_kv_heads, input_layout,
+                 weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None
+                 else None, sparse_mode, self.scale, decode_meta.block_table,
+                 block_size, decode_meta.seq_lens_list, actual_seq_lengths,
+                 weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))
 
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
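The record appended above is a plain positional tuple, so this append has to stay in lockstep with the unpack in update_mla_attn_params (acl_graph.py, next file); that is why workspace is removed from both ends in the same change, and why per-step values such as seq_lens_list and actual_seq_lengths are refreshed from live metadata at replay rather than trusted from capture time. Purely as an illustrative aside (not code from this PR, and with field types guessed), the same record written with named fields makes that contract explicit:

# Hypothetical sketch only, not code from this PR. It names the fields of the
# per-capture MLA record whose positional form is appended in mla_v1.py and
# unpacked in update_mla_attn_params(); the field types are assumptions.
from typing import Any, NamedTuple, Optional

import torch


class MLADecodeRecord(NamedTuple):
    q_nope: torch.Tensor            # stored as weak references at capture time
    k_nope: torch.Tensor
    q_pe: torch.Tensor
    k_pe: torch.Tensor
    num_heads: int
    num_kv_heads: int
    input_layout: str
    spec_attn_mask: Optional[torch.Tensor]  # may be None, hence the guard above
    sparse_mode: int
    scale: float
    block_table: torch.Tensor
    block_size: int
    seq_lens_list: list[int]        # refreshed from live metadata at replay
    actual_seq_lengths: Any
    attn_output: torch.Tensor       # weak reference to the graph's output buffer
    softmax_lse: torch.Tensor
    # Note: no workspace field. After this PR the update path fetches the
    # workspace from graph_params.workspaces by runtime shape instead.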
8 changes: 3 additions & 5 deletions vllm_ascend/compilation/acl_graph.py
Original file line number Diff line number Diff line change
@@ -212,7 +212,6 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             seq_lens,
             output,
         ) = param
-        # block_table = forward_context.attn_metadata[key].block_tables
         seq_lens = forward_context.attn_metadata[key].seq_lens
         torch_npu_check = version_check()
 
@@ -258,8 +257,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
     ):
         (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
          spec_attn_mask, sparse_mode, scale, block_table, block_size,
-         seq_lens_list, actual_seq_lengths, workspace, attn_output,
-         softmax_lse) = param
+         seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
         seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
         if speculative_config and speculative_config.method == "deepseek_mtp":
             actual_seq_lengths = forward_context.attn_metadata[
@@ -295,7 +293,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
                 block_size=block_size,
                 actual_seq_lengths_kv=seq_lens_list,
                 actual_seq_lengths=actual_seq_lengths,
-                workspace=workspace,
+                workspace=graph_params.workspaces.get(runtime_shape),
                 out=[attn_output, softmax_lse])
             torch.npu.graph_task_update_end(update_stream)
 
@@ -329,7 +327,7 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
     )
 
 
-def update_graph_params_workspaces(num_tokens: int, workspace: int):
+def update_graph_params_workspaces(num_tokens: int, workspace: Any):
     global _graph_params
     if _graph_params is not None:
         _graph_params.workspaces[num_tokens] = workspace
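Taken together, these hunks stop threading the FIA workspace through the per-call parameter tuples and route it through the shared per-shape registry instead: update_graph_params_workspaces records the (weakly referenced) workspace under its token count at capture time, and the update helpers fetch it back with graph_params.workspaces.get(runtime_shape) at replay time. Below is a trimmed, self-contained sketch of that bookkeeping; the real GraphParams in acl_graph.py also tracks events, handles, and attention parameters, and lookup_workspace here is only a stand-in for the inline .get() call, not an actual function in the module:

# Simplified sketch of the workspace registry these hunks rely on (illustration
# only; the real GraphParams carries more state than shown here).
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class GraphParams:
    # One workspace slot per ACL graph capture size (token count).
    workspaces: dict[int, Any] = field(default_factory=dict)


_graph_params: Optional[GraphParams] = None


def set_graph_params(aclgraph_capture_sizes: set[int]) -> None:
    global _graph_params
    # Assumed initialization for the sketch: one empty slot per capture size.
    _graph_params = GraphParams(
        workspaces={size: None for size in aclgraph_capture_sizes})


def get_graph_params() -> Optional[GraphParams]:
    return _graph_params


def update_graph_params_workspaces(num_tokens: int, workspace: Any) -> None:
    # Capture time: remember the workspace computed for this token count.
    if _graph_params is not None:
        _graph_params.workspaces[num_tokens] = workspace


def lookup_workspace(runtime_shape: int) -> Any:
    # Replay time: the update helpers re-fetch the workspace by shape instead
    # of carrying it inside every stored parameter tuple.
    params = get_graph_params()
    return params.workspaces.get(runtime_shape) if params is not None else None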
vllm_ascend/utils.py (7 changes: 7 additions & 0 deletions)
@@ -686,6 +686,13 @@ def weak_ref_tensors(
     """
     Convenience function to create weak references to tensors,
     for single tensor, list of tensors or tuple of tensors.
+
+    This function should be used in the following scenario:
+    When a tensor is created during graph capture, and it's held by a method
+    that's not part of the graph, we don't really need to store it, but we
+    **do need** its buffer pointer. If we don't handle this, it cannot
+    be garbage collected, leading to a memory leak. To avoid this,
+    we should create a weak reference to the tensor.
     """
     if isinstance(tensors, torch.Tensor):
         return weak_ref_tensor(tensors)
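The new docstring spells out the motivation: a registry that outlives graph capture must not hold the only strong reference to a capture-time tensor, because the graph needs the tensor's buffer pointer, not ownership of it. The snippet below is a minimal CPU-only illustration of the lifetime side of that argument, using Python's weakref on ordinary tensors; it is not the real weak_ref_tensor primitive, which keeps the recorded buffer pointer usable while dropping ownership.

# CPU-only illustration of the docstring's point (not the real weak_ref_tensor).
import weakref

import torch

registry = {}                       # long-lived, like graph_params.workspaces


def capture_step_strong():
    t = torch.empty(1 << 20)        # "created during graph capture"
    registry["strong"] = t          # strong ref: the registry now pins the tensor


def capture_step_weak():
    t = torch.empty(1 << 20)
    registry["weak"] = weakref.ref(t)   # weak ref: the registry does not pin it


capture_step_strong()
capture_step_weak()

# The strongly held tensor stays alive through the registry for as long as the
# registry lives; the weakly held one was freed as soon as its function
# returned (CPython's reference counting reclaims it immediately).
print(type(registry["strong"]))     # <class 'torch.Tensor'>
print(registry["weak"]() is None)   # True: already garbage collected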