Commit c3ad518

implement async scheduling for mtp
Signed-off-by: Ronald1995 <[email protected]>
1 parent 84d7f5a commit c3ad518

File tree: 5 files changed, +924 -531 lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 1 addition & 1 deletion
@@ -348,7 +348,7 @@ def build(
                 device=query_start_loc_cpu.device)
         ])
 
-        query_start_loc = query_start_loc_cpu.to(self.device,
+        query_start_loc = query_start_loc_cpu.pin_memory().to(self.device,
                                                  non_blocking=True)
 
         if get_ascend_device_type() == AscendDeviceType._310P:
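
The pin_memory() call above matters because a non_blocking host-to-device copy only runs asynchronously when the CPU source lives in page-locked memory; from ordinary pageable memory the runtime has to stage the transfer and the host may end up waiting anyway. A minimal sketch of the pattern, assuming a stock-PyTorch CUDA device stands in for the Ascend NPU:

import torch

# Sketch (not the repo's code): non_blocking=True is only a true async copy
# when the CPU source is page-locked; otherwise it can silently degrade to
# a blocking staged copy. "cuda" stands in for "npu" so this runs anywhere.
device = "cuda"

pageable = torch.arange(1024, dtype=torch.int32)             # ordinary CPU tensor
pinned = torch.arange(1024, dtype=torch.int32).pin_memory()  # page-locked copy

d1 = pageable.to(device, non_blocking=True)  # may fall back to a blocking copy
d2 = pinned.to(device, non_blocking=True)    # async copy that overlaps compute

torch.cuda.synchronize()  # make both transfers observable before checking
assert torch.equal(d1.cpu(), d2.cpu())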

vllm_ascend/attention/mla_v1.py

Lines changed: 22 additions & 14 deletions
@@ -566,10 +566,13 @@ def build(
                     out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
                     dtype=torch.int32,
                 )
-                chunked_context_metadata = \
-                    AscendMLAPrefillMetadata.ChunkedContextMetadata(
-                    cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
-                    starts=local_chunk_starts.to(device, non_blocking=True),
+                chunked_context_metadata = AscendMLAPrefillMetadata.ChunkedContextMetadata(
+                    cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
+                        device, non_blocking=True
+                    ),
+                    starts=local_chunk_starts.pin_memory().to(
+                        device, non_blocking=True
+                    ),
                     seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                     chunk_seq_lens=chunk_seq_lens,
@@ -578,22 +581,27 @@ def build(
                     padded_chunk_seq_lens_npu=padded_local_chunk_seq_lens.npu(),
                     padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
                     local_context_lens_allranks=local_context_lens_allranks.tolist(),
-                    padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
+                    padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.pin_memory().to(
                         device, non_blocking=True
                     ),
                     cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
                     chunk_size=padded_local_max_context_chunk_across_ranks,
                 )
             else:
-                chunked_context_metadata = \
+                chunked_context_metadata = (
                     AscendMLAPrefillMetadata.ChunkedContextMetadata(
-                    cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
-                    starts=chunk_starts.to(device, non_blocking=True),
-                    seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
-                    max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
-                    chunk_seq_lens=chunk_seq_lens,
-                    chunk_seq_lens_npu=chunk_seq_lens.npu(),
-                    workspace=self.chunked_prefill_workspace,
+                        cu_seq_lens=cu_seq_lens_cpu.pin_memory().to(
+                            device, non_blocking=True
+                        ),
+                        starts=chunk_starts.pin_memory().to(
+                            device, non_blocking=True
+                        ),
+                        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
+                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                        chunk_seq_lens=chunk_seq_lens,
+                        chunk_seq_lens_npu=chunk_seq_lens.npu(),
+                        workspace=self.chunked_prefill_workspace,
+                    )
                 )
             prefill_input_positions = input_positions[tokens_start:]
             cos = self.cos_cache[
@@ -626,7 +634,7 @@ def build(
             cos = common_attn_metadata.cos
             sin = common_attn_metadata.sin
             # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
-            actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
+            actual_seq_lengths_q = query_start_loc_cpu[1:num_decodes + 1].tolist()
            max_seq_lens = seq_lens[:num_decodes].max().item()
            seq_lens = seq_lens[:num_decodes]
            input_positions = input_positions[:num_decode_tokens]
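
The last hunk swaps query_start_loc for query_start_loc_cpu before .tolist(). Calling .tolist() on a device tensor forces a blocking device-to-host read, serializing the host with the compute stream; the builder already holds the same values on the host, so slicing the CPU tensor costs nothing on the device. A sketch of the difference, with illustrative values and torch.cuda in place of torch.npu:

import torch

# Sketch of the query_start_loc -> query_start_loc_cpu change; tensor
# contents here are illustrative, not the builder's actual state.
device = "cuda"
num_decodes = 4

query_start_loc_cpu = torch.tensor([0, 1, 2, 3, 4, 7], dtype=torch.int32)
query_start_loc = query_start_loc_cpu.pin_memory().to(device, non_blocking=True)

# Device-side slice: .tolist() must copy back and wait for the device.
blocking = query_start_loc[1:num_decodes + 1].tolist()

# Host-side slice: same values, no device round-trip, no sync.
non_blocking = query_start_loc_cpu[1:num_decodes + 1].tolist()

assert blocking == non_blocking == [1, 2, 3, 4]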

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 29 additions & 2 deletions
@@ -144,6 +144,9 @@ def __init__(
         self.arange = torch.arange(max_num_slots_for_arange,
                                    device=device,
                                    dtype=torch.int32)
+        self.arange_cpu = torch.arange(
+            max_num_slots_for_arange, device="cpu", dtype=torch.int32
+        )
 
         self.inputs_embeds = torch.zeros(
             (self.max_num_tokens, self.hidden_size),
@@ -159,6 +162,7 @@ def __init__(
         )
         self.use_sparse = hasattr(vllm_config.model_config.hf_config,
                                   "index_topk")
+        self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
 
     def load_model(self, model) -> None:
         loader = get_model_loader(self.vllm_config.load_config)
@@ -342,6 +346,7 @@ def generate_token_ids(self,
                 self.runner.discard_request_indices.gpu,
                 self.runner.num_discarded_requests
             )
+        self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
 
         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
         if self.pcp_size > 1:
@@ -421,6 +426,24 @@ def generate_token_ids(self,
         )
 
         return draft_token_ids
+
+    def _copy_valid_sampled_token_count(
+        self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
+    ) -> None:
+        if self.runner.valid_sampled_token_count_event is not None:
+            default_stream = torch.npu.current_stream()
+            # Use a dedicated stream to overlap the copy operation with
+            # prepare_input of the draft model.
+            with torch.npu.stream(self.runner.valid_sampled_token_count_copy_stream):
+                self.runner.valid_sampled_token_count_copy_stream.wait_stream(
+                    default_stream
+                )  # type: ignore
+                self.runner.valid_sampled_token_count_cpu[
+                    : valid_sampled_tokens_count.shape[0]
+                ].copy_(valid_sampled_tokens_count, non_blocking=True)
+                self.runner.valid_sampled_token_count_event.record()
+
+        self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)
 
     def _init_mtp_model(self):
         architecture = self.vllm_config.model_config.architecture
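
The new _copy_valid_sampled_token_count is the heart of the async-scheduling change: the device-to-host readback of the accepted-token counts is issued on a dedicated stream that first waits on the default stream (so the sampled counts are final), copies into a pre-allocated pinned CPU buffer, and records an event; the consumer blocks only when it actually needs the host values. A self-contained sketch of that pattern, with torch.cuda standing in for torch.npu and local buffers standing in for the runner's state:

import torch

# Sketch of the side-stream readback pattern; names and sizes are illustrative.
device = "cuda"
max_reqs = 8

copy_stream = torch.cuda.Stream()
done_event = torch.cuda.Event()
counts_cpu = torch.zeros(max_reqs, dtype=torch.int64, pin_memory=True)

counts_dev = torch.randint(0, 4, (5,), dtype=torch.int64, device=device)

default_stream = torch.cuda.current_stream()
with torch.cuda.stream(copy_stream):
    # Order the copy after everything already queued on the default stream.
    copy_stream.wait_stream(default_stream)
    counts_cpu[:counts_dev.shape[0]].copy_(counts_dev, non_blocking=True)
    done_event.record()

# ... the host keeps preparing the next step here, overlapping the copy ...

done_event.synchronize()  # wait only at the point the host values are needed
print(counts_cpu[:5])

Recording the event on the side stream means the eventual synchronize() covers only the copy itself, not whatever else the default stream is still executing.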
@@ -689,7 +712,11 @@ def _propose(
                                            uniform_decode=False)
         aclgraph_runtime_mode, batch_descriptor = \
             self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
-
+        if self.use_async_scheduling:
+            # There is a synchronize between MTP steps when aclgraph is
+            # enabled; disable aclgraph under async scheduling to avoid
+            # the synchronize overhead.
+            aclgraph_runtime_mode = CUDAGraphMode.NONE
         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
         ) and aclgraph_runtime_mode == CUDAGraphMode.FULL:
             graph_pad_size = num_input_tokens
@@ -795,7 +822,7 @@ def _propose(
         # When disable_padded_drafter_batch=False, it should not to be updating these params, maybe.
         if self.speculative_config.disable_padded_drafter_batch or \
             aclgraph_runtime_mode != CUDAGraphMode.FULL:
-            attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
+            attn_metadata_i.decode.actual_seq_lengths_q = self.arange_cpu[
                 1:batch_size + 1].tolist()
         if aclgraph_runtime_mode == CUDAGraphMode.FULL:
             attn_metadata_i.decode.actual_seq_lengths_q = \
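
This final hunk reads actual_seq_lengths_q from the precomputed arange_cpu (added in __init__) instead of the device-resident query_start_loc, whose .tolist() would force a sync. In the MTP propose loop each request contributes exactly one token per step, so the cumulative query lengths are simply 1..batch_size and a host-side arange yields the same list with no device round-trip. A sketch, with an illustrative slot count:

import torch

# Sketch: with one draft token per request per MTP step, cumulative query
# lengths are 1..batch_size, so a precomputed CPU arange replaces a device
# read. max_num_slots is illustrative.
max_num_slots = 16
arange_cpu = torch.arange(max_num_slots, dtype=torch.int32)

batch_size = 4
actual_seq_lengths_q = arange_cpu[1:batch_size + 1].tolist()
assert actual_seq_lengths_q == [1, 2, 3, 4]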
