Conversation

@majiayu000
Contributor

Summary

Previously, enabling EPLB required the number of redundant experts to be chosen up front. This PR changes num_redundant_experts to default to None, which triggers automatic computation of the minimum valid value based on the model's configuration.

Changes

  • EPLBConfig.num_redundant_experts now defaults to None instead of 0
  • When EPLB is enabled and num_redundant_experts is None, the minimum valid value is computed as (see the sketch after this list):
    min_redundant = max(0, ep_size - num_logical_experts)
    
    This ensures at least one local physical expert per EP rank.
  • Added validation that num_redundant_experts must be non-negative when explicitly set
  • Logs the computed value for debugging
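
For concreteness, a minimal sketch of the computation described above, assuming ep_size = tensor_parallel_size * data_parallel_size (an illustration, not the literal vLLM implementation):

def default_num_redundant_experts(num_logical_experts: int, ep_size: int) -> int:
    # Guarantee at least one local physical expert per EP rank:
    # (num_logical_experts + redundant) / ep_size >= 1
    return max(0, ep_size - num_logical_experts)

# 4 logical experts across 16 EP ranks need 12 redundant copies,
assert default_num_redundant_experts(4, 16) == 12
# while 64 experts across 8 ranks need none.
assert default_num_redundant_experts(64, 8) == 0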

Benefits

  • Reduces friction when enabling EPLB for the first time
  • Allows the same configuration to work across multiple EP sizes
  • Users can still override with an explicit value if needed

Test Plan

  • Added unit test test_eplb_num_redundant_experts_default
  • Tests cover: the default value, explicit values, negative-value validation, and EPLB-disabled scenarios

Fixes #30075

@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces a helpful feature by automatically calculating the minimum valid num_redundant_experts for EPLB when it's not specified. The changes to EPLBConfig and ParallelConfig are well-implemented, and the logic for the automatic computation in ModelConfig is sound. However, I've identified a gap in test coverage for the new computation logic, which is critical to ensure its correctness and prevent future regressions. Please see my detailed comment.

Comment on lines 1055 to 1081
def test_eplb_num_redundant_experts_default():
    """Test that num_redundant_experts defaults to None and can be set."""
    from vllm.config.parallel import EPLBConfig, ParallelConfig

    # Test default is None
    eplb_config = EPLBConfig()
    assert eplb_config.num_redundant_experts is None

    # Test explicit value
    eplb_config_explicit = EPLBConfig(num_redundant_experts=4)
    assert eplb_config_explicit.num_redundant_experts == 4

    # Test validation for negative value
    with pytest.raises(ValueError, match="non-negative"):
        EPLBConfig(num_redundant_experts=-1)

    # Test ParallelConfig validation - EPLB disabled with None is OK
    parallel_config = ParallelConfig(
        enable_eplb=False,
        enable_expert_parallel=False,
    )
    # Should not raise - None is allowed when EPLB is disabled
    assert parallel_config.eplb_config.num_redundant_experts is None

    # Test ParallelConfig validation - EPLB disabled with non-zero value
    with pytest.raises(ValueError, match="EPLB is not enabled"):
        ParallelConfig(
            enable_eplb=False,
            enable_expert_parallel=False,
            eplb_config=EPLBConfig(num_redundant_experts=4),
        )

Severity: high

This test function is a good start and covers the validation logic within EPLBConfig and ParallelConfig. However, it's missing tests for the core feature of this PR: the automatic computation of num_redundant_experts which happens in ModelConfig.verify_with_parallel_config.

To ensure the new logic is robust, please add test cases that cover the following scenarios:

  • When enable_eplb is True and num_redundant_experts is None, verify that it's correctly computed as max(0, ep_size - num_logical_experts). You'll need to mock ModelConfig.get_num_experts() and set tensor_parallel_size and data_parallel_size in ParallelConfig to test this.
  • When enable_eplb is False and num_redundant_experts is None, verify that it defaults to 0.

Adding these tests will help prevent regressions and ensure the correctness of this important new feature.
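
A minimal sketch of such a test, assuming the names mentioned above (ModelConfig.get_num_experts, verify_with_parallel_config) match the actual vLLM APIs and that a model_config fixture is available:

from unittest.mock import patch

from vllm.config import ModelConfig
from vllm.config.parallel import ParallelConfig


def test_eplb_num_redundant_experts_auto_computed(model_config):
    # model_config is assumed to come from an existing test fixture.
    parallel_config = ParallelConfig(
        enable_eplb=True,
        enable_expert_parallel=True,
        tensor_parallel_size=8,  # ep_size = TP * DP = 8
        data_parallel_size=1,
    )
    with patch.object(ModelConfig, "get_num_experts", return_value=6):
        model_config.verify_with_parallel_config(parallel_config)
    # max(0, 8 - 6) == 2 redundant experts, so every rank holds an expert
    assert parallel_config.eplb_config.num_redundant_experts == 2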

@mergify

mergify bot commented Dec 13, 2025

Hi @majiayu000, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@majiayu000 force-pushed the fix/eplb-default-redundant-experts-30075 branch from efdf56b to 377b395 on December 13, 2025 at 15:22
@ApostaC
Collaborator

ApostaC commented Dec 16, 2025

@WoosukKwon @youkaichao Can you help take a look at this PR? Thanks!

Contributor

@ilmarkov left a comment


Thank you for your work! I left some comments.

)
# Minimum value ensures at least 1 local physical expert per rank:
# (num_logical_experts + num_redundant_experts) / ep_size >= 1
min_redundant = max(0, ep_size - num_logical_experts)

Would it make sense to set min_redundant to the smallest value that makes the total divisible by ep_size, so that we could automatically support unusual ep_sizes? Something like (ep_size - num_logical_experts % ep_size) % ep_size.
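
For illustration, a quick self-contained check of the suggested padding formula (plain Python, not vLLM code):

def min_redundant(num_logical_experts: int, ep_size: int) -> int:
    # Pad the expert count up to the next multiple of ep_size.
    return (ep_size - num_logical_experts % ep_size) % ep_size

assert min_redundant(64, 8) == 0  # already divisible: no padding needed
assert min_redundant(7, 4) == 1   # 7 + 1 = 8 experts, 2 per rank
assert min_redundant(1, 4) == 3   # 1 + 3 = 4 experts, 1 per rank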

If None (default), the minimum valid value will be computed automatically
based on the number of logical experts and the expert parallel size."""

@model_validator(mode="after")

I think this belongs in either _validate_parallel_config or _validate_eplb_config, not in a separate method for a single config parameter.

"""

# Test the formula logic directly
def compute_min_redundant(num_logical_experts: int, ep_size: int) -> int:

I believe we want to test the code that actually executes in the main codebase, not a replicated copy of the formula.

@majiayu000
Contributor Author

Thanks! Updated per your feedback:

  • Formula now uses (ep_size - num_experts % ep_size) % ep_size for non-divisible cases
  • Moved validation to _validate_parallel_config
  • Tests now use parametrize without duplicating formula

@mergify
mergify bot commented Dec 16, 2025

Hi @majiayu000, the pre-commit checks have failed. Please run the same pre-commit steps listed above, then commit the changes and push to your branch.

    (7, 4, 1),  # Non-divisible: (4 - 7 % 4) % 4 = (4 - 3) % 4 = 1
    (1, 4, 3),  # Single expert: (4 - 1 % 4) % 4 = 3
])
def test_eplb_num_redundant_experts_auto_computation(num_experts, ep_size, expected):

In this test you want to check that parallel_config.eplb_config.num_redundant_experts is computed correctly at initialization, not re-derive the formula that you replicate here.
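
Concretely, something along these lines (a sketch; the model_config fixture and exact API are assumptions, as in the earlier sketch):

import pytest
from unittest.mock import patch

from vllm.config import ModelConfig
from vllm.config.parallel import ParallelConfig


@pytest.mark.parametrize("num_experts, ep_size, expected", [
    (8, 4, 0),  # divisible: no redundancy needed
    (7, 4, 1),
    (1, 4, 3),
])
def test_auto_num_redundant_experts(num_experts, ep_size, expected, model_config):
    parallel_config = ParallelConfig(
        enable_eplb=True,
        enable_expert_parallel=True,
        tensor_parallel_size=ep_size,
    )
    with patch.object(ModelConfig, "get_num_experts", return_value=num_experts):
        model_config.verify_with_parallel_config(parallel_config)
    # Assert on the value the production code stored, not a re-derived formula.
    assert parallel_config.eplb_config.num_redundant_experts == expected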

if parallel_config.enable_expert_parallel:
    self._verify_with_expert_parallelism()

# Compute num_redundant_experts if not specified

I believe config-updating logic shouldn't live in a function named verify_*. I would put it in the post-init of eplb_config.
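
A self-contained sketch of that eventual shape (the _Sketch classes are stand-ins; the method name, compute_eplb_num_redundant_experts on ParallelConfig, comes from the commit messages below):

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class _EPLBConfigSketch:
    num_redundant_experts: Optional[int] = None


@dataclass
class _ParallelConfigSketch:
    enable_eplb: bool = False
    tensor_parallel_size: int = 1
    data_parallel_size: int = 1
    eplb_config: _EPLBConfigSketch = field(default_factory=_EPLBConfigSketch)

    def compute_eplb_num_redundant_experts(self, num_logical_experts: int) -> None:
        if self.eplb_config.num_redundant_experts is not None:
            return  # an explicit user value always wins
        if not self.enable_eplb:
            self.eplb_config.num_redundant_experts = 0
            return
        ep_size = self.tensor_parallel_size * self.data_parallel_size
        # Pad up to the next multiple of ep_size (the formula discussed above).
        self.eplb_config.num_redundant_experts = (
            ep_size - num_logical_experts % ep_size
        ) % ep_size

Per the commit messages below, VllmConfig.__post_init__ then calls this method before verification runs, so verify_with_parallel_config only checks the config and never mutates it.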

Contributor

@ilmarkov left a comment


Thanks for the update! I added short comments on the test and on where the logic lives. Also, please fix the pre-commit failures.

@tlrmchlsmth I recall you mentioned that such an automatic choice of num_redundant_experts would be helpful.


majiayu000 and others added 3 commits December 18, 2025 23:10
When EPLB is enabled but num_redundant_experts is not specified,
automatically compute and use the minimum valid value based on:
- Number of logical experts in the model
- Expert parallel size (TP * DP)

The minimum valid value ensures at least 1 local physical expert per rank:
  min_redundant = max(0, ep_size - num_logical_experts)

This reduces friction when enabling EPLB for the first time and allows
the same configuration to work across multiple EP sizes.

Changes:
- EPLBConfig.num_redundant_experts now defaults to None instead of 0
- ModelConfig.verify_with_parallel_config() computes the minimum value
  when num_redundant_experts is None and EPLB is enabled
- Added validation that num_redundant_experts must be non-negative

Fixes vllm-project#30075

Signed-off-by: majiayu000 <[email protected]>
Signed-off-by: lif <[email protected]>
- Use formula (ep_size - num_experts % ep_size) % ep_size to support
  non-standard ep_size values where num_experts is not divisible by ep_size
- Move validation from EPLBConfig._validate_num_redundant_experts to
  ParallelConfig._validate_parallel_config (consolidate validation logic)
- Update tests to use @pytest.mark.parametrize and test actual formula
  instead of duplicating the computation logic

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Signed-off-by: lif <[email protected]>
Address reviewer feedback:
- Move num_redundant_experts computation logic from verify_with_parallel_config
  to a dedicated method compute_eplb_num_redundant_experts in ParallelConfig
- Call the computation method in VllmConfig.__post_init__ before verification
- Update tests to verify actual computed values instead of replicating formula
- Add tests for EPLB disabled case and explicit value preservation
- Fix nested if statement lint warning (SIM102)

Signed-off-by: lif <[email protected]>

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Signed-off-by: lif <[email protected]>
@majiayu000 force-pushed the fix/eplb-default-redundant-experts-30075 branch from be9971a to 0d58dcf on December 18, 2025 at 15:10

Development

Successfully merging this pull request may close these issues.

[Feature]: Default eplb num_redundant_experts to the lowest valid value if unspecified
