@@ -58,7 +58,6 @@ def setUp(self):
5858
5959 kwargs = {"with_quant" : False , "top_k" : 8 , "num_experts" : 128 }
6060 self .dispatcher = TokenDispatcherWithMC2 (** kwargs )
61- self .row_idx = torch .arange (10 , dtype = torch .int32 )
6261
6362 def tearDown (self ):
6463 self .mc2_group_patch .stop ()
@@ -96,7 +95,7 @@ def test_token_permutation_dispatch(self):
9695 (None , None )) as mock_dispatch :
9796 output = self .dispatcher .token_dispatch (hidden_states ,
9897 topk_weights , topk_ids ,
99- self . row_idx , expert_map )
98+ expert_map )
10099 mock_dispatch .assert_called_once ()
101100 self .assertEqual (output ["group_list_type" ],
102101 0 ) # group_list_type == 0
@@ -117,7 +116,6 @@ def test_token_dispatch_with_shared_experts_and_quant(self):
117116 self .dispatcher .token_dispatch (self .hidden_states ,
118117 self .topk_weights ,
119118 torch .randint (0 , 8 , (10 , 1 )),
120- self .row_idx ,
121119 torch .tensor (
122120 [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ]),
123121 shared_experts = self .shared_experts )
@@ -181,7 +179,6 @@ def setUp(self):
181179 torch .tensor ([0 , 1 , 2 , 3 , 4 , 5 ]), # expanded_row_idx
182180 torch .tensor ([0 , 1 , 0 , 1 , 0 , 1 ]), # expanded_expert_idx
183181 torch .tensor ([0 , 1 , 0 , 1 , 0 , 1 ]))
184- self .row_idx = torch .arange (10 , dtype = torch .int32 )
185182 self .patcher_npu_moe_token_unpermute = patch (
186183 'torch_npu.npu_moe_token_unpermute' )
187184 self .mock_npu_moe_token_unpermute = self .patcher_npu_moe_token_unpermute .start (
@@ -198,7 +195,7 @@ def test_token_dispatch_without_expert_map(self):
198195 topk_ids = torch .tensor ([[0 , 1 ], [1 , 2 ], [2 , 3 ]])
199196
200197 results = self .dispatcher .token_dispatch (hidden_states , topk_weights ,
201- topk_ids , self . row_idx , None )
198+ topk_ids , None )
202199
203200 # Verify npu_moe_init_routing is called
204201 self .mock_npu_moe_init_routing_v2 .assert_called_once ()
@@ -213,7 +210,7 @@ def test_token_dispatch_with_expert_map(self):
213210 topk_ids = torch .tensor ([[0 , 1 ], [1 , 2 ], [2 , 3 ]])
214211
215212 results = self .dispatcher .token_dispatch (hidden_states , topk_weights ,
216- topk_ids , self . row_idx , None )
213+ topk_ids , None )
217214
218215 # Verify npu_moe_init_routing is called
219216 self .mock_npu_moe_init_routing_v2 .assert_called_once ()
@@ -237,7 +234,7 @@ def test_token_dispatch_without_quant(self):
237234
238235 results = self .dispatcher_quant .token_dispatch (hidden_states ,
239236 topk_weights , topk_ids ,
240- self . row_idx , None )
237+ None )
241238
242239 self .assertEqual (results ["group_list_type" ], 1 )
243240
@@ -258,7 +255,6 @@ def test_token_dispatch_with_quant(self):
258255 results = self .dispatcher_quant .token_dispatch (hidden_states ,
259256 topk_weights ,
260257 topk_ids ,
261- self .row_idx ,
262258 None ,
263259 with_quant = True )
264260
@@ -401,7 +397,6 @@ def setUp(self):
401397 num_experts = 4 ,
402398 num_local_experts = 2 ,
403399 with_quant = False )
404- self .row_idx = torch .arange (10 , dtype = torch .int32 )
405400
406401 def test_token_dispatch (self ):
407402 hidden_states = torch .randn (8 , 16 )
@@ -416,7 +411,6 @@ def test_token_dispatch(self):
416411 result = self .dispatcher .token_dispatch (hidden_states = hidden_states ,
417412 topk_weights = topk_weights ,
418413 topk_ids = topk_ids ,
419- row_idx = self .row_idx ,
420414 expert_map = expert_map )
421415
422416 self .assertIsNotNone (result ["hidden_states" ])
@@ -463,7 +457,6 @@ def test_token_dispatch_with_quant(self):
463457 result = self .dispatcher .token_dispatch (hidden_states = hidden_states ,
464458 topk_weights = topk_weights ,
465459 topk_ids = topk_ids ,
466- row_idx = self .row_idx ,
467460 expert_map = expert_map ,
468461 with_quant = True )
469462
@@ -492,7 +485,6 @@ def test_token_dispatch_with_quant_no_active_tokens(self):
492485 result = self .dispatcher .token_dispatch (hidden_states = hidden_states ,
493486 topk_weights = topk_weights ,
494487 topk_ids = topk_ids ,
495- row_idx = self .row_idx ,
496488 expert_map = expert_map ,
497489 with_quant = True )
498490
@@ -515,7 +507,6 @@ def test_token_dispatch_with_log2phy(self):
515507 result = self .dispatcher .token_dispatch (hidden_states = hidden_states ,
516508 topk_weights = topk_weights ,
517509 topk_ids = topk_ids ,
518- row_idx = self .row_idx ,
519510 expert_map = expert_map ,
520511 log2phy = log2phy )
521512
0 commit comments