
Commit e093c82

[EPLB][BugFix] generate experts map with redundant experts
Signed-off-by: shenchuxiaofugui <[email protected]>
1 parent 0d2b18d commit e093c82

2 files changed: +27 −3 lines

vllm_ascend/eplb/core/eplb_utils.py

Lines changed: 23 additions & 0 deletions
@@ -25,6 +25,29 @@
 import vllm_ascend.envs as envs_ascend
 
 
+def generate_experts_map(ep_rank, ep_size, n_expert, n_redundant):
+
+    def split_and_insert(n, k, m):
+        # Split the n logical experts across k ranks, then hand out m
+        # redundant slots by cycling backwards over the groups.
+        all_experts = torch.arange(n)
+        groups = list(torch.tensor_split(all_experts, k))
+        for i in range(m):
+            j = i % k + 1
+            if len(groups[-j]) == 0:
+                groups[-j] = torch.tensor([j])
+            else:
+                groups[-j] = torch.cat(
+                    [groups[-j], (groups[-j][-1:] + 1) % n_expert])
+        return torch.cat(groups)
+
+    random_placement = split_and_insert(n_expert, ep_size, n_redundant)
+    global_num_experts = random_placement.shape[0]
+    local_num_experts = global_num_experts // ep_size
+
+    # Expose only this rank's slice of the placement; every other physical
+    # slot stays masked as -1.
+    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
+    expert_map[ep_rank * local_num_experts:(ep_rank + 1) * local_num_experts] = \
+        random_placement[ep_rank * local_num_experts:(ep_rank + 1) * local_num_experts]
+
+    return local_num_experts, expert_map
+
+
 def generate_log2phy_map(expert_map):
     num_local_experts = expert_map.max() + 1
     log2phy_map = expert_map.clone()
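
A minimal sketch (not part of the commit) of what the new helper yields for a toy setup, assuming the corrected torch.tensor_split / torch.cat calls above: with 2 EP ranks, 8 logical experts and 2 redundant slots, the physical placement becomes [0, 1, 2, 3, 4, 4, 5, 6, 7, 0], and each rank keeps only its own slice in expert_map.

from vllm_ascend.eplb.core.eplb_utils import generate_experts_map

# 8 logical experts + 2 redundant copies -> 10 physical slots, 5 per rank.
local_num_experts, expert_map = generate_experts_map(
    ep_rank=0, ep_size=2, n_expert=8, n_redundant=2)

assert local_num_experts == 5
# Rank 0 hosts physical slots 0..4; the other rank's slots are masked as -1.
assert expert_map.tolist() == [0, 1, 2, 3, 4, -1, -1, -1, -1, -1]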

vllm_ascend/ops/fused_moe/fused_moe.py

Lines changed: 4 additions & 3 deletions
@@ -35,7 +35,8 @@
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
+from vllm_ascend.eplb.core.eplb_utils import (determine_default_log2phy_map,
+                                              generate_experts_map)
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
@@ -182,8 +183,8 @@ def __init__(self, *args, **kwargs):
             dtype=vllm_config.model_config.dtype)
 
         # init moe.
-        self.local_num_experts, self.expert_map, _ = determine_expert_map(
-            self.ep_size, self.ep_rank, self.global_num_experts)
+        self.local_num_experts, self.expert_map = generate_experts_map(
+            self.ep_rank, self.ep_size, num_experts, self.global_redundant_expert_num)
         # TODO: Temporary flag to indicate if static EPLB is enabled. This is a
         # workaround to bypass a quantization check that fails with float weights.
         init_eplb_enable = False
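
The call-site change in isolation, as a hedged sketch (toy values stand in for the FusedMoE attributes; they are not the real config): generate_experts_map returns just (local_num_experts, expert_map), unlike the previous determine_expert_map call that unpacked a third, unused value, and its arguments follow the new helper's (ep_rank, ep_size, n_expert, n_redundant) order.

from vllm_ascend.eplb.core.eplb_utils import generate_experts_map

# Stand-in values; in the real module these come from the parallel config.
ep_rank, ep_size = 1, 4
num_experts, global_redundant_expert_num = 64, 4

local_num_experts, expert_map = generate_experts_map(
    ep_rank, ep_size, num_experts, global_redundant_expert_num)

# 64 logical experts + 4 redundant slots -> 68 physical slots, 17 per rank.
assert expert_map.shape[0] == num_experts + global_redundant_expert_num
assert local_num_experts == 17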
