
Commit b4bf01e

weijinqian0 and weijinqian_v1 authored
[Refactor] Remove redundant attention operator branches. (#4531)
[Refactor] Remove redundant attention operator branches.

Reason: we replace the other attention ops with fused_infer_attention_score, except in the decode_only state; clean up the code and remove 310P support. #4455

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: weijinqian_v1 <[email protected]>
Co-authored-by: weijinqian_v1 <[email protected]>
1 parent 981a14f commit b4bf01e
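The commit message describes the consolidation: every prefill-style state (PrefillNoCache, PrefillCacheHit, ChunkedPrefill) now goes through a single fused call, and only the decode-only state keeps its dedicated kernel. Below is a minimal, hypothetical Python sketch of that dispatch; the enum values, helper names, and the plain-SDPA stand-in for torch_npu.npu_fused_infer_attention_score are illustrative assumptions, not the vllm-ascend implementation.

from enum import Enum
from types import SimpleNamespace

import torch


class AscendAttentionState(Enum):
    # Mirrors the states referenced in the tests changed by this commit.
    PrefillNoCache = 0
    PrefillCacheHit = 1
    ChunkedPrefill = 2
    DecodeOnly = 3


def _fused_infer_attention_score(query, key, value):
    # Stand-in for torch_npu.npu_fused_infer_attention_score, which returns an
    # (attn_output, softmax_lse) tuple; plain SDPA keeps this sketch runnable.
    q, k, v = (t.transpose(0, 1).unsqueeze(0) for t in (query, key, value))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    return out.squeeze(0).transpose(0, 1), None


def _paged_attention_decode(query, output):
    # Stand-in for the dedicated decode kernel (torch_npu._npu_paged_attention).
    output.copy_(query)
    return output


def forward_attention(query, key, value, metadata, output):
    """After the refactor, every non-decode state shares one fused call."""
    if metadata.attn_state == AscendAttentionState.DecodeOnly:
        return _paged_attention_decode(query, output)
    attn_out, _ = _fused_infer_attention_score(query, key, value)
    output.copy_(attn_out)
    return output


# Example: a chunked-prefill batch of 10 tokens, 8 heads, head_size 64.
meta = SimpleNamespace(attn_state=AscendAttentionState.ChunkedPrefill)
q = k = v = torch.randn(10, 8, 64)
out = forward_attention(q, k, v, meta, torch.empty_like(q))
print(out.shape)  # torch.Size([10, 8, 64])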

File tree

3 files changed, +119 -470 lines changed

tests/ut/attention/test_attention_v1.py

Lines changed: 12 additions & 234 deletions
@@ -25,12 +25,6 @@ def test_get_builder_cls(self):
         self.assertEqual(AscendAttentionBackend.get_builder_cls(),
                          AscendAttentionMetadataBuilder)
 
-    @patch('vllm_ascend.attention.attention_v1.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    def test_get_kv_cache_shape_310p(self, mock_soc_version):
-        result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
-        self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16))
-
     @patch('vllm_ascend.utils.get_ascend_device_type',
            return_value=AscendDeviceType._910_93)
     def test_get_kv_cache_shape_not_310p(self, mock_soc_version):
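For reference, the deleted test_get_kv_cache_shape_310p above encoded the NZ cache layout produced by the removed 310P branch. A small hedged sketch of that shape arithmetic follows; the argument order (num_blocks, block_size, num_kv_heads, head_size) and the default ND layout are assumed from the surrounding tests rather than taken from the patched source.

def kv_cache_shape_nd(num_blocks, block_size, num_kv_heads, head_size):
    # Assumed default ND layout: separate K and V caches, token-major blocks.
    return (2, num_blocks, block_size, num_kv_heads, head_size)


def kv_cache_shape_nz_310p(num_blocks, block_size, num_kv_heads, head_size):
    # Removed 310P branch: NZ layout folds the hidden dim into 16-wide fractals.
    return (2, num_blocks, num_kv_heads * head_size // 16, block_size, 16)


# Matches the expectation in the deleted test.
assert kv_cache_shape_nz_310p(10, 20, 30, 40) == (2, 10, 30 * 40 // 16, 20, 16)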
@@ -95,76 +89,6 @@ def test_reorder_batch(self):
 
         self.assertFalse(result)
 
-    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
-    @patch('torch_npu.npu_format_cast')
-    @patch('vllm_ascend.utils.nd_to_nz_2d')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    def test_build_prefill_no_cache(self, mock_soc_version, mock_nd_to_nz_2d,
-                                    mock_npu_format_cast,
-                                    mock_ascend_metadata):
-        common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=torch.tensor([0, 3, 7]),
-            query_start_loc_cpu=torch.tensor([0, 3, 7]),
-            seq_lens_cpu=torch.tensor([5, 6]),
-            num_reqs=2,
-            num_actual_tokens=10,
-            max_query_len=5,
-            decode_token_per_req=torch.tensor([1, 1]),
-            block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping=torch.tensor(range(20)),
-            actual_seq_lengths_q=torch.tensor([0, 1]),
-            positions=torch.tensor([10, 10]),
-            attn_mask=torch.ones((10, 10)),
-            spec_attn_mask=None,
-            attn_state=AscendAttentionState.PrefillNoCache,
-            num_computed_tokens_cpu=None,
-            seq_lens=None)
-
-        mock_nz_tensor = MagicMock()
-        mock_model = MagicMock()
-        mock_nd_to_nz_2d.return_value = mock_nz_tensor
-        mock_npu_format_cast.return_value = mock_nz_tensor
-
-        self.builder.build(1, common_attn_metadata, mock_model)
-
-    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
-    @patch('torch_npu.npu_format_cast')
-    @patch('vllm_ascend.utils.nd_to_nz_spec')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    @patch('vllm_ascend.attention.attention_v1.AscendAttentionState')
-    def test_build_chunked_prefill(self, mock_ascend_attention_state,
-                                   mock_soc_version, mock_nd_to_nz_spec,
-                                   mock_npu_format_cast, mock_ascend_metadata):
-        common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=torch.tensor([0, 2, 5, 9]),
-            query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
-            seq_lens_cpu=torch.tensor([4, 5, 6]),
-            num_reqs=3,
-            num_actual_tokens=15,
-            max_query_len=6,
-            decode_token_per_req=torch.tensor([1, 1, 1]),
-            block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping=torch.tensor(range(20)),
-            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
-            positions=torch.tensor([10, 10]),
-            attn_mask=torch.ones((15, 15)),
-            spec_attn_mask=None,
-            attn_state=AscendAttentionState.ChunkedPrefill,
-            num_computed_tokens_cpu=None,
-            seq_lens=None)
-
-        mock_ascend_attention_state = MagicMock()
-        mock_ascend_attention_state.PrefillNoCache = 0
-
-        mock_nz_tensor = MagicMock()
-        mock_model = MagicMock()
-        mock_nd_to_nz_spec.return_value = mock_nz_tensor
-        mock_npu_format_cast.return_value = mock_nz_tensor
-
-        self.builder.build(1, common_attn_metadata, mock_model)
-
     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
     @patch('vllm_ascend.utils.get_ascend_device_type',
            return_value=AscendDeviceType._910_93)
@@ -286,73 +210,40 @@ def test_forward_no_attn_metadata(self):
 
         assert output.shape == (10, 8 * 64)
 
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('torch_npu._npu_flash_attention')
-    def test_forward_prefill_no_cache(self, mock_flash_attention,
-                                      mock_reshape_cache,
-                                      mock_get_forward_context):
-        """Test forward pass in PrefillNoCache state"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-
-        metadata = self.attn_metadata
-        metadata.attn_state = AscendAttentionState.PrefillNoCache
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.seq_lens = torch.tensor([10])
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_reshape_cache.assert_called_once()
-        mock_flash_attention.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_prefill_cache_hit(self, mock_get_forward_context,
-                                       mock_npu_fused_infer_attention_score,
-                                       mock_npu_reshape_and_cache):
+    def test_forward_prefill(self, mock_get_forward_context,
+                             mock_npu_fused_infer_attention_score,
+                             mock_npu_reshape_and_cache):
         """Test forward pass in PrefillCacheHit state"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
+        query = torch.randn(10, 8, 64)
+        key = torch.randn(10, 8, 64)
+        value = torch.randn(10, 8, 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)
-
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.PrefillCacheHit
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
         metadata.query_lens = torch.tensor([10])
         metadata.seq_lens = torch.tensor([10])
+        metadata.actual_seq_lengths_q = [10]
         metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
         metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
+        metadata.num_decode_tokens = 0
         metadata.num_decodes = 0
         metadata.num_prefills = 10
+        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
         layer = self.layer_no_quant
 
         mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
+        mock_npu_fused_infer_attention_score.return_value = (torch.ones(
+            10, 8, 64), torch.ones(10, 8, 64))
         output = self.impl.forward(layer, query, key, value, kv_cache,
                                    metadata, output)
 
         mock_npu_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
+        assert output.shape == (10, 8, 64)
 
     @patch('torch_npu._npu_paged_attention')
     @patch('torch_npu._npu_reshape_and_cache')
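The reworked test_forward_prefill above mocks torch_npu.npu_fused_infer_attention_score with a two-element tuple because the fused op returns (attn_output, softmax_lse) and the implementation keeps only the first element. A self-contained toy version of that mocking pattern is sketched below; TinyAttentionImpl and its forward signature are invented for illustration and are not part of vllm-ascend.

import unittest
from unittest.mock import MagicMock

import torch


class TinyAttentionImpl:
    """Toy stand-in for the attention impl under test."""

    def forward(self, query, fused_op):
        # The fused op returns (attn_output, softmax_lse); only the output is used.
        attn_output, _softmax_lse = fused_op(query)
        return attn_output


class TestTinyAttention(unittest.TestCase):

    def test_forward_prefill_mocked_fused_op(self):
        query = torch.randn(10, 8, 64)
        # Mock the fused kernel the same way the real test does: a 2-tuple.
        fused_op = MagicMock(return_value=(torch.ones(10, 8, 64),
                                           torch.ones(10, 8, 64)))
        out = TinyAttentionImpl().forward(query, fused_op)
        fused_op.assert_called_once()
        self.assertEqual(out.shape, (10, 8, 64))


if __name__ == "__main__":
    unittest.main()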
@@ -454,119 +345,6 @@ def test_forward_decode_only_swa_seq_len_mismatch(
 
         assert output.shape == (10, 8 * 64)
 
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._910_93)
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
-    def test_forward_head_size_192(self, mock_vanilla_prefill,
-                                   mock_npu_reshape_and_cache,
-                                   mock_soc_version, mock_get_forward_context):
-        """Test forward pass when head_size is 192"""
-
-        self.impl.head_size = 192
-        query = torch.randn(10, 8 * 192)
-        key = torch.randn(10, 8 * 192)
-        value = torch.randn(10, 8 * 192)
-        kv_cache = torch.empty(2, 5, 128, 8, 192)
-        output = torch.empty_like(query)
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 10
-        metadata.num_prefills = 0
-        layer = self.layer_no_quant
-        mock_vanilla_prefill.return_value = MagicMock()
-
-        output = self.impl_192.forward(layer, query, key, value, kv_cache,
-                                       metadata, output)
-
-        mock_vanilla_prefill.assert_called_once()
-        assert output.shape == (10, 8 * 192)
-
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('torch_npu.npu_fused_infer_attention_score')
-    @patch('torch_npu._npu_reshape_and_cache')
-    def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
-                                         mock_npu_fused_infer_attention_score,
-                                         mock_get_forward_context):
-        """Test forward pass in normal V1 situation"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_npu_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
-    @patch('torch_npu.npu_format_cast')
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('torch_npu.npu_fused_infer_attention_score')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_310p_device(self, mock_get_forward_context,
-                                 mock_soc_version,
-                                 mock_npu_fused_infer_attention_score,
-                                 mock_npu_reshape_and_cache,
-                                 mock_npu_format_cast):
-        """Test forward pass on 310P device"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-
-        mock_npu_format_cast.return_value = metadata.attn_mask
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_npu_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
     @patch('torch_npu._npu_reshape_and_cache')
     def test_forward_raise_error(self, mock_paged_attention):
         query = torch.randn(10, 8 * 64)
