4141 split_decodes_and_prefills )
4242from vllm_ascend .compilation .acl_graph import (get_graph_params ,
4343 update_graph_params_workspaces )
44- from vllm_ascend .utils import prefill_context_parallel_enable , weak_ref_tensors
44+ from vllm_ascend .utils import (AscendDeviceType , get_ascend_device_type ,
45+ prefill_context_parallel_enable ,
46+ weak_ref_tensors )
4547
4648# isort: off
4749if prefill_context_parallel_enable ():
@@ -1421,12 +1423,20 @@ def forward(
14211423 if has_decode :
14221424 slot_mapping = attn_metadata .slot_mapping [:num_decode_tokens * self .pcp_size : self .pcp_size ] \
14231425 if self .pcp_size * self .dcp_size > 1 else attn_metadata .slot_mapping [:num_decode_tokens ]
1424- torch_npu ._npu_reshape_and_cache (
1425- key = key [:num_decode_tokens ],
1426- value = value [:num_decode_tokens ],
1427- key_cache = self .key_cache ,
1428- value_cache = self .value_cache ,
1429- slot_indices = slot_mapping )
1426+ if get_ascend_device_type () == AscendDeviceType ._910_95 :
1427+ torch_npu .npu_scatter_pa_kv_cache (
1428+ key = key [:num_decode_tokens ],
1429+ value = value [:num_decode_tokens ],
1430+ key_cache = self .key_cache ,
1431+ value_cache = self .value_cache ,
1432+ slot_mapping = slot_mapping ,
1433+ out = (self .key_cache , self .value_cache ))
1433+ else :
1434+ torch_npu ._npu_reshape_and_cache (
1435+ key = key [:num_decode_tokens ],
1436+ value = value [:num_decode_tokens ],
1437+ key_cache = self .key_cache ,
1438+ value_cache = self .value_cache ,
1439+ slot_indices = slot_mapping )
14301440
14311441 if has_prefill :
14321442 if self .pcp_size > 1 :
@@ -1440,18 +1450,35 @@ def forward(
14401450 key , value = all_kv .split ([self .head_size , self .head_size ],
14411451 dim = - 1 )
14421452
1443- torch_npu ._npu_reshape_and_cache (
1444- key = key [self .pcp_size * num_decode_tokens :attn_metadata .
1445- num_actual_tokens_pcp_padded ],
1446- value = value [self .pcp_size *
1453+ if get_ascend_device_type () == AscendDeviceType ._910_95 :
1454+ torch_npu .npu_scatter_pa_kv_cache (
1455+ key = key [self .pcp_size *
14471456 num_decode_tokens :attn_metadata .
14481457 num_actual_tokens_pcp_padded ],
1449- key_cache = self .key_cache ,
1450- value_cache = self .value_cache ,
1451- slot_indices = attn_metadata .
1452- slot_mapping [self .pcp_size *
1453- num_decode_tokens :attn_metadata .
1454- num_actual_tokens_pcp_padded ])
1458+ value = value [self .pcp_size *
1459+ num_decode_tokens :attn_metadata .
1460+ num_actual_tokens_pcp_padded ].contiguous (),
1461+ key_cache = self .key_cache ,
1462+ value_cache = self .value_cache ,
1463+ slot_mapping = attn_metadata .
1464+ slot_mapping [self .pcp_size *
1465+ num_decode_tokens :attn_metadata .
1466+ num_actual_tokens_pcp_padded ],
1467+ out = (self .key_cache , self .value_cache ))
1468+ else :
1469+ torch_npu ._npu_reshape_and_cache (
1470+ key = key [self .pcp_size *
1471+ num_decode_tokens :attn_metadata .
1472+ num_actual_tokens_pcp_padded ],
1473+ value = value [self .pcp_size *
1474+ num_decode_tokens :attn_metadata .
1475+ num_actual_tokens_pcp_padded ],
1476+ key_cache = self .key_cache ,
1477+ value_cache = self .value_cache ,
1478+ slot_indices = attn_metadata .
1479+ slot_mapping [self .pcp_size *
1480+ num_decode_tokens :attn_metadata .
1481+ num_actual_tokens_pcp_padded ])
14551482
14561483 forward_context : ForwardContext = get_forward_context ()
14571484 if not forward_context .capturing :
0 commit comments