
Commit d628411

bugfix
Signed-off-by: shenchuxiaofugui <[email protected]>
1 parent: e093c82

File tree: 2 files changed (+8, -14 lines)


vllm_ascend/ops/expert_load_balancer.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -14,8 +14,6 @@ def __init__(self, expert_map_path, num_experts):
         self.tensor_data = []
         self.expert_map_tensor, self.layers_num, self.ranks_num = (
             self._expert_file_to_tensor())
-        self.global_expert_num = num_experts + self.get_global_redundant_expert_num(
-        )
         self.expert_placement_map = self.generate_expert_placement_map()
 
     def _expert_file_to_tensor(self):
@@ -47,7 +45,7 @@ def generate_index_dicts(self, tensor_2d):
 
     def generate_expert_placement_map(self):
         expert_placement_map = torch.full(
-            (self.layers_num, self.ranks_num, self.global_expert_num),
+            (self.layers_num, self.ranks_num, self.num_experts),
             -1,
             dtype=torch.int32,
         )
@@ -70,7 +68,7 @@ def generate_log2phy_expert_map(self, layer_id):
             result_dict[key] = []
             result_dict[key].append(idx)
 
-        log2phy_map = torch.full((self.ranks_num, self.global_expert_num),
+        log2phy_map = torch.full((self.ranks_num, self.num_experts),
                                  -1,
                                  dtype=torch.int32)
         for rank in range(self.ranks_num):
```
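
For context, a minimal sketch of the shapes this fix produces (toy values; assuming `self.num_experts` is the logical expert count set from the `num_experts` argument earlier in `__init__`, which this hunk does not show). The point is that redundant physical replicas appear as *values* in a log-to-physical map, not as extra columns:

```python
import torch

layers_num, ranks_num = 2, 4
num_experts = 8        # logical experts (toy value)
redundant_num = 4      # extra physical replicas spread across ranks (toy value)

# Placement map: one entry per (layer, rank, logical expert); -1 = not hosted.
expert_placement_map = torch.full(
    (layers_num, ranks_num, num_experts), -1, dtype=torch.int32)

# Log-to-physical map: for each rank, logical expert id -> physical expert id.
# Sizing the last dim as num_experts + redundant_num (the old code) leaves
# trailing -1 columns that no logical expert id can ever index.
log2phy_map = torch.full((ranks_num, num_experts), -1, dtype=torch.int32)

# Physical ids range over num_experts + redundant_num but show up as values:
log2phy_map[0, 3] = 10  # rank 0 serves logical expert 3 via physical copy 10

print(expert_placement_map.shape)  # torch.Size([2, 4, 8])
print(log2phy_map.shape)           # torch.Size([4, 8])
```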

vllm_ascend/ops/fused_moe/token_dispatcher.py

Lines changed: 6 additions & 10 deletions
```diff
@@ -118,18 +118,15 @@ def get_dispatch_mc2_kwargs(
         mc2_mask: torch.Tensor,
         global_redundant_expert_num: int = 0,
     ):
-        if self.with_quant:
-            quant_mode = 2
-            moe_expert_num = len(expert_map)
-        else:
-            quant_mode = 0
-            moe_expert_num = len(expert_map)
+        quant_mode = 2 if self.with_quant else 0
+        self.physics_num_experts = len(expert_map) + global_redundant_expert_num
+
         kwargs_mc2 = {
             "x": hidden_states,
             "expert_ids": topk_ids,
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
-            "moe_expert_num": moe_expert_num,
+            "moe_expert_num": self.physics_num_experts,
             "global_bs": 0,
             "expert_token_nums_type": 0,
         }
```
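
A rough sketch of the accounting this hunk fixes (toy values only; the real kwargs feed Ascend MC2 dispatch kernels, which are not invoked here). Assuming `expert_map` holds one entry per logical expert, the kernel must be told the *physical* expert count, i.e. logical experts plus the globally redundant replicas:

```python
import torch

# Toy stand-ins; assuming expert_map has one entry per logical expert.
expert_map = torch.full((8,), -1, dtype=torch.int32)  # 8 logical experts
global_redundant_expert_num = 4                       # extra physical replicas
with_quant = True

quant_mode = 2 if with_quant else 0
physics_num_experts = len(expert_map) + global_redundant_expert_num

kwargs_mc2 = {
    "expert_shard_type": 0,
    "shared_expert_rank_num": 0,
    "moe_expert_num": physics_num_experts,  # 12 physical experts, not 8 logical
    "global_bs": 0,
}
print(quant_mode, kwargs_mc2["moe_expert_num"])  # 2 12
```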

```diff
@@ -247,15 +244,14 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
         expand_scales = context_metadata["expand_scales"]
 
         assert expert_map is not None
-        moe_expert_num = len(expert_map)
 
         kwargs_mc2 = {
             "expand_x": hidden_states,
             "expert_ids": topk_ids,
             "expert_scales": topk_weights.to(torch.float32),
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
-            "moe_expert_num": moe_expert_num,
+            "moe_expert_num": self.physics_num_experts,
             "global_bs": 0,
         }
 
```
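
By dropping the local `moe_expert_num = len(expert_map)`, the combine path now reuses the `self.physics_num_experts` cached during dispatch, so dispatch and combine cannot disagree on the expert count. A minimal sketch of that pattern (hypothetical class, not the vllm-ascend dispatcher):

```python
class TokenDispatcherSketch:
    """Hypothetical sketch: combine must see the same expert count as dispatch."""

    def dispatch_kwargs(self, expert_map, global_redundant_expert_num=0):
        # Cache the physical expert count when building dispatch kwargs...
        self.physics_num_experts = len(expert_map) + global_redundant_expert_num
        return {"moe_expert_num": self.physics_num_experts}

    def combine_kwargs(self):
        # ...and reuse it on combine instead of recomputing len(expert_map),
        # which would silently drop the redundant replicas.
        return {"moe_expert_num": self.physics_num_experts}

d = TokenDispatcherSketch()
assert d.dispatch_kwargs([0] * 8, 4) == d.combine_kwargs() == {"moe_expert_num": 12}
```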

```diff
@@ -360,7 +356,7 @@ def token_dispatch(self,
             hidden_states = hidden_states * \
                 topk_weights.to(hidden_states.dtype)
         if expert_map is not None:
-            global_num_experts = len(expert_map)
+            global_num_experts = len(expert_map) + global_redundant_expert_num
             mask = (expert_map[topk_ids] != -1)
             topk_weights = topk_weights * mask
             first_expert_idx = get_ep_group(
```
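
To make the visible part of this hunk concrete, a toy example (made-up values) of the masking step: `expert_map` holds `-1` for logical experts not hosted on the local rank, so indexing it with `topk_ids` zeroes the weights routed elsewhere, while `global_num_experts` now counts physical experts (logical plus redundant):

```python
import torch

expert_map = torch.tensor([-1, 0, 1, -1, -1, 2, -1, -1])  # 8 logical, 3 local
global_redundant_expert_num = 4

global_num_experts = len(expert_map) + global_redundant_expert_num  # 12 physical

topk_ids = torch.tensor([[1, 3], [5, 0]])              # router picks per token
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4]])

mask = (expert_map[topk_ids] != -1)  # True only where the expert is local
topk_weights = topk_weights * mask   # -> [[0.7, 0.0], [0.6, 0.0]]
print(global_num_experts, topk_weights)
```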
