Commit 1fc5313

committed
[cherry-pick][refactor]support gatingtopk operator generalization (#4050)
### What this PR does / why we need it?

Picked from #2958.

Past: `npu_moe_gating_top_k` could only support the `group_count=256` pattern.

Now:
1. `npu_moe_gating_top_k` supports all sizes of `group_count`.
2. The functionality of `torch_npu.npu_moe_gating_top_k_softmax` is included in `torch_npu.npu_moe_gating_top_k`.

CANN: depends on 8.3.RC1.

Performance:
1. GLM4.5-w8a8: TPS improves by 6%.
2. Qwen3: the same as before.

Signed-off-by: 1092626063 <[email protected]>
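For illustration only (not part of this commit), a minimal sketch of how both scoring modes can now go through the single fused op; the argument names and values are assumed from the call shown in the `vllm_ascend/ops/moe/experts_selector.py` diff below:

```python
import torch
import torch_npu  # Ascend NPU extension; CANN 8.3.RC1 or later per the note above


def fused_gating_topk(router_logits: torch.Tensor,
                      top_k: int,
                      scoring_func: str = "softmax",
                      topk_group: int = 1,
                      num_expert_group: int = 1,
                      bias: torch.Tensor = None):
    # Softmax scoring needs no expert grouping; sigmoid keeps the caller's groups.
    if scoring_func == "softmax":
        norm_type, topk_group, num_expert_group = 0, 1, 1
    else:
        norm_type = 1
    topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
        router_logits,
        k=top_k,
        bias=bias,
        k_group=topk_group,
        group_count=num_expert_group,  # no longer limited to the 256-expert pattern
        group_select_mode=1,           # 1: sum of top-2 within each group
        renorm=0,                      # score first, then top-k
        norm_type=norm_type,           # 0: softmax; 1: sigmoid
        routed_scaling_factor=1,
        eps=float(1e-20))
    return topk_weights, topk_ids
```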
1 parent b6d63bb commit 1fc5313

File tree

3 files changed: +74 -69 lines changed


tests/ut/quantization/test_w8a8.py

Lines changed: 35 additions & 28 deletions
@@ -753,6 +753,14 @@ def setUp(self):
 
         self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
         self.router_logits = torch.randn(self.num_tokens, self.num_experts)
+        """Mock custom routing"""
+        self.mock_custom_routing = MagicMock()
+        self.mock_custom_routing.return_value = (torch.ones(
+            self.num_tokens, self.top_k),
+                                                 torch.zeros(
+                                                     self.num_tokens,
+                                                     self.top_k,
+                                                     dtype=torch.int32))
 
         self.mock_ctx = MagicMock()
         self.mock_ctx.weight_prefetch_method = MagicMock()
@@ -762,7 +770,7 @@ def setUp(self):
         self.addCleanup(patcher.stop)
         patcher.start()
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_softmax_scoring(self, mock_topk):
         """Test softmax scoring function"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -789,12 +797,14 @@ def test_softmax_scoring(self, mock_topk):
     def test_sigmoid_scoring(self):
         """Test sigmoid scoring function"""
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=False,
-                                      renormalize=False,
-                                      scoring_func="sigmoid")
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=False,
+            renormalize=False,
+            scoring_func="sigmoid",
+            custom_routing_function=self.mock_custom_routing)
 
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -807,7 +817,8 @@ def test_invalid_scoring_func(self):
                 top_k=self.top_k,
                 use_grouped_topk=False,
                 renormalize=False,
-                scoring_func="invalid_func")
+                scoring_func="invalid_func",
+                custom_routing_function=self.mock_custom_routing)
 
     @patch('torch.topk')
     def test_grouped_topk(self, mock_topk):
@@ -817,13 +828,15 @@ def test_grouped_topk(self, mock_topk):
                                               self.top_k,
                                               dtype=torch.long))
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=True,
-                                      renormalize=False,
-                                      topk_group=4,
-                                      num_expert_group=2)
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=True,
+            renormalize=False,
+            topk_group=4,
+            num_expert_group=2,
+            custom_routing_function=self.mock_custom_routing)
 
         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -845,35 +858,29 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
             renormalize=False,
             topk_group=4,
             num_expert_group=2,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            custom_routing_function=self.mock_custom_routing)
 
         mock_grouped_topk.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
 
     def test_custom_routing_function(self):
         """Test custom routing function"""
-        mock_custom_routing = MagicMock()
-        mock_custom_routing.return_value = (torch.ones(self.num_tokens,
-                                                       self.top_k),
-                                            torch.zeros(self.num_tokens,
-                                                        self.top_k,
-                                                        dtype=torch.int32))
-
         weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
             use_grouped_topk=False,
             renormalize=False,
-            custom_routing_function=mock_custom_routing)
+            custom_routing_function=self.mock_custom_routing)
 
-        mock_custom_routing.assert_called_once()
+        self.mock_custom_routing.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.dtype, torch.int32)
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_renormalize(self, mock_topk):
         """Test renormalization"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -899,13 +906,13 @@ def test_renormalize(self, mock_topk):
         sums = weights.sum(dim=-1)
         self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_output_dtypes(self, mock_topk):
         """Test output dtypes"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
                                   torch.zeros(self.num_tokens,
                                               self.top_k,
-                                              dtype=torch.long),
+                                              dtype=torch.int32),
                                   torch.arange(0,
                                                self.num_tokens * self.top_k,
                                                dtype=torch.int32).view(
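All of the tests above share the same stubbing pattern; a condensed, hypothetical stand-alone version of it (assuming only `unittest.mock` and a `torch_npu` module that can be patched) looks roughly like this:

```python
from unittest.mock import patch

import torch

num_tokens, top_k = 4, 2

# The fused op is now the single patch target; the mocked return value mimics
# (topk_weights, topk_ids, row_idx) with int32 expert ids.
with patch('torch_npu.npu_moe_gating_top_k') as mock_topk:
    mock_topk.return_value = (
        torch.ones(num_tokens, top_k),                      # topk_weights
        torch.zeros(num_tokens, top_k, dtype=torch.int32),  # topk_ids
        torch.arange(num_tokens * top_k,
                     dtype=torch.int32).view(num_tokens, top_k))
    # ...call select_experts(...) here and assert on shapes and dtypes...
```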

vllm_ascend/ascend_forward_context.py

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ def set_ascend_forward_context(
         ep_size = (get_ep_group().world_size if
                    vllm_config.parallel_config.enable_expert_parallel else 1)
 
+        # fused_moe_state is used in torchair, it will be deleted along with torchair
         is_deepseek_v3_r1 = hasattr(
             vllm_config.model_config.hf_config, 'n_routed_experts'
         ) and vllm_config.model_config.hf_config.n_routed_experts == 256

vllm_ascend/ops/moe/experts_selector.py

Lines changed: 38 additions & 41 deletions
@@ -20,8 +20,6 @@
 import torch_npu
 from vllm.forward_context import get_forward_context
 
-from vllm_ascend.ascend_config import get_ascend_config
-
 
 def select_experts(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor,
@@ -62,21 +60,20 @@ def select_experts(hidden_states: torch.Tensor,
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
             hidden_states, "gate_up")
-    topk_weights, topk_ids = _select_experts_with_fusion_ops(
-        hidden_states=hidden_states,
-        router_logits=router_logits,
-        top_k=top_k,
-        use_grouped_topk=use_grouped_topk,
-        topk_group=topk_group,
-        renormalize=renormalize,
-        e_score_correction_bias=e_score_correction_bias,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-        scoring_func=scoring_func,
-        routed_scaling_factor=routed_scaling_factor,
-        global_num_experts=global_num_experts)
-
-    if topk_weights is None:
+    if custom_routing_function is None:
+        topk_weights, topk_ids = _select_experts_with_fusion_ops(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            topk_group=topk_group,
+            renormalize=renormalize,
+            e_score_correction_bias=e_score_correction_bias,
+            num_expert_group=num_expert_group,
+            scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
+            global_num_experts=global_num_experts)
+    else:
         topk_weights, topk_ids = _native_select_experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
@@ -171,34 +168,34 @@ def _select_experts_with_fusion_ops(
         e_score_correction_bias: Optional[torch.Tensor],
         topk_group: Optional[int],
         num_expert_group: Optional[int],
-        custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         routed_scaling_factor=1.0,
         global_num_experts: int = -1):
 
-    topk_weights, topk_ids = None, None
-    # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
-    global_redundant_expert_num = get_ascend_config().init_redundancy_expert
-    is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
-    if is_deepseek_v3_r1:
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-            router_logits,
-            k=top_k,  # topk currently 8
-            bias=e_score_correction_bias,
-            k_group=topk_group,  # fix: 4
-            group_count=num_expert_group,  # fix 8
-            group_select_mode=
-            1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-            renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-            norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-            # out_flag=False, # todo new api; should the third output be output
-            # y2_flag=False, # old api; should the third output be output
-            routed_scaling_factor=1,
-            eps=float(1e-20))
-    if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
-            x=router_logits, finished=None, k=top_k)
-        topk_ids = topk_ids.to(torch.int32)
+    if scoring_func == "softmax":
+        norm_type = 0
+        topk_group = 1
+        num_expert_group = 1
+    else:
+        norm_type = 1
+        if e_score_correction_bias is not None and \
+                e_score_correction_bias.dtype != router_logits.dtype:
+            e_score_correction_bias = e_score_correction_bias.to(
+                router_logits.dtype)
+    topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+        router_logits,
+        k=top_k,
+        bias=e_score_correction_bias,
+        k_group=topk_group,
+        group_count=num_expert_group,
+        group_select_mode=1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+        renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+        norm_type=norm_type,  # 0: softmax; 1: sigmoid
+        # out_flag=False, # todo new api; should the third output be output
+        # y2_flag=False, # old api; should the third output be output
+        routed_scaling_factor=1,
+        eps=float(1e-20))
+    if scoring_func == "softmax":
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
 
     return topk_weights, topk_ids
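In short, the new control flow in `select_experts` amounts to the dispatch sketched below. This is a simplified illustration, not the full implementation: `_dispatch_experts_selection` is a hypothetical name, while the two helpers are the ones shown in the diff above.

```python
# Simplified view of the routing dispatch after this change.
def _dispatch_experts_selection(custom_routing_function, **kwargs):
    if custom_routing_function is None:
        # Fused NPU path: npu_moe_gating_top_k now covers both softmax and
        # sigmoid scoring, for any group_count.
        return _select_experts_with_fusion_ops(**kwargs)
    # Native PyTorch path: defer to the caller-provided routing function.
    return _native_select_experts(
        custom_routing_function=custom_routing_function, **kwargs)
```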
