Commit b330b75

[feature] support pcp + mtp in full graph

Signed-off-by: zhangsicheng5 <[email protected]>
1 parent 9af3475

File tree: 4 files changed, +77 -38 lines

vllm_ascend/attention/mla_v1.py

Lines changed: 13 additions & 15 deletions
@@ -433,17 +433,9 @@ def build(
                 common_attn_metadata.block_table_tensor[:graph_pad_size])
         else:
             block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
-            # NOTE: Currently, MTP-fullgraph is incompatibility pcp
-            if self.pcp_size > 1:
-                num_decodes_flatten = num_decodes * self.decode_threshold
-                block_table = common_attn_metadata.block_table_tensor[:
-                                                                      num_decodes_flatten
-                                                                      +
-                                                                      num_prefills]
         if num_actual_tokens_pcp_padded is None:
             num_actual_tokens_pcp_padded = num_actual_tokens
 
-        # NOTE: Currently, MTP-fullgraph is incompatibility pcp
         slot_mapping = common_attn_metadata.slot_mapping[:
                                                          num_actual_tokens_pcp_padded]
         input_positions = common_attn_metadata.positions[:
@@ -466,6 +458,13 @@ def build(
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
         num_computed_tokens_cpu = (seq_lens - query_lens)
 
+        if self.pcp_size * self.dcp_size > 1:
+            num_decodes_flatten = query_lens[:num_decodes].sum().item()
+            block_table = common_attn_metadata.block_table_tensor[:
+                                                                  num_decodes_flatten
+                                                                  +
+                                                                  num_prefills]
+
         prefill_metadata = None
         chunked_context_metadata = None
         if num_prefills > 0:
@@ -530,8 +529,9 @@ def build(
             if self.dcp_size * self.pcp_size > 1:
                 if num_computed_tokens_of_pcp_dcp is not None:
                     local_context_lens_allranks = torch.tensor(
-                        num_computed_tokens_of_pcp_dcp[reqs_start:num_reqs]
-                    ).reshape(-1, self.dcp_size * self.pcp_size)
+                        num_computed_tokens_of_pcp_dcp[
+                            num_decodes_flatten:]).reshape(
+                                -1, self.dcp_size * self.pcp_size)
                     # Note(qcs): The max local context lengths
                     # padded to `cp_local_block_size`.
                     padded_local_context_lens_cpu = (cdiv(
@@ -617,7 +617,7 @@ def build(
                 cos=cos,
                 pcp_metadata=pcp_metadata,
             )
-            if self.pcp_size > 1:
+            if self.pcp_size * self.dcp_size > 1:
                 prefill_metadata.block_table = block_table[
                     num_decodes_flatten:, ...]
 
@@ -630,13 +630,12 @@ def build(
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decodes]
             input_positions = input_positions[:num_decode_tokens]
-            if self.pcp_size > 1:
+            if self.pcp_size * self.dcp_size > 1:
                 # For pcp + spec decode, we flatten seq_lens and block_table
                 # to avoid irregular spec_attn_mask shape
                 block_table = block_table[:num_decodes_flatten, ...]
             else:
                 block_table = block_table[:num_decodes, ...]
-            # NOTE: Currently, MTP-fullgraph is incompatibility pcp
             # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
             if graph_pad_size > num_decodes and \
                     self.speculative_config.disable_padded_drafter_batch:
@@ -646,8 +645,7 @@ def build(
             if num_computed_tokens_of_pcp_dcp is not None:
                 # [bs, pcp_size, dcp_size]
                 num_computed_tokens_of_cp_dcp_array = np.array(
-                    num_computed_tokens_of_pcp_dcp)[:num_decodes *
-                                                    self.decode_threshold]
+                    num_computed_tokens_of_pcp_dcp)[:num_decodes_flatten]
 
                 cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
                                                                  self.pcp_rank,
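The recurring change in these hunks is that the flattened decode count is no longer `num_decodes * decode_threshold` but the sum of the actual per-request query lengths. A minimal standalone sketch of the difference (the names `query_lens` and `decode_threshold` mirror the diff; the values are illustrative):

import torch

decode_threshold = 3                  # 1 verified token + 2 speculative tokens
query_lens = torch.tensor([3, 1, 3])  # ragged: request 1 had its drafts rejected
num_decodes = query_lens.numel()

# Old count assumed every decode request fills all decode_threshold slots.
flatten_fixed = num_decodes * decode_threshold         # 9 (over-counts request 1)
# New count sums what was actually scheduled per request.
flatten_exact = query_lens[:num_decodes].sum().item()  # 7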

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 25 additions & 10 deletions
@@ -255,6 +255,13 @@ def dummy_run(self,
             cos=self.runner.cos,
             sin=self.runner.sin,
         )
+        if self.pcp_size * self.dcp_size > 1:
+            # update long_seq related params and flatten block_table
+            common_attn_metadata.prefill_context_parallel_metadata = \
+                self.runner.long_seq_metadata
+            common_attn_metadata.block_table_tensor = \
+                self.runner.input_batch.block_table[0].get_device_tensor()[
+                    :num_reqs * self.decode_threshold]
 
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata_mtp = builder.build_for_graph_capture(
@@ -344,7 +351,7 @@ def generate_token_ids(self,
         )
 
         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = self.runner.long_seq_metadata
             input_ids_pcp_full = self.runner.input_ids_pcp_full
             query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full
@@ -381,7 +388,6 @@ def generate_token_ids(self,
                 query_start_loc_pcp_full[:num_reqs + 1]
         if self.speculative_config.disable_padded_drafter_batch:
             assert isinstance(sampled_token_ids, list)
-            # NOTE: Currently, MTP-fullgraph is incompatibility with pcp
             token_indices_to_sample = None
             common_attn_metadata, token_indices =\
                 self._prepare_inputs(
@@ -592,28 +598,35 @@ def _propose(
         self.input_ids[last_token_indices] = next_token_ids
 
         # update pcp related params
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             assert long_seq_metadata is not None
             common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
+        if self.pcp_size > 1:
             # 1. preprocess decode/prefill input_ids & target_hidden_states
             # decode input_ids: keep unchanged
             # decode target_hidden_states: remove padding
             # prefill input_ids: add padding and pcp split
             # prefill target_hidden_states: pcp split
-            num_tokens_d = num_decode_reqs * self.decode_threshold
+            query_lens_d = self.runner.query_lens[:num_decode_reqs]
+            num_tokens_d = query_lens_d.sum().item()
             num_tokens_d_padded = num_tokens_d * self.pcp_size
             input_ids_d = self.input_ids[:num_tokens_d]
             input_ids_p = self.input_ids[num_tokens_d:num_tokens]
             target_hidden_states_d_padded = \
                 target_hidden_states[:num_tokens_d_padded]
             if num_tokens_d:
                 # remove padding (from pcp all-gather) in decode part
-                target_hidden_states_d = target_hidden_states_d_padded.reshape(
-                    [
-                        num_decode_reqs, self.decode_threshold * self.pcp_size,
-                        -1
-                    ])[:, :self.decode_threshold, :].reshape(
-                        [num_tokens_d, -1])
+                mask_start_loc = torch.cat([
+                    torch.tensor([0], dtype=torch.int32),
+                    torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]
+                ])
+                mask_len = query_lens_d
+                mask = []
+                for req_id in range(num_decode_reqs):
+                    mask += list(
+                        range(mask_start_loc[req_id],
+                              mask_start_loc[req_id] + mask_len[req_id]))
+                target_hidden_states_d = target_hidden_states_d_padded[mask]
             else:
                 target_hidden_states_d = target_hidden_states_d_padded
             target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
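The reshape-based de-padding being removed assumed every decode request contributed exactly `decode_threshold * pcp_size` rows after the pcp all-gather; the mask-based replacement handles ragged query lengths. A toy, self-contained sketch of the same index construction (the tensors here are stand-ins, not the runner's state):

import torch

pcp_size = 2
query_lens_d = torch.tensor([3, 1], dtype=torch.int32)  # ragged decode lengths
num_tokens_d = int(query_lens_d.sum())                  # 4 real decode tokens
padded = torch.arange(num_tokens_d * pcp_size)          # stand-in hidden states

mask_start_loc = torch.cat([
    torch.tensor([0], dtype=torch.int32),
    torch.cumsum(query_lens_d * pcp_size, dim=0)[:-1],
])
mask = []
for req_id in range(query_lens_d.numel()):
    mask += list(range(mask_start_loc[req_id],
                       mask_start_loc[req_id] + query_lens_d[req_id]))
real = padded[mask]  # rows [0, 1, 2, 6]: the first query_len rows of each block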
@@ -749,6 +762,8 @@ def _propose(
                 (0, max_num_reqs_across_dp - num_indices))
 
         if self.pcp_size > 1:
+            # remove graph padding before all_gather
+            hidden_states = hidden_states[:num_tokens]
             hidden_states = get_pcp_group().all_gather(hidden_states, 0)
             hidden_states = torch.index_select(
                 hidden_states, 0, self.runner.
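The two added lines slice off padding before the all-gather. A sketch of why, under the assumption that full-graph capture pads `hidden_states` to a fixed batch shape: gathering the padded tensor would interleave each rank's pad rows into the concatenated result, so only the real tokens are gathered (sizes below are illustrative, and the gather call is shown only as a comment):

import torch

num_tokens, graph_pad_size, hidden_dim = 5, 8, 16
hidden_states = torch.randn(graph_pad_size, hidden_dim)  # captured-graph shape
hidden_states = hidden_states[:num_tokens]               # keep only real tokens
# gathered = get_pcp_group().all_gather(hidden_states, 0)
# -> (pcp_size * num_tokens, hidden_dim), with no pad rows mixed in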

vllm_ascend/worker/block_table.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def __init__(self,
         logical_table_size = max_num_blocks_per_req
 
         duplicate_size = 1
-        if self.pcp_world_size > 1:
+        if self.pcp_world_size * self.dcp_world_size > 1:
             duplicate_size += num_speculative_tokens
         self.block_table = torch.zeros(
             (max_num_reqs * duplicate_size, logical_table_size),
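With flattening, each decode request can own several rows of the physical table (up to `1 + num_speculative_tokens`, matching the `decode_threshold` used elsewhere in the diff), so the allocation is scaled accordingly; the change here only widens the trigger from pcp to pcp or dcp. A sketch with illustrative sizes:

import torch

max_num_reqs, num_speculative_tokens, logical_table_size = 4, 2, 64
pcp_world_size, dcp_world_size = 1, 2  # dcp alone now triggers the duplication

duplicate_size = 1
if pcp_world_size * dcp_world_size > 1:
    duplicate_size += num_speculative_tokens  # 3 rows per request
block_table = torch.zeros((max_num_reqs * duplicate_size, logical_table_size),
                          dtype=torch.int32)  # shape (12, 64)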

vllm_ascend/worker/model_runner_v1.py

Lines changed: 38 additions & 12 deletions
@@ -496,7 +496,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                                 dtype=torch.int32,
                                                 device=self.device)
         self.num_actual_tokens_pcp_padded = 0
-        if self.speculative_config and self.pcp_size > 1:
+        if self.speculative_config and self.pcp_size * self.dcp_size > 1:
             self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
                                                   dtype=torch.int32,
                                                   device=self.device)
@@ -1738,7 +1738,7 @@ def _prepare_inputs(
             self.num_accepted_tokens.np[num_reqs:].fill(1)
             self.num_accepted_tokens.copy_to_gpu()
 
-        if self.speculative_config and self.pcp_size > 1:
+        if self.speculative_config and self.pcp_size * self.dcp_size > 1:
             self._generate_pcp_mtp_input(
                 num_reqs, scheduler_output.total_num_scheduled_tokens,
                 scheduler_output.num_scheduled_tokens)
@@ -1820,28 +1820,29 @@ def _prepare_inputs(
                 prefill_context_parallel_metadata=long_seq_metadata,
             )
 
-        if self.speculative_config and self.pcp_size > 1:
+        if self.speculative_config and self.pcp_size * self.dcp_size > 1:
             # For pcp + spec decode, we flatten block_table
             # to avoid irregular spec_attn_mask shape, e.g.,
             # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
             # ori block_table: # [d0, d1, p0, p1, p2]
             # (num_reqs_d + num_reqs_p, max_num_blocks),
             # flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
             # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
-            ori_query_lens = self.query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
-                self.query_start_loc_pcp_full_cpu[:num_reqs]
+            ori_query_lens = self.query_start_loc_pcp_full[1:num_reqs+1] - \
+                self.query_start_loc_pcp_full[:num_reqs]
             num_prefill_reqs = (ori_query_lens
                                 > self.decode_threshold).sum().item()
             num_decode_reqs = num_reqs - num_prefill_reqs
-            num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold
+            num_decode_reqs_flatten = \
+                ori_query_lens[:num_decode_reqs].sum().item()
             blk_table_tensor[
                 num_decode_reqs_flatten:num_decode_reqs_flatten +
                 num_prefill_reqs].copy_(
                     blk_table_tensor[num_decode_reqs:num_decode_reqs +
                                      num_prefill_reqs].clone())
             blk_table_tensor[:num_decode_reqs_flatten].copy_(
                 blk_table_tensor[:num_decode_reqs].repeat_interleave(
-                    self.decode_threshold, dim=0))
+                    ori_query_lens[:num_decode_reqs], dim=0))
             common_attn_metadata.block_table_tensor = \
                 blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs]
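Following the example in the comment above, a standalone sketch of the ragged-safe flattening (the real code does it in place with `copy_`; `torch.cat` here just shows the resulting layout, and the values are illustrative):

import torch

blk_table = torch.tensor([[10], [20], [30], [40], [50]])  # [d0, d1, p0, p1, p2]
ori_query_lens = torch.tensor([2, 2, 9, 9, 9])            # decode lens first
num_decode_reqs = 2

decode_rows = blk_table[:num_decode_reqs].repeat_interleave(
    ori_query_lens[:num_decode_reqs], dim=0)              # [d0, d0, d1, d1]
flattened = torch.cat([decode_rows, blk_table[num_decode_reqs:]], dim=0)
# -> [d0, d0, d1, d1, p0, p1, p2], matching the comment's example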

@@ -2788,7 +2789,7 @@ def _build_dummy_attn_metadata(
             sin=self.sin,
             prefill_context_parallel_metadata=long_seq_metadata,
         )
-        if self.pcp_size > 1:
+        if self.pcp_size * self.dcp_size > 1:
             common_attn_metadata.block_table_tensor = \
                 block_table_tensor[:num_reqs * self.decode_threshold]
         attn_state = AscendAttentionState.DecodeOnly
@@ -4250,8 +4251,8 @@ def _get_cp_local_seq_lens(
     def _generate_pcp_metadata(self, total_num_scheduled_tokens):
         # In dummy run num_reqs == 0, update it from seq_lens
         num_reqs = self.input_batch.num_reqs or self.query_lens.size(0)
-        num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
-                          >= self.input_batch.num_prompt_tokens[:num_reqs])
+        num_decodes = (self.query_lens <= self.decode_threshold).sum().item()
+        num_prefills = num_reqs - num_decodes
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
         self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
@@ -4269,16 +4270,41 @@ def _generate_pcp_metadata(self, total_num_scheduled_tokens):
                 dtype=torch.int32,
             )
             # For pcp + spec decode, we flatten seq_lens
-            # to avoid irregular spec_attn_mask shape
+            # to avoid irregular spec_attn_mask shape.
+            # Same as block_table, we flatten decode seq_lens to query_lens,
+            # and keep prefill seq_lens unchanged.
             for decode_idx in range(self.decode_threshold):
                 num_computed_tokens_of_pcp_dcp[
                     self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \
                     self._get_cp_local_seq_lens(
-                        torch.tensor(context_lens),
+                        torch.tensor(context_lens) - decode_idx,
                         self.pcp_size,
                         self.dcp_size,
                         self.parallel_config.cp_kv_cache_interleave_size,
                     )
+            if self.decode_threshold > 1:
+                num_computed_tokens_of_pcp_dcp_list = []
+                if num_decodes:
+                    num_decodes_flatten = \
+                        self.query_lens[:num_decodes].sum().item()
+                    if self.query_lens[:num_decodes].min().item(
+                    ) == self.decode_threshold:
+                        decode_flatten_idx = list(range(num_decodes_flatten))
+                    else:
+                        decode_flatten_idx = []
+                        for req_id in range(num_decodes):
+                            offset = (req_id + 1) * self.decode_threshold
+                            decode_flatten_idx += \
+                                list(range(offset - self.query_lens[req_id], offset))
+                    num_computed_tokens_of_pcp_dcp_list.append(
+                        num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
+                if num_prefills:
+                    num_computed_tokens_of_pcp_dcp_list.append(
+                        num_computed_tokens_of_pcp_dcp[
+                            (num_decodes + 1) * self.decode_threshold -
+                            1::self.decode_threshold])
+                num_computed_tokens_of_pcp_dcp = torch.cat(
+                    num_computed_tokens_of_pcp_dcp_list, dim=0)
             long_seq_metadata = AscendPrefillContextParallelMetadata(
                 num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
                 num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.
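A toy sketch of the slot selection when `decode_threshold > 1`: the per-request block of `decode_threshold` slots is filled back-to-front (`context_lens - decode_idx`), so a request that scheduled only `query_len` tokens keeps the last `query_len` slots of its block. Values below are illustrative:

import torch

decode_threshold, num_decodes = 3, 2
query_lens = torch.tensor([3, 2])  # ragged decode query lengths

decode_flatten_idx = []
for req_id in range(num_decodes):
    offset = (req_id + 1) * decode_threshold
    decode_flatten_idx += list(range(offset - query_lens[req_id], offset))
# req 0 -> [0, 1, 2]; req 1 -> [4, 5] (slot 3 belongs to an unscheduled position)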
