Commit 6391f06

[v0.11.0-dev][bugfix] Add branch for stream up-lifting in update_attn_params (vllm-project#4437)
### What this PR does / why we need it?

vllm-project#3985 moved the stream-context initialization ahead of the per-layer for-loops to improve performance. However, we find that this may cause a potential accuracy drop when used with pd disaggregation. We therefore partly revert that change when pd disaggregation is enabled, and will fix this bug properly in the future.

### Does this PR introduce _any_ user-facing change?

No.

---------

Signed-off-by: Angazenn <[email protected]>
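For reference, the shape of the change is a branch on whether a KV-transfer (pd-disaggregation) config is present: the npu-stream context is lifted above the per-layer loop only when it is absent. The sketch below is illustrative only; `fake_stream`, `update_params`, and `layers` are hypothetical stand-ins and not part of the patch, with the dummy context manager standing in for `torch.npu.stream(update_stream)`.

```python
from contextlib import contextmanager


@contextmanager
def fake_stream(name):
    # Stand-in for `torch.npu.stream(update_stream)`; entering the stream
    # context has a per-entry host cost, which is what vllm-project#3985
    # lifted out of the loop.
    print(f"enter stream context ({name})")
    yield
    print(f"exit stream context ({name})")


def update_params(layers, kv_transfer_config=None):
    if kv_transfer_config is not None:
        # pd disaggregation: keep the pre-#3985 behaviour and enter the
        # stream context once per layer, trading host overhead for accuracy.
        for layer in layers:
            with fake_stream("update"):
                print(f"update attention params for {layer}")
    else:
        # No pd disaggregation: enter the stream context once and update
        # every layer inside it (the #3985 optimization).
        with fake_stream("update"):
            for layer in layers:
                print(f"update attention params for {layer}")


update_params(["layers.0.attn", "layers.1.attn"])  # lifted stream context
update_params(["layers.0.attn", "layers.1.attn"],
              kv_transfer_config={})  # per-layer stream context
```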
1 parent 2598124 commit 6391f06

2 files changed: +80 -23 lines changed
vllm_ascend/compilation/acl_graph.py

Lines changed: 76 additions & 21 deletions
@@ -189,11 +189,20 @@ def __call__(self, *args, **kwargs):
         return entry.output
 
 
-def update_attn_params(update_stream, forward_context, runtime_shape):
+def update_attn_params(update_stream,
+                       forward_context,
+                       runtime_shape,
+                       kv_transfer_config=None):
     graph_params = get_graph_params()
-    # FIXME: Behold! We are using a temporary hack here to update the args
-    # for each layer's attention op in the graph.
-    with torch.npu.stream(update_stream):
+
+    # NOTE(Angazenn): By moving the npu-stream context ahead
+    # (see https://github.com/vllm-project/vllm-ascend/pull/3985)
+    # we can reduce host overhead introduced by stream initialization.
+    # However, we find that this might cause potential accuracy problems
+    # with pd-disaggregation. Therefore, this optimization is only enabled
+    # without pd-disaggregation. We are working to solve this problem
+    # directly in the future.
+    if kv_transfer_config is not None:
         for key, param, handle, event in zip(
                 forward_context.attn_metadata,
                 graph_params.attn_params[runtime_shape],
@@ -215,10 +224,9 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
 
             # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
             # mode with GQA. This is triggered by getting workspace for _npu_paged_attention
-            # in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
-            # might encounter a bigger workspace, while currently we use max_model_len to
-            # calculate max workspace in capturing. So additional get_workspace is added
-            # here to avoid such bugs.
+            # in torch_npu. In some cases, _npu_paged_attention requires a different workspace
+            # for different seq_lens. So an additional get_workspace call is added here
+            # to avoid such bugs.
             # TODO(Angazenn): we will remove this once _npu_paged_attention is fully
             # replaced by npu_fused_infer_attention_score which does not contain such bugs.
             workspace = torch_npu._npu_paged_attention_get_workspace(
@@ -231,20 +239,67 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
                 block_table=block_table,
                 context_lens=seq_lens,
                 out=output)
-            torch.npu.graph_task_update_begin(update_stream, handle)
-            torch_npu._npu_paged_attention(query=query,
-                                           key_cache=key_cache,
-                                           value_cache=value_cache,
-                                           num_kv_heads=num_kv_heads,
-                                           num_heads=num_heads,
-                                           scale_value=scale,
-                                           block_table=block_table,
-                                           context_lens=seq_lens,
-                                           out=output,
-                                           workspace=workspace)
-            torch.npu.graph_task_update_end(update_stream)
 
-            event.record(update_stream)
+            with torch.npu.stream(update_stream):
+                torch.npu.graph_task_update_begin(update_stream, handle)
+                torch_npu._npu_paged_attention(query=query,
+                                               key_cache=key_cache,
+                                               value_cache=value_cache,
+                                               num_kv_heads=num_kv_heads,
+                                               num_heads=num_heads,
+                                               scale_value=scale,
+                                               block_table=block_table,
+                                               context_lens=seq_lens,
+                                               out=output,
+                                               workspace=workspace)
+                torch.npu.graph_task_update_end(update_stream)
+
+            event.record(update_stream)
+    else:
+        with torch.npu.stream(update_stream):
+            for key, param, handle, event in zip(
+                    forward_context.attn_metadata,
+                    graph_params.attn_params[runtime_shape],
+                    graph_params.handles[runtime_shape],
+                    graph_params.events[runtime_shape],
+            ):
+                (
+                    query,
+                    key_cache,
+                    value_cache,
+                    num_kv_heads,
+                    num_heads,
+                    scale,
+                    block_table,
+                    seq_lens,
+                    output,
+                ) = param
+                seq_lens = forward_context.attn_metadata[key].seq_lens
+
+                workspace = torch_npu._npu_paged_attention_get_workspace(
+                    query=query,
+                    key_cache=key_cache,
+                    value_cache=value_cache,
+                    num_kv_heads=num_kv_heads,
+                    num_heads=num_heads,
+                    scale_value=scale,
+                    block_table=block_table,
+                    context_lens=seq_lens,
+                    out=output)
+                torch.npu.graph_task_update_begin(update_stream, handle)
+                torch_npu._npu_paged_attention(query=query,
+                                               key_cache=key_cache,
+                                               value_cache=value_cache,
+                                               num_kv_heads=num_kv_heads,
+                                               num_heads=num_heads,
+                                               scale_value=scale,
+                                               block_table=block_table,
+                                               context_lens=seq_lens,
+                                               out=output,
+                                               workspace=workspace)
+                torch.npu.graph_task_update_end(update_stream)
+
+                event.record(update_stream)
 
 
 def update_mla_attn_params(update_stream, forward_context, runtime_shape,

vllm_ascend/worker/model_runner_v1.py

Lines changed: 4 additions & 2 deletions
@@ -1598,7 +1598,8 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
                 self.speculative_config)
         else:
             update_attn_params(self.update_stream, forward_context,
-                               maybe_padded_num_tokens)
+                               maybe_padded_num_tokens,
+                               self.vllm_config.kv_transfer_config)
 
         if get_forward_context().sp_enabled:
             hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
@@ -2359,7 +2360,8 @@ def _generate_dummy_run_hidden_states(self, with_prefill,
                 num_tokens, self.speculative_config)
         else:
             update_attn_params(self.update_stream, forward_context,
-                               num_tokens)
+                               num_tokens,
+                               self.vllm_config.kv_transfer_config)
 
         if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
             hidden_states, _ = hidden_states
