
Commit 65e37c7

Author: wangxiaoxin-sherie (committed)
add full and piecewise graph.
Signed-off-by: wangxiaoxin-sherie <[email protected]>
1 parent 94dd832 commit 65e37c7

File tree

4 files changed: +241 −75 lines changed


vllm_ascend/attention/attention_v1.py

Lines changed: 150 additions & 46 deletions
@@ -39,6 +39,7 @@
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
 
+from ..utils import weak_ref_tensors
 
 class AscendAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
@@ -140,6 +141,9 @@ class AscendMetadata:
     seq_lens: torch.Tensor = None
 
     query_start_loc: torch.Tensor = None
+    seq_lens_list: List[int] = None
+
+    query_start_loc_list: List[int] = None
     query_lens: torch.Tensor = None
     # Maximum query length in the batch (None for decoding).
     max_query_len: Optional[int] = None
@@ -207,8 +211,6 @@ def build(
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs
                                                                        + 1]
-        query_start_loc = query_start_loc_cpu.to(self.device,
-                                                 non_blocking=True)
 
         if is_310p():
             if attn_state == AscendAttentionState.PrefillNoCache:
@@ -223,8 +225,10 @@
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             block_tables=block_table,
-            query_start_loc=query_start_loc,
+            query_start_loc=query_start_loc_cpu,
+            query_start_loc_list=query_start_loc_cpu[1:].cpu().int().tolist(),
             query_lens=query_lens,
+            seq_lens_list=seq_lens.cpu().int().tolist(),
             seq_lens=seq_lens,
             max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
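The new query_start_loc_list is built from query_start_loc_cpu[1:], i.e. the prefix-sum tensor with its leading zero dropped, which yields the cumulative per-request end offsets that the fused kernel's actual_seq_lengths argument consumes in the TND layout. A minimal sketch of that conversion with made-up lengths (illustrative only, not part of the commit):

import torch

# Hypothetical per-request query lengths for three requests.
query_lens = torch.tensor([4, 2, 3])
# Prefix sums with a leading zero, as stored in query_start_loc_cpu.
query_start_loc_cpu = torch.cat(
    [torch.zeros(1, dtype=torch.int64), torch.cumsum(query_lens, dim=0)])
assert query_start_loc_cpu.tolist() == [0, 4, 6, 9]
# Dropping the leading zero gives cumulative end offsets, as passed above.
assert query_start_loc_cpu[1:].cpu().int().tolist() == [4, 6, 9]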
@@ -391,51 +395,151 @@ def _forward_decode_only(
         else:
             graph_params = get_graph_params()
             forward_context: ForwardContext = get_forward_context()
-            num_tokens = query.shape[0]
             if forward_context.capturing:
-                stream = torch_npu.npu.current_stream()
-
-                event = torch.npu.ExternalEvent()
-                event.wait(stream)
-                event.reset(stream)
-                graph_params.events[num_tokens].append(event)
-
-                graph_params.attn_params[num_tokens].append((
-                    query,
-                    self.key_cache,
-                    self.value_cache,
-                    self.num_kv_heads,
-                    self.num_heads,
-                    self.scale,
-                    attn_metadata.block_tables,
-                    attn_metadata.seq_lens,
-                    output,
-                ))
-
-                torch.npu.graph_task_group_begin(stream)
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
-                handle = torch.npu.graph_task_group_end(stream)
-                graph_params.handles[num_tokens].append(handle)
+                if torch.version.cann.startswith("8.3"):
+                    # Prepare tensors for attention output
+                    query_start_loc = attn_metadata.query_start_loc_list
+                    seq_lens = attn_metadata.seq_lens_list
+                    num_tokens = query_start_loc[-1]
+                    query = query[:num_tokens]
+
+                    # Get workspace from cache or calculate it if not present.
+                    workspace = graph_params.workspaces.get(num_tokens)
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    softmax_lse = torch.empty(num_tokens,
+                                              dtype=query.dtype,
+                                              device=query.device)
+                    if workspace is None:
+                        workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
+                            query=query,
+                            key=key,
+                            value=value,
+                            block_table=attn_metadata.block_tables,
+                            input_layout="TND",
+                            block_size=block_size,
+                            actual_seq_lengths=query_start_loc,
+                            actual_seq_lengths_kv=seq_lens,
+                            num_key_value_heads=self.num_kv_heads,
+                            num_heads=self.num_heads,
+                            sparse_mode=0,
+                            scale=self.scale)
+                        graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
+
+                    # Handle graph capturing mode
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(key),
+                        weak_ref_tensors(value),
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        block_size,
+                        seq_lens,
+                        query_start_loc,
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(output),
+                        weak_ref_tensors(softmax_lse)
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch_npu.npu_fused_infer_attention_score.out(
+                        query=query,
+                        key=key,
+                        value=value,
+                        block_table=attn_metadata.block_tables,
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=query_start_loc,
+                        actual_seq_lengths_kv=seq_lens,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0,
+                        workspace=workspace,
+                        out=[output, softmax_lse],
+                    )
+
+                    output = output.view(num_tokens, self.num_heads,
+                                         self.head_size)
+
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
+                else:
+                    stream = torch_npu.npu.current_stream()
+
+                    event = torch.npu.ExternalEvent()
+                    event.wait(stream)
+                    event.reset(stream)
+                    graph_params.events[num_tokens].append(event)
+
+                    graph_params.attn_params[num_tokens].append((
+                        weak_ref_tensors(query),
+                        weak_ref_tensors(self.key_cache),
+                        weak_ref_tensors(self.value_cache),
+                        self.num_kv_heads,
+                        self.num_heads,
+                        self.scale,
+                        weak_ref_tensors(attn_metadata.block_tables),
+                        attn_metadata.seq_lens,
+                        weak_ref_tensors(output),
+                    ))
+
+                    torch.npu.graph_task_group_begin(stream)
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
+                    handle = torch.npu.graph_task_group_end(stream)
+                    graph_params.handles[num_tokens].append(handle)
             else:
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
+                if torch.version.cann.startswith("8.3"):
+                    num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                    key = self.key_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+                    value = self.value_cache.view(  # type: ignore
+                        num_block, block_size, -1)
+
+                    output, _ = torch_npu.npu_fused_infer_attention_score(
+                        query=query,
+                        key=key,
+                        value=value,
+                        block_table=attn_metadata.block_tables,
+                        input_layout="TND",
+                        block_size=block_size,
+                        actual_seq_lengths=attn_metadata.query_start_loc_list,
+                        actual_seq_lengths_kv=attn_metadata.seq_lens_list,
+                        num_key_value_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale=self.scale,
+                        sparse_mode=0
+                    )
+                else:
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
         return output
 
     def _forward_v1_style(
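In the capture path above, the CANN 8.3 branch caches the fused-attention workspace per captured token count so that later replays of the same shape reuse it instead of recomputing it. A toy sketch of that get-or-create pattern (plain Python stand-ins; the real code stores a weak reference to the tensor returned by torch_npu):

from typing import Any, Dict

class _GraphParams:
    """Toy stand-in for the graph_params object used above."""
    def __init__(self) -> None:
        self.workspaces: Dict[int, Any] = {}

def _get_or_create_workspace(params: _GraphParams, num_tokens: int) -> Any:
    # Reuse the workspace captured for this token count; compute it once otherwise.
    workspace = params.workspaces.get(num_tokens)
    if workspace is None:
        workspace = f"workspace-for-{num_tokens}-tokens"  # placeholder for the real allocation
        params.workspaces[num_tokens] = workspace
    return workspace

params = _GraphParams()
assert _get_or_create_workspace(params, 16) is _get_or_create_workspace(params, 16)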

vllm_ascend/compilation/acl_graph.py

Lines changed: 69 additions & 28 deletions
@@ -199,34 +199,75 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             graph_params.handles[runtime_shape],
             graph_params.events[runtime_shape],
     ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        # block_table = forward_context.attn_metadata[key].block_tables
-        seq_lens = forward_context.attn_metadata[key].seq_lens
-
-        with torch.npu.stream(update_stream):
-            torch.npu.graph_task_update_begin(update_stream, handle)
-            torch_npu._npu_paged_attention(query=query,
-                                           key_cache=key_cache,
-                                           value_cache=value_cache,
-                                           num_kv_heads=num_kv_heads,
-                                           num_heads=num_heads,
-                                           scale_value=scale,
-                                           block_table=block_table,
-                                           context_lens=seq_lens,
-                                           out=output)
-            torch.npu.graph_task_update_end(update_stream)
-
-        event.record(update_stream)
+        if torch.version.cann.startswith("8.3"):
+            (
+                query,
+                key_cache,
+                value,
+                block_tables,
+                block_size,
+                seq_lens,
+                query_start_loc,
+                num_kv_heads,
+                num_heads,
+                scale,
+                attn_output,
+                softmax_lse
+            ) = param
+
+            # block_table = forward_context.attn_metadata[key].block_tables
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+
+            with torch.npu.stream(update_stream):
+                torch.npu.graph_task_update_begin(update_stream, handle)
+                torch_npu.npu_fused_infer_attention_score.out(
+                    query=query,
+                    key=key_cache,
+                    value=value,
+                    block_table=block_tables,
+                    input_layout="TND",
+                    block_size=block_size,
+                    actual_seq_lengths=query_start_loc,
+                    actual_seq_lengths_kv=seq_lens,
+                    num_key_value_heads=num_kv_heads,
+                    num_heads=num_heads,
+                    scale=scale,
+                    sparse_mode=0,
+                    workspace=graph_params.workspaces.get(runtime_shape),
+                    out=[attn_output, softmax_lse],
+                )
+                torch.npu.graph_task_update_end(update_stream)
+
+            event.record(update_stream)
+        else:
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            # block_table = forward_context.attn_metadata[key].block_tables
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+
+            with torch.npu.stream(update_stream):
+                torch.npu.graph_task_update_begin(update_stream, handle)
+                torch_npu._npu_paged_attention(query=query,
+                                               key_cache=key_cache,
+                                               value_cache=value_cache,
+                                               num_kv_heads=num_kv_heads,
+                                               num_heads=num_heads,
+                                               scale_value=scale,
+                                               block_table=block_table,
+                                               context_lens=seq_lens,
+                                               out=output)
+                torch.npu.graph_task_update_end(update_stream)
+
+            event.record(update_stream)
 
 
 @dataclass
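Because the capture path now records two different parameter layouts, the replay path has to unpack them per CANN version before re-issuing the kernel between graph_task_update_begin/end. A small sketch of just the unpacking step, using the tuple layouts recorded in attention_v1.py above (illustrative names only, no NPU calls):

def describe_attn_param(param: tuple, cann_is_8_3: bool) -> str:
    """Return which replay kernel a recorded parameter tuple targets."""
    if cann_is_8_3:
        # 12-element layout recorded by the CANN 8.3 capture path above.
        (query, key, value, block_tables, block_size, seq_lens,
         query_start_loc, num_kv_heads, num_heads, scale,
         attn_output, softmax_lse) = param
        return "npu_fused_infer_attention_score.out"
    # 9-element layout recorded by the legacy paged-attention path.
    (query, key_cache, value_cache, num_kv_heads, num_heads,
     scale, block_table, seq_lens, output) = param
    return "_npu_paged_attention"

assert describe_attn_param(tuple(range(12)), cann_is_8_3=True).startswith("npu_fused")
assert describe_attn_param(tuple(range(9)), cann_is_8_3=False) == "_npu_paged_attention"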

vllm_ascend/platform.py

Lines changed: 17 additions & 1 deletion
@@ -178,7 +178,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 
         compilation_config.cudagraph_num_of_warmups = 1
 
-        if compilation_config.level not in [
+        if compilation_config.level == CompilationLevel.PIECEWISE:
+            logger.warning(
+                "NEW NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
+                compilation_config.level)
+        elif compilation_config.level not in [
                 CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
         ]:
             logger.warning(
@@ -231,6 +235,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
             ])
             update_aclgraph_sizes(vllm_config)
+        elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
+            logger.info(
+                "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
+                "using only ACL Graph mode")
+            assert compilation_config.level == CompilationLevel.PIECEWISE, \
+                "When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
+            compilation_config.set_splitting_ops_for_v1()
+            compilation_config.use_inductor = False
+            compilation_config.splitting_ops.extend([
+                "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
+            ])
+            update_aclgraph_sizes(vllm_config)
         elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
             logger.info(
                 "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "

vllm_ascend/worker/model_runner_v1.py

Lines changed: 5 additions & 0 deletions
@@ -1406,8 +1406,10 @@ def _prepare_inputs(
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=self.query_start_loc[:num_reqs + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+            query_start_loc_list=self.query_start_loc_cpu[:num_reqs + 1].cpu().int().tolist(),
             seq_lens_cpu=self.seq_lens_cpu,
             seq_lens=self.seq_lens_cpu[:num_reqs],
+            seq_lens_list=self.seq_lens_cpu[:num_reqs].cpu().int().tolist(),
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
             actual_seq_lengths_q=self.actual_seq_lengths_q,
@@ -2172,6 +2174,9 @@ def _build_attention_metadata(self, create_mixed_batch, num_reqs,
         self.seq_lens_np[:num_reqs] = seq_lens
         self.seq_lens_np[num_reqs:] = 0
 
+        self.query_start_loc[:num_reqs + 1] = num_tokens
+        self.query_start_loc_cpu[:num_reqs + 1] = num_tokens
+
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
 