127 changes: 127 additions & 0 deletions tests/test_config.py
@@ -1050,3 +1050,130 @@ def test_scheduler_config_init():
     with pytest.raises(AttributeError):
         # InitVar does not become an attribute
         print(SchedulerConfig.default_factory().max_model_len)
+
+
+def test_eplb_num_redundant_experts_default():
+    """Test that num_redundant_experts defaults to None and can be set."""
+    from vllm.config.parallel import EPLBConfig, ParallelConfig
+
+    # Test default is None
+    eplb_config = EPLBConfig()
+    assert eplb_config.num_redundant_experts is None
+
+    # Test explicit value
+    eplb_config_explicit = EPLBConfig(num_redundant_experts=4)
+    assert eplb_config_explicit.num_redundant_experts == 4
+
+    # Test ParallelConfig validation - EPLB disabled with None is OK
+    parallel_config = ParallelConfig(
+        enable_eplb=False,
+        enable_expert_parallel=False,
+    )
+    # Should not raise - None is allowed when EPLB is disabled
+    assert parallel_config.eplb_config.num_redundant_experts is None
+
+    # Test ParallelConfig validation - EPLB disabled with non-zero value
+    with pytest.raises(ValidationError, match="EPLB is not enabled"):
+        ParallelConfig(
+            enable_eplb=False,
+            enable_expert_parallel=False,
+            eplb_config=EPLBConfig(num_redundant_experts=4),
+        )
+
+    # Test validation for negative value (validated in ParallelConfig)
+    with pytest.raises(ValidationError, match="non-negative"):
+        ParallelConfig(
+            enable_eplb=False,
+            enable_expert_parallel=False,
+            eplb_config=EPLBConfig(num_redundant_experts=-1),
+        )
+
+
+@pytest.mark.parametrize(
+    "num_experts,tp_size,dp_size,expected",
+    [
+        (8, 8, 1, 0),  # ep_size=8, divisible: 8 % 8 = 0
+        (8, 8, 2, 8),  # ep_size=16, ep_size > experts: need 8 redundant
+        (8, 2, 3, 4),  # ep_size=6, non-divisible: need 4 redundant
+        (16, 4, 2, 0),  # ep_size=8, divisible: 16 % 8 = 0
+        (10, 4, 2, 6),  # ep_size=8, non-divisible: need 6 redundant
+        (7, 2, 2, 1),  # ep_size=4, non-divisible: need 1 redundant
+        (1, 2, 2, 3),  # ep_size=4, single expert: need 3 redundant
+    ],
+)
+def test_eplb_num_redundant_experts_auto_computation(
+    num_experts, tp_size, dp_size, expected
+):
+    """Test that num_redundant_experts is correctly computed by ParallelConfig.
+
+    The computation ensures (num_logical_experts + num_redundant_experts)
+    is divisible by ep_size (= tp_size * dp_size).
+    """
+    from vllm.config.parallel import ParallelConfig
+
+    parallel_config = ParallelConfig(
+        tensor_parallel_size=tp_size,
+        data_parallel_size=dp_size,
+        enable_expert_parallel=True,
+        enable_eplb=True,
+    )
+    # num_redundant_experts should be None before computation
+    assert parallel_config.eplb_config.num_redundant_experts is None
+
+    # Call the computation method
+    parallel_config.compute_eplb_num_redundant_experts(num_experts)
+
+    # Verify the computed value matches expected
+    assert parallel_config.eplb_config.num_redundant_experts == expected, (
+        f"Expected num_redundant_experts={expected} for "
+        f"num_experts={num_experts}, ep_size={tp_size * dp_size}, "
+        f"got {parallel_config.eplb_config.num_redundant_experts}"
+    )
+    # Verify divisibility constraint
+    ep_size = tp_size * dp_size
+    total = num_experts + parallel_config.eplb_config.num_redundant_experts
+    assert total % ep_size == 0, (
+        f"Divisibility check failed: ({num_experts} + "
+        f"{parallel_config.eplb_config.num_redundant_experts}) % {ep_size} != 0"
+    )
+
+
+def test_eplb_num_redundant_experts_disabled():
+    """Test that num_redundant_experts resolves to 0 when EPLB is disabled."""
+    from vllm.config.parallel import ParallelConfig
+
+    parallel_config = ParallelConfig(
+        tensor_parallel_size=2,
+        data_parallel_size=1,
+        enable_expert_parallel=False,
+        enable_eplb=False,
+    )
+    # num_redundant_experts should be None before computation
+    assert parallel_config.eplb_config.num_redundant_experts is None
+
+    # Call the computation method
+    parallel_config.compute_eplb_num_redundant_experts(num_logical_experts=8)
+
+    # When EPLB is disabled, should resolve to 0
+    assert parallel_config.eplb_config.num_redundant_experts == 0
+
+
+def test_eplb_num_redundant_experts_explicit_value_preserved():
+    """Test that explicitly set num_redundant_experts is not overwritten."""
+    from vllm.config.parallel import EPLBConfig, ParallelConfig
+
+    parallel_config = ParallelConfig(
+        tensor_parallel_size=4,
+        data_parallel_size=2,
+        enable_expert_parallel=True,
+        enable_eplb=True,
+        eplb_config=EPLBConfig(num_redundant_experts=10),
+    )
+    # num_redundant_experts is explicitly set
+    assert parallel_config.eplb_config.num_redundant_experts == 10
+
+    # Call the computation method - should not override
+    parallel_config.compute_eplb_num_redundant_experts(num_logical_experts=8)
+
+    # Should still be the explicit value
+    assert parallel_config.eplb_config.num_redundant_experts == 10
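A note on the exception type these tests expect: the validation added in vllm/config/parallel.py raises ValueError, while the assertions above match pydantic's ValidationError (presumably imported at the top of tests/test_config.py, outside this hunk). Assuming ParallelConfig is validated through a pydantic model validator, as that usage suggests, a ValueError raised inside the validator reaches the caller wrapped in a ValidationError. A minimal self-contained sketch of that wrapping, using a hypothetical Demo model rather than vLLM code:

```python
import pydantic


class Demo(pydantic.BaseModel):
    n: int = 0

    @pydantic.model_validator(mode="after")
    def _check(self) -> "Demo":
        # A plain ValueError raised here is surfaced by pydantic as a
        # ValidationError when the model is constructed.
        if self.n < 0:
            raise ValueError("n must be non-negative")
        return self


try:
    Demo(n=-1)
except pydantic.ValidationError as exc:
    print(exc)  # the message still contains "n must be non-negative"
```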
68 changes: 58 additions & 10 deletions vllm/config/parallel.py
@@ -53,8 +53,10 @@ class EPLBConfig:
     of the last `lb_window_size` steps will be used for rearranging experts.
     """
 
-    num_redundant_experts: int = Field(default=0, ge=0)
-    """Number of redundant experts to use for expert parallelism."""
+    num_redundant_experts: int | None = None
+    """Number of redundant experts to use for expert parallelism.
+    If None (default), the minimum valid value will be computed automatically
+    based on the number of logical experts and the expert parallel size."""
 
     log_balancedness: bool = False
     """
@@ -310,17 +312,63 @@ def _validate_parallel_config(self) -> Self:
f"to be greater than 1, but got "
f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
)
else:
if self.eplb_config.num_redundant_experts != 0:
raise ValueError(
"num_redundant_experts is set to "
f"{self.eplb_config.num_redundant_experts} but EPLB is not "
"enabled. Either enable EPLB or unset "
"num_redundant_experts."
)
# Validate num_redundant_experts is non-negative if explicitly set
if (
self.eplb_config.num_redundant_experts is not None
and self.eplb_config.num_redundant_experts < 0
):
raise ValueError(
f"num_redundant_experts must be non-negative, "
f"got {self.eplb_config.num_redundant_experts}"
)

# When EPLB is disabled, num_redundant_experts must be None or 0
if (
not self.enable_eplb
and self.eplb_config.num_redundant_experts is not None
and self.eplb_config.num_redundant_experts != 0
):
raise ValueError(
"num_redundant_experts is set to "
f"{self.eplb_config.num_redundant_experts} but EPLB is not "
"enabled. Either enable EPLB or unset "
"num_redundant_experts."
)

return self

+    def compute_eplb_num_redundant_experts(self, num_logical_experts: int) -> None:
+        """Compute and set num_redundant_experts if not explicitly specified.
+
+        This method should be called after ParallelConfig is initialized and
+        when the number of logical experts is known (from ModelConfig).
+
+        Args:
+            num_logical_experts: The number of logical experts from the model.
+        """
+        if self.eplb_config.num_redundant_experts is not None:
+            # Already explicitly set, don't override
+            return
+
+        if self.enable_eplb:
+            # EP size is TP * DP for EPLB
+            ep_size = self.tensor_parallel_size * self.data_parallel_size
+            # Ensure (num_logical_experts + num_redundant_experts) is divisible
+            # by ep_size, even when ep_size exceeds or does not divide the count
+            min_redundant = (ep_size - num_logical_experts % ep_size) % ep_size
+            self.eplb_config.num_redundant_experts = min_redundant
+            logger.info(
+                "EPLB num_redundant_experts not specified, "
+                "defaulting to minimum valid value: %d "
+                "(num_logical_experts=%d, ep_size=%d)",
+                min_redundant,
+                num_logical_experts,
+                ep_size,
+            )
+        else:
+            # EPLB disabled, default to 0
+            self.eplb_config.num_redundant_experts = 0
+
     @property
     def world_size_across_dp(self) -> int:
         """world_size_across_dp is TPxPPxDP, it is the size of the world
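For readers skimming the diff, the padding rule implemented by compute_eplb_num_redundant_experts reduces to (ep_size - num_logical_experts % ep_size) % ep_size. The sketch below mirrors the new tests rather than adding anything to the PR; it shows the auto-computed path and the explicit-value path side by side, assuming the ParallelConfig constructor arguments used in those tests:

```python
from vllm.config.parallel import EPLBConfig, ParallelConfig

# Auto path: 10 logical experts with ep_size = 4 * 2 = 8 are padded up to 16.
auto_cfg = ParallelConfig(
    tensor_parallel_size=4,
    data_parallel_size=2,
    enable_expert_parallel=True,
    enable_eplb=True,
)
auto_cfg.compute_eplb_num_redundant_experts(num_logical_experts=10)
assert auto_cfg.eplb_config.num_redundant_experts == 6  # (8 - 10 % 8) % 8

# Explicit path: a user-supplied value is preserved, never recomputed.
explicit_cfg = ParallelConfig(
    tensor_parallel_size=4,
    data_parallel_size=2,
    enable_expert_parallel=True,
    enable_eplb=True,
    eplb_config=EPLBConfig(num_redundant_experts=10),
)
explicit_cfg.compute_eplb_num_redundant_experts(num_logical_experts=8)
assert explicit_cfg.eplb_config.num_redundant_experts == 10
```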
5 changes: 5 additions & 0 deletions vllm/config/vllm.py
@@ -519,6 +519,11 @@ def __post_init__(self):
         self.try_verify_and_update_config()
 
         if self.model_config is not None:
+            # Compute EPLB num_redundant_experts before verification
+            num_experts = self.model_config.get_num_experts()
+            if num_experts is not None:
+                self.parallel_config.compute_eplb_num_redundant_experts(num_experts)
+
             self.model_config.verify_with_parallel_config(self.parallel_config)
             self.model_config.verify_dual_chunk_attention_config(self.load_config)
 
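The ordering in VllmConfig.__post_init__ is deliberate: the redundant-expert count is resolved before verify_with_parallel_config runs, presumably so that later checks see a concrete integer instead of the None sentinel, and the hook is skipped entirely for dense models where get_num_experts() returns None. A hypothetical standalone helper (not vLLM code) that mirrors that wiring:

```python
from vllm.config.parallel import ParallelConfig


def resolve_eplb_padding(parallel_config: ParallelConfig, num_experts: int | None) -> None:
    """Mirror of the __post_init__ hook: resolve padding only for MoE models."""
    if num_experts is None:
        # Dense models report no experts; leave the None sentinel untouched.
        return
    parallel_config.compute_eplb_num_redundant_experts(num_experts)


cfg = ParallelConfig(
    tensor_parallel_size=2,
    data_parallel_size=1,
    enable_expert_parallel=False,
    enable_eplb=False,
)
resolve_eplb_padding(cfg, num_experts=8)
assert cfg.eplb_config.num_redundant_experts == 0  # EPLB disabled resolves to 0
```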