@@ -409,12 +409,14 @@ def step(self,
409409 self .expert_rearrangement_step = 0
410410 self .rearrange (model )
411411
412- def rearrange (self ,
413- model : MixtureOfExperts ,
414- is_profile : bool = False ,
415- execute_shuffle : bool = True ,
416- global_expert_load : Optional [torch .Tensor ] = None ,
417- rank_mapping : Optional [dict [int , int ]] = None ) -> None :
412+ def rearrange (
413+ self ,
414+ model : MixtureOfExperts ,
415+ is_profile : bool = False ,
416+ execute_shuffle : bool = True ,
417+ global_expert_load : Optional [torch .Tensor ] = None ,
418+ rank_mapping : Optional [dict [int ,
419+ int ]] = None ) -> Optional [torch .Tensor ]:
418420 """
419421 Rearrange the experts according to the current load.
420422 """
@@ -548,6 +550,7 @@ def rearrange(self,
548550 " (profile) " if is_profile else " " ,
549551 time_end - time_start ,
550552 )
553+ return None
551554
552555 @staticmethod
553556 def recv_state () -> tuple [torch .Tensor , torch .Tensor ]:
@@ -613,4 +616,4 @@ def _node_count_with_rank_mapping(
613616 if is_same_node and node_assignment [other_rank ] == 0 :
614617 node_assignment [other_rank ] = next_node_id
615618
616- return next_node_id
619+ return next_node_id
0 commit comments