@@ -50,7 +50,7 @@
                                                update_graph_params_workspaces)
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
-                               nd_to_nz_2d, nd_to_nz_spec,
+                               is_Ascend950, nd_to_nz_2d, nd_to_nz_spec,
                                prefill_context_parallel_enable,
                                weak_ref_tensors)
 
@@ -1448,6 +1448,71 @@ def _load_kv_for_chunk(self, attn_metadata, kv_cache,
         )
         return key, value
 
+    def _forward_ascend_950(self, query: torch.Tensor, key: torch.Tensor,
+                            value: torch.Tensor, attn_metadata: AscendMetadata,
+                            output: torch.Tensor) -> torch.Tensor:
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            num_tokens = attn_metadata.query_start_loc[-1]
+            output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query[:num_tokens],
+                key[:num_tokens],
+                value[:num_tokens],
+                atten_mask=attn_metadata.attn_mask.to(torch.bool),
+                actual_seq_qlen=attn_metadata.query_lens.cumsum(0),
+                actual_seq_kvlen=attn_metadata.seq_lens.cumsum(0),
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                input_layout="TND",
+                softmax_scale=self.scale)
+            return output[:num_tokens]
+        else:
+            batch_size = attn_metadata.query_lens.shape[0]
+            block_table = attn_metadata.block_tables[:batch_size, :]
+            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                query = query.view(batch_size, 1, self.num_heads * self.head_size)
+                key = self.key_cache.flatten(2, 3).contiguous()
+                value = self.value_cache.flatten(2, 3).contiguous()
+                # block_size is also required by the fused kernel call below.
+                block_size = self.key_cache.shape[1]  # type: ignore
+                atten_mask = None
+                actual_seq_qlen = None
+                actual_seq_kvlen = attn_metadata.seq_lens
+                sparse_mode = 0
+                input_layout = "BSH"
+            else:
+                num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+                key = self.key_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                value = self.value_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                input_layout = "TND"
+                atten_mask = attn_metadata.attn_mask.to(torch.bool)
+                if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
+                    actual_seq_qlen = attn_metadata.query_lens.cumsum(0)
+                    actual_seq_kvlen = attn_metadata.seq_lens
+                    sparse_mode = 2
+                else:
+                    actual_seq_qlen = attn_metadata.actual_seq_lengths_q
+                    actual_seq_kvlen = attn_metadata.seq_lens_list
+                    sparse_mode = 0
+            output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query=query,
+                key=key,
+                value=value,
+                block_table=block_table,
+                atten_mask=atten_mask,
+                actual_seq_qlen=actual_seq_qlen,
+                actual_seq_kvlen=actual_seq_kvlen,
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                softmax_scale=self.scale,
+                sparse_mode=sparse_mode,
+                block_size=block_size,
+                input_layout=input_layout)
+            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                output = output.view(batch_size, self.num_heads, self.head_size)
+            return output
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -1507,12 +1570,20 @@ def forward(
         if has_decode:
             slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size:self.pcp_size] \
                 if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_decode_tokens],
-                value=value[:num_decode_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slot_mapping)
+            if is_Ascend950():
+                num_tokens = slot_mapping.shape[0]
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_tokens],
+                    value=value[:num_tokens].contiguous(),
+                    slot_mapping=slot_mapping,
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slot_mapping)
 
         if has_prefill:
             if self.pcp_size > 1:
@@ -1526,22 +1597,34 @@ def forward(
                 key, value = all_kv.split([self.head_size, self.head_size],
                                           dim=-1)
 
-            torch_npu._npu_reshape_and_cache(
-                key=key[self.pcp_size * num_decode_tokens:attn_metadata.
-                        num_actual_tokens_pcp_padded],
-                value=value[self.pcp_size *
+            if is_Ascend950():
+                num_tokens = attn_metadata.slot_mapping.shape[0]
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_tokens],
+                    value=value[:num_tokens].contiguous(),
+                    slot_mapping=attn_metadata.slot_mapping,
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[self.pcp_size *
                             num_decode_tokens:attn_metadata.
                             num_actual_tokens_pcp_padded],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=attn_metadata.
-                slot_mapping[self.pcp_size *
-                             num_decode_tokens:attn_metadata.
-                             num_actual_tokens_pcp_padded])
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded])
 
         forward_context: ForwardContext = get_forward_context()
         if not forward_context.capturing:
-            if self.pcp_size * self.dcp_size > 1:
+            if is_Ascend950():
+                intermediate_output = self._forward_ascend_950(
+                    query, key, value, attn_metadata, output)
+            elif self.pcp_size * self.dcp_size > 1:
                 intermediate_output = self._forward_pcp_dcp(
                     query, key, value, kv_cache, attn_metadata, output)
             elif attn_type == AttentionType.ENCODER_ONLY:
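
Not part of the diff: a compact reading aid, derived only from the hunks above, summarizing which arguments _forward_ascend_950 passes to torch_npu.npu_fused_infer_attention_score_v2 for each attention state (field names are the metadata attributes shown in the diff; states not listed fall into the last row).

# attn_state        input_layout  atten_mask        actual_seq_qlen        actual_seq_kvlen     sparse_mode  block_table
# PrefillNoCache    "TND"         attn_mask (bool)  query_lens.cumsum(0)   seq_lens.cumsum(0)   default      not passed
# DecodeOnly        "BSH"         None              None                   seq_lens             0            passed
# PrefillCacheHit   "TND"         attn_mask (bool)  query_lens.cumsum(0)   seq_lens             2            passed
# other states      "TND"         attn_mask (bool)  actual_seq_lengths_q   seq_lens_list        0            passed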