Skip to content

Commit 0d58dcf

Browse files
majiayu000 and claude
committed
refactor: move EPLB num_redundant_experts computation to ParallelConfig
Address reviewer feedback:
- Move num_redundant_experts computation logic from verify_with_parallel_config to a dedicated method compute_eplb_num_redundant_experts in ParallelConfig
- Call the computation method in VllmConfig.__post_init__ before verification
- Update tests to verify actual computed values instead of replicating the formula
- Add tests for the EPLB-disabled case and explicit value preservation
- Fix nested if statement lint warning (SIM102)

Signed-off-by: lif <[email protected]>

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Signed-off-by: lif <[email protected]>
1 parent 2f8f71a commit 0d58dcf

File tree

4 files changed

+133
-58
lines changed

4 files changed

+133
-58
lines changed

tests/test_config.py

Lines changed: 84 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,28 +1089,91 @@ def test_eplb_num_redundant_experts_default():
10891089
)
10901090

10911091

1092-
@pytest.mark.parametrize("num_experts,ep_size,expected", [
1093-
(8, 8, 0), # Divisible: 8 % 8 = 0
1094-
(8, 16, 8), # ep_size > experts: (16 - 8%16) % 16 = 8
1095-
(8, 6, 4), # Non-divisible: (6 - 8%6) % 6 = (6-2)%6 = 4
1096-
(16, 8, 0), # Divisible: 16 % 8 = 0
1097-
(10, 8, 6), # Non-divisible: (8 - 10%8) % 8 = (8-2)%8 = 6
1098-
(7, 4, 1), # Non-divisible: (4 - 7%4) % 4 = (4-3)%4 = 1
1099-
(1, 4, 3), # Single expert: (4 - 1%4) % 4 = 3
1100-
])
1101-
def test_eplb_num_redundant_experts_auto_computation(num_experts, ep_size, expected):
1102-
"""Test the formula: (ep_size - num_experts % ep_size) % ep_size.
1103-
1104-
This ensures (num_logical_experts + num_redundant_experts) is divisible
1105-
by ep_size, supporting non-standard ep_size values.
1092+
@pytest.mark.parametrize(
1093+
"num_experts,tp_size,dp_size,expected",
1094+
[
1095+
(8, 8, 1, 0), # ep_size=8, divisible: 8 % 8 = 0
1096+
(8, 8, 2, 8), # ep_size=16, ep_size > experts: need 8 redundant
1097+
(8, 2, 3, 4), # ep_size=6, non-divisible: need 4 redundant
1098+
(16, 4, 2, 0), # ep_size=8, divisible: 16 % 8 = 0
1099+
(10, 4, 2, 6), # ep_size=8, non-divisible: need 6 redundant
1100+
(7, 2, 2, 1), # ep_size=4, non-divisible: need 1 redundant
1101+
(1, 2, 2, 3), # ep_size=4, single expert: need 3 redundant
1102+
],
1103+
)
1104+
def test_eplb_num_redundant_experts_auto_computation(
1105+
num_experts, tp_size, dp_size, expected
1106+
):
1107+
"""Test that num_redundant_experts is correctly computed by ParallelConfig.
1108+
1109+
The computation ensures (num_logical_experts + num_redundant_experts)
1110+
is divisible by ep_size (= tp_size * dp_size).
11061111
"""
1107-
# Compute using the actual formula from model.py
1108-
result = (ep_size - num_experts % ep_size) % ep_size
1109-
assert result == expected, (
1110-
f"Formula failed for experts={num_experts}, ep_size={ep_size}: "
1111-
f"got {result}, expected {expected}"
1112+
from vllm.config.parallel import ParallelConfig
1113+
1114+
parallel_config = ParallelConfig(
1115+
tensor_parallel_size=tp_size,
1116+
data_parallel_size=dp_size,
1117+
enable_expert_parallel=True,
1118+
enable_eplb=True,
1119+
)
1120+
# num_redundant_experts should be None before computation
1121+
assert parallel_config.eplb_config.num_redundant_experts is None
1122+
1123+
# Call the computation method
1124+
parallel_config.compute_eplb_num_redundant_experts(num_experts)
1125+
1126+
# Verify the computed value matches expected
1127+
assert parallel_config.eplb_config.num_redundant_experts == expected, (
1128+
f"Expected num_redundant_experts={expected} for "
1129+
f"num_experts={num_experts}, ep_size={tp_size * dp_size}, "
1130+
f"got {parallel_config.eplb_config.num_redundant_experts}"
11121131
)
11131132
# Verify divisibility constraint
1114-
assert (num_experts + result) % ep_size == 0, (
1115-
f"Divisibility check failed: ({num_experts} + {result}) % {ep_size} != 0"
1133+
ep_size = tp_size * dp_size
1134+
total = num_experts + parallel_config.eplb_config.num_redundant_experts
1135+
assert total % ep_size == 0, (
1136+
f"Divisibility check failed: ({num_experts} + "
1137+
f"{parallel_config.eplb_config.num_redundant_experts}) % {ep_size} != 0"
1138+
)
1139+
1140+
1141+
def test_eplb_num_redundant_experts_disabled():
1142+
"""Test that num_redundant_experts defaults to 0 when EPLB is disabled."""
1143+
from vllm.config.parallel import ParallelConfig
1144+
1145+
parallel_config = ParallelConfig(
1146+
tensor_parallel_size=2,
1147+
data_parallel_size=1,
1148+
enable_expert_parallel=False,
1149+
enable_eplb=False,
11161150
)
1151+
# num_redundant_experts should be None before computation
1152+
assert parallel_config.eplb_config.num_redundant_experts is None
1153+
1154+
# Call the computation method
1155+
parallel_config.compute_eplb_num_redundant_experts(num_logical_experts=8)
1156+
1157+
# When EPLB is disabled, should default to 0
1158+
assert parallel_config.eplb_config.num_redundant_experts == 0
1159+
1160+
1161+
def test_eplb_num_redundant_experts_explicit_value_preserved():
1162+
"""Test that explicitly set num_redundant_experts is not overwritten."""
1163+
from vllm.config.parallel import EPLBConfig, ParallelConfig
1164+
1165+
parallel_config = ParallelConfig(
1166+
tensor_parallel_size=4,
1167+
data_parallel_size=2,
1168+
enable_expert_parallel=True,
1169+
enable_eplb=True,
1170+
eplb_config=EPLBConfig(num_redundant_experts=10),
1171+
)
1172+
# num_redundant_experts is explicitly set
1173+
assert parallel_config.eplb_config.num_redundant_experts == 10
1174+
1175+
# Call the computation method - should not override
1176+
parallel_config.compute_eplb_num_redundant_experts(num_logical_experts=8)
1177+
1178+
# Should still be the explicit value
1179+
assert parallel_config.eplb_config.num_redundant_experts == 10

vllm/config/model.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,31 +1040,6 @@ def verify_with_parallel_config(
10401040
if parallel_config.enable_expert_parallel:
10411041
self._verify_with_expert_parallelism()
10421042

1043-
# Compute num_redundant_experts if not specified
1044-
if parallel_config.eplb_config.num_redundant_experts is None:
1045-
if parallel_config.enable_eplb:
1046-
num_logical_experts = self.get_num_experts()
1047-
# EP size is TP * DP for EPLB
1048-
ep_size = (
1049-
parallel_config.tensor_parallel_size
1050-
* parallel_config.data_parallel_size
1051-
)
1052-
# Ensure (num_logical_experts + num_redundant_experts) is
1053-
# divisible by ep_size, supporting non-standard ep_size values
1054-
min_redundant = (ep_size - num_logical_experts % ep_size) % ep_size
1055-
parallel_config.eplb_config.num_redundant_experts = min_redundant
1056-
logger.info(
1057-
"EPLB num_redundant_experts not specified, "
1058-
"defaulting to minimum valid value: %d "
1059-
"(num_logical_experts=%d, ep_size=%d)",
1060-
min_redundant,
1061-
num_logical_experts,
1062-
ep_size,
1063-
)
1064-
else:
1065-
# EPLB disabled, default to 0
1066-
parallel_config.eplb_config.num_redundant_experts = 0
1067-
10681043
pipeline_parallel_size = parallel_config.pipeline_parallel_size
10691044
if pipeline_parallel_size > 1 and not self.registry.is_pp_supported_model(
10701045
self.architectures, self

vllm/config/parallel.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -322,21 +322,53 @@ def _validate_parallel_config(self) -> Self:
322322
f"got {self.eplb_config.num_redundant_experts}"
323323
)
324324

325-
if not self.enable_eplb:
326-
# When EPLB is disabled, num_redundant_experts must be None or 0
327-
if (
328-
self.eplb_config.num_redundant_experts is not None
329-
and self.eplb_config.num_redundant_experts != 0
330-
):
331-
raise ValueError(
332-
"num_redundant_experts is set to "
333-
f"{self.eplb_config.num_redundant_experts} but EPLB is not "
334-
"enabled. Either enable EPLB or unset "
335-
"num_redundant_experts."
336-
)
325+
# When EPLB is disabled, num_redundant_experts must be None or 0
326+
if (
327+
not self.enable_eplb
328+
and self.eplb_config.num_redundant_experts is not None
329+
and self.eplb_config.num_redundant_experts != 0
330+
):
331+
raise ValueError(
332+
"num_redundant_experts is set to "
333+
f"{self.eplb_config.num_redundant_experts} but EPLB is not "
334+
"enabled. Either enable EPLB or unset "
335+
"num_redundant_experts."
336+
)
337337

338338
return self
339339

340+
def compute_eplb_num_redundant_experts(self, num_logical_experts: int) -> None:
341+
"""Compute and set num_redundant_experts if not explicitly specified.
342+
343+
This method should be called after ParallelConfig is initialized and
344+
when the number of logical experts is known (from ModelConfig).
345+
346+
Args:
347+
num_logical_experts: The number of logical experts from the model.
348+
"""
349+
if self.eplb_config.num_redundant_experts is not None:
350+
# Already explicitly set, don't override
351+
return
352+
353+
if self.enable_eplb:
354+
# EP size is TP * DP for EPLB
355+
ep_size = self.tensor_parallel_size * self.data_parallel_size
356+
# Ensure (num_logical_experts + num_redundant_experts) is
357+
# divisible by ep_size, supporting non-standard ep_size values
358+
min_redundant = (ep_size - num_logical_experts % ep_size) % ep_size
359+
self.eplb_config.num_redundant_experts = min_redundant
360+
logger.info(
361+
"EPLB num_redundant_experts not specified, "
362+
"defaulting to minimum valid value: %d "
363+
"(num_logical_experts=%d, ep_size=%d)",
364+
min_redundant,
365+
num_logical_experts,
366+
ep_size,
367+
)
368+
else:
369+
# EPLB disabled, default to 0
370+
self.eplb_config.num_redundant_experts = 0
371+
340372
@property
341373
def world_size_across_dp(self) -> int:
342374
"""world_size_across_dp is TPxPPxDP, it is the size of the world

vllm/config/vllm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,11 @@ def __post_init__(self):
519519
self.try_verify_and_update_config()
520520

521521
if self.model_config is not None:
522+
# Compute EPLB num_redundant_experts before verification
523+
num_experts = self.model_config.get_num_experts()
524+
if num_experts is not None:
525+
self.parallel_config.compute_eplb_num_redundant_experts(num_experts)
526+
522527
self.model_config.verify_with_parallel_config(self.parallel_config)
523528
self.model_config.verify_dual_chunk_attention_config(self.load_config)
524529

0 commit comments

Comments (0)