-
Notifications
You must be signed in to change notification settings - Fork 665
[UT] Add mooncake ut test #5080
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: tongyuzhou <[email protected]> Signed-off-by: wangxiaochao <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
The pull request adds a new unit test for _get_kv_split_metadata. My review focuses on improving this test. I've identified a critical bug in the test setup that would cause failures, a gap in test coverage for a key code path, and an opportunity to refactor the test for better readability and maintainability. I've provided a single comprehensive code suggestion to address all these points.
| def test_get_kv_split_metadata(self): | ||
|
|
||
| def get_kv_split_metadata(use_mla, pcp_size, dcp_size, tp_size, | ||
| tp_rank, pcp_rank, _prefill_tp_size, | ||
| remote_pcp_size, remote_dcp_size, | ||
| remote_port, remote_block_ids, | ||
| local_block_ids): | ||
|
|
||
| worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id) | ||
|
|
||
| worker.use_mla = use_mla | ||
| worker.pcp_size = pcp_size | ||
| worker.dcp_size = dcp_size | ||
| worker.tp_size = tp_size | ||
| worker.tp_rank = tp_rank | ||
| worker.pcp_rank = pcp_rank | ||
| worker._prefill_tp_size = _prefill_tp_size | ||
| worker.local_remote_block_port_mapping = None | ||
|
|
||
| meta = types.SimpleNamespace() | ||
|
|
||
| meta.remote_pcp_size = remote_pcp_size | ||
| meta.remote_dcp_size = remote_dcp_size | ||
| meta.remote_port = remote_port | ||
| meta.remote_block_ids = remote_block_ids | ||
| meta.local_block_ids = local_block_ids | ||
|
|
||
| remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = worker._get_kv_split_metadata( | ||
| '0', meta) | ||
|
|
||
| return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list | ||
|
|
||
| self.assertEqual( | ||
| get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 1, 8, 30000, [1], | ||
| [1]), | ||
| ([[30001], [30002], [30003], [30004], [30005], [30006], [30007], | ||
| [30000]], [[], [], [], [], [], [], [], [1]], [[], [], [], [], [], | ||
| [], [], [1]])) | ||
|
|
||
| self.assertEqual( | ||
| get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 8, 30000, [1], | ||
| [1]), | ||
| ([[30001], [30002], [30003], [30004], [30005], [30006], [30007], | ||
| [30008], [30009], [30010], [30011], [30012], [30013], [30014], | ||
| [30015], [30000] | ||
| ], [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], | ||
| [1]], [[], [], [], [], [], [], [], [], [], [], [], [], [], | ||
| [], [], [1]])) | ||
|
|
||
| self.assertEqual( | ||
| get_kv_split_metadata(True, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1], | ||
| [1]), | ||
| ([[30001], [30008], [30009], [30000]], [[], [], [], [1] | ||
| ], [[], [], [], [1]])) | ||
|
|
||
| self.assertEqual( | ||
| get_kv_split_metadata(False, 1, 1, 8, 1, 0, 8, 2, 2, 30000, [1], | ||
| [1]), | ||
| ([[30001], [30008], [30009], [30000]], [[], [], [], [1] | ||
| ], [[], [], [], [1]])) | ||
|
|
||
| self.assertEqual( | ||
| get_kv_split_metadata(True, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1], | ||
| [1]), | ||
| ([[30009], [30001]], [[], [1]], [[], [1]])) | ||
|
|
||
| self.assertEqual( | ||
| get_kv_split_metadata(False, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1], | ||
| [1]), | ||
| ([[30009], [30001]], [[], [1]], [[], [1]])) | ||
|
|
||
| self.assertEqual( | ||
| get_kv_split_metadata(True, 1, 2, 8, 0, 0, 8, 2, 2, 30000, | ||
| [1, 2, 3], [1, 2, 3, 4, 5]), | ||
| ([[30008], [30000]], [[1, 2], [3, 4, 5]], [[1, 2], [1, 2, 3]])) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This new test method has a critical bug and can be improved for better readability and test coverage.
-
Critical Bug: The test will fail with a
TypeErrorwhen testing certain paths becauseself.vllm_config.model_config.hf_config.num_key_value_headsis not initialized. This value is used in_get_remote_ranks_for_req, which is called by_get_kv_split_metadata. -
Missing Test Coverage: The test suite does not cover the case where
meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1. This is an important branch in the function under test. -
Readability and Maintainability: The test consists of multiple repetitive calls to
self.assertEqual, which makes it verbose and hard to maintain.
I suggest refactoring this test to be data-driven. This makes it more readable, easier to extend, and allows us to fix the bug and add the missing test case cleanly.
def test_get_kv_split_metadata(self):
def get_kv_split_metadata(use_mla, pcp_size, dcp_size, tp_size,
tp_rank, pcp_rank, _prefill_tp_size,
remote_pcp_size, remote_dcp_size,
remote_port, remote_block_ids,
local_block_ids):
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
# This is needed for _get_remote_ranks_for_req to work correctly.
self.vllm_config.model_config.hf_config.num_key_value_heads = 8
worker.use_mla = use_mla
worker.pcp_size = pcp_size
worker.dcp_size = dcp_size
worker.tp_size = tp_size
worker.tp_rank = tp_rank
worker.pcp_rank = pcp_rank
worker._prefill_tp_size = _prefill_tp_size
worker.local_remote_block_port_mapping = None
meta = types.SimpleNamespace()
meta.remote_pcp_size = remote_pcp_size
meta.remote_dcp_size = remote_dcp_size
meta.remote_port = remote_port
meta.remote_block_ids = remote_block_ids
meta.local_block_ids = local_block_ids
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = worker._get_kv_split_metadata(
'0', meta)
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
test_cases = [
{
"name": "MLA, remote_dcp=8",
"params": dict(use_mla=True, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=1, remote_dcp_size=8, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
"expected": ([[30001], [30002], [30003], [30004], [30005], [30006], [30007], [30000]], [[], [], [], [], [], [], [], [1]], [[], [], [], [], [], [], [], [1]])
},
{
"name": "no MLA, remote_pcp=2, remote_dcp=8",
"params": dict(use_mla=False, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=8, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
"expected": ([[30001], [30002], [30003], [30004], [30005], [30006], [30007], [30008], [30009], [30010], [30011], [30012], [30013], [30014], [30015], [30000]], [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [1]], [[], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [1]])
},
{
"name": "MLA, remote_pcp=2, remote_dcp=2",
"params": dict(use_mla=True, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
"expected": ([[30001], [30008], [30009], [30000]], [[], [], [], [1]], [[], [], [], [1]])
},
{
"name": "no MLA, remote_pcp=2, remote_dcp=2",
"params": dict(use_mla=False, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
"expected": ([[30001], [30008], [30009], [30000]], [[], [], [], [1]], [[], [], [], [1]])
},
{
"name": "MLA, dcp=2, remote_pcp=2, remote_dcp=2",
"params": dict(use_mla=True, pcp_size=1, dcp_size=2, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
"expected": ([[30009], [30001]], [[], [1]], [[], [1]])
},
{
"name": "no MLA, dcp=2, remote_pcp=2, remote_dcp=2",
"params": dict(use_mla=False, pcp_size=1, dcp_size=2, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
"expected": ([[30009], [30001]], [[], [1]], [[], [1]])
},
{
"name": "MLA, dcp=2, remote_pcp=2, remote_dcp=2, multiple_blocks",
"params": dict(use_mla=True, pcp_size=1, dcp_size=2, tp_size=8, tp_rank=0, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=2, remote_dcp_size=2, remote_port=30000, remote_block_ids=[1, 2, 3], local_block_ids=[1, 2, 3, 4, 5]),
"expected": ([[30008], [30000]], [[1, 2], [3, 4, 5]], [[1, 2], [1, 2, 3]])
},
{
"name": "no parallel, simple case",
"params": dict(use_mla=True, pcp_size=1, dcp_size=1, tp_size=8, tp_rank=1, pcp_rank=0, _prefill_tp_size=8, remote_pcp_size=1, remote_dcp_size=1, remote_port=30000, remote_block_ids=[1], local_block_ids=[1]),
"expected": ([[30001]], [[1]], [[1]])
},
]
for case in test_cases:
with self.subTest(case["name"]):
result = get_kv_split_metadata(**case["params"])
self.assertEqual(result, case["expected"])|
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according Contributing and Testing. |
What this PR does / why we need it?
Add UT for mooncake
Does this PR introduce any user-facing change?
How was this patch tested?