@@ -25,12 +25,6 @@ def test_get_builder_cls(self):
         self.assertEqual(AscendAttentionBackend.get_builder_cls(),
                          AscendAttentionMetadataBuilder)
 
-    @patch('vllm_ascend.attention.attention_v1.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    def test_get_kv_cache_shape_310p(self, mock_soc_version):
-        result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
-        self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16))
-
     @patch('vllm_ascend.utils.get_ascend_device_type',
            return_value=AscendDeviceType._910_93)
     def test_get_kv_cache_shape_not_310p(self, mock_soc_version):
@@ -95,76 +89,6 @@ def test_reorder_batch(self):
 
         self.assertFalse(result)
 
-    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
-    @patch('torch_npu.npu_format_cast')
-    @patch('vllm_ascend.utils.nd_to_nz_2d')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    def test_build_prefill_no_cache(self, mock_soc_version, mock_nd_to_nz_2d,
-                                    mock_npu_format_cast,
-                                    mock_ascend_metadata):
-        common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=torch.tensor([0, 3, 7]),
-            query_start_loc_cpu=torch.tensor([0, 3, 7]),
-            seq_lens_cpu=torch.tensor([5, 6]),
-            num_reqs=2,
-            num_actual_tokens=10,
-            max_query_len=5,
-            decode_token_per_req=torch.tensor([1, 1]),
-            block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping=torch.tensor(range(20)),
-            actual_seq_lengths_q=torch.tensor([0, 1]),
-            positions=torch.tensor([10, 10]),
-            attn_mask=torch.ones((10, 10)),
-            spec_attn_mask=None,
-            attn_state=AscendAttentionState.PrefillNoCache,
-            num_computed_tokens_cpu=None,
-            seq_lens=None)
-
-        mock_nz_tensor = MagicMock()
-        mock_model = MagicMock()
-        mock_nd_to_nz_2d.return_value = mock_nz_tensor
-        mock_npu_format_cast.return_value = mock_nz_tensor
-
-        self.builder.build(1, common_attn_metadata, mock_model)
-
-    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
-    @patch('torch_npu.npu_format_cast')
-    @patch('vllm_ascend.utils.nd_to_nz_spec')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    @patch('vllm_ascend.attention.attention_v1.AscendAttentionState')
-    def test_build_chunked_prefill(self, mock_ascend_attention_state,
-                                   mock_soc_version, mock_nd_to_nz_spec,
-                                   mock_npu_format_cast, mock_ascend_metadata):
-        common_attn_metadata = AscendCommonAttentionMetadata(
-            query_start_loc=torch.tensor([0, 2, 5, 9]),
-            query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
-            seq_lens_cpu=torch.tensor([4, 5, 6]),
-            num_reqs=3,
-            num_actual_tokens=15,
-            max_query_len=6,
-            decode_token_per_req=torch.tensor([1, 1, 1]),
-            block_table_tensor=torch.zeros((10, 10)),
-            slot_mapping=torch.tensor(range(20)),
-            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
-            positions=torch.tensor([10, 10]),
-            attn_mask=torch.ones((15, 15)),
-            spec_attn_mask=None,
-            attn_state=AscendAttentionState.ChunkedPrefill,
-            num_computed_tokens_cpu=None,
-            seq_lens=None)
-
-        mock_ascend_attention_state = MagicMock()
-        mock_ascend_attention_state.PrefillNoCache = 0
-
-        mock_nz_tensor = MagicMock()
-        mock_model = MagicMock()
-        mock_nd_to_nz_spec.return_value = mock_nz_tensor
-        mock_npu_format_cast.return_value = mock_nz_tensor
-
-        self.builder.build(1, common_attn_metadata, mock_model)
-
     @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
     @patch('vllm_ascend.utils.get_ascend_device_type',
            return_value=AscendDeviceType._910_93)
@@ -297,8 +221,6 @@ def test_forward_prefill(self, mock_get_forward_context,
         key = torch.randn(10, 8 * 64)
         value = torch.randn(10, 8 * 64)
         kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
         metadata = self.attn_metadata
         metadata.attn_state = AscendAttentionState.PrefillCacheHit
         metadata.attn_mask = torch.randn(1, 1, 10, 10)
@@ -307,17 +229,16 @@ def test_forward_prefill(self, mock_get_forward_context,
         metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
         metadata.num_actual_tokens = 10
         metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
         layer = self.layer_no_quant
 
         mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
+        output = self.impl.forward(layer,
+                                   query,
+                                   key,
+                                   value,
+                                   kv_cache,
+                                   metadata,
+                                   trace_flag=False)
 
         mock_npu_fused_infer_attention_score.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -422,85 +343,6 @@ def test_forward_decode_only_swa_seq_len_mismatch(
 
         assert output.shape == (10, 8 * 64)
 
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._910_93)
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
-    def test_forward_head_size_192(self, mock_vanilla_prefill,
-                                   mock_npu_reshape_and_cache,
-                                   mock_soc_version, mock_get_forward_context):
-        """Test forward pass when head_size is 192"""
-
-        self.impl.head_size = 192
-        query = torch.randn(10, 8 * 192)
-        key = torch.randn(10, 8 * 192)
-        value = torch.randn(10, 8 * 192)
-        kv_cache = torch.empty(2, 5, 128, 8, 192)
-        output = torch.empty_like(query)
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 10
-        metadata.num_prefills = 0
-        layer = self.layer_no_quant
-        mock_vanilla_prefill.return_value = MagicMock()
-
-        output = self.impl_192.forward(layer, query, key, value, kv_cache,
-                                       metadata, output)
-
-        mock_vanilla_prefill.assert_called_once()
-        assert output.shape == (10, 8 * 192)
-
-    @patch('torch_npu.npu_format_cast')
-    @patch('torch_npu._npu_reshape_and_cache')
-    @patch('torch_npu.npu_fused_infer_attention_score')
-    @patch('vllm_ascend.utils.get_ascend_device_type',
-           return_value=AscendDeviceType._310P)
-    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
-    def test_forward_310p_device(self, mock_get_forward_context,
-                                 mock_soc_version,
-                                 mock_npu_fused_infer_attention_score,
-                                 mock_npu_reshape_and_cache,
-                                 mock_npu_format_cast):
-        """Test forward pass on 310P device"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 5, 128, 8, 64)
-        output = torch.empty_like(query)
-
-        metadata = self.attn_metadata
-        metadata.attn_mask = torch.randn(1, 1, 10, 10)
-        metadata.query_lens = torch.tensor([10])
-        metadata.seq_lens = torch.tensor([10])
-        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
-        metadata.num_actual_tokens = 10
-        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
-        metadata.num_decodes = 0
-        metadata.num_prefills = 10
-        layer = self.layer_no_quant
-
-        mock_npu_format_cast.return_value = metadata.attn_mask
-
-        mock_get_forward_context.return_value = MagicMock(capturing=False)
-        mock_npu_fused_infer_attention_score.return_value = (output,
-                                                             torch.ones(
-                                                                 10, 8, 64))
-
-        output = self.impl.forward(layer, query, key, value, kv_cache,
-                                   metadata, output)
-
-        mock_npu_fused_infer_attention_score.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
     @patch('torch_npu._npu_reshape_and_cache')
     def test_forward_raise_error(self, mock_paged_attention):
         query = torch.randn(10, 8 * 64)