
Commit 15b2e5c

Remove unused row_idx in token_dispatcher (#3442)
### What this PR does / why we need it?
The `row_idx` parameter is no longer used since PR [#2689](#2689), so remove it across multiple files to eliminate the unnecessary computation and parameter passing.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Accuracy tests passed for Qwen3 235B and DeepSeek V3 671B after this PR.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: CaranLic <[email protected]>
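A minimal sketch of the resulting calling-convention change for `select_experts` (illustrative only; the remaining keyword arguments shown in the per-file diffs below are elided):

```python
# Before this change: select_experts returned an extra row_idx tensor that
# callers then threaded through fused_experts / token_dispatch.
topk_weights, topk_ids, row_idx = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=top_k,
)

# After this change: only the routing weights and expert ids are returned,
# and row_idx is no longer computed or passed downstream.
topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=top_k,
)
```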
1 parent 3642b64 commit 15b2e5c

File tree

11 files changed: +37 additions, -88 deletions


tests/e2e/singlecard/ops/test_fused_moe.py

Lines changed: 1 addition & 16 deletions
@@ -118,12 +118,6 @@ def test_token_dispatcher_with_all_gather(
     score = torch.softmax(score, dim=-1, dtype=dtype)
     topk_weights, topk_ids = torch.topk(score, topk)
     topk_ids = topk_ids.to(torch.int32)
-    row_idx = (torch.arange(
-        0,
-        m * topk,
-        device=device,
-        dtype=torch.int32,
-    ).view(topk, -1).permute(1, 0).contiguous())
 
     dispatcher_kwargs = {
         "num_experts": e,
@@ -137,7 +131,6 @@ def test_token_dispatcher_with_all_gather(
         hidden_states=a,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        row_idx=row_idx,
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input)
@@ -201,12 +194,6 @@ def test_token_dispatcher_with_all_gather_quant(
     score = torch.softmax(score, dim=-1, dtype=dtype)
     topk_weights, topk_ids = torch.topk(score, topk)
     topk_ids = topk_ids.to(torch.int32)
-    row_idx = (torch.arange(
-        0,
-        m * topk,
-        device=device,
-        dtype=torch.int32,
-    ).view(topk, -1).permute(1, 0).contiguous())
 
     dispatcher_kwargs = {
         "num_experts": e,
@@ -220,7 +207,6 @@ def test_token_dispatcher_with_all_gather_quant(
         hidden_states=a,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        row_idx=row_idx,
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input,
         with_quant=True)
@@ -297,7 +283,7 @@ def test_select_experts(
     mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
         x)
 
-    topk_weights, topk_ids, row_idx = select_experts(
+    topk_weights, topk_ids = select_experts(
         hidden_states=hidden_states,
         router_logits=router_logits,
         top_k=topk,
@@ -318,7 +304,6 @@ def test_select_experts(
     assert topk_weights.shape == (m, topk)
     assert topk_ids.shape == (m, topk)
     assert topk_ids.dtype == torch.int32
-    assert row_idx.shape == (m, topk)
 
     gc.collect()
     torch.npu.empty_cache()

tests/ut/ops/test_fused_ops.py

Lines changed: 1 addition & 1 deletion
@@ -263,7 +263,7 @@ def test_select_experts(self, mock_dist_env, mock_moe_env,
 
        x = torch.randn(8, 2)
        router_logits = torch.randn(8, 2)
-       topk_weights, topk_ids, _ = select_experts(
+       topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=2,

tests/ut/ops/test_moe_comm_method.py

Lines changed: 0 additions & 2 deletions
@@ -204,7 +204,6 @@ def test_fused_experts_method(self, mock_unified_apply_mlp,
        topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
                                     [0.6, 0.4]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])
-       row_idx = torch.arange(4)
 
        # Make sure tensors are contiguous and have correct strides
        hidden_states = hidden_states.contiguous()
@@ -216,7 +215,6 @@ def test_fused_experts_method(self, mock_unified_apply_mlp,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
-           row_idx=row_idx,
            activation="silu")
 
        # Verify result shape

tests/ut/ops/test_token_dispatcher.py

Lines changed: 4 additions & 13 deletions
@@ -58,7 +58,6 @@ def setUp(self):
 
        kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
        self.dispatcher = TokenDispatcherWithMC2(**kwargs)
-       self.row_idx = torch.arange(10, dtype=torch.int32)
 
    def tearDown(self):
        self.mc2_group_patch.stop()
@@ -96,7 +95,7 @@ def test_token_permutation_dispatch(self):
                               (None, None)) as mock_dispatch:
            output = self.dispatcher.token_dispatch(hidden_states,
                                                    topk_weights, topk_ids,
-                                                   self.row_idx, expert_map)
+                                                   expert_map)
            mock_dispatch.assert_called_once()
            self.assertEqual(output["group_list_type"],
                             0)  # group_list_type == 0
@@ -117,7 +116,6 @@ def test_token_dispatch_with_shared_experts_and_quant(self):
            self.dispatcher.token_dispatch(self.hidden_states,
                                           self.topk_weights,
                                           torch.randint(0, 8, (10, 1)),
-                                          self.row_idx,
                                           torch.tensor(
                                               [0, 1, 2, 3, 4, 5, 6, 7]),
                                           shared_experts=self.shared_experts)
@@ -181,7 +179,6 @@ def setUp(self):
            torch.tensor([0, 1, 2, 3, 4, 5]),  # expanded_row_idx
            torch.tensor([0, 1, 0, 1, 0, 1]),  # expanded_expert_idx
            torch.tensor([0, 1, 0, 1, 0, 1]))
-       self.row_idx = torch.arange(10, dtype=torch.int32)
        self.patcher_npu_moe_token_unpermute = patch(
            'torch_npu.npu_moe_token_unpermute')
        self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
@@ -198,7 +195,7 @@ def test_token_dispatch_without_expert_map(self):
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
 
        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
-                                                topk_ids, self.row_idx, None)
+                                                topk_ids, None)
 
        # Verify npu_moe_init_routing is called
        self.mock_npu_moe_init_routing_v2.assert_called_once()
@@ -213,7 +210,7 @@ def test_token_dispatch_with_expert_map(self):
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
 
        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
-                                                topk_ids, self.row_idx, None)
+                                                topk_ids, None)
 
        # Verify npu_moe_init_routing is called
        self.mock_npu_moe_init_routing_v2.assert_called_once()
@@ -237,7 +234,7 @@ def test_token_dispatch_without_quant(self):
 
        results = self.dispatcher_quant.token_dispatch(hidden_states,
                                                       topk_weights, topk_ids,
-                                                      self.row_idx, None)
+                                                      None)
 
        self.assertEqual(results["group_list_type"], 1)
 
@@ -258,7 +255,6 @@ def test_token_dispatch_with_quant(self):
        results = self.dispatcher_quant.token_dispatch(hidden_states,
                                                       topk_weights,
                                                       topk_ids,
-                                                      self.row_idx,
                                                       None,
                                                       with_quant=True)
 
@@ -401,7 +397,6 @@ def setUp(self):
            num_experts=4,
            num_local_experts=2,
            with_quant=False)
-       self.row_idx = torch.arange(10, dtype=torch.int32)
 
    def test_token_dispatch(self):
        hidden_states = torch.randn(8, 16)
@@ -416,7 +411,6 @@ def test_token_dispatch(self):
        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                topk_weights=topk_weights,
                                                topk_ids=topk_ids,
-                                               row_idx=self.row_idx,
                                                expert_map=expert_map)
 
        self.assertIsNotNone(result["hidden_states"])
@@ -463,7 +457,6 @@ def test_token_dispatch_with_quant(self):
        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                topk_weights=topk_weights,
                                                topk_ids=topk_ids,
-                                               row_idx=self.row_idx,
                                                expert_map=expert_map,
                                                with_quant=True)
 
@@ -492,7 +485,6 @@ def test_token_dispatch_with_quant_no_active_tokens(self):
        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                topk_weights=topk_weights,
                                                topk_ids=topk_ids,
-                                               row_idx=self.row_idx,
                                                expert_map=expert_map,
                                                with_quant=True)
 
@@ -515,7 +507,6 @@ def test_token_dispatch_with_log2phy(self):
        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                topk_weights=topk_weights,
                                                topk_ids=topk_ids,
-                                               row_idx=self.row_idx,
                                                expert_map=expert_map,
                                                log2phy=log2phy)

tests/ut/quantization/test_w8a8.py

Lines changed: 23 additions & 23 deletions
@@ -777,25 +777,25 @@ def test_softmax_scoring(self, mock_topk):
                        -1).permute(1,
                                    0).contiguous())
 
-       weights, ids, _ = select_experts(hidden_states=self.hidden_states,
-                                        router_logits=self.router_logits,
-                                        top_k=self.top_k,
-                                        use_grouped_topk=False,
-                                        renormalize=False,
-                                        scoring_func="softmax")
+       weights, ids = select_experts(hidden_states=self.hidden_states,
+                                     router_logits=self.router_logits,
+                                     top_k=self.top_k,
+                                     use_grouped_topk=False,
+                                     renormalize=False,
+                                     scoring_func="softmax")
 
        self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
        self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
 
    def test_sigmoid_scoring(self):
        """Test sigmoid scoring function"""
 
-       weights, ids, _ = select_experts(hidden_states=self.hidden_states,
-                                        router_logits=self.router_logits,
-                                        top_k=self.top_k,
-                                        use_grouped_topk=False,
-                                        renormalize=False,
-                                        scoring_func="sigmoid")
+       weights, ids = select_experts(hidden_states=self.hidden_states,
+                                     router_logits=self.router_logits,
+                                     top_k=self.top_k,
+                                     use_grouped_topk=False,
+                                     renormalize=False,
+                                     scoring_func="sigmoid")
 
        self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
        self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -818,13 +818,13 @@ def test_grouped_topk(self, mock_topk):
                       self.top_k,
                       dtype=torch.long))
 
-       weights, ids, _ = select_experts(hidden_states=self.hidden_states,
-                                        router_logits=self.router_logits,
-                                        top_k=self.top_k,
-                                        use_grouped_topk=True,
-                                        renormalize=False,
-                                        topk_group=4,
-                                        num_expert_group=2)
+       weights, ids = select_experts(hidden_states=self.hidden_states,
+                                     router_logits=self.router_logits,
+                                     top_k=self.top_k,
+                                     use_grouped_topk=True,
+                                     renormalize=False,
+                                     topk_group=4,
+                                     num_expert_group=2)
 
        mock_topk.assert_called()
        self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -838,7 +838,7 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
                                         self.num_experts)
 
        e_score_correction_bias = torch.randn(self.num_experts)
-       weights, ids, _ = select_experts(
+       weights, ids = select_experts(
            hidden_states=self.hidden_states,
            router_logits=self.router_logits,
            top_k=self.top_k,
@@ -861,7 +861,7 @@ def test_custom_routing_function(self):
                       self.top_k,
                       dtype=torch.int32))
 
-       weights, ids, _ = select_experts(
+       weights, ids = select_experts(
            hidden_states=self.hidden_states,
            router_logits=self.router_logits,
            top_k=self.top_k,
@@ -888,7 +888,7 @@ def test_renormalize(self, mock_topk):
                        -1).permute(1,
                                    0).contiguous())
 
-       weights, ids, _ = select_experts(
+       weights, ids = select_experts(
            hidden_states=self.hidden_states,
            router_logits=self.router_logits,
            top_k=self.top_k,
@@ -914,7 +914,7 @@ def test_output_dtypes(self, mock_topk):
                        -1).permute(1,
                                    0).contiguous())
 
-       weights, ids, _ = select_experts(
+       weights, ids = select_experts(
            hidden_states=self.hidden_states,
            router_logits=self.router_logits,
            top_k=self.top_k,

vllm_ascend/ops/common_fused_moe.py

Lines changed: 1 addition & 2 deletions
@@ -110,7 +110,7 @@ def apply(self,
              shared_experts: Optional[Any] = None,
              **kwargs) -> torch.Tensor:
 
-       topk_weights, topk_ids, row_idx = select_experts(
+       topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
@@ -138,7 +138,6 @@ def apply(self,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
-           row_idx=row_idx,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            shared_experts=shared_experts,

vllm_ascend/ops/moe/experts_selector.py

Lines changed: 5 additions & 19 deletions
@@ -21,17 +21,6 @@
 from vllm.forward_context import get_forward_context
 
 
-def return_row_idx(hidden_states, top_k):
-    num_tokens = hidden_states.shape[0]
-    row_idx_len = num_tokens * top_k
-    row_idx = (torch.arange(0,
-                            row_idx_len,
-                            dtype=torch.int32,
-                            device=hidden_states.device).view(
-                                top_k, -1).permute(1, 0).contiguous())
-    return row_idx
-
-
 def select_experts(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor,
                    top_k: int,
@@ -71,7 +60,7 @@ def select_experts(hidden_states: torch.Tensor,
    if weight_prefetch_method:
        weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
            hidden_states, "gate_up")
-   topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
+   topk_weights, topk_ids = _select_experts_with_fusion_ops(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
@@ -99,9 +88,7 @@ def select_experts(hidden_states: torch.Tensor,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts,
        )
-   if row_idx is None:
-       row_idx = return_row_idx(hidden_states, top_k)
-   return topk_weights, topk_ids, row_idx
+   return topk_weights, topk_ids
 
 
 def _native_grouped_topk(
@@ -187,7 +174,7 @@ def _select_experts_with_fusion_ops(
        routed_scaling_factor=1.0,
        global_num_experts: int = -1):
 
-   topk_weights, topk_ids, row_idx = None, None, None
+   topk_weights, topk_ids = None, None
    # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
    is_deepseek_v3_r1 = global_num_experts == 256
    if is_deepseek_v3_r1:
@@ -205,14 +192,13 @@ def _select_experts_with_fusion_ops(
            # y2_flag=False, # old api; should the third output be output
            routed_scaling_factor=1,
            eps=float(1e-20))
-       row_idx = return_row_idx(hidden_states, top_k)
    if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
-       topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
+       topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
            x=router_logits, finished=None, k=top_k)
        topk_ids = topk_ids.to(torch.int32)
        topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
 
-   return topk_weights, topk_ids, row_idx
+   return topk_weights, topk_ids
 
 
 def _native_select_experts(

vllm_ascend/ops/moe/moe_comm_method.py

Lines changed: 0 additions & 2 deletions
@@ -88,7 +88,6 @@ def fused_experts(
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
-       row_idx: torch.Tensor,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        use_int8_w8a8: bool = False,
@@ -122,7 +121,6 @@ def fused_experts(
        hidden_states=hidden_states,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
-       row_idx=row_idx,
        expert_map=expert_map,
        log2phy=log2phy,
        global_redundant_expert_num=global_redundant_expert_num,
