@@ -454,12 +454,20 @@ def setUp(self):
454454 "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
455455 )
456456 @patch ("vllm_ascend.attention.mla_v1.get_ascend_config" )
457- def test_build_prefix_no_cache_metadata (self , mock_get_ascend_config ,
457+ @patch ("vllm_ascend.attention.mla_v1.torch.zeros" , wraps = torch .zeros )
458+ @patch ("torch.Tensor.npu" , new = lambda self : self )
459+ @patch ("torch.npu.is_available" )
460+ def test_build_prefix_no_cache_metadata (self , mock_npu_available ,
461+ mock_zeros , mock_get_ascend_config ,
458462 mock_dcp_world_size ):
459- if not torch .npu .is_available ():
460- self .skipTest ("NPU not available, skipping NPU-dependent tests" )
463+ mock_npu_available .return_value = False
461464 mock_dcp_world_size .return_value = 1
462465
466+ def zeros_override (* args , ** kwargs ):
467+ kwargs .pop ('pin_memory' , None )
468+ return mock_zeros ._mock_wraps (* args , ** kwargs )
469+
470+ mock_zeros .side_effect = zeros_override
463471 common_attn_metadata = AscendCommonAttentionMetadata (
464472 query_start_loc = torch .tensor ([0 , 3 , 7 ]),
465473 query_start_loc_cpu = torch .tensor ([0 , 3 , 7 ]),
@@ -506,12 +514,21 @@ def test_build_prefix_no_cache_metadata(self, mock_get_ascend_config,
506514 "vllm_ascend.attention.mla_v1.get_decode_context_model_parallel_world_size"
507515 )
508516 @patch ("vllm_ascend.attention.mla_v1.get_ascend_config" )
509- def test_build_chunked_prefix_metadata (self , mock_get_ascend_config ,
517+ @patch ("vllm_ascend.attention.mla_v1.torch.zeros" , wraps = torch .zeros )
518+ @patch ("torch.Tensor.npu" , new = lambda self : self )
519+ @patch ("torch.npu.is_available" )
520+ def test_build_chunked_prefix_metadata (self , mock_npu_available ,
521+ mock_zeros , mock_get_ascend_config ,
510522 mock_dcp_world_size ):
511- if not torch .npu .is_available ():
512- self .skipTest ("NPU not available, skipping NPU-dependent tests" )
523+ mock_npu_available .return_value = False
513524 mock_dcp_world_size .return_value = 1
514525
526+ def zeros_override (* args , ** kwargs ):
527+ kwargs .pop ('pin_memory' , None )
528+ return mock_zeros ._mock_wraps (* args , ** kwargs )
529+
530+ mock_zeros .side_effect = zeros_override
531+
515532 common_attn_metadata = AscendCommonAttentionMetadata (
516533 query_start_loc = torch .tensor ([0 , 2 , 5 , 9 ]),
517534 query_start_loc_cpu = torch .tensor ([0 , 2 , 5 , 9 ]),
0 commit comments