@@ -25,12 +25,6 @@ def test_get_builder_cls(self):
         self.assertEqual(AscendAttentionBackend.get_builder_cls(),
                          AscendAttentionMetadataBuilder)
 
-    @patch('vllm_ascend.attention.attention_v1.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    def test_get_kv_cache_shape_310p(self, mock_soc_version):
-        result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
-        self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16))
-
     @patch('vllm_ascend.utils.get_ascend_device_type',
            return_value=AscendDeviceType._910_93)
     def test_get_kv_cache_shape_not_310p(self, mock_soc_version):
@@ -95,76 +89,6 @@ def test_reorder_batch(self):
 
         self.assertFalse(result)
 
-    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
-    @patch('torch_npu.npu_format_cast')
-    @patch('vllm_ascend.utils.nd_to_nz_2d')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    def test_build_prefill_no_cache(self, mock_soc_version, mock_nd_to_nz_2d,
-                                    mock_npu_format_cast,
-                                    mock_ascend_metadata):
-        common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=torch.tensor([0, 3, 7]),
-            query_start_loc_cpu=torch.tensor([0, 3, 7]),
-            seq_lens_cpu=torch.tensor([5, 6]),
-            num_reqs=2,
-            num_actual_tokens=10,
-            max_query_len=5,
-            decode_token_per_req=torch.tensor([1, 1]),
-            block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping=torch.tensor(range(20)),
-            actual_seq_lengths_q=torch.tensor([0, 1]),
-            positions=torch.tensor([10, 10]),
-            attn_mask=torch.ones((10, 10)),
-            spec_attn_mask=None,
-            attn_state=AscendAttentionState.PrefillNoCache,
-            num_computed_tokens_cpu=None,
-            seq_lens=None)
-
-        mock_nz_tensor = MagicMock()
-        mock_model = MagicMock()
-        mock_nd_to_nz_2d.return_value = mock_nz_tensor
-        mock_npu_format_cast.return_value = mock_nz_tensor
-
-        self.builder.build(1, common_attn_metadata, mock_model)
-
-    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
-    @patch('torch_npu.npu_format_cast')
-    @patch('vllm_ascend.utils.nd_to_nz_spec')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    @patch('vllm_ascend.attention.attention_v1.AscendAttentionState')
-    def test_build_chunked_prefill(self, mock_ascend_attention_state,
-                                   mock_soc_version, mock_nd_to_nz_spec,
-                                   mock_npu_format_cast, mock_ascend_metadata):
-        common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=torch.tensor([0, 2, 5, 9]),
-            query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
-            seq_lens_cpu=torch.tensor([4, 5, 6]),
-            num_reqs=3,
-            num_actual_tokens=15,
-            max_query_len=6,
-            decode_token_per_req=torch.tensor([1, 1, 1]),
-            block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping=torch.tensor(range(20)),
-            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
-            positions=torch.tensor([10, 10]),
-            attn_mask=torch.ones((15, 15)),
-            spec_attn_mask=None,
-            attn_state=AscendAttentionState.ChunkedPrefill,
-            num_computed_tokens_cpu=None,
-            seq_lens=None)
-
-        mock_ascend_attention_state = MagicMock()
-        mock_ascend_attention_state.PrefillNoCache = 0
-
-        mock_nz_tensor = MagicMock()
-        mock_model = MagicMock()
-        mock_nd_to_nz_spec.return_value = mock_nz_tensor
-        mock_npu_format_cast.return_value = mock_nz_tensor
-
-        self.builder.build(1, common_attn_metadata, mock_model)
-
     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
     @patch('vllm_ascend.utils.get_ascend_device_type',
            return_value=AscendDeviceType._910_93)
@@ -286,73 +210,40 @@ def test_forward_no_attn_metadata(self):
 
         assert output.shape == (10, 8 * 64)
 
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('torch_npu._npu_flash_attention')
-    def test_forward_prefill_no_cache(self, mock_flash_attention,
-                                      mock_reshape_cache,
-                                      mock_get_forward_context):
-        """Test forward pass in PrefillNoCache state"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-
-        metadata = self.attn_metadata
-        metadata.attn_state = AscendAttentionState.PrefillNoCache
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.seq_lens = torch.tensor([10])
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_reshape_cache.assert_called_once()
-        mock_flash_attention.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
     @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_prefill_cache_hit(self, mock_get_forward_context,
-                                       mock_npu_fused_infer_attention_score,
-                                       mock_npu_reshape_and_cache):
+    def test_forward_prefill(self, mock_get_forward_context,
+                             mock_npu_fused_infer_attention_score,
+                             mock_npu_reshape_and_cache):
         """Test forward pass in PrefillCacheHit state"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
+        query = torch.randn(10, 8, 64)
+        key = torch.randn(10, 8, 64)
+        value = torch.randn(10, 8, 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
         output = torch.empty_like(query)
-
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.PrefillCacheHit
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
         metadata.query_lens = torch.tensor([10])
         metadata.seq_lens = torch.tensor([10])
+        metadata.actual_seq_lengths_q = [10]
         metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
         metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
+        metadata.num_decode_tokens = 0
         metadata.num_decodes = 0
         metadata.num_prefills = 10
+        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
         layer = self.layer_no_quant
 
         mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
+        mock_npu_fused_infer_attention_score.return_value = (torch.ones(
+            10, 8, 64), torch.ones(10, 8, 64))
         output = self.impl.forward(layer, query, key, value, kv_cache,
                                    metadata, output)
 
         mock_npu_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
+        assert output.shape == (10, 8, 64)
 
     @patch('torch_npu._npu_paged_attention')
     @patch('torch_npu._npu_reshape_and_cache')
@@ -454,119 +345,6 @@ def test_forward_decode_only_swa_seq_len_mismatch(
 
         assert output.shape == (10, 8 * 64)
 
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._910_93)
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
-    def test_forward_head_size_192(self, mock_vanilla_prefill,
-                                   mock_npu_reshape_and_cache,
-                                   mock_soc_version, mock_get_forward_context):
-        """Test forward pass when head_size is 192"""
-
-        self.impl.head_size = 192
-        query = torch.randn(10, 8 * 192)
-        key = torch.randn(10, 8 * 192)
-        value = torch.randn(10, 8 * 192)
-        kv_cache = torch.empty(2, 5, 128, 8, 192)
-        output = torch.empty_like(query)
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 10
-        metadata.num_prefills = 0
-        layer = self.layer_no_quant
-        mock_vanilla_prefill.return_value = MagicMock()
-
-        output = self.impl_192.forward(layer, query, key, value, kv_cache,
-                                       metadata, output)
-
-        mock_vanilla_prefill.assert_called_once()
-        assert output.shape == (10, 8 * 192)
-
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('torch_npu.npu_fused_infer_attention_score')
-    @patch('torch_npu._npu_reshape_and_cache')
-    def test_forward_normal_v1_situation(self, mock_npu_reshape_and_cache,
-                                         mock_npu_fused_infer_attention_score,
-                                         mock_get_forward_context):
-        """Test forward pass in normal V1 situation"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_npu_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
-    @patch('torch_npu.npu_format_cast')
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('torch_npu.npu_fused_infer_attention_score')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_310p_device(self, mock_get_forward_context,
-                                 mock_soc_version,
-                                 mock_npu_fused_infer_attention_score,
-                                 mock_npu_reshape_and_cache,
-                                 mock_npu_format_cast):
-        """Test forward pass on 310P device"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-
-        mock_npu_format_cast.return_value = metadata.attn_mask
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_npu_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
     @patch('torch_npu._npu_reshape_and_cache')
     def test_forward_raise_error(self, mock_paged_attention):
         query = torch.randn(10, 8 * 64)