
Commit 31fc2ca

fix bug
Signed-off-by: zzhx1 <[email protected]>
1 parent 3d9b21b commit 31fc2ca

File tree

1 file changed: +79 −30 lines


tests/ut/attention/test_mla_v1.py

Lines changed: 79 additions & 30 deletions
@@ -714,10 +714,13 @@ def test_init(self, mock_distributed):
     def test_q_proj_and_k_up_proj(self, mock_distributed):
         batch_size = 4
         x = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim)
-        q_proj_output = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim)
-        self.impl.q_proj.return_value = (q_proj_output,)
+        q_proj_output = torch.randn(batch_size, self.impl.num_heads,
+                                    self.impl.qk_head_dim)
+        self.impl.q_proj.return_value = (q_proj_output, )
         if not hasattr(self.impl, 'W_UK_T') or self.impl.W_UK_T is None:
-            self.impl.W_UK_T = torch.randn(self.impl.num_heads, self.impl.qk_nope_head_dim, self.impl.kv_lora_rank)
+            self.impl.W_UK_T = torch.randn(self.impl.num_heads,
+                                           self.impl.qk_nope_head_dim,
+                                           self.impl.kv_lora_rank)
         ql_nope, q_pe = self.impl._q_proj_and_k_up_proj(x)
         assert ql_nope.shape[0] == batch_size
         assert ql_nope.shape[1] == self.impl.num_heads
@@ -733,7 +736,8 @@ def test_process_weights_after_loading(self, mock_distributed):
         apply = MagicMock()
         quant_method.apply = apply
         layer.quant_method = quant_method
-        shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim + self.impl.v_head_dim)
+        shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim +
+                                         self.impl.v_head_dim)
         shape_1 = self.impl.kv_lora_rank
         layer.weight = torch.randn(shape_0, shape_1)
         self.impl.kv_b_proj = layer
@@ -753,15 +757,18 @@ def test_process_weights_after_loading(self, mock_distributed):
     def test_compute_prefill_context_none(self, mock_distributed):
         batch_size = 4
         kv_cache = torch.randn(10, 1, 1, 192)
-        query = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim)
+        query = torch.randn(batch_size, self.impl.num_heads,
+                            self.impl.qk_head_dim)
         metadata = MagicMock()
         metadata.prefill = None
         prefix_out = torch.randn(2, 16, 128)
         prefix_lse = torch.randn(2, 16, 8)
         q_pe = query[..., self.impl.qk_nope_head_dim:]
         q_nope = query[..., :self.impl.qk_nope_head_dim]

-        out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache, 32, metadata, prefix_out, prefix_lse)
+        out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
+                                                      32, metadata, prefix_out,
+                                                      prefix_lse)

         assert torch.equal(prefix_out, out)
         assert torch.equal(prefix_lse, lse)
@@ -801,7 +808,8 @@ def test_compute_prefill_context(self, mock_distributed):
         # Mock the two NPU ops inside the method
         with patch("torch_npu.atb.npu_paged_cache_load") as mock_load, \
              patch("torch_npu.atb.npu_ring_mla") as mock_ring:
-            out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache, 32, meta, prefix_out, prefix_lse)
+            out, lse = self.impl._compute_prefill_context(
+                q_nope, q_pe, kv_cache, 32, meta, prefix_out, prefix_lse)

            mock_load.assert_called_once()
            mock_ring.assert_called_once()
@@ -812,10 +820,14 @@ def test_compute_prefill_context(self, mock_distributed):
     def test_forward_decode_without_graph(self, mock_distributed):
         num_tokens = 100
         block_size = 4
-        q_nope = torch.randn(num_tokens, self.impl.num_heads, self.impl.qk_nope_head_dim)
-        q_pe = torch.randn(num_tokens, self.impl.num_heads, self.impl.qk_rope_head_dim)
-        k_nope = torch.randn(num_tokens, self.impl.num_heads, self.impl.qk_nope_head_dim)
-        k_pe = torch.randn(num_tokens, self.impl.num_heads, self.impl.qk_rope_head_dim)
+        q_nope = torch.randn(num_tokens, self.impl.num_heads,
+                             self.impl.qk_nope_head_dim)
+        q_pe = torch.randn(num_tokens, self.impl.num_heads,
+                           self.impl.qk_rope_head_dim)
+        k_nope = torch.randn(num_tokens, self.impl.num_heads,
+                             self.impl.qk_nope_head_dim)
+        k_pe = torch.randn(num_tokens, self.impl.num_heads,
+                           self.impl.qk_rope_head_dim)
         metadata = MagicMock()
         metadata.decode = MagicMock()
         metadata.decode.block_table = MagicMock()
@@ -824,10 +836,15 @@ def test_forward_decode_without_graph(self, mock_distributed):
         with patch("torch_npu.npu_fused_infer_attention_score") as mock_score, \
              patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj") as mock_up, \
              patch('vllm_ascend.attention.mla_v1.get_forward_context', return_value=MagicMock(capturing=False)):
-            mock_score.return_value = [torch.randn(num_tokens, self.impl.num_heads, self.impl.kv_lora_rank), None]
-            mock_up.return_value = torch.randn(num_tokens, self.impl.num_heads, self.impl.v_head_dim)
+            mock_score.return_value = [
+                torch.randn(num_tokens, self.impl.num_heads,
+                            self.impl.kv_lora_rank), None
+            ]
+            mock_up.return_value = torch.randn(num_tokens, self.impl.num_heads,
+                                               self.impl.v_head_dim)

-            result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, block_size, metadata)
+            result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe,
+                                               block_size, metadata)

         assert result.shape[0] == num_tokens
         assert result.shape[1] == self.impl.num_heads
@@ -855,21 +872,48 @@ def test_mla_preprocess(self, mock_distributed):
         attn_metadata.prefill.cos = torch.randn(2, 64)
         attn_metadata.prefill.sin = torch.randn(2, 64)

-        self.impl.q_a_layernorm = MagicMock(return_value=torch.randn(attn_metadata.num_actual_tokens, self.impl.num_heads, self.impl.qk_rope_head_dim))
-        self.impl.kv_a_proj_with_mqa = MagicMock(return_value=[torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.qk_rope_head_dim + self.impl.kv_lora_rank)])
-        self.impl.fused_qkv_a_proj = MagicMock(return_value=[torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.qk_rope_head_dim + self.impl.kv_lora_rank + self.impl.q_lora_rank)])
-        self.impl.q_proj = MagicMock(return_value=[torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.qk_head_dim)])
-        self.impl.kv_b_proj = MagicMock(return_value=[torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.v_head_dim + self.impl.qk_nope_head_dim)])
-        self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
-        self.impl.exec_kv_decode = MagicMock(return_value=[MagicMock(), MagicMock()])
+        self.impl.q_a_layernorm = MagicMock(return_value=torch.randn(
+            attn_metadata.num_actual_tokens, self.impl.num_heads,
+            self.impl.qk_rope_head_dim))
+        self.impl.kv_a_proj_with_mqa = MagicMock(return_value=[
+            torch.randn(
+                num_prefill_tokens, self.impl.num_heads,
+                self.impl.qk_rope_head_dim + self.impl.kv_lora_rank)
+        ])
+        self.impl.fused_qkv_a_proj = MagicMock(return_value=[
+            torch.randn(
+                num_prefill_tokens, self.impl.num_heads,
+                self.impl.qk_rope_head_dim + self.impl.kv_lora_rank +
+                self.impl.q_lora_rank)
+        ])
+        self.impl.q_proj = MagicMock(return_value=[
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.qk_head_dim)
+        ])
+        self.impl.kv_b_proj = MagicMock(return_value=[
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.v_head_dim + self.impl.qk_nope_head_dim)
+        ])
+        self.impl.rope_single = MagicMock(
+            side_effect=lambda x, cos, sin: x)
+        self.impl.exec_kv_decode = MagicMock(
+            return_value=[MagicMock(), MagicMock()])
         self.impl.exec_kv_prefill = MagicMock(return_value=[
-            torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.qk_rope_head_dim),
-            torch.randn(num_prefill_tokens, self.impl.num_heads, self.impl.kv_lora_rank)
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.qk_rope_head_dim),
+            torch.randn(num_prefill_tokens, self.impl.num_heads,
+                        self.impl.kv_lora_rank)
         ])
-        self.impl._q_proj_and_k_up_proj = MagicMock(return_value=[MagicMock(), MagicMock()])
+        self.impl._q_proj_and_k_up_proj = MagicMock(
+            return_value=[MagicMock(), MagicMock()])
         self.impl.num_kv_heads = self.impl.num_heads

-        decode_res, prefill_res = self.impl._mla_preprocess("mock_layer", hidden_states, kv_cache, attn_metadata, need_gather_q_kv=False)
+        decode_res, prefill_res = self.impl._mla_preprocess(
+            "mock_layer",
+            hidden_states,
+            kv_cache,
+            attn_metadata,
+            need_gather_q_kv=False)

         assert decode_res is not None
         assert prefill_res is not None
@@ -893,7 +937,8 @@ def test_exec_kv_prefill(self, mock_distributed):
             torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
             torch.randn(B, N, 1, self.impl.kv_lora_rank)
         ]
-        k_pe, k_nope = self.impl.exec_kv_prefill(kv_no_split, cos, sin, kv_cache, slots)
+        k_pe, k_nope = self.impl.exec_kv_prefill(kv_no_split, cos, sin,
+                                                 kv_cache, slots)

         assert k_pe.shape[-1] == self.impl.qk_rope_head_dim
         assert k_nope.shape[-1] == self.impl.kv_lora_rank
@@ -916,7 +961,8 @@ def test_exec_kv_decode(self, mock_distributed):
             torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
             torch.randn(B, N, 1, self.impl.kv_lora_rank), None, None
         ]
-        k_pe, k_nope = self.impl.exec_kv_decode(kv_no_split, cos, sin, kv_cache, slots)
+        k_pe, k_nope = self.impl.exec_kv_decode(kv_no_split, cos, sin,
+                                                kv_cache, slots)

         assert k_pe.shape[-1] == self.impl.qk_rope_head_dim
         assert k_nope.shape[-1] == self.impl.kv_lora_rank
@@ -942,9 +988,12 @@ def test_forward_decode(self, mock_distributed):

         with patch("torch_npu.npu_fused_infer_attention_score") as mock_score, \
              patch('vllm_ascend.attention.mla_v1.get_forward_context', return_value=MagicMock(capturing=False)):
-            mock_score.return_value = [torch.randn(B, N, self.impl.kv_lora_rank), None]
-            result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS, attn_metadata)
+            mock_score.return_value = [
+                torch.randn(B, N, self.impl.kv_lora_rank), None
+            ]
+            result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
+                                               attn_metadata)

         assert result.shape[0] == B
         assert result.shape[1] == N
-        assert result.shape[2] == HD
+        assert result.shape[2] == HD
