@@ -434,7 +434,6 @@ def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_get_dcp_size,
434434
435435
436436class TestAscendMLAMetadataBuilderBuild (TestBase ):
437-
438437 def setUp (self ):
439438 self .mock_vllm_config = MagicMock (spec = VllmConfig )
440439 self .mock_vllm_config .model_config = ModelConfig (max_model_len = 2048 )
@@ -454,9 +453,14 @@ def setUp(self):
454453 "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
455454 )
456455 @patch ("vllm_ascend.attention.mla_v1.get_ascend_config" )
457- def test_build_prefix_no_cache_metadata (self , mock_get_ascend_config ,
456+ @patch ("vllm_ascend.attention.mla_v1.torch" )
457+ def test_build_prefix_no_cache_metadata (self , mock_torch , mock_get_ascend_config ,
458458 mock_dcp_world_size ):
459459 mock_dcp_world_size .return_value = 1
460+ def mock_zeros (* args , ** kwargs ):
461+ return torch .empty (* args , ** kwargs , device = "cpu" )
462+
463+ mock_torch .zeros .side_effect = mock_zeros
460464
461465 common_attn_metadata = AscendCommonAttentionMetadata (
462466 query_start_loc = torch .tensor ([0 , 3 , 7 ]),
@@ -504,9 +508,14 @@ def test_build_prefix_no_cache_metadata(self, mock_get_ascend_config,
504508 "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
505509 )
506510 @patch ("vllm_ascend.attention.mla_v1.get_ascend_config" )
507- def test_build_chunked_prefix_metadata (self , mock_get_ascend_config ,
511+ @patch ("vllm_ascend.attention.mla_v1.torch" )
512+ def test_build_chunked_prefix_metadata (self , mock_torch , mock_get_ascend_config ,
508513 mock_dcp_world_size ):
509514 mock_dcp_world_size .return_value = 1
515+ def mock_zeros (* args , ** kwargs ):
516+ return torch .empty (* args , ** kwargs , device = "cpu" )
517+
518+ mock_torch .zeros .side_effect = mock_zeros
510519
511520 common_attn_metadata = AscendCommonAttentionMetadata (
512521 query_start_loc = torch .tensor ([0 , 2 , 5 , 9 ]),
0 commit comments