@@ -422,69 +422,6 @@ def token_combine(self,
         return final_hidden_states
 
 
-# mypy: disable-error-code="override"
-class TokenDispatcherWithMoge(MoETokenDispatcher):
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.apply_router_weight_on_input = False
-        self.local_num_experts = self.num_experts // self.ep_size
-        self.local_num_group = self.top_k // self.ep_size
-        self.bsz = None
-
-    def token_dispatch(self,
-                       hidden_states: torch.Tensor,
-                       topk_weights: torch.Tensor,
-                       topk_ids: torch.Tensor,
-                       expert_map: Optional[torch.Tensor] = None,
-                       log2phy: Optional[torch.Tensor] = None,
-                       global_redundant_expert_num: int = 0,
-                       shared_experts: Optional[Any] = None,
-                       quantized_x_for_share: Optional[Any] = None,
-                       dynamic_scale_for_share: Optional[Any] = None,
-                       mc2_mask: Optional[torch.Tensor] = None,
-                       apply_router_weight_on_input: bool = False,
-                       with_quant: bool = False,
-                       dynamic_eplb: bool = False,
-                       pertoken_scale: Optional[torch.Tensor] = None):
-        self.bsz, _ = hidden_states.shape
-        flatten_topk_ids = topk_ids.view(-1)
-        self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
-        self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
-        sorted_hidden_states = hidden_states.index_select(
-            0, self.sorted_topk_ids // self.local_num_group)
-
-        experts_id = torch.arange(0,
-                                  self.local_num_experts,
-                                  dtype=topk_ids.dtype,
-                                  device=topk_ids.device)
-        num_tokens_per_expert = (
-            flatten_topk_ids.unsqueeze(-1) == experts_id).to(
-                torch.float32).sum(0)
-        topk_scales = topk_weights.view(-1).index_select(
-            0, self.sorted_topk_ids).unsqueeze(-1)
-        group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
-        group_list_type = 0
-        return {
-            "group_list_type": group_list_type,
-            "hidden_states": sorted_hidden_states,
-            "group_list": group_list,
-            "topk_scales": topk_scales
-        }
-
-    def token_combine(self,
-                      hidden_states: torch.Tensor,
-                      context_metadata: dict,
-                      bias: torch.Tensor = None):
-        unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
-            torch.int32)
-        unsorted_hidden_states = hidden_states.index_select(
-            0, unsorted_topk_ids)
-        final_hidden_states = unsorted_hidden_states.reshape(
-            self.bsz, self.top_k // self.ep_size, -1).sum(1)
-        return final_hidden_states
-
-
 class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
     """
     The implementation of the AlltoAll-based token dispatcher, which handles token