@@ -1464,6 +1464,72 @@ def _load_kv_for_chunk(self, attn_metadata, kv_cache,
         )
         return key, value
 
+    def _forward_ascend_950(self, query: torch.Tensor, key: torch.Tensor,
+                            value: torch.Tensor, attn_metadata: AscendMetadata,
+                            output: torch.Tensor) -> torch.Tensor:
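+        # Attention path for the Ascend 910_95 device. Every attention state
+        # goes through npu_fused_infer_attention_score_v2; the branches below
+        # only select the layout, mask and sequence-length arguments.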
+        num_tokens = attn_metadata.query_start_loc[-1]
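+        # Prefill with no cached KV: attend directly over the packed (TND)
+        # query/key/value tensors using cumulative sequence lengths.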
+        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            output_data, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query[:num_tokens],
+                key[:num_tokens],
+                value[:num_tokens],
+                atten_mask=attn_metadata.attn_mask.to(  # type: ignore
+                    torch.bool),
+                actual_seq_qlen=attn_metadata.query_lens.cumsum(0),
+                actual_seq_kvlen=attn_metadata.seq_lens.cumsum(0),
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                input_layout="TND",
+                softmax_scale=self.scale)
+        else:
+            batch_size = attn_metadata.query_lens.shape[0]
+            block_table = attn_metadata.block_tables[:batch_size, :]
+            num_block, block_size, _, _ = self.key_cache.shape  # type: ignore
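+            # Decode-only: one query token per sequence, so the query is
+            # reshaped to BSH and the paged cache is flattened per block.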
+            if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                query = query[:batch_size]
+                query = query.view(batch_size, 1,
+                                   self.num_heads * self.head_size)
+                key = self.key_cache.flatten(2, 3).contiguous()  # type: ignore
+                value = self.value_cache.flatten(  # type: ignore
+                    2, 3).contiguous()
+                atten_mask = None
+                actual_seq_qlen = None
+                actual_seq_kvlen = attn_metadata.seq_lens
+                sparse_mode = 0
+                input_layout = "BSH"
+            else:
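+                # Remaining cached states (e.g. chunked prefill): keep the TND
+                # layout and attend through the block table with a causal-style
+                # mask (sparse_mode=3).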
+                query = query[:num_tokens]
+                key = self.key_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                value = self.value_cache.view(  # type: ignore
+                    num_block, block_size, -1)
+                input_layout = "TND"
+                atten_mask = attn_metadata.attn_mask
+                actual_seq_qlen = attn_metadata.actual_seq_lengths_q
+                actual_seq_kvlen = attn_metadata.seq_lens_list
+                sparse_mode = 3
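+            # One fused paged-attention call serves both cached branches.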
+            output_data, _ = torch_npu.npu_fused_infer_attention_score_v2(
+                query=query,
+                key=key,
+                value=value,
+                block_table=block_table,
+                atten_mask=atten_mask,
+                actual_seq_qlen=actual_seq_qlen,
+                actual_seq_kvlen=actual_seq_kvlen,
+                num_query_heads=self.num_heads,
+                num_key_value_heads=self.num_kv_heads,
+                softmax_scale=self.scale,
+                sparse_mode=sparse_mode,
+                block_size=block_size,
+                input_layout=input_layout)
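+        # Write the kernel result back into the caller-provided output buffer.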
+        if attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            output[:batch_size] = output_data.view(batch_size,
+                                                   self.num_heads,
+                                                   self.head_size)
+        else:
+            output[:num_tokens] = output_data
+        return output
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -1523,12 +1589,19 @@ def forward(
         if has_decode:
             slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size:self.pcp_size] \
                 if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
-            torch_npu._npu_reshape_and_cache(
-                key=key[:num_decode_tokens],
-                value=value[:num_decode_tokens],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=slot_mapping)
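+            # 910_95 updates the paged KV cache with the fused scatter kernel;
+            # other devices keep the original reshape-and-cache path.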
+            if get_ascend_device_type() == AscendDeviceType._910_95:
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens].contiguous(),
+                    slot_mapping=slot_mapping,
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[:num_decode_tokens],
+                    value=value[:num_decode_tokens],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=slot_mapping)
 
         if has_prefill:
             if self.pcp_size > 1:
@@ -1542,22 +1615,40 @@ def forward(
                 key, value = all_kv.split([self.head_size, self.head_size],
                                           dim=-1)
 
-            torch_npu._npu_reshape_and_cache(
-                key=key[self.pcp_size * num_decode_tokens:attn_metadata.
-                        num_actual_tokens_pcp_padded],
-                value=value[self.pcp_size *
-                            num_decode_tokens:attn_metadata.
-                            num_actual_tokens_pcp_padded],
-                key_cache=self.key_cache,
-                value_cache=self.value_cache,
-                slot_indices=attn_metadata.
-                slot_mapping[self.pcp_size *
-                             num_decode_tokens:attn_metadata.
-                             num_actual_tokens_pcp_padded])
+            if get_ascend_device_type() == AscendDeviceType._910_95:
+                torch_npu.npu_scatter_pa_kv_cache(
+                    key=key[self.pcp_size *
+                            num_decode_tokens:attn_metadata.
+                            num_actual_tokens_pcp_padded],
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded].contiguous(),
+                    slot_mapping=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded],
+                    out=(self.key_cache, self.value_cache))
+            else:
+                torch_npu._npu_reshape_and_cache(
+                    key=key[self.pcp_size *
+                            num_decode_tokens:attn_metadata.
+                            num_actual_tokens_pcp_padded],
+                    value=value[self.pcp_size *
+                                num_decode_tokens:attn_metadata.
+                                num_actual_tokens_pcp_padded],
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    slot_indices=attn_metadata.
+                    slot_mapping[self.pcp_size *
+                                 num_decode_tokens:attn_metadata.
+                                 num_actual_tokens_pcp_padded])
 
         forward_context: ForwardContext = get_forward_context()
         if not forward_context.capturing:
-            if self.pcp_size * self.dcp_size > 1:
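+            # 910_95 is routed to the fused-attention helper added above.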
+            if get_ascend_device_type() == AscendDeviceType._910_95:
+                intermediate_output = self._forward_ascend_950(
+                    query, key, value, attn_metadata, output)
+            elif self.pcp_size * self.dcp_size > 1:
                 intermediate_output = self._forward_pcp_dcp(
                     query, key, value, kv_cache, attn_metadata, output)
             elif attn_type == AttentionType.ENCODER_ONLY: