
Commit 81291e9

Author: wangxiaoxin-sherie
Commit message: add fullandpiecesewise graph.
Signed-off-by: wangxiaoxin-sherie <[email protected]>
Parent: 19b85ef

File tree

4 files changed: +244 −75 lines

vllm_ascend/attention/attention_v1.py

Lines changed: 154 additions & 47 deletions
@@ -34,13 +34,17 @@
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          maybe_save_kv_layer_to_connector,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.compilation.acl_graph import get_graph_params
+from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               update_graph_params_workspaces)
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)

 from ..utils import weak_ref_tensors
+<<<<<<< HEAD

+=======
+>>>>>>> 0f6e3d6 (add fullandpiecesewise graph.)

 class AscendAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
@@ -142,6 +146,9 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None

     query_start_loc: torch.Tensor = None
+    seq_lens_list: List[int] = None
+
+    query_start_loc_list: List[int] = None
     query_lens: torch.Tensor = None
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
@@ -209,8 +216,6 @@ def build(
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs
                                                                        + 1]
-        query_start_loc = query_start_loc_cpu.to(self.device,
-                                                 non_blocking=True)

         if is_310p():
             if attn_state == AscendAttentionState.PrefillNoCache:
@@ -225,8 +230,10 @@ def build(
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
-            query_start_loc=query_start_loc,
+            query_start_loc=query_start_loc_cpu,
+            query_start_loc_list=query_start_loc_cpu[1:].cpu().int().tolist(),
             query_lens=query_lens,
+            seq_lens_list=seq_lens.cpu().int().tolist(),
             seq_lens=seq_lens,
             max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
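Editor's note on the two new list arguments: a minimal, self-contained sketch (values illustrative, not taken from the commit) of what the build() change above produces for a three-request batch. query_start_loc_cpu is the cumulative prefix sum of query lengths; dropping the leading zero with [1:] yields per-request end offsets, and seq_lens_list is simply the per-request KV lengths as a Python list.

import torch

# Hypothetical batch: query lens [3, 5, 2], KV (context) lens [7, 9, 4].
query_start_loc_cpu = torch.tensor([0, 3, 8, 10], dtype=torch.int32)
seq_lens = torch.tensor([7, 9, 4], dtype=torch.int32)

# Same conversions as in build() above.
query_start_loc_list = query_start_loc_cpu[1:].cpu().int().tolist()  # [3, 8, 10]
seq_lens_list = seq_lens.cpu().int().tolist()                        # [7, 9, 4]
print(query_start_loc_list, seq_lens_list)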
@@ -394,51 +401,151 @@ def _forward_decode_only(
         else:
             graph_params = get_graph_params()
             forward_context: ForwardContext = get_forward_context()
-            num_tokens = query.shape[0]
             if forward_context.capturing:
-                stream = torch_npu.npu.current_stream()
-
-                event = torch.npu.ExternalEvent()
-                event.wait(stream)
-                event.reset(stream)
-                graph_params.events[num_tokens].append(event)
-
-                graph_params.attn_params[num_tokens].append((
-                    weak_ref_tensors(query),
-                    weak_ref_tensors(self.key_cache),
-                    weak_ref_tensors(self.value_cache),
-                    self.num_kv_heads,
-                    self.num_heads,
-                    self.scale,
-                    weak_ref_tensors(attn_metadata.block_tables),
-                    attn_metadata.seq_lens,
-                    weak_ref_tensors(output),
-                ))
-
-                torch.npu.graph_task_group_begin(stream)
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
-                handle = torch.npu.graph_task_group_end(stream)
-                graph_params.handles[num_tokens].append(handle)
+                if torch.version.cann.startswith("8.3"):
+                    # Prepare tensors for attention output
+                    query_start_loc = attn_metadata.query_start_loc_list
+                    seq_lens = attn_metadata.seq_lens_list
+                    num_tokens = query_start_loc[-1]
+                    query = query[:num_tokens]
+
+                    # Get workspace from cache or calculate it if not present.
+                    workspace = graph_params.workspaces.get(num_tokens)
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    softmax_lse = torch.empty(num_tokens,
+                                              dtype=query.dtype,
+                                              device=query.device)
+                    if workspace is None:
+                        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                            query=query,
+                            key=key,
+                            value=value,
+                            block_table=attn_metadata.block_tables,
+                            input_layout="TND",
+                            block_size=block_size,
+                            actual_seq_lengths=query_start_loc,
+                            actual_seq_lengths_kv=seq_lens,
+                            num_key_value_heads=self.num_kv_heads,
+                            num_heads=self.num_heads,
+                            sparse_mode=0,
+                            scale=self.scale,)
+                        update_graph_params_workspaces(num_tokens, workspace)
+
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(key),
+                        weak_ref_tensors(value),
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        block_size,
+                        seq_lens,
+                        query_start_loc,
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(output),
+                        weak_ref_tensors(softmax_lse)
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch_npu.npu_fused_infer_attention_score.out(
+                        query=query,
+                        key=key,
+                        value=value,
+                        block_table=attn_metadata.block_tables,
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=query_start_loc,
+                        actual_seq_lengths_kv=seq_lens,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0,
+                        workspace=workspace,
+                        out=[output, softmax_lse],
+                    )
+
+                    output = output.view(num_tokens, self.num_heads,
+                                         self.head_size)
+
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
+                else:
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(self.key_cache),
+                        weak_ref_tensors(self.value_cache),
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        attn_metadata.seq_lens,
+                        weak_ref_tensors(output),
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
             else:
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
+                if torch.version.cann.startswith("8.3"):
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+
+                    output, _ = torch_npu.npu_fused_infer_attention_score(
+                        query=query,
+                        key=key,
+                        value=value,
+                        block_table=attn_metadata.block_tables,
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=attn_metadata.query_start_loc_list,
+                        actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0
+                    )
+                else:
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
         return output

     def _forward_v1_style(
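Editor's note: the CANN 8.3 capture path above introduces a per-shape workspace cache for npu_fused_infer_attention_score. Pulled out into a standalone helper for readability, the pattern looks roughly like the sketch below; the helper name _get_or_create_workspace and the fia_kwargs bundling are illustrative only, while the two torch_npu/acl_graph calls are the ones used in the diff.

import torch_npu

from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                               update_graph_params_workspaces)


def _get_or_create_workspace(num_tokens, **fia_kwargs):
    # One workspace per captured token count; later captures and replays of the
    # same shape reuse it instead of re-querying the maximum workspace size.
    graph_params = get_graph_params()
    workspace = graph_params.workspaces.get(num_tokens)
    if workspace is None:
        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
            **fia_kwargs)
        update_graph_params_workspaces(num_tokens, workspace)
    return workspace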

vllm_ascend/compilation/acl_graph.py

Lines changed: 73 additions & 28 deletions
@@ -199,34 +199,75 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
         graph_params.handles[runtime_shape],
         graph_params.events[runtime_shape],
     ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        # block_table = forward_context.attn_metadata[key].block_tables
-        seq_lens = forward_context.attn_metadata[key].seq_lens
-
-        with torch.npu.stream(update_stream):
-            torch.npu.graph_task_update_begin(update_stream, handle)
-            torch_npu._npu_paged_attention(query=query,
-                                           key_cache=key_cache,
-                                           value_cache=value_cache,
-                                           num_kv_heads=num_kv_heads,
-                                           num_heads=num_heads,
-                                           scale_value=scale,
-                                           block_table=block_table,
-                                           context_lens=seq_lens,
-                                           out=output)
-            torch.npu.graph_task_update_end(update_stream)
-
-        event.record(update_stream)
+        if torch.version.cann.startswith("8.3"):
+            (
+                query,
+                key_cache,
+                value,
+                block_tables,
+                block_size,
+                seq_lens,
+                query_start_loc,
+                num_kv_heads,
+                num_heads,
+                scale,
+                attn_output,
+                softmax_lse
+            ) = param
+
+            # block_table = forward_context.attn_metadata[key].block_tables
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+
+            with torch.npu.stream(update_stream):
+                torch.npu.graph_task_update_begin(update_stream, handle)
+                torch_npu.npu_fused_infer_attention_score.out(
+                    query=query,
+                    key=key_cache,
+                    value=value,
+                    block_table=block_tables,
+                    input_layout="TND",
+                    block_size=block_size,
+                    actual_seq_lengths=query_start_loc,
+                    actual_seq_lengths_kv=seq_lens,
+                    num_key_value_heads=num_kv_heads,
+                    num_heads=num_heads,
+                    scale=scale,
+                    sparse_mode=0,
+                    workspace=graph_params.workspaces.get(runtime_shape),
+                    out=[attn_output, softmax_lse],
+                )
+                torch.npu.graph_task_update_end(update_stream)
+
+            event.record(update_stream)
+        else:
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            # block_table = forward_context.attn_metadata[key].block_tables
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+
+            with torch.npu.stream(update_stream):
+                torch.npu.graph_task_update_begin(update_stream, handle)
+                torch_npu._npu_paged_attention(query=query,
+                                               key_cache=key_cache,
+                                               value_cache=value_cache,
+                                               num_kv_heads=num_kv_heads,
+                                               num_heads=num_heads,
+                                               scale_value=scale,
+                                               block_table=block_table,
+                                               context_lens=seq_lens,
+                                               out=output)
+                torch.npu.graph_task_update_end(update_stream)
+
+            event.record(update_stream)


 def update_mla_attn_params(update_stream, forward_context, runtime_shape):
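Editor's note: taken together with the capture branch in attention_v1.py, the intent of the CANN 8.3 path appears to be a capture-once, patch-per-replay handshake. A comment-only outline of the calls this commit makes, in order (no new behavior, names from the diff):

# Capture (inside _forward_decode_only, forward_context.capturing == True):
#     event = torch.npu.ExternalEvent()
#     event.wait(stream); event.reset(stream)       # captured graph waits here on replay
#     graph_params.events[num_tokens].append(event)
#     graph_params.attn_params[num_tokens].append(<saved kernel args>)
#     torch.npu.graph_task_group_begin(stream)
#     torch_npu.npu_fused_infer_attention_score.out(...)   # recorded into the task group
#     handle = torch.npu.graph_task_group_end(stream)
#     graph_params.handles[num_tokens].append(handle)
#
# Replay (update_attn_params, called for the runtime shape before launch):
#     torch.npu.graph_task_update_begin(update_stream, handle)
#     torch_npu.npu_fused_infer_attention_score.out(...)   # same kernel, fresh seq_lens
#     torch.npu.graph_task_update_end(update_stream)
#     event.record(update_stream)                   # releases the captured wait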
@@ -301,6 +342,10 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
                 for size in aclgraph_capture_sizes},
     )

+def update_graph_params_workspaces(num_tokens: int, workspace: int):
+    global _graph_params
+    if _graph_params is not None:
+        _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)

 def get_graph_params():
     return _graph_params
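Editor's note: update_graph_params_workspaces writes into a workspaces map on the module-level _graph_params object. The real container is defined elsewhere in acl_graph.py; the sketch below is only an assumed shape, inferred from how the four maps are indexed in this commit (all keyed by the captured token count / runtime shape).

from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class GraphParamsSketch:  # hypothetical name; mirrors the fields used above
    events: Dict[int, List[Any]] = field(default_factory=dict)
    attn_params: Dict[int, List[Tuple]] = field(default_factory=dict)
    handles: Dict[int, List[Any]] = field(default_factory=dict)
    workspaces: Dict[int, Any] = field(default_factory=dict)  # filled by update_graph_params_workspaces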

vllm_ascend/platform.py

Lines changed: 12 additions & 0 deletions
@@ -226,6 +226,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
             ])
             update_aclgraph_sizes(vllm_config)
+        elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
+            logger.info(
+                "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
+                "using only ACL Graph mode")
+            assert compilation_config.level == CompilationLevel.PIECEWISE, \
+                "When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
+            compilation_config.set_splitting_ops_for_v1()
+            compilation_config.use_inductor = False
+            compilation_config.splitting_ops.extend([
+                "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
+            ])
+            update_aclgraph_sizes(vllm_config)
         elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
             logger.info(
                 "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "

vllm_ascend/worker/model_runner_v1.py

Lines changed: 5 additions & 0 deletions
@@ -1431,8 +1431,10 @@ def _prepare_inputs(
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=self.query_start_loc[:num_reqs + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+            query_start_loc_list=self.query_start_loc_cpu[:num_reqs + 1].cpu().int().tolist(),
             seq_lens_cpu=self.seq_lens_cpu,
             seq_lens=self.seq_lens_cpu[:num_reqs],
+            seq_lens_list=self.seq_lens_cpu[:num_reqs].cpu().int().tolist(),
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
             actual_seq_lengths_q=self.actual_seq_lengths_q,
@@ -2205,6 +2207,9 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
         self.seq_lens_np[:num_reqs] = seq_lens
         self.seq_lens_np[num_reqs:] = 0

+        self.query_start_loc[:num_reqs + 1] = num_tokens
+        self.query_start_loc_cpu[:num_reqs + 1] = num_tokens
+
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
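Editor's note: to tie the four files together, a comment-only outline of where the new host-side lists are produced and consumed (names taken from the hunks above):

# _prepare_inputs (model_runner_v1.py) adds the lists to the common metadata:
#     AscendCommonAttentionMetadata(query_start_loc_list=..., seq_lens_list=...)
#
# build() (attention_v1.py) derives its own copies from the CPU tensors:
#     AscendMetadata(query_start_loc_list=query_start_loc_cpu[1:]...tolist(),
#                    seq_lens_list=seq_lens...tolist())
#
# _forward_decode_only (attention_v1.py, CANN 8.3 path) feeds them to the kernel:
#     torch_npu.npu_fused_infer_attention_score(...,
#         actual_seq_lengths=attn_metadata.query_start_loc_list,
#         actual_seq_lengths_kv=attn_metadata.seq_lens_list)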
