File tree Expand file tree Collapse file tree 2 files changed +4
-2
lines changed
Expand file tree Collapse file tree 2 files changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -292,7 +292,7 @@ def dummy_run(self,
292292 positions = self .positions [:num_tokens ]
293293 previous_hidden_states = self .hidden_states [:num_tokens ]
294294 for i in range (self .num_speculative_tokens ):
295- if i > 0 and not in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode .FULL :
295+ if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode .FULL :
296296 aclgraph_runtime_mode = CUDAGraphMode .NONE
297297 with set_ascend_forward_context (
298298 attn_metadata ,
Original file line number Diff line number Diff line change @@ -2268,6 +2268,8 @@ def dummy_drafter_compute_logits(hidden_states):
22682268 dummy_compute_logits (hidden_states )
22692269
22702270 if self .drafter :
2271+ # `in_graph_capturing` indicates whether the main model is in graph capturing.
2272+ # The value is only used in `mtp_proposer.py` currently and defaults to False.
22712273 self .drafter .dummy_run (
22722274 num_tokens = num_tokens_padded ,
22732275 with_prefill = with_prefill ,
@@ -2276,7 +2278,7 @@ def dummy_drafter_compute_logits(hidden_states):
22762278 aclgraph_runtime_mode = aclgraph_runtime_mode ,
22772279 batch_descriptor = batch_descriptor ,
22782280 dummy_compute_logits = dummy_drafter_compute_logits ,
2279- in_graph_capturing = not force_attention )
2281+ in_graph_capturing = force_attention )
22802282 if self .in_profile_run and self .dynamic_eplb :
22812283 self .model .clear_all_moe_loads ()
22822284 if not self .in_profile_run and self .dynamic_eplb :
You can’t perform that action at this time.
0 commit comments