Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions tests/ut/kv_connector/test_mooncake_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,82 @@ def get_tp_rank(prefill_tp_size: int, prefill_pp_size: int,
get_tp_rank(4, 4, 4, 1, 1, True),
[[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]])

def test_get_kv_split_metadata(self):

def get_kv_split_metadata(use_mla, pcp_size, dcp_size, tp_size,
tp_rank, pcp_rank, _prefill_tp_size,
remote_pcp_size, remote_dcp_size,
remote_port, remote_block_ids,
local_block_ids):

worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)

worker.use_mla = use_mla
worker.pcp_size = pcp_size
worker.dcp_size = dcp_size
worker.tp_size = tp_size
worker.tp_rank = tp_rank
worker.pcp_rank = pcp_rank
worker._prefill_tp_size = _prefill_tp_size
worker.local_remote_block_port_mapping = None

meta = types.SimpleNamespace()

meta.remote_pcp_size = remote_pcp_size
meta.remote_dcp_size = remote_dcp_size
meta.remote_port = remote_port
meta.remote_block_ids = remote_block_ids
meta.local_block_ids = local_block_ids

remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = worker._get_kv_split_metadata(
'0', meta)

return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list

self.assertEqual(
get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1],
[1]),
([[30001], [30002], [30003], [30004], [30005], [30006], [30007],
[30000]], [[], [], [], [], [], [], [], [1]], [[], [], [], [], [],
[], [], [1]]))

self.assertEqual(
get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1],
[1]),
([[30001], [30002], [30003], [30004], [30005], [30006], [30007],
[30008], [30009], [30010], [30011], [30012], [30013], [30014],
[30015], [30000]
], [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [],
[1]], [[], [], [], [], [], [], [], [], [], [], [], [], [],
[], [], [1]]))

self.assertEqual(
get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1],
[1]),
([[30001], [30008], [30009], [30000]], [[], [], [], [1]
], [[], [], [], [1]]))

self.assertEqual(
get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1],
[1]),
([[30001], [30008], [30009], [30000]], [[], [], [], [1]
], [[], [], [], [1]]))

self.assertEqual(
get_kv_split_metadata(True, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1],
[1]),
([[30009], [30001]], [[], [1]], [[], [1]]))

self.assertEqual(
get_kv_split_metadata(False, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1],
[1]),
([[30009], [30001]], [[], [1]], [[], [1]]))

self.assertEqual(
get_kv_split_metadata(True, 1, 2, 8, 0, 0, 8, 2, 2, 30000,
[1, 2, 3], [1, 2, 3, 4, 5]),
([[30008], [30000]], [[1, 2], [3, 4, 5]], [[1, 2], [1, 2, 3]]))

Comment on lines +1263 to +1338
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This new test method has a critical bug and can be improved for better readability and test coverage.

  1. Critical Bug: The test will fail with a TypeError when testing certain paths because self.vllm_config.model_config.hf_config.num_key_value_heads is not initialized. This value is used in _get_remote_ranks_for_req, which is called by _get_kv_split_metadata.

  2. Missing Test Coverage: The test suite does not cover the case where meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1. This is an important branch in the function under test.

  3. Readability and Maintainability: The test consists of multiple repetitive calls to self.assertEqual, which makes it verbose and hard to maintain.

I suggest refactoring this test to be data-driven. This makes it more readable, easier to extend, and allows us to fix the bug and add the missing test case cleanly.

    def test_get_kv_split_metadata(self):

        def get_kv_split_metadata(use_mla, pcp_size, dcp_size, tp_size,
                                  tp_rank, pcp_rank, _prefill_tp_size,
                                  remote_pcp_size, remote_dcp_size,
                                  remote_port, remote_block_ids,
                                  local_block_ids):

            worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
            # This is needed for _get_remote_ranks_for_req to work correctly.
            self.vllm_config.model_config.hf_config.num_key_value_heads = 8

            worker.use_mla = use_mla
            worker.pcp_size = pcp_size
            worker.dcp_size = dcp_size
            worker.tp_size = tp_size
            worker.tp_rank = tp_rank
            worker.pcp_rank = pcp_rank
            worker._prefill_tp_size = _prefill_tp_size
            worker.local_remote_block_port_mapping = None

            meta = types.SimpleNamespace()

            meta.remote_pcp_size = remote_pcp_size
            meta.remote_dcp_size = remote_dcp_size
            meta.remote_port = remote_port
            meta.remote_block_ids = remote_block_ids
            meta.local_block_ids = local_block_ids

            remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = worker._get_kv_split_metadata(
                '0', meta)

            return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list

        test_cases = [
            {
                "name": "MLA, remote_dcp=8",
                "params": dict(use_mla=True, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=1, remote_dcp_size=8, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
                "expected": ([[30001], [30002], [30003], [30004], [30005], [30006], [30007], [30000]], [[], [], [], [], [], [], [], [1]], [[], [], [], [], [], [], [], [1]])
            },
            {
                "name": "no MLA, remote_pcp=2, remote_dcp=8",
                "params": dict(use_mla=False, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=8, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
                "expected": ([[30001], [30002], [30003], [30004], [30005], [30006], [30007], [30008], [30009], [30010], [30011], [30012], [30013], [30014], [30015], [30000]], [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [1]], [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [1]])
            },
            {
                "name": "MLA, remote_pcp=2, remote_dcp=2",
                "params": dict(use_mla=True, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
                "expected": ([[30001], [30008], [30009], [30000]], [[], [], [], [1]], [[], [], [], [1]])
            },
            {
                "name": "no MLA, remote_pcp=2, remote_dcp=2",
                "params": dict(use_mla=False, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
                "expected": ([[30001], [30008], [30009], [30000]], [[], [], [], [1]], [[], [], [], [1]])
            },
            {
                "name": "MLA, dcp=2, remote_pcp=2, remote_dcp=2",
                "params": dict(use_mla=True, pcp_size=1, dcp_size=2, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
                "expected": ([[30009], [30001]], [[], [1]], [[], [1]])
            },
            {
                "name": "no MLA, dcp=2, remote_pcp=2, remote_dcp=2",
                "params": dict(use_mla=False, pcp_size=1, dcp_size=2, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
                "expected": ([[30009], [30001]], [[], [1]], [[], [1]])
            },
            {
                "name": "MLA, dcp=2, remote_pcp=2, remote_dcp=2, multiple_blocks",
                "params": dict(use_mla=True, pcp_size=1, dcp_size=2, tp_size=8, tp_rank=0, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1, 2, 3], local_block_ids=[1, 2, 3, 4, 5]),
                "expected": ([[30008], [30000]], [[1, 2], [3, 4, 5]], [[1, 2], [1, 2, 3]])
            },
            {
                "name": "no parallel, simple case",
                "params": dict(use_mla=True, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=1, remote_dcp_size=1, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
                "expected": ([[30001]], [[1]], [[1]])
            },
        ]

        for case in test_cases:
            with self.subTest(case["name"]):
                result = get_kv_split_metadata(**case["params"])
                self.assertEqual(result, case["expected"])


if __name__ == '__main__':
unittest.main()
Loading