@@ -753,6 +753,14 @@ def setUp(self):
 
         self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
         self.router_logits = torch.randn(self.num_tokens, self.num_experts)
+        # Mock custom routing
+        self.mock_custom_routing = MagicMock()
+        self.mock_custom_routing.return_value = (torch.ones(
+            self.num_tokens, self.top_k),
+                                                 torch.zeros(
+                                                     self.num_tokens,
+                                                     self.top_k,
+                                                     dtype=torch.int32))
 
         self.mock_ctx = MagicMock()
         self.mock_ctx.weight_prefetch_method = MagicMock()
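
For reference, the fixture hoisted into `setUp` above can be reproduced standalone. A minimal sketch, assuming `select_experts` accepts a `custom_routing_function` callable and expects it to return a `(topk_weights, topk_ids)` pair shaped `[num_tokens, top_k]`; the sizes here are placeholders, not the test class's actual values:

```python
# Hypothetical standalone version of the setUp fixture above.
from unittest.mock import MagicMock

import torch

num_tokens, top_k = 16, 2  # assumed sizes; the test class defines its own

mock_custom_routing = MagicMock()
# select_experts is expected to call this and unpack per-token routing
# weights plus int32 expert ids, both shaped [num_tokens, top_k].
mock_custom_routing.return_value = (
    torch.ones(num_tokens, top_k),
    torch.zeros(num_tokens, top_k, dtype=torch.int32))
```

Moving the mock into `setUp` lets every test share one callable instead of rebuilding it per test, which is why the `select_experts` calls below all gain a `custom_routing_function` argument.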
@@ -762,7 +770,7 @@ def setUp(self):
         self.addCleanup(patcher.stop)
         patcher.start()
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_softmax_scoring(self, mock_topk):
         """Test softmax scoring function"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -789,12 +797,14 @@ def test_softmax_scoring(self, mock_topk):
     def test_sigmoid_scoring(self):
         """Test sigmoid scoring function"""
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=False,
-                                      renormalize=False,
-                                      scoring_func="sigmoid")
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=False,
+            renormalize=False,
+            scoring_func="sigmoid",
+            custom_routing_function=self.mock_custom_routing)
 
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -807,7 +817,8 @@ def test_invalid_scoring_func(self):
                            top_k=self.top_k,
                            use_grouped_topk=False,
                            renormalize=False,
-                           scoring_func="invalid_func")
+                           scoring_func="invalid_func",
+                           custom_routing_function=self.mock_custom_routing)
 
     @patch('torch.topk')
     def test_grouped_topk(self, mock_topk):
@@ -817,13 +828,15 @@ def test_grouped_topk(self, mock_topk):
                                               self.top_k,
                                               dtype=torch.long))
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=True,
-                                      renormalize=False,
-                                      topk_group=4,
-                                      num_expert_group=2)
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=True,
+            renormalize=False,
+            topk_group=4,
+            num_expert_group=2,
+            custom_routing_function=self.mock_custom_routing)
 
         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
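
`test_grouped_topk` exercises the grouped path by stubbing `torch.topk`. For orientation, a sketch of what DeepSeek-style grouped top-k routing generally computes; this illustrates the algorithm only, not the `torch_npu` kernel or vllm-ascend's exact implementation, and `grouped_topk_reference` is a hypothetical name. It assumes `num_experts` is divisible by `num_expert_group` and `topk_group <= num_expert_group`:

```python
import torch


def grouped_topk_reference(router_logits: torch.Tensor, top_k: int,
                           num_expert_group: int, topk_group: int):
    """Illustrative grouped top-k: pick the best expert groups first,
    then the best experts inside the surviving groups."""
    num_tokens, num_experts = router_logits.shape
    scores = router_logits.softmax(dim=-1)
    # Score each group by its strongest expert.
    group_scores = scores.view(num_tokens, num_expert_group,
                               -1).max(dim=-1).values
    # Keep the topk_group best groups per token, mask the rest out.
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand(
        num_tokens, num_expert_group,
        num_experts // num_expert_group).reshape(num_tokens, -1)
    masked = scores.masked_fill(expert_mask == 0, float("-inf"))
    # Final per-token top_k over experts in the surviving groups.
    return torch.topk(masked, k=top_k, dim=-1)


# Example: 64 experts in 8 groups, keep 4 groups, then 2 experts.
weights, ids = grouped_topk_reference(torch.randn(16, 64), top_k=2,
                                      num_expert_group=8, topk_group=4)
```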
@@ -845,35 +858,29 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
             renormalize=False,
             topk_group=4,
             num_expert_group=2,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            custom_routing_function=self.mock_custom_routing)
 
         mock_grouped_topk.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
 
     def test_custom_routing_function(self):
         """Test custom routing function"""
-        mock_custom_routing = MagicMock()
-        mock_custom_routing.return_value = (torch.ones(self.num_tokens,
-                                                       self.top_k),
-                                            torch.zeros(self.num_tokens,
-                                                        self.top_k,
-                                                        dtype=torch.int32))
-
         weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
             use_grouped_topk=False,
             renormalize=False,
-            custom_routing_function=mock_custom_routing)
+            custom_routing_function=self.mock_custom_routing)
 
-        mock_custom_routing.assert_called_once()
+        self.mock_custom_routing.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.dtype, torch.int32)
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_renormalize(self, mock_topk):
         """Test renormalization"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -899,13 +906,13 @@ def test_renormalize(self, mock_topk):
         sums = weights.sum(dim=-1)
         self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_output_dtypes(self, mock_topk):
         """Test output dtypes"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
                                   torch.zeros(self.num_tokens,
                                               self.top_k,
-                                              dtype=torch.long),
+                                              dtype=torch.int32),
                                   torch.arange(0,
                                                self.num_tokens * self.top_k,
                                                dtype=torch.int32).view(
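
Throughout the diff, the patch target changes from `torch_npu.npu_moe_gating_top_k_softmax` to `torch_npu.npu_moe_gating_top_k`, and the mocked expert-id dtype tightens from `torch.long` to `torch.int32`. A minimal sketch of stubbing the renamed op, assuming it returns the same three-element `(weights, ids, row_idx)` tuple these tests mock; the real operator's contract should be checked against the torch_npu documentation:

```python
# Hedged sketch: stubbing the renamed NPU gating op in a test.
from unittest.mock import patch

import torch

num_tokens, top_k = 16, 2  # assumed sizes
fake_gating_out = (
    torch.ones(num_tokens, top_k),                        # routing weights
    torch.zeros(num_tokens, top_k, dtype=torch.int32),    # expert ids
    torch.arange(num_tokens * top_k, dtype=torch.int32),  # row index; the
)                                                         # exact shape is assumed

with patch('torch_npu.npu_moe_gating_top_k', return_value=fake_gating_out):
    pass  # call select_experts(..., scoring_func="softmax") here
```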