@@ -714,10 +714,13 @@ def test_init(self, mock_distributed):
     def test_q_proj_and_k_up_proj(self, mock_distributed):
         batch_size = 4
         x = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim)
-        q_proj_output = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim)
-        self.impl.q_proj.return_value = (q_proj_output,)
+        q_proj_output = torch.randn(batch_size, self.impl.num_heads,
+                                    self.impl.qk_head_dim)
+        self.impl.q_proj.return_value = (q_proj_output, )
         if not hasattr(self.impl, 'W_UK_T') or self.impl.W_UK_T is None:
-            self.impl.W_UK_T = torch.randn(self.impl.num_heads, self.impl.qk_nope_head_dim, self.impl.kv_lora_rank)
+            self.impl.W_UK_T = torch.randn(self.impl.num_heads,
+                                           self.impl.qk_nope_head_dim,
+                                           self.impl.kv_lora_rank)
         ql_nope, q_pe = self.impl._q_proj_and_k_up_proj(x)
         assert ql_nope.shape[0] == batch_size
         assert ql_nope.shape[1] == self.impl.num_heads
@@ -733,7 +736,8 @@ def test_process_weights_after_loading(self, mock_distributed):
         apply = MagicMock()
         quant_method.apply = apply
         layer.quant_method = quant_method
-        shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim + self.impl.v_head_dim)
+        shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim +
+                                         self.impl.v_head_dim)
         shape_1 = self.impl.kv_lora_rank
         layer.weight = torch.randn(shape_0, shape_1)
         self.impl.kv_b_proj = layer
@@ -753,15 +757,18 @@ def test_process_weights_after_loading(self, mock_distributed):
     def test_compute_prefill_context_none(self, mock_distributed):
         batch_size = 4
         kv_cache = torch.randn(10, 1, 1, 192)
-        query = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim)
+        query = torch.randn(batch_size, self.impl.num_heads,
+                            self.impl.qk_head_dim)
         metadata = MagicMock()
         metadata.prefill = None
         prefix_out = torch.randn(2, 16, 128)
         prefix_lse = torch.randn(2, 16, 8)
         q_pe = query[..., self.impl.qk_nope_head_dim:]
         q_nope = query[..., :self.impl.qk_nope_head_dim]

-        out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache, 32, metadata, prefix_out, prefix_lse)
+        out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
+                                                      32, metadata, prefix_out,
+                                                      prefix_lse)

         assert torch.equal(prefix_out, out)
         assert torch.equal(prefix_lse, lse)
@@ -801,7 +808,8 @@ def test_compute_prefill_context(self, mock_distributed):
         # Mock the two NPU ops inside the method
         with patch("torch_npu.atb.npu_paged_cache_load") as mock_load, \
                 patch("torch_npu.atb.npu_ring_mla") as mock_ring:
-            out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache, 32, meta, prefix_out, prefix_lse)
+            out, lse = self.impl._compute_prefill_context(
+                q_nope, q_pe, kv_cache, 32, meta, prefix_out, prefix_lse)

         mock_load.assert_called_once()
         mock_ring.assert_called_once()
@@ -812,10 +820,14 @@ def test_compute_prefill_context(self, mock_distributed):
     def test_forward_decode_without_graph(self, mock_distributed):
         num_tokens = 100
         block_size = 4
-        q_nope = torch.randn(num_tokens, self.impl.num_heads, self.impl.qk_nope_head_dim)
-        q_pe = torch.randn(num_tokens, self.impl.num_heads, self.impl.qk_rope_head_dim)
-        k_nope = torch.randn(num_tokens, self.impl.num_heads, self.impl.qk_nope_head_dim)
-        k_pe = torch.randn(num_tokens, self.impl.num_heads, self.impl.qk_rope_head_dim)
+        q_nope = torch.randn(num_tokens, self.impl.num_heads,
+                             self.impl.qk_nope_head_dim)
+        q_pe = torch.randn(num_tokens, self.impl.num_heads,
+                           self.impl.qk_rope_head_dim)
+        k_nope = torch.randn(num_tokens, self.impl.num_heads,
+                             self.impl.qk_nope_head_dim)
+        k_pe = torch.randn(num_tokens, self.impl.num_heads,
+                           self.impl.qk_rope_head_dim)
         metadata = MagicMock()
         metadata.decode = MagicMock()
         metadata.decode.block_table = MagicMock()
@@ -824,10 +836,15 @@ def test_forward_decode_without_graph(self, mock_distributed):
         with patch("torch_npu.npu_fused_infer_attention_score") as mock_score, \
                 patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj") as mock_up, \
                 patch('vllm_ascend.attention.mla_v1.get_forward_context', return_value=MagicMock(capturing=False)):
-            mock_score.return_value = [torch.randn(num_tokens, self.impl.num_heads, self.impl.kv_lora_rank), None]
-            mock_up.return_value = torch.randn(num_tokens, self.impl.num_heads, self.impl.v_head_dim)
+            mock_score.return_value = [
+                torch.randn(num_tokens, self.impl.num_heads,
+                            self.impl.kv_lora_rank), None
+            ]
+            mock_up.return_value = torch.randn(num_tokens, self.impl.num_heads,
+                                               self.impl.v_head_dim)

-            result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, block_size, metadata)
+            result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe,
+                                               block_size, metadata)

         assert result.shape[0] == num_tokens
         assert result.shape[1] == self.impl.num_heads
@@ -855,21 +872,48 @@ def test_mla_preprocess(self, mock_distributed):
         attn_metadata.prefill.cos = torch.randn(2, 64)
         attn_metadata.prefill.sin = torch.randn(2, 64)

-        self.impl.q_a_layernorm = MagicMock(return_value=torch.randn(attn_metadata.num_actual_tokens, self.impl.num_heads, self.impl.qk_rope_head_dim))
-        self.impl.kv_a_proj_with_mqa = MagicMock(return_value=[torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.qk_rope_head_dim + self.impl.kv_lora_rank)])
-        self.impl.fused_qkv_a_proj = MagicMock(return_value=[torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.qk_rope_head_dim + self.impl.kv_lora_rank + self.impl.q_lora_rank)])
-        self.impl.q_proj = MagicMock(return_value=[torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.qk_head_dim)])
-        self.impl.kv_b_proj = MagicMock(return_value=[torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.v_head_dim + self.impl.qk_nope_head_dim)])
-        self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
-        self.impl.exec_kv_decode = MagicMock(return_value=[MagicMock(), MagicMock()])
+        self.impl.q_a_layernorm = MagicMock(return_value=torch.randn(
+            attn_metadata.num_actual_tokens, self.impl.num_heads,
+            self.impl.qk_rope_head_dim))
+        self.impl.kv_a_proj_with_mqa = MagicMock(return_value=[
+            torch.randn(
+                num_prefill_tokens, self.impl.num_heads,
+                self.impl.qk_rope_head_dim + self.impl.kv_lora_rank)
+        ])
+        self.impl.fused_qkv_a_proj = MagicMock(return_value=[
+            torch.randn(
+                num_prefill_tokens, self.impl.num_heads,
+                self.impl.qk_rope_head_dim + self.impl.kv_lora_rank +
+                self.impl.q_lora_rank)
+        ])
+        self.impl.q_proj = MagicMock(return_value=[
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.qk_head_dim)
+        ])
+        self.impl.kv_b_proj = MagicMock(return_value=[
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.v_head_dim + self.impl.qk_nope_head_dim)
+        ])
+        self.impl.rope_single = MagicMock(
+            side_effect=lambda x, cos, sin: x)
+        self.impl.exec_kv_decode = MagicMock(
+            return_value=[MagicMock(), MagicMock()])
         self.impl.exec_kv_prefill = MagicMock(return_value=[
-            torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.qk_rope_head_dim),
-            torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.kv_lora_rank)
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.qk_rope_head_dim),
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.kv_lora_rank)
         ])
-        self.impl._q_proj_and_k_up_proj = MagicMock(return_value=[MagicMock(), MagicMock()])
+        self.impl._q_proj_and_k_up_proj = MagicMock(
+            return_value=[MagicMock(), MagicMock()])
         self.impl.num_kv_heads = self.impl.num_heads

-        decode_res, prefill_res = self.impl._mla_preprocess("mock_layer", hidden_states, kv_cache, attn_metadata, need_gather_q_kv=False)
+        decode_res, prefill_res = self.impl._mla_preprocess(
+            "mock_layer",
+            hidden_states,
+            kv_cache,
+            attn_metadata,
+            need_gather_q_kv=False)

         assert decode_res is not None
         assert prefill_res is not None
@@ -893,7 +937,8 @@ def test_exec_kv_prefill(self, mock_distributed):
             torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
             torch.randn(B, N, 1, self.impl.kv_lora_rank)
         ]
-        k_pe, k_nope = self.impl.exec_kv_prefill(kv_no_split, cos, sin, kv_cache, slots)
+        k_pe, k_nope = self.impl.exec_kv_prefill(kv_no_split, cos, sin,
+                                                 kv_cache, slots)

         assert k_pe.shape[-1] == self.impl.qk_rope_head_dim
         assert k_nope.shape[-1] == self.impl.kv_lora_rank
@@ -916,7 +961,8 @@ def test_exec_kv_decode(self, mock_distributed):
             torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
             torch.randn(B, N, 1, self.impl.kv_lora_rank), None, None
         ]
-        k_pe, k_nope = self.impl.exec_kv_decode(kv_no_split, cos, sin, kv_cache, slots)
+        k_pe, k_nope = self.impl.exec_kv_decode(kv_no_split, cos, sin,
+                                                kv_cache, slots)

         assert k_pe.shape[-1] == self.impl.qk_rope_head_dim
         assert k_nope.shape[-1] == self.impl.kv_lora_rank
@@ -942,9 +988,12 @@ def test_forward_decode(self, mock_distributed):

         with patch("torch_npu.npu_fused_infer_attention_score") as mock_score, \
                 patch('vllm_ascend.attention.mla_v1.get_forward_context', return_value=MagicMock(capturing=False)):
-            mock_score.return_value = [torch.randn(B, N, self.impl.kv_lora_rank), None]
-            result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS, attn_metadata)
+            mock_score.return_value = [
+                torch.randn(B, N, self.impl.kv_lora_rank), None
+            ]
+            result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
+                                               attn_metadata)

         assert result.shape[0] == B
         assert result.shape[1] == N
-        assert result.shape[2] == HD
+        assert result.shape[2] == HD