
Commit 6a3ec43

Support mtp run in full graph mode

Signed-off-by: anon189Ty <[email protected]>
1 parent 51e5806 commit 6a3ec43

File tree

5 files changed: +185 -17 lines changed


vllm_ascend/ascend_forward_context.py

Lines changed: 3 additions & 1 deletion
@@ -71,7 +71,8 @@ def set_ascend_forward_context(
         batch_descriptor: Optional[BatchDescriptor] = None,
         prefetch_stream: torch.npu.Stream = None,
         model_instance: torch.nn.Module = None,
-        weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
+        weight_prefetch_method: Optional[WeightPrefetchMethod] = None,
+        is_mtp_model=False):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.
@@ -153,6 +154,7 @@ def set_ascend_forward_context(
     forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
     forward_context.model_instance = model_instance
     forward_context.weight_prefetch_method = weight_prefetch_method
+    forward_context.is_mtp_model = is_mtp_model

     # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
     # It will be improved later by implementing operator fusion through the FX graph.
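Note on the pattern: is_mtp_model is threaded through the per-step forward
context rather than through function signatures, so the MLA attention path can
branch on it without API churn. A minimal, self-contained sketch of the
context-manager pattern (simplified names, not the actual vLLM-Ascend classes):

    from contextlib import contextmanager
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class _ForwardContext:
        is_mtp_model: bool = False

    _current_ctx: Optional[_ForwardContext] = None

    @contextmanager
    def set_forward_context(is_mtp_model: bool = False):
        # Install a fresh context for the duration of one forward pass.
        global _current_ctx
        prev, _current_ctx = _current_ctx, _ForwardContext(is_mtp_model)
        try:
            yield _current_ctx
        finally:
            _current_ctx = prev  # restore on exit, even on error

    def get_forward_context() -> _ForwardContext:
        assert _current_ctx is not None, "no forward pass in progress"
        return _current_ctx

    # The MTP proposer enters the context with is_mtp_model=True; the
    # attention kernel later reads the flag via get_forward_context().
    with set_forward_context(is_mtp_model=True):
        assert get_forward_context().is_mtp_model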

vllm_ascend/attention/mla_v1.py

Lines changed: 5 additions & 1 deletion
@@ -26,6 +26,7 @@
                              trans_rope_weight, transdata,
                              wait_for_kv_layer_from_connector)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               get_mtp_graph_params,
                                                update_graph_params_workspaces)
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
@@ -1028,8 +1029,11 @@ def _forward_decode(
             "actual_seq_lengths": actual_seq_lengths,
             "actual_seq_lengths_kv": decode_meta.seq_lens_list,
         }
-        graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
+        if forward_context.is_mtp_model:
+            graph_params = get_mtp_graph_params()
+        else:
+            graph_params = get_graph_params()
         if forward_context.capturing:
             stream = torch_npu.npu.current_stream()
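Why two registries: the draft MTP model is captured as a separate full ACL
graph from the target model, and both record per-layer kernel parameters keyed
by capture size; a shared registry would let one model clobber the other's
entries. A hedged sketch of the consumer side (runtime_shape and kernel_args
are placeholders for the values _forward_decode actually records):

    def record_or_lookup_params(forward_context, runtime_shape, kernel_args):
        # runtime_shape / kernel_args are illustrative placeholders.
        graph_params = (get_mtp_graph_params() if forward_context.is_mtp_model
                        else get_graph_params())
        if forward_context.capturing:
            # Record this layer's kernel args under the current capture size
            # so update_mla_attn_params can patch them again at replay time.
            graph_params.attn_params[runtime_shape].append(kernel_args)
        return graph_params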

vllm_ascend/compilation/acl_graph.py

Lines changed: 43 additions & 2 deletions
@@ -249,7 +249,10 @@ def update_attn_params(update_stream, forward_context, runtime_shape):

 def update_mla_attn_params(update_stream, forward_context, runtime_shape,
                            speculative_config):
-    graph_params = get_graph_params()
+    if forward_context.is_mtp_model:
+        graph_params = get_mtp_graph_params()
+    else:
+        graph_params = get_graph_params()
     # FIXME: Behold! We are using a temporary hack here to update the args
     # for each layer's attention op in the graph.
     with torch.npu.stream(update_stream):
@@ -265,7 +268,8 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
              softmax_lse) = param
             seq_lens_list = forward_context.attn_metadata[
                 key].decode.seq_lens_list
-            if speculative_config and speculative_config.method == "deepseek_mtp":
+            if speculative_config and speculative_config.method == "deepseek_mtp" \
+                    and not forward_context.is_mtp_model:
                 actual_seq_lengths = forward_context.attn_metadata[
                     key].decode.actual_seq_lengths_q
                 spec_multiple = speculative_config.num_speculative_tokens + 1
@@ -340,3 +344,40 @@ def update_graph_params_workspaces(num_tokens: int, workspace: Any):

 def get_graph_params():
     return _graph_params
+
+
+@dataclass
+class MTPGraphParams:
+    events: dict[int, list[torch.npu.ExternalEvent]]
+    workspaces: dict[int, torch.Tensor]
+    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
+    attn_params: dict[int, list[tuple]]
+
+
+_mtp_graph_params: Optional[MTPGraphParams] = None
+
+
+def set_mtp_graph_params(aclgraph_capture_sizes: set[int]):
+    global _mtp_graph_params
+    if _mtp_graph_params is not None:
+        raise ValueError("MTPGraph parameters have already been set!")
+    _mtp_graph_params = MTPGraphParams(
+        {size: []
+         for size in aclgraph_capture_sizes},
+        {size: None
+         for size in aclgraph_capture_sizes},
+        {size: []
+         for size in aclgraph_capture_sizes},
+        {size: []
+         for size in aclgraph_capture_sizes},
+    )
+
+
+def update_mtp_graph_params_workspaces(num_tokens: int, workspace: Any):
+    global _mtp_graph_params
+    if _mtp_graph_params is not None:
+        _mtp_graph_params.workspaces[num_tokens] = workspace
+
+
+def get_mtp_graph_params():
+    return _mtp_graph_params
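Lifecycle of the new registry, as a usage sketch grounded in the functions
above (capture sizes are illustrative; set_mtp_graph_params is called once
from load_model, the workspace update happens during capture, and replay reads
the same entries):

    import torch

    capture_sizes = {1, 2, 4, 8}           # e.g. cudagraph_capture_sizes
    set_mtp_graph_params(capture_sizes)    # raises ValueError if called twice

    # During capture of the batch-size-4 graph, the attention op registers
    # its workspace so replay reuses the same storage:
    update_mtp_graph_params_workspaces(4, torch.empty(1024))
    assert get_mtp_graph_params().workspaces[4].numel() == 1024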

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 134 additions & 12 deletions
@@ -17,9 +17,13 @@

 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
+from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.patch.worker.patch_deepseek_mtp import \
     AscendDeepSeekMTP as DeepSeekMTP
+from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
+                                               set_mtp_graph_params,
+                                               update_mla_attn_params)
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
     TorchairDeepSeekMTP
@@ -71,6 +75,23 @@ def __init__(
         self.use_sparse = hasattr(vllm_config.model_config.hf_config,
                                   "index_topk")

+        self.actual_seq_lengths_q = list(
+            range(1, self.runner.max_num_tokens + 1, 1))
+        self.query_start_loc = torch.zeros(self.runner.max_num_reqs + 1,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        self.query_start_loc_cpu = torch.zeros(self.runner.max_num_reqs + 1,
+                                               dtype=torch.int32,
+                                               device="cpu",
+                                               pin_memory=True)
+        self.slot_mapping = torch.zeros(self.runner.max_num_tokens,
+                                        dtype=torch.int32,
+                                        device=self.device)
+        self.seq_lens_cpu = torch.zeros(self.runner.max_num_reqs,
+                                        dtype=torch.int32,
+                                        device="cpu",
+                                        pin_memory=True)
+
     def load_model(self, model) -> None:
         loader = get_model_loader(self.vllm_config.load_config)
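Why these buffers move to __init__: full-graph replay re-executes kernels
against the exact memory addresses recorded at capture time, so per-step
inputs must be copied into long-lived tensors rather than freshly allocated.
A minimal sketch of the pattern (sizes are hypothetical):

    import torch

    MAX_TOKENS = 4096
    slot_mapping = torch.zeros(MAX_TOKENS, dtype=torch.int32)  # allocated once

    def prepare_step(step_slot_mapping: torch.Tensor) -> torch.Tensor:
        n = step_slot_mapping.shape[0]
        # Copy in place instead of rebinding: the captured graph reads this
        # exact storage, so its address must never change between steps.
        slot_mapping[:n].copy_(step_slot_mapping)
        return slot_mapping[:n]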

@@ -106,6 +127,15 @@ def load_model(self, model) -> None:
         process_weights_after_loading(self.model, draft_model_config,
                                       target_device)

+        if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
+        ):
+            self.update_stream: torch.npu.Stream = torch.npu.Stream()
+            set_mtp_graph_params(
+                self.vllm_config.compilation_config.cudagraph_capture_sizes)
+            self.model = ACLGraphWrapper(self.model,
+                                         self.vllm_config,
+                                         runtime_mode=CUDAGraphMode.FULL)
+
     @torch.inference_mode()
     def dummy_run(self,
                   num_tokens: int,
@@ -131,7 +161,7 @@ def dummy_run(self,
             skip_attn = False
         if skip_attn:
             attn_metadata = None
-        else:
+        elif is_running_torchair:
             common_attn_metadata = TorchairCommonAttentionMetadata(
                 num_reqs=num_reqs,
                 num_actual_tokens=1,
@@ -142,6 +172,56 @@
             )
             attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
                 common_attn_metadata)
+        elif aclgraph_runtime_mode == CUDAGraphMode.FULL:
+            assert with_prefill is False, \
+                "Full decode graph only supports uniform batch now."
+            num_reqs = num_tokens
+            max_seq_lens = self.runner.model_config.max_model_len
+            self.seq_lens_cpu[:num_reqs] = max_seq_lens
+            self.seq_lens_cpu[num_reqs:] = 0
+            if len(self.runner.attn_groups) > 0:
+                num_computed_tokens_cpu = (
+                    self.runner.input_batch.
+                    num_computed_tokens_cpu_tensor[:num_reqs])
+                query_start_loc = torch.tensor(
+                    [0] + self.actual_seq_lengths_q[:num_reqs],
+                    device=self.runner.device,
+                    dtype=torch.int32)
+                self.query_start_loc[:num_reqs + 1].copy_(query_start_loc)
+                common_attn_metadata = AscendCommonAttentionMetadata(
+                    query_start_loc=self.query_start_loc[:num_reqs + 1],
+                    query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
+                                                                 1],
+                    seq_lens_cpu=self.seq_lens_cpu,
+                    seq_lens=self.seq_lens_cpu[:num_reqs],
+                    num_reqs=num_reqs,
+                    num_actual_tokens=num_tokens,
+                    max_query_len=self.num_speculative_tokens + 1,
+                    num_computed_tokens_cpu=num_computed_tokens_cpu,
+                    actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
+                    block_table_tensor=self.runner.input_batch.block_table[0].
+                    get_device_tensor()[:num_reqs],
+                    slot_mapping=self.slot_mapping,
+                    positions=self.positions,
+                    attn_mask=self.runner.attn_mask,
+                    spec_attn_mask=self.runner.spec_attn_mask,
+                    attn_state=self.runner.attn_state,
+                    decode_token_per_req=self.runner.decode_token_per_req,
+                    cos=self.runner.cos,  # considering mrope: can these be shared?
+                    sin=self.runner.sin,
+                )
+
+                builder = self.runner.attn_groups[0][0].get_metadata_builder()
+                attn_metadata_mtp = builder.build_for_graph_capture(
+                    common_attn_metadata, AscendAttentionState.SpecDecoding,
+                    self.runner.get_model())
+                attn_metadata = {}
+                for layer_name in self.attn_layer_name:
+                    attn_metadata[layer_name] = attn_metadata_mtp
+            else:
+                attn_metadata = None
+        else:
+            attn_metadata = None

         input_ids = self.input_ids[:num_tokens]
         positions = self.positions[:num_tokens]
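A quick sanity check on the dummy decode shapes built above (assumption: one
query token per request during capture, so num_reqs == num_tokens, and
actual_seq_lengths_q == [1, 2, 3, ...] makes the cumulative query offsets a
simple arange):

    import torch

    num_tokens = 8
    num_reqs = num_tokens
    max_model_len = 4096

    seq_lens_cpu = torch.zeros(64, dtype=torch.int32)
    seq_lens_cpu[:num_reqs] = max_model_len  # capture against worst-case KV len
    seq_lens_cpu[num_reqs:] = 0

    actual_seq_lengths_q = list(range(1, num_tokens + 1))
    query_start_loc = torch.tensor([0] + actual_seq_lengths_q[:num_reqs],
                                   dtype=torch.int32)
    assert query_start_loc.tolist() == list(range(num_reqs + 1))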
@@ -158,7 +238,8 @@
                 in_profile_run=self.runner.in_profile_run,
                 num_actual_tokens=0,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
-                batch_descriptor=batch_descriptor):
+                batch_descriptor=batch_descriptor,
+                is_mtp_model=True):
             if is_running_torchair:
                 assert attn_metadata is not None
                 torch._dynamo.mark_static(input_ids)
@@ -188,6 +269,14 @@
                 self.model(input_ids=input_ids,
                            positions=positions,
                            hidden_states=previous_hidden_states)
+                forward_context = get_forward_context()
+                if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
+                        not forward_context.capturing:
+                    if self.vllm_config.model_config.use_mla:
+                        update_mla_attn_params(
+                            self.update_stream, forward_context,
+                            positions.shape[0],
+                            self.vllm_config.speculative_config)
             if with_prefill:
                 break
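Why the post-run update: a replayed graph still carries the sequence lengths
that were baked in at capture time, so after each non-capturing run the MLA
kernels' arguments are rewritten on the dedicated update stream. A heavily
hedged sketch of the shape of that patching loop (the tuple layout inside
attn_params is simplified here, not the exact one in acl_graph.py):

    import torch

    def patch_captured_seq_lens(update_stream, graph_params, runtime_shape,
                                new_seq_lens_list):
        # Runs on its own stream so the rewrites do not serialize with the
        # default stream that launches the graph replay.
        with torch.npu.stream(update_stream):
            for param in graph_params.attn_params[runtime_shape]:
                seq_lens_holder = param[-1]  # hypothetical tuple position
                seq_lens_holder[:] = new_seq_lens_list  # in-place update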

@@ -260,7 +349,8 @@ def generate_token_ids(self,
                 cu_num_tokens=cu_num_tokens,
                 block_table=attn_metadata.block_tables,
                 sampling_metadata=sampling_metadata,
-                token_indices=accepted_token_indices)
+                token_indices=accepted_token_indices,
+                scheduler_output=scheduler_output)
             spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids

@@ -322,6 +412,17 @@ def _prepare_inputs(
         target_positions = positions[token_indices]
         target_hidden_states = hidden_states[token_indices]
         target_slot_mapping = slot_mapping[token_indices]
+
+        batch_size = num_rejected_tokens.shape[0]
+        self.query_start_loc[:batch_size + 1].copy_(cu_num_tokens[:batch_size +
+                                                                  1])
+        self.query_start_loc_cpu[:batch_size + 1].copy_(
+            self.query_start_loc[:batch_size + 1], non_blocking=True)
+        target_positions_len = target_positions.shape[0]
+        self.positions[:target_positions_len].copy_(target_positions)
+        target_slot_mapping_len = target_slot_mapping.shape[0]
+        self.slot_mapping[:target_slot_mapping_len].copy_(target_slot_mapping)
+
         return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping

     def _propose(
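Side note on the non_blocking copies above: they only overlap with compute
because the destination CPU tensors were allocated with pin_memory=True in
__init__; a device-to-pageable copy would silently fall back to a synchronous
transfer. A small sketch of the pairing (a generic PyTorch illustration; the
NPU path is analogous):

    import torch

    staging = torch.zeros(64, dtype=torch.int32, pin_memory=True)

    def fetch_async(device_tensor: torch.Tensor) -> torch.Tensor:
        n = device_tensor.shape[0]
        # Async D2H copy into pinned memory; the caller must synchronize the
        # stream before reading the values on the CPU.
        staging[:n].copy_(device_tensor, non_blocking=True)
        return staging[:n]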
@@ -341,7 +442,8 @@ def _propose(
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-        token_indices=None) -> torch.Tensor:
+        token_indices=None,
+        scheduler_output: SchedulerOutput = None) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
@@ -385,18 +487,20 @@ def _propose(

         seq_lens = target_positions[last_token_indices] + 1
         seq_lens = seq_lens.int()
+        seq_lens_len = seq_lens.shape[0]
+        self.seq_lens_cpu[:seq_lens_len].copy_(seq_lens, non_blocking=True)
         common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=cu_num_tokens[:batch_size + 1],
-            query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(),
-            seq_lens_cpu=seq_lens.cpu(),
+            query_start_loc=self.query_start_loc[:batch_size + 1],
+            query_start_loc_cpu=self.query_start_loc_cpu[:batch_size + 1],
+            seq_lens_cpu=self.seq_lens_cpu[:seq_lens_len],
             num_reqs=batch_size,
             num_actual_tokens=num_tokens,
             max_query_len=max_query_len,
             actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
             block_table_tensor=self.runner.input_batch.block_table[0].
             get_device_tensor(),
-            slot_mapping=target_slot_mapping,
-            positions=target_positions,
+            slot_mapping=self.slot_mapping[:target_slot_mapping.shape[0]],
+            positions=self.positions[:target_positions.shape[0]],
             attn_mask=self.runner.attn_mask,
             spec_attn_mask=self.runner.spec_attn_mask,
             attn_state=self.runner.attn_state,
@@ -434,8 +538,18 @@ def _propose(

         moe_comm_type = self.runner._select_moe_comm_method(
             num_input_tokens, with_prefill)
-        batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
-                                           uniform_decode=False)
+
+        if scheduler_output:
+            uniform_decode = (max_query_len in list(
+                range(1, self.num_speculative_tokens + 2))) and (
+                    scheduler_output.total_num_scheduled_tokens //
+                    (self.num_speculative_tokens + 2 - max_query_len)
+                    == self.runner.input_batch.num_reqs * max_query_len)
+            batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
+                                               uniform_decode=uniform_decode)
+        else:
+            batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
+                                               uniform_decode=False)
         aclgraph_runtime_mode, batch_descriptor = \
             self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
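Worked example of the uniform_decode test above, with num_speculative_tokens
assumed to be 1 (so a uniform decode batch schedules max_query_len = 2 tokens,
one verified plus one draft, for every request):

    num_spec = 1
    num_reqs = 4
    max_query_len = 2                   # 1 verified token + 1 draft token
    total_num_scheduled_tokens = 8      # 4 requests x 2 tokens each

    uniform_decode = (max_query_len in list(range(1, num_spec + 2))) and (
        total_num_scheduled_tokens //
        (num_spec + 2 - max_query_len) == num_reqs * max_query_len)
    assert uniform_decode  # 8 // 1 == 4 * 2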

@@ -451,7 +565,8 @@ def _propose(
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
                 in_profile_run=self.runner.in_profile_run,
-                num_actual_tokens=num_tokens):
+                num_actual_tokens=num_tokens,
+                is_mtp_model=True):
             with ProfileExecuteDuration().capture_async('mtp_forward'):
                 model_kwargs = {}
                 model_kwargs["attn_metadata"] = attn_metadata
@@ -475,6 +590,13 @@
                         positions=self.positions[:num_input_tokens],
                         hidden_states=self.hidden_states[:num_input_tokens]
                     )
+                forward_context = get_forward_context()
+                if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
+                    if self.vllm_config.model_config.use_mla:
+                        update_mla_attn_params(
+                            self.update_stream, forward_context,
+                            num_input_tokens,
+                            self.vllm_config.speculative_config)

             num_indices = last_token_indices.shape[0]
             if lmhead_tp_enable():

vllm_ascend/worker/model_runner_v1.py

Lines changed: 0 additions & 1 deletion
@@ -2501,7 +2501,6 @@ def dummy_compute_logits(hidden_states):
             self.drafter.dummy_run(
                 num_tokens=num_tokens,
                 with_prefill=with_prefill,
-                skip_attn=True,
                 num_reqs=num_reqs,
                 num_tokens_across_dp=num_tokens_across_dp,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
