Commit 3dcad2c

[Fix] Prevent memory leak in MLA decode graph (#3743)
The cache for MLA decode graph parameters was holding strong references to tensors, preventing them from being garbage collected and leading to increased memory usage. This change wraps the cached tensors in weak references, allowing them to be deallocated when no longer in use and reducing overall memory pressure.

- vLLM version: v0.11.0rc3
- vLLM main: vllm-project/vllm@c9461e0

Signed-off-by: Yizhou Liu <[email protected]>
1 parent 1b16c01 commit 3dcad2c
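
In essence, the graph-parameter cache previously owned the workspace and attention tensors outright, so they stayed alive for the lifetime of the captured graph. The sketch below illustrates only the before/after caching pattern; it uses Python's standard weakref module as a stand-in for the repository's weak_ref_tensors helper (which instead returns a tensor aliasing the same buffer without owning its storage), and the function and dictionary names are hypothetical.

# Illustrative sketch of the caching change described above; not the
# vllm_ascend implementation. weakref.ref stands in for weak_ref_tensors(),
# whose real return value is a tensor aliasing the buffer without owning it.
import weakref

import torch

_workspaces: dict[int, object] = {}


def cache_workspace_before_fix(num_tokens: int, workspace: torch.Tensor) -> None:
    # Strong reference: the cache keeps the workspace alive for as long as
    # the graph parameters live, even if nothing else uses it.
    _workspaces[num_tokens] = workspace


def cache_workspace_after_fix(num_tokens: int, workspace: torch.Tensor) -> None:
    # Weak handle: the storage can be reclaimed once all strong references
    # elsewhere are gone.
    _workspaces[num_tokens] = weakref.ref(workspace)


if __name__ == "__main__":
    ws = torch.empty(1 << 20)
    cache_workspace_after_fix(16, ws)
    del ws  # last strong reference dropped; the cached weak handle goes dead
    print(_workspaces[16]() is None)  # True: the workspace tensor was freed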

File tree: 4 files changed, +26, -16 lines changed

vllm_ascend/attention/attention_v1.py
Lines changed: 3 additions & 2 deletions

@@ -443,7 +443,8 @@ def _forward_decode_only(
                 block_table=attn_metadata.block_tables,
                 context_lens=attn_metadata.seq_lens,
                 out=output)
-            update_graph_params_workspaces(num_tokens, workspace)
+            update_graph_params_workspaces(
+                num_tokens, weak_ref_tensors(workspace))

         # Handle graph capturing mode
         stream = torch_npu.npu.current_stream()
@@ -459,7 +460,7 @@ def _forward_decode_only(
                 self.num_kv_heads,
                 self.num_heads,
                 self.scale,
-                weak_ref_tensors(attn_metadata.block_tables),
+                attn_metadata.block_tables,
                 attn_metadata.seq_lens,
                 weak_ref_tensors(output),
             ))

vllm_ascend/attention/mla_v1.py
Lines changed: 13 additions & 9 deletions

@@ -25,14 +25,15 @@
                                         split_decodes_and_prefills,
                                         trans_rope_weight, transdata,
                                         wait_for_kv_layer_from_connector)
-from vllm_ascend.compilation.acl_graph import get_graph_params
+from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               update_graph_params_workspaces)
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               is_enable_nz)
+                               is_enable_nz, weak_ref_tensors)
 from vllm_ascend.worker.npu_input_batch import InputBatch

 if TYPE_CHECKING:
@@ -663,7 +664,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                 getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
                         None), AscendW8A8LinearMethod):
             self.enable_mlapo = False
-            logger.warning(
+            logger.warning_once(
                 "Currently mlapo only supports W8A8 quantization in MLA scenario."
                 "Some layers in your model are not quantized with W8A8,"
                 "thus mlapo is disabled for these layers.")
@@ -1035,19 +1036,22 @@ def _forward_decode(
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                     q_nope, k_nope, k_nope, **common_kwargs)
-                graph_params.workspaces[num_tokens] = workspace
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))

             attn_output = torch.empty_like(q_nope)
             softmax_lse = torch.empty(num_tokens,
                                       dtype=q_nope.dtype,
                                       device=q_nope.device)

             graph_params.attn_params[num_tokens].append(
-                (q_nope, k_nope, q_pe, k_pe, self.num_heads, self.num_kv_heads,
-                 input_layout, spec_attn_mask, sparse_mode, self.scale,
-                 decode_meta.block_table, block_size,
-                 decode_meta.seq_lens_list, actual_seq_lengths, workspace,
-                 attn_output, softmax_lse))
+                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                 weak_ref_tensors(q_pe), weak_ref_tensors(k_pe),
+                 self.num_heads, self.num_kv_heads, input_layout,
+                 weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None
+                 else None, sparse_mode, self.scale, decode_meta.block_table,
+                 block_size, decode_meta.seq_lens_list, actual_seq_lengths,
+                 weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))

             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(

vllm_ascend/compilation/acl_graph.py
Lines changed: 3 additions & 5 deletions

@@ -212,7 +212,6 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
            seq_lens,
            output,
        ) = param
-        # block_table = forward_context.attn_metadata[key].block_tables
        seq_lens = forward_context.attn_metadata[key].seq_lens
        torch_npu_check = version_check()

@@ -258,8 +257,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
    ):
        (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
         spec_attn_mask, sparse_mode, scale, block_table, block_size,
-         seq_lens_list, actual_seq_lengths, workspace, attn_output,
-         softmax_lse) = param
+         seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
        seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
        if speculative_config and speculative_config.method == "deepseek_mtp":
            actual_seq_lengths = forward_context.attn_metadata[
@@ -295,7 +293,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
                block_size=block_size,
                actual_seq_lengths_kv=seq_lens_list,
                actual_seq_lengths=actual_seq_lengths,
-                workspace=workspace,
+                workspace=graph_params.workspaces.get(runtime_shape),
                out=[attn_output, softmax_lse])
            torch.npu.graph_task_update_end(update_stream)

@@ -329,7 +327,7 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
    )


-def update_graph_params_workspaces(num_tokens: int, workspace: int):
+def update_graph_params_workspaces(num_tokens: int, workspace: Any):
    global _graph_params
    if _graph_params is not None:
        _graph_params.workspaces[num_tokens] = workspace
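
Because the captured attn_params tuple no longer carries the workspace, the replay-time update now looks it up from graph_params.workspaces keyed by the runtime shape. Below is a simplified, hypothetical sketch of that capture-then-lookup flow; the GraphParams layout and the workspace_for_replay helper are assumptions that only mirror names visible in the diff.

# Hedged sketch of the capture/replay split for workspaces; not the actual
# vllm_ascend module, only an illustration of the pattern in this commit.
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class GraphParams:
    # keyed by the captured token count (the graph's runtime shape)
    workspaces: dict[int, Any] = field(default_factory=dict)


_graph_params: Optional[GraphParams] = GraphParams()


def update_graph_params_workspaces(num_tokens: int, workspace: Any) -> None:
    # Capture time: callers pass weak_ref_tensors(workspace), so the cache
    # no longer pins the original allocation.
    if _graph_params is not None:
        _graph_params.workspaces[num_tokens] = workspace


def workspace_for_replay(runtime_shape: int) -> Any:
    # Update/replay time: fetch by runtime shape instead of carrying the
    # workspace inside every captured parameter tuple.
    if _graph_params is None:
        return None
    return _graph_params.workspaces.get(runtime_shape)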

vllm_ascend/utils.py
Lines changed: 7 additions & 0 deletions

@@ -686,6 +686,13 @@ def weak_ref_tensors(
     """
     Convenience function to create weak references to tensors,
     for single tensor, list of tensors or tuple of tensors.
+
+    This function should be used in the following scenario:
+    When a tensor is created during graph capture, and it's held by a method
+    that's not part of the graph, we don't really need to store it, but we
+    **do need** its buffer pointer. If we don't handle this, it cannot
+    be garbage collected, leading to a memory leak. To avoid this,
+    we should create a weak reference to the tensor.
     """
     if isinstance(tensors, torch.Tensor):
         return weak_ref_tensor(tensors)
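
The new docstring spells out when a weak reference is appropriate. For orientation only, here is a rough, non-authoritative sketch of a weak_ref_tensors-style dispatcher over a single tensor, a list, or a tuple; the per-tensor weak_ref_tensor primitive is only named in the diff, so the stand-in below (a plain aliasing view) exists purely to make the sketch self-contained and does not reflect the real implementation.

# Rough sketch of a weak_ref_tensors-style dispatcher; not the repository's
# code. weak_ref_tensor() below is a stand-in so the example runs: the real
# primitive returns a tensor sharing the buffer without owning the storage.
from typing import Any, Union

import torch


def weak_ref_tensor(t: torch.Tensor) -> torch.Tensor:
    # Stand-in only (assumption): a plain alias that shares t's storage.
    return t.view(t.shape)


def weak_ref_tensors_sketch(
        tensors: Union[torch.Tensor, list, tuple]) -> Any:
    # Dispatch mirrors the docstring: single tensor, list, or tuple.
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
    if isinstance(tensors, list):
        return [weak_ref_tensor(t) for t in tensors]
    if isinstance(tensors, tuple):
        return tuple(weak_ref_tensor(t) for t in tensors)
    raise TypeError("expected a tensor, list of tensors, or tuple of tensors")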
