@@ -639,10 +639,15 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
def setUp(self):
    """Build a ``MooncakeConnectorScheduler`` with both config hooks mocked.

    ``init_ascend_config`` and ``get_ascend_config`` are replaced, as they
    are looked up in ``vllm_ascend.distributed.mooncake_connector``, with
    ``MagicMock`` objects before the scheduler is constructed. The patchers
    stay available as ``self.p1`` and ``self.p2``.
    """
    config = MockVllmConfig()

    self.p1 = patch(
        'vllm_ascend.distributed.mooncake_connector.init_ascend_config',
        new=MagicMock())
    self.p2 = patch(
        'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
        new=MagicMock(return_value=MagicMock()))

    # Register each stop() right after its own start(): if the second
    # start() raises, the first patch is still undone instead of leaking
    # into later tests. Registration order (p1, then p2) is unchanged, so
    # the LIFO cleanup order is the same as before.
    self.p1.start()
    self.addCleanup(self.p1.stop)
    self.p2.start()
    self.addCleanup(self.p2.stop)

    self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
647652
648653 def test_get_num_new_matched_tokens (self ):
@@ -716,7 +721,9 @@ def test_scheduler_role(self):
716721 config = MockVllmConfig ()
717722 with patch (
718723 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
719- ):
724+ ), patch (
725+ 'vllm_ascend.distributed.mooncake_connector.get_ascend_config' ,
726+ return_value = MagicMock ()):
720727 connector = MooncakeConnector (config , KVConnectorRole .SCHEDULER )
721728 self .assertIsNotNone (connector .connector_scheduler )
722729 self .assertIsNone (connector .connector_worker )
@@ -726,7 +733,9 @@ def test_scheduler_methods(self, mock_method):
726733 config = MockVllmConfig ()
727734 with patch (
728735 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
729- ):
736+ ), patch (
737+ 'vllm_ascend.distributed.mooncake_connector.get_ascend_config' ,
738+ return_value = MagicMock ()):
730739 connector = MooncakeConnector (config , KVConnectorRole .SCHEDULER )
731740 request = MockRequest ("req1" )
732741 connector .get_num_new_matched_tokens (request , 0 )
@@ -756,7 +765,9 @@ def setUp(self):
756765 def test_scheduler_initialization (self ):
757766 with patch (
758767 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
759- ):
768+ ), patch (
769+ 'vllm_ascend.distributed.mooncake_connector.get_ascend_config' ,
770+ return_value = MagicMock ()):
760771 connector = MooncakeConnector (self .config ,
761772 KVConnectorRole .SCHEDULER )
762773 self .assertIsNotNone (connector .connector_scheduler )
@@ -766,7 +777,9 @@ def test_scheduler_initialization(self):
766777 def test_get_num_new_matched_tokens (self , mock_method ):
767778 with patch (
768779 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
769- ):
780+ ), patch (
781+ 'vllm_ascend.distributed.mooncake_connector.get_ascend_config' ,
782+ return_value = MagicMock ()):
770783 connector = MooncakeConnector (self .config ,
771784 KVConnectorRole .SCHEDULER )
772785 request = MockRequest ("req1" )
@@ -777,7 +790,9 @@ def test_get_num_new_matched_tokens(self, mock_method):
777790 def test_update_state_after_alloc (self , mock_method ):
778791 with patch (
779792 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
780- ):
793+ ), patch (
794+ 'vllm_ascend.distributed.mooncake_connector.get_ascend_config' ,
795+ return_value = MagicMock ()):
781796 connector = MooncakeConnector (self .config ,
782797 KVConnectorRole .SCHEDULER )
783798 request = MockRequest ("req1" )
@@ -789,7 +804,9 @@ def test_update_state_after_alloc(self, mock_method):
789804 def test_build_connector_meta (self , mock_method ):
790805 with patch (
791806 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
792- ):
807+ ), patch (
808+ 'vllm_ascend.distributed.mooncake_connector.get_ascend_config' ,
809+ return_value = MagicMock ()):
793810 connector = MooncakeConnector (self .config ,
794811 KVConnectorRole .SCHEDULER )
795812 scheduler_output = MockSchedulerOutput ()
@@ -800,7 +817,9 @@ def test_build_connector_meta(self, mock_method):
800817 def test_request_finished (self , mock_method ):
801818 with patch (
802819 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
803- ):
820+ ), patch (
821+ 'vllm_ascend.distributed.mooncake_connector.get_ascend_config' ,
822+ return_value = MagicMock ()):
804823 connector = MooncakeConnector (self .config ,
805824 KVConnectorRole .SCHEDULER )
806825 request = MockRequest ("req1" )
@@ -814,7 +833,9 @@ def setUp(self):
814833 self .config = MockVllmConfig ()
815834 with patch (
816835 'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
817- ):
836+ ), patch (
837+ 'vllm_ascend.distributed.mooncake_connector.get_ascend_config' ,
838+ return_value = MagicMock ()):
818839 self .scheduler = MooncakeConnectorScheduler (
819840 self .config , "test_engine" )
820841
@@ -1037,9 +1058,6 @@ def setUp(self):
10371058 self .mock_pcp_group .device_group = MagicMock ()
10381059
10391060 self .patches = [
1040- patch (
1041- 'vllm_ascend.distributed.mooncake_layerwise_connector.envs_ascend.PHYSICAL_DEVICES' ,
1042- '10,11' ),
10431061 patch ('torch.Tensor.size' , return_value = (10 , 16 , 8 , 16 )),
10441062 patch ('torch.Tensor.element_size' , return_value = 4 ),
10451063 patch ('torch.Tensor.data_ptr' , return_value = 0x1000 ),
@@ -1056,8 +1074,11 @@ def setUp(self):
10561074 'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash' ,
10571075 mock_string_to_int64_hash ),
10581076 patch (
1059- 'vllm_ascend.distributed.mooncake_transfer_engine.TransferEngine ' ,
1077+ 'vllm_ascend.distributed.mooncake_connector.global_te.get_transfer_engine ' ,
10601078 return_value = self .mock_transfer_engine ),
1079+ patch (
1080+ 'vllm_ascend.distributed.mooncake_connector.global_te.register_buffer' ,
1081+ return_value = None ),
10611082 patch (
10621083 'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread' ,
10631084 MagicMock ()),
@@ -1073,10 +1094,13 @@ def setUp(self):
10731094 patch ('vllm.distributed.parallel_state._DCP' ,
10741095 return_value = self .mock_dcp ),
10751096 patch (
1076- 'vllm .distributed.get_decode_context_model_parallel_world_size' ,
1097+ 'vllm_ascend .distributed.mooncake_connector .get_decode_context_model_parallel_world_size' ,
10771098 return_value = 1 ),
10781099 patch ('vllm_ascend.distributed.mooncake_connector.get_pcp_group' ,
10791100 return_value = self .mock_pcp_group ),
1101+ patch (
1102+ 'vllm_ascend.distributed.mooncake_connector.get_ascend_config' ,
1103+ return_value = MagicMock ()),
10801104 ]
10811105
10821106 for p in self .patches :
@@ -1090,46 +1114,6 @@ def tearDown(self):
10901114 for p in self .patches :
10911115 p .stop () # type: ignore
10921116
1093- def test_worker_use_ascend_direct (self ):
1094- test_case = [True , False ]
1095-
1096- for use_ascend_direct in test_case :
1097- with self .subTest (use_ascend_direct = use_ascend_direct ):
1098- config = MagicMock ()
1099- config .kv_transfer_config = MagicMock ()
1100- config .kv_transfer_config .get_from_extra_config .side_effect = (
1101- lambda k , d : {
1102- "prefill" : {
1103- "tp_size" : 2 ,
1104- "dp_size" : 1
1105- },
1106- "decode" : {
1107- "tp_size" : 2 ,
1108- "dp_size" : 1
1109- },
1110- "use_ascend_direct" : use_ascend_direct ,
1111- }.get (k , d ))
1112-
1113- config .parallel_config = MagicMock ()
1114- config .parallel_config .tensor_parallel_size = 2
1115- config .parallel_config .data_parallel_rank = 0
1116- config .parallel_config .data_parallel_size_local = 1
1117- config .kv_transfer_config .kv_port = 8000
1118- config .kv_transfer_config .kv_role = 'worker'
1119-
1120- with patch (
1121- "vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank" ,
1122- return_value = 0 ):
1123- with patch (
1124- "vllm_ascend.distributed.mooncake_connector.get_tp_group" ,
1125- return_value = None ):
1126- with patch (
1127- "vllm_ascend.distributed.mooncake_connector.get_ip" ,
1128- return_value = "127.0.0.1" ):
1129- worker = MooncakeConnectorWorker (
1130- config , self .engine_id )
1131- self .assertIsNotNone (worker )
1132-
11331117 def test_register_kv_caches_producer (self ):
11341118 worker = MooncakeConnectorWorker (self .vllm_config , self .engine_id )
11351119 worker .register_kv_caches (self .kv_caches )
@@ -1160,7 +1144,7 @@ def test_device_id_selection_with_physical_devices(self):
11601144 # Test with physical devices set
11611145 worker = MooncakeConnectorWorker (self .vllm_config , self .engine_id )
11621146 # device_id is no longer asserted; only check that construction left the worker with a non-None engine
1163- self .assertEqual (worker .device_id , 10 )
1147+ self .assertIsNotNone (worker .engine )
11641148
11651149
11661150if __name__ == '__main__' :
0 commit comments