@YzTongNiar (Contributor) commented Dec 16, 2025

What this PR does / why we need it?

Add a unit test (UT) for _get_kv_split_metadata in the Mooncake connector worker.

Does this PR introduce any user-facing change?

How was this patch tested?

Signed-off-by: tongyuzhou <[email protected]>
Signed-off-by: wangxiaochao <[email protected]>
@gemini-code-assist (bot, Contributor) left a comment

Code Review

The pull request adds a new unit test for _get_kv_split_metadata. My review focuses on improving this test. I've identified a critical bug in the test setup that would cause failures, a gap in test coverage for a key code path, and an opportunity to refactor the test for better readability and maintainability. I've provided a single comprehensive code suggestion to address all these points.

Comment on lines +1260 to +1335
    def test_get_kv_split_metadata(self):

        def get_kv_split_metadata(use_mla, pcp_size, dcp_size, tp_size,
                                  tp_rank, pcp_rank, _prefill_tp_size,
                                  remote_pcp_size, remote_dcp_size,
                                  remote_port, remote_block_ids,
                                  local_block_ids):

            worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)

            worker.use_mla = use_mla
            worker.pcp_size = pcp_size
            worker.dcp_size = dcp_size
            worker.tp_size = tp_size
            worker.tp_rank = tp_rank
            worker.pcp_rank = pcp_rank
            worker._prefill_tp_size = _prefill_tp_size
            worker.local_remote_block_port_mapping = None

            meta = types.SimpleNamespace()

            meta.remote_pcp_size = remote_pcp_size
            meta.remote_dcp_size = remote_dcp_size
            meta.remote_port = remote_port
            meta.remote_block_ids = remote_block_ids
            meta.local_block_ids = local_block_ids

            remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = worker._get_kv_split_metadata(
                '0', meta)

            return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list

        self.assertEqual(
            get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1],
                                  [1]),
            ([[30001], [30002], [30003], [30004], [30005], [30006], [30007],
              [30000]], [[], [], [], [], [], [], [], [1]], [[], [], [], [], [],
                                                            [], [], [1]]))

        self.assertEqual(
            get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1],
                                  [1]),
            ([[30001], [30002], [30003], [30004], [30005], [30006], [30007],
              [30008], [30009], [30010], [30011], [30012], [30013], [30014],
              [30015], [30000]
              ], [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [],
                  [1]], [[], [], [], [], [], [], [], [], [], [], [], [], [],
                         [], [], [1]]))

        self.assertEqual(
            get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1],
                                  [1]),
            ([[30001], [30008], [30009], [30000]], [[], [], [], [1]
                                                    ], [[], [], [], [1]]))

        self.assertEqual(
            get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1],
                                  [1]),
            ([[30001], [30008], [30009], [30000]], [[], [], [], [1]
                                                    ], [[], [], [], [1]]))

        self.assertEqual(
            get_kv_split_metadata(True, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1],
                                  [1]),
            ([[30009], [30001]], [[], [1]], [[], [1]]))

        self.assertEqual(
            get_kv_split_metadata(False, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1],
                                  [1]),
            ([[30009], [30001]], [[], [1]], [[], [1]]))

        self.assertEqual(
            get_kv_split_metadata(True, 1, 2, 8, 0, 0, 8, 2, 2, 30000,
                                  [1, 2, 3], [1, 2, 3, 4, 5]),
            ([[30008], [30000]], [[1, 2], [3, 4, 5]], [[1, 2], [1, 2, 3]]))

critical

This new test method has a critical bug and can be improved for better readability and test coverage.

  1. Critical Bug: The test will fail with a TypeError when testing certain paths because self.vllm_config.model_config.hf_config.num_key_value_heads is not initialized. This value is used in _get_remote_ranks_for_req, which is called by _get_kv_split_metadata.

  2. Missing Test Coverage: The test suite does not cover the case where meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1. This is an important branch in the function under test.

  3. Readability and Maintainability: The test consists of multiple repetitive calls to self.assertEqual, which makes it verbose and hard to maintain.

I suggest refactoring this test to be data-driven. This makes it more readable, easier to extend, and allows us to fix the bug and add the missing test case cleanly.

    def test_get_kv_split_metadata(self):

        def get_kv_split_metadata(use_mla, pcp_size, dcp_size, tp_size,
                                  tp_rank, pcp_rank, _prefill_tp_size,
                                  remote_pcp_size, remote_dcp_size,
                                  remote_port, remote_block_ids,
                                  local_block_ids):

            worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
            # This is needed for _get_remote_ranks_for_req to work correctly.
            self.vllm_config.model_config.hf_config.num_key_value_heads = 8

            worker.use_mla = use_mla
            worker.pcp_size = pcp_size
            worker.dcp_size = dcp_size
            worker.tp_size = tp_size
            worker.tp_rank = tp_rank
            worker.pcp_rank = pcp_rank
            worker._prefill_tp_size = _prefill_tp_size
            worker.local_remote_block_port_mapping = None

            meta = types.SimpleNamespace()

            meta.remote_pcp_size = remote_pcp_size
            meta.remote_dcp_size = remote_dcp_size
            meta.remote_port = remote_port
            meta.remote_block_ids = remote_block_ids
            meta.local_block_ids = local_block_ids

            remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = worker._get_kv_split_metadata(
                '0', meta)

            return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list

        test_cases = [
            {
                "name": "MLA, remote_dcp=8",
                "params": dict(use_mla=True, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=1, remote_dcp_size=8, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
                "expected": ([[30001], [30002], [30003], [30004], [30005], [30006], [30007], [30000]], [[], [], [], [], [], [], [], [1]], [[], [], [], [], [], [], [], [1]])
            },
            {
                "name": "no MLA, remote_pcp=2, remote_dcp=8",
                "params": dict(use_mla=False, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=8, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
                "expected": ([[30001], [30002], [30003], [30004], [30005], [30006], [30007], [30008], [30009], [30010], [30011], [30012], [30013], [30014], [30015], [30000]], [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [1]], [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [1]])
            },
            {
                "name": "MLA, remote_pcp=2, remote_dcp=2",
                "params": dict(use_mla=True, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
                "expected": ([[30001], [30008], [30009], [30000]], [[], [], [], [1]], [[], [], [], [1]])
            },
            {
                "name": "no MLA, remote_pcp=2, remote_dcp=2",
                "params": dict(use_mla=False, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
                "expected": ([[30001], [30008], [30009], [30000]], [[], [], [], [1]], [[], [], [], [1]])
            },
            {
                "name": "MLA, dcp=2, remote_pcp=2, remote_dcp=2",
                "params": dict(use_mla=True, pcp_size=1, dcp_size=2, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
                "expected": ([[30009], [30001]], [[], [1]], [[], [1]])
            },
            {
                "name": "no MLA, dcp=2, remote_pcp=2, remote_dcp=2",
                "params": dict(use_mla=False, pcp_size=1, dcp_size=2, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
                "expected": ([[30009], [30001]], [[], [1]], [[], [1]])
            },
            {
                "name": "MLA, dcp=2, remote_pcp=2, remote_dcp=2, multiple_blocks",
                "params": dict(use_mla=True, pcp_size=1, dcp_size=2, tp_size=8, tp_rank=0, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1, 2, 3], local_block_ids=[1, 2, 3, 4, 5]),
                "expected": ([[30008], [30000]], [[1, 2], [3, 4, 5]], [[1, 2], [1, 2, 3]])
            },
            {
                "name": "no parallel, simple case",
                "params": dict(use_mla=True, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=1, remote_dcp_size=1, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
                "expected": ([[30001]], [[1]], [[1]])
            },
        ]

        for case in test_cases:
            with self.subTest(case["name"]):
                result = get_kv_split_metadata(**case["params"])
                self.assertEqual(result, case["expected"])
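
As a quick sanity check on the coverage point (item 2 above), the sketch below multiplies the four parallel sizes for each test case in this thread and reports whether the product equals 1. It is standalone and does not call MooncakeConnectorWorker; the helper name `hits_unit_product_branch` and the tuple layout are assumptions made for illustration, and the size values are copied from the test arguments shown above.

```python
from math import prod

# Parallel sizes copied from the seven assertEqual calls in the original test,
# in the order (pcp_size, dcp_size, remote_pcp_size, remote_dcp_size).
ORIGINAL_CASES = [
    (1, 1, 1, 8),
    (1, 1, 2, 8),
    (1, 1, 2, 2),
    (1, 1, 2, 2),
    (1, 2, 2, 2),
    (1, 2, 2, 2),
    (1, 2, 2, 2),
]

# Sizes used by the extra "no parallel, simple case" entry in the suggestion.
ADDED_CASE = (1, 1, 1, 1)


def hits_unit_product_branch(sizes):
    """Hypothetical helper: True when the four sizes multiply to exactly 1."""
    return prod(sizes) == 1


# One boolean per original case, then the result for the added case.
original_hits = [hits_unit_product_branch(case) for case in ORIGINAL_CASES]
added_hits = hits_unit_product_branch(ADDED_CASE)

print("original cases reaching the product == 1 branch:", sum(original_hits))
print("added case reaches the product == 1 branch:", added_hits)
```

Under these assumptions, none of the original argument sets satisfies the condition, while the added case does, which is why the extra entry is needed to exercise that branch.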

@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description, to help reviewers and future developers understand the change.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
