@@ -623,11 +623,8 @@ def test_exec_kv_decode(self, mock_kv_rmsnorm_rope_cache):
623623 self .assertEqual (k_nope .shape [- 1 ], self .impl .kv_lora_rank )
624624
625625 @patch ('vllm_ascend.attention.mla_v1.get_forward_context' )
626- @patch ("torch.npu.stream" )
627- @patch ("vllm_ascend.attention.mla_v1.get_multistream_comm_context" )
628626 @patch ("torch_npu.npu_fused_infer_attention_score" )
629627 def test_forward_decode (self , mock_npu_fused_infer_attention_score ,
630- mock_get_multistream_comm_context , mock_npu_stream ,
631628 mock_get_forward_context ):
632629 B = 2
633630 N = self .impl .num_kv_heads
@@ -651,27 +648,10 @@ def test_forward_decode(self, mock_npu_fused_infer_attention_score,
651648 mock_npu_fused_infer_attention_score .return_value = [
652649 torch .randn (B , N , self .impl .kv_lora_rank ), None
653650 ]
654- mock_get_multistream_comm_context .return_value = None
655-
656651 mock_get_forward_context .return_value = MagicMock (capturing = False )
657652 result = self .impl ._forward_decode (q_nope , q_pe , k_nope , k_pe , BS ,
658653 attn_metadata )
659654
660655 self .assertEqual (result .shape [0 ], B )
661656 self .assertEqual (result .shape [1 ], N )
662657 self .assertEqual (result .shape [2 ], HD )
663-
664- self .impl .enable_kv_nz = False
665- attn_metadata .attn_state = None
666- mock_return_value = MagicMock ()
667- mock_get_multistream_comm_context .return_value = mock_return_value
668- mock_return_value .before_comm_event = MagicMock ()
669- mock_return_value .comm_stream = MagicMock ()
670- mock_npu_stream .return_value = MagicMock ()
671-
672- result = self .impl ._forward_decode (q_nope , q_pe , k_nope , k_pe , BS ,
673- attn_metadata )
674-
675- self .assertEqual (result .shape [0 ], B )
676- self .assertEqual (result .shape [1 ], N )
677- self .assertEqual (result .shape [2 ], HD )
0 commit comments