
Commit 0708877

Author: weijinqian_v1
[Refactor] add fia_v3 attention & remove other attention operators.
Signed-off-by: weijinqian_v1 <[email protected]>
1 parent: 02dd06c
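The core change these test edits track: the updated test_forward_prefill no longer preallocates an output tensor and passes it into forward; it calls forward with trace_flag=False and uses the returned tensor directly. A minimal before/after sketch of the call site, taken from the hunks below (the surrounding mock setup is unchanged test scaffolding, and the exact semantics of trace_flag are not documented in this commit):

    # Before (removed in this commit): the test preallocated the
    # output buffer and passed it as a positional argument.
    output = torch.empty_like(query)
    output = self.impl.forward(layer, query, key, value, kv_cache,
                               metadata, output)

    # After (added in this commit): no preallocated buffer; forward
    # is called with trace_flag=False and returns the output tensor.
    output = self.impl.forward(layer, query, key, value, kv_cache,
                               metadata, trace_flag=False)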

File tree: 1 file changed (+7, -165 lines)


tests/ut/attention/test_attention_v1.py

Lines changed: 7 additions & 165 deletions
@@ -25,12 +25,6 @@ def test_get_builder_cls(self):
         self.assertEqual(AscendAttentionBackend.get_builder_cls(),
                          AscendAttentionMetadataBuilder)
 
-    @patch('vllm_ascend.attention.attention_v1.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    def test_get_kv_cache_shape_310p(self, mock_soc_version):
-        result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
-        self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16))
-
     @patch('vllm_ascend.utils.get_ascend_device_type',
            return_value=AscendDeviceType._910_93)
     def test_get_kv_cache_shape_not_310p(self, mock_soc_version):
@@ -95,76 +89,6 @@ def test_reorder_batch(self):
 
         self.assertFalse(result)
 
-    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
-    @patch('torch_npu.npu_format_cast')
-    @patch('vllm_ascend.utils.nd_to_nz_2d')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    def test_build_prefill_no_cache(self, mock_soc_version, mock_nd_to_nz_2d,
-                                    mock_npu_format_cast,
-                                    mock_ascend_metadata):
-        common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=torch.tensor([0, 3, 7]),
-            query_start_loc_cpu=torch.tensor([0, 3, 7]),
-            seq_lens_cpu=torch.tensor([5, 6]),
-            num_reqs=2,
-            num_actual_tokens=10,
-            max_query_len=5,
-            decode_token_per_req=torch.tensor([1, 1]),
-            block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping=torch.tensor(range(20)),
-            actual_seq_lengths_q=torch.tensor([0, 1]),
-            positions=torch.tensor([10, 10]),
-            attn_mask=torch.ones((10, 10)),
-            spec_attn_mask=None,
-            attn_state=AscendAttentionState.PrefillNoCache,
-            num_computed_tokens_cpu=None,
-            seq_lens=None)
-
-        mock_nz_tensor = MagicMock()
-        mock_model = MagicMock()
-        mock_nd_to_nz_2d.return_value = mock_nz_tensor
-        mock_npu_format_cast.return_value = mock_nz_tensor
-
-        self.builder.build(1, common_attn_metadata, mock_model)
-
-    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
-    @patch('torch_npu.npu_format_cast')
-    @patch('vllm_ascend.utils.nd_to_nz_spec')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    @patch('vllm_ascend.attention.attention_v1.AscendAttentionState')
-    def test_build_chunked_prefill(self, mock_ascend_attention_state,
-                                   mock_soc_version, mock_nd_to_nz_spec,
-                                   mock_npu_format_cast, mock_ascend_metadata):
-        common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=torch.tensor([0, 2, 5, 9]),
-            query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
-            seq_lens_cpu=torch.tensor([4, 5, 6]),
-            num_reqs=3,
-            num_actual_tokens=15,
-            max_query_len=6,
-            decode_token_per_req=torch.tensor([1, 1, 1]),
-            block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping=torch.tensor(range(20)),
-            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
-            positions=torch.tensor([10, 10]),
-            attn_mask=torch.ones((15, 15)),
-            spec_attn_mask=None,
-            attn_state=AscendAttentionState.ChunkedPrefill,
-            num_computed_tokens_cpu=None,
-            seq_lens=None)
-
-        mock_ascend_attention_state = MagicMock()
-        mock_ascend_attention_state.PrefillNoCache = 0
-
-        mock_nz_tensor = MagicMock()
-        mock_model = MagicMock()
-        mock_nd_to_nz_spec.return_value = mock_nz_tensor
-        mock_npu_format_cast.return_value = mock_nz_tensor
-
-        self.builder.build(1, common_attn_metadata, mock_model)
-
     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
     @patch('vllm_ascend.utils.get_ascend_device_type',
            return_value=AscendDeviceType._910_93)
@@ -297,8 +221,6 @@ def test_forward_prefill(self, mock_get_forward_context,
         key = torch.randn(10, 8 * 64)
         value = torch.randn(10, 8 * 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.PrefillCacheHit
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
@@ -307,17 +229,16 @@ def test_forward_prefill(self, mock_get_forward_context,
         metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
         metadata.num_actual_tokens = 10
         metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
         layer = self.layer_no_quant
 
         mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
+        output = self.impl.forward(layer,
+                                   query,
+                                   key,
+                                   value,
+                                   kv_cache,
+                                   metadata,
+                                   trace_flag=False)
 
         mock_npu_fused_infer_attention_score.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -422,85 +343,6 @@ def test_forward_decode_only_swa_seq_len_mismatch(
 
         assert output.shape == (10, 8 * 64)
 
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._910_93)
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
-    def test_forward_head_size_192(self, mock_vanilla_prefill,
-                                   mock_npu_reshape_and_cache,
-                                   mock_soc_version, mock_get_forward_context):
-        """Test forward pass when head_size is 192"""
-
-        self.impl.head_size = 192
-        query = torch.randn(10, 8 * 192)
-        key = torch.randn(10, 8 * 192)
-        value = torch.randn(10, 8 * 192)
-        kv_cache = torch.empty(2, 5, 128, 8, 192)
-        output = torch.empty_like(query)
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 10
-        metadata.num_prefills = 0
-        layer = self.layer_no_quant
-        mock_vanilla_prefill.return_value = MagicMock()
-
-        output = self.impl_192.forward(layer, query, key, value, kv_cache,
-                                       metadata, output)
-
-        mock_vanilla_prefill.assert_called_once()
-        assert output.shape == (10, 8 * 192)
-
-    @patch('torch_npu.npu_format_cast')
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('torch_npu.npu_fused_infer_attention_score')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_310p_device(self, mock_get_forward_context,
-                                 mock_soc_version,
-                                 mock_npu_fused_infer_attention_score,
-                                 mock_npu_reshape_and_cache,
-                                 mock_npu_format_cast):
-        """Test forward pass on 310P device"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-
-        mock_npu_format_cast.return_value = metadata.attn_mask
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_npu_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
     @patch('torch_npu._npu_reshape_and_cache')
     def test_forward_raise_error(self, mock_paged_attention):
         query = torch.randn(10, 8 * 64)
