                                          update_graph_params_workspaces)
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
-                               nd_to_nz_2d, nd_to_nz_spec,
+                               is_Ascend950, nd_to_nz_2d, nd_to_nz_spec,
                                prefill_context_parallel_enable,
                                weak_ref_tensors)
 
@@ -1448,6 +1448,72 @@ def _load_kv_for_chunk(self, attn_metadata, kv_cache,
         )
         return key, value
 
+    def _forward_ascend_950(self, query: torch.Tensor, key: torch.Tensor,
+                            value: torch.Tensor, attn_metadata: AscendMetadata,
+                            output: torch.Tensor) -> torch.Tensor:
+        num_tokens = attn_metadata.query_start_loc[-1]
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query[:num_tokens],
+                key[:num_tokens],
+                value[:num_tokens],
+                atten_mask=attn_metadata.attn_mask.to(torch.bool),  # type: ignore
+                actual_seq_qlen=attn_metadata.query_lens.cumsum(0),
+                actual_seq_kvlen=attn_metadata.seq_lens.cumsum(0),
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                input_layout="TND",
+                softmax_scale=self.scale)
+            return output[:num_tokens]
+        else:
+            batch_size = attn_metadata.query_lens.shape[0]
+            block_table = attn_metadata.block_tables[:batch_size, :]
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
+            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                query = query.view(batch_size, 1,
+                                   self.num_heads * self.head_size)
+                key = self.key_cache.flatten(2, 3).contiguous()  # type: ignore
+                value = self.value_cache.flatten(2, 3).contiguous()  # type: ignore
+                atten_mask = None
+                actual_seq_qlen = None
+                actual_seq_kvlen = attn_metadata.seq_lens
+                sparse_mode = 0
+                input_layout = "BSH"
+            else:
+                key = self.key_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                value = self.value_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                input_layout = "TND"
+                atten_mask = attn_metadata.attn_mask.to(torch.bool)  # type: ignore
+                if attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
+                    actual_seq_qlen = attn_metadata.query_lens.cumsum(0)
+                    actual_seq_kvlen = attn_metadata.seq_lens
+                    sparse_mode = 2
+                else:
+                    query = query[:num_tokens]
+                    actual_seq_qlen = attn_metadata.actual_seq_lengths_q
+                    actual_seq_kvlen = attn_metadata.seq_lens_list
+                    sparse_mode = 0
+            output, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query=query,
+                key=key,
+                value=value,
+                block_table=block_table,
+                atten_mask=atten_mask,
+                actual_seq_qlen=actual_seq_qlen,
+                actual_seq_kvlen=actual_seq_kvlen,
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                softmax_scale=self.scale,
+                sparse_mode=sparse_mode,
+                block_size=block_size,
+                input_layout=input_layout)
+            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                output = output.view(batch_size, self.num_heads,
+                                     self.head_size)
+            return output
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -1507,12 +1573,20 @@ def forward(
         if has_decode:
             slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size:self.pcp_size] \
                 if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_decode_tokens],
-                value=value[:num_decode_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slot_mapping)
+            if is_Ascend950():
+                num_tokens = slot_mapping.shape[0]
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_tokens],
+                    value=value[:num_tokens].contiguous(),
+                    slot_mapping=slot_mapping,
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slot_mapping)
 
         if has_prefill:
             if self.pcp_size > 1:
@@ -1526,22 +1600,34 @@ def forward(
                 key, value = all_kv.split([self.head_size, self.head_size],
                                           dim=-1)
 
-            torch_npu._npu_reshape_and_cache(
-                key=key[self.pcp_size * num_decode_tokens:attn_metadata.
-                        num_actual_tokens_pcp_padded],
-                value=value[self.pcp_size *
-                            num_decode_tokens:attn_metadata.
-                            num_actual_tokens_pcp_padded],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=attn_metadata.
-                slot_mapping[self.pcp_size *
-                             num_decode_tokens:attn_metadata.
-                             num_actual_tokens_pcp_padded])
+            if is_Ascend950():
+                num_tokens = attn_metadata.slot_mapping.shape[0]
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_tokens],
+                    value=value[:num_tokens].contiguous(),
+                    slot_mapping=attn_metadata.slot_mapping,
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[self.pcp_size *
+                            num_decode_tokens:attn_metadata.
+                            num_actual_tokens_pcp_padded],
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded])
 
         forward_context: ForwardContext = get_forward_context()
         if not forward_context.capturing:
-            if self.pcp_size * self.dcp_size > 1:
+            if is_Ascend950():
+                intermediate_output = self._forward_ascend_950(
+                    query, key, value, attn_metadata, output)
+            elif self.pcp_size * self.dcp_size > 1:
                 intermediate_output = self._forward_pcp_dcp(
                     query, key, value, kv_cache, attn_metadata, output)
             elif attn_type == AttentionType.ENCODER_ONLY:
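
For reference, both the existing `_npu_reshape_and_cache` call and the new `npu_scatter_pa_kv_cache` call write per-token key/value vectors into a block-paged KV cache addressed by `slot_mapping`. The sketch below is a minimal pure-PyTorch illustration of that addressing scheme, not the NPU kernels themselves; the helper name `scatter_kv_reference` and the exact cache layout are assumptions for illustration only.

```python
import torch


def scatter_kv_reference(key: torch.Tensor, value: torch.Tensor,
                         slot_mapping: torch.Tensor, key_cache: torch.Tensor,
                         value_cache: torch.Tensor) -> None:
    """Reference scatter of per-token K/V into a block-paged cache.

    key / value:  (num_tokens, num_kv_heads, head_size)
    slot_mapping: (num_tokens,) flat slot id = block_id * block_size + offset
    key_cache / value_cache: (num_blocks, block_size, num_kv_heads, head_size)
    """
    block_size = key_cache.shape[1]
    block_ids = slot_mapping // block_size  # which cache block each token lands in
    offsets = slot_mapping % block_size     # position inside that block
    key_cache[block_ids, offsets] = key
    value_cache[block_ids, offsets] = value


if __name__ == "__main__":
    num_blocks, block_size, num_kv_heads, head_size = 4, 8, 2, 16
    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
    value_cache = torch.zeros_like(key_cache)
    key = torch.randn(5, num_kv_heads, head_size)
    value = torch.randn(5, num_kv_heads, head_size)
    slot_mapping = torch.tensor([3, 9, 10, 17, 25])  # flat slots across blocks
    scatter_kv_reference(key, value, slot_mapping, key_cache, value_cache)
    assert torch.allclose(key_cache[1, 1], key[1])  # slot 9 -> block 1, offset 1
```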
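Similarly, the `actual_seq_qlen` / `actual_seq_kvlen` arguments that `_forward_ascend_950` builds for the packed "TND" layout (total tokens, num heads, head dim) are cumulative end offsets over the per-request lengths, which is what `query_lens.cumsum(0)` and `seq_lens.cumsum(0)` produce. A small sketch of that metadata with illustrative values only; the fused kernel itself is not reproduced here.

```python
import torch

# Hypothetical batch of three prefill requests with different prompt lengths.
query_lens = torch.tensor([4, 2, 3])
seq_lens = torch.tensor([4, 2, 3])      # no prior KV cache, so KV lens == query lens

num_heads, head_size = 8, 128
total_tokens = int(query_lens.sum())    # the "T" in the TND layout

# Packed query: all requests concatenated along the token dimension.
query = torch.randn(total_tokens, num_heads, head_size)

# Cumulative end offsets delimit each request inside the packed tensor,
# mirroring attn_metadata.query_lens.cumsum(0) / seq_lens.cumsum(0).
actual_seq_qlen = query_lens.cumsum(0)   # tensor([4, 6, 9])
actual_seq_kvlen = seq_lens.cumsum(0)    # tensor([4, 6, 9])

# Tokens of request i live in query[start:end] with these boundaries.
starts = torch.cat([torch.zeros(1, dtype=torch.long), actual_seq_qlen[:-1]])
for i, (s, e) in enumerate(zip(starts.tolist(), actual_seq_qlen.tolist())):
    print(f"request {i}: query tokens [{s}, {e})")
```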