@@ -189,11 +189,20 @@ def __call__(self, *args, **kwargs):
         return entry.output
 
 
-def update_attn_params(update_stream, forward_context, runtime_shape):
+def update_attn_params(update_stream,
+                       forward_context,
+                       runtime_shape,
+                       kv_transfer_config=None):
     graph_params = get_graph_params()
-    # FIXME: Behold! We are using a temporary hack here to update the args
-    # for each layer's attention op in the graph.
-    with torch.npu.stream(update_stream):
+
+    # NOTE(Angazenn): By moving the npu-stream context ahead
+    # (see https://github.com/vllm-project/vllm-ascend/pull/3985),
+    # we can reduce the host overhead introduced by stream initialization.
+    # However, we found that this might cause accuracy problems with
+    # pd-disaggregation. Therefore, this optimization is only enabled
+    # when pd-disaggregation is not used. We are working on solving
+    # this problem directly in the future.
+    if kv_transfer_config is not None:
         for key, param, handle, event in zip(
                 forward_context.attn_metadata,
                 graph_params.attn_params[runtime_shape],
@@ -215,10 +224,9 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
 
             # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
             # mode with GQA. This is triggered by getting workspace for _npu_paged_attention
-            # in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
-            # might encounter a bigger workspace, while currently we use max_model_len to
-            # calculate max workspace in capturing. So additional get_workspace is added
-            # here to avoid such bugs.
+            # in torch_npu. In some cases, _npu_paged_attention requires a different
+            # workspace for different seq_lens. So an additional get_workspace call
+            # is added here to avoid such bugs.
             # TODO(Angazenn): we will remove this once _npu_paged_attention is fully
             # replaced by npu_fused_infer_attention_score which does not contain such bugs.
             workspace = torch_npu._npu_paged_attention_get_workspace(
@@ -231,20 +239,67 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
                 block_table=block_table,
                 context_lens=seq_lens,
                 out=output)
-            torch.npu.graph_task_update_begin(update_stream, handle)
-            torch_npu._npu_paged_attention(query=query,
-                                           key_cache=key_cache,
-                                           value_cache=value_cache,
-                                           num_kv_heads=num_kv_heads,
-                                           num_heads=num_heads,
-                                           scale_value=scale,
-                                           block_table=block_table,
-                                           context_lens=seq_lens,
-                                           out=output,
-                                           workspace=workspace)
-            torch.npu.graph_task_update_end(update_stream)
 
-            event.record(update_stream)
+            with torch.npu.stream(update_stream):
+                torch.npu.graph_task_update_begin(update_stream, handle)
+                torch_npu._npu_paged_attention(query=query,
+                                               key_cache=key_cache,
+                                               value_cache=value_cache,
+                                               num_kv_heads=num_kv_heads,
+                                               num_heads=num_heads,
+                                               scale_value=scale,
+                                               block_table=block_table,
+                                               context_lens=seq_lens,
+                                               out=output,
+                                               workspace=workspace)
+                torch.npu.graph_task_update_end(update_stream)
+
+            event.record(update_stream)
+    else:
+        with torch.npu.stream(update_stream):
+            for key, param, handle, event in zip(
+                    forward_context.attn_metadata,
+                    graph_params.attn_params[runtime_shape],
+                    graph_params.handles[runtime_shape],
+                    graph_params.events[runtime_shape],
+            ):
+                (
+                    query,
+                    key_cache,
+                    value_cache,
+                    num_kv_heads,
+                    num_heads,
+                    scale,
+                    block_table,
+                    seq_lens,
+                    output,
+                ) = param
+                seq_lens = forward_context.attn_metadata[key].seq_lens
+
+                workspace = torch_npu._npu_paged_attention_get_workspace(
+                    query=query,
+                    key_cache=key_cache,
+                    value_cache=value_cache,
+                    num_kv_heads=num_kv_heads,
+                    num_heads=num_heads,
+                    scale_value=scale,
+                    block_table=block_table,
+                    context_lens=seq_lens,
+                    out=output)
+                torch.npu.graph_task_update_begin(update_stream, handle)
+                torch_npu._npu_paged_attention(query=query,
+                                               key_cache=key_cache,
+                                               value_cache=value_cache,
+                                               num_kv_heads=num_kv_heads,
+                                               num_heads=num_heads,
+                                               scale_value=scale,
+                                               block_table=block_table,
+                                               context_lens=seq_lens,
+                                               out=output,
+                                               workspace=workspace)
+                torch.npu.graph_task_update_end(update_stream)
+
+                event.record(update_stream)
 
 
 def update_mla_attn_params(update_stream, forward_context, runtime_shape,
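
The NOTE in the first hunk implies that the caller is expected to pass kv_transfer_config only when prefill/decode disaggregation is configured, so that the hoisted-stream fast path stays active for ordinary serving. Below is a minimal sketch of such a call site, assuming the file lives at vllm_ascend/compilation/acl_graph.py and that the engine's config object exposes a kv_transfer_config attribute; both of those names are illustrative assumptions, not something this diff confirms.

# Hypothetical call site; only update_attn_params and its new keyword come
# from the diff above, the surrounding names are assumptions.
import torch

from vllm_ascend.compilation.acl_graph import update_attn_params  # path assumed


def replay_with_updated_attn_params(vllm_config, forward_context, runtime_shape):
    update_stream = torch.npu.Stream()  # side stream used for graph task updates

    # Pass the config only when pd-disaggregation is configured: a non-None
    # value selects the per-layer stream path, while None keeps the
    # hoisted-stream optimization from PR #3985.
    kv_transfer_config = getattr(vllm_config, "kv_transfer_config", None)

    update_attn_params(update_stream,
                       forward_context,
                       runtime_shape,
                       kv_transfer_config=kv_transfer_config)

Keeping the argument optional (defaulting to None) means callers that have not been updated silently get the hoisted-stream path, which matches the stated intent that the optimization stays on unless pd-disaggregation is in play.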