@@ -754,6 +754,14 @@ def setUp(self):
 
         self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
         self.router_logits = torch.randn(self.num_tokens, self.num_experts)
+        # Mock custom routing shared by all tests
+        self.mock_custom_routing = MagicMock()
+        self.mock_custom_routing.return_value = (torch.ones(
+            self.num_tokens, self.top_k),
+                                                 torch.zeros(
+                                                     self.num_tokens,
+                                                     self.top_k,
+                                                     dtype=torch.int32))
 
         self.mock_ctx = MagicMock()
         self.mock_ctx.weight_prefetch_method = MagicMock()
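The `setUp` hunk above hoists the `MagicMock` routing stub that `test_custom_routing_function` previously built locally, so every test can pass the same object to `select_experts`. A minimal standalone sketch of how such a stub behaves when called (tensor sizes here are illustrative, not taken from the test class):

```python
import torch
from unittest.mock import MagicMock

num_tokens, top_k = 4, 2  # illustrative sizes

mock_custom_routing = MagicMock()
mock_custom_routing.return_value = (
    torch.ones(num_tokens, top_k),                      # routing weights
    torch.zeros(num_tokens, top_k, dtype=torch.int32),  # expert ids
)

# The mock ignores whatever arguments the router passes and returns the
# canned (weights, ids) pair, which is all the shape/dtype assertions
# in these tests need.
weights, ids = mock_custom_routing(hidden_states=None, router_logits=None)
assert weights.shape == (num_tokens, top_k)
assert ids.dtype == torch.int32
```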
@@ -763,7 +771,7 @@ def setUp(self):
         self.addCleanup(patcher.stop)
         patcher.start()
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_softmax_scoring(self, mock_topk):
         """Test softmax scoring function"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -790,12 +798,14 @@ def test_softmax_scoring(self, mock_topk):
     def test_sigmoid_scoring(self):
         """Test sigmoid scoring function"""
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=False,
-                                      renormalize=False,
-                                      scoring_func="sigmoid")
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=False,
+            renormalize=False,
+            scoring_func="sigmoid",
+            custom_routing_function=self.mock_custom_routing)
 
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -808,7 +818,8 @@ def test_invalid_scoring_func(self):
                            top_k=self.top_k,
                            use_grouped_topk=False,
                            renormalize=False,
-                           scoring_func="invalid_func")
+                           scoring_func="invalid_func",
+                           custom_routing_function=self.mock_custom_routing)
 
     @patch('torch.topk')
     def test_grouped_topk(self, mock_topk):
@@ -818,13 +829,15 @@ def test_grouped_topk(self, mock_topk):
                                               self.top_k,
                                               dtype=torch.long))
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=True,
-                                      renormalize=False,
-                                      topk_group=4,
-                                      num_expert_group=2)
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=True,
+            renormalize=False,
+            topk_group=4,
+            num_expert_group=2,
+            custom_routing_function=self.mock_custom_routing)
 
         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -846,35 +859,29 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
             renormalize=False,
             topk_group=4,
             num_expert_group=2,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            custom_routing_function=self.mock_custom_routing)
 
         mock_grouped_topk.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
 
     def test_custom_routing_function(self):
         """Test custom routing function"""
-        mock_custom_routing = MagicMock()
-        mock_custom_routing.return_value = (torch.ones(self.num_tokens,
-                                                       self.top_k),
-                                            torch.zeros(self.num_tokens,
-                                                        self.top_k,
-                                                        dtype=torch.int32))
-
         weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
             use_grouped_topk=False,
             renormalize=False,
-            custom_routing_function=mock_custom_routing)
+            custom_routing_function=self.mock_custom_routing)
 
-        mock_custom_routing.assert_called_once()
+        self.mock_custom_routing.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.dtype, torch.int32)
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_renormalize(self, mock_topk):
         """Test renormalization"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -900,13 +907,13 @@ def test_renormalize(self, mock_topk):
         sums = weights.sum(dim=-1)
         self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_output_dtypes(self, mock_topk):
         """Test output dtypes"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
                                   torch.zeros(self.num_tokens,
                                               self.top_k,
-                                              dtype=torch.long),
+                                              dtype=torch.int32),
                                   torch.arange(0,
                                                self.num_tokens * self.top_k,
                                                dtype=torch.int32).view(
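For reference, the keyword surface these tests exercise implies a `select_experts` signature roughly like the sketch below. It is reconstructed purely from the call sites in the diff, not from the actual vllm-ascend implementation, so the parameter order, defaults, and annotations are assumptions:

```python
from typing import Callable, Optional

import torch


# Hypothetical signature inferred from the tests' call sites; the real
# select_experts may order or default these parameters differently.
def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (topk_weights, topk_ids), each shaped (num_tokens, top_k)."""
    ...
```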