
Commit 9328f37

[refactor]support gatingtopk operator generalization (#2958)
### What this PR does / why we need it?

Past: `npu_moe_gating_top_k` could only support the `group_count=256` pattern.

Now:
1. `npu_moe_gating_top_k` supports any size of `group_count`.
2. The functionality of `torch_npu.npu_moe_gating_top_k_softmax` is included in `torch_npu.npu_moe_gating_top_k`.

CANN: depends on 8.3.RC1.

Performance:
1. GLM4.5-w8a8: TPS improves by 6%.
2. Qwen3: the same as before.

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@2918c1b

Signed-off-by: 1092626063 <[email protected]>
1 parent 63561d6 commit 9328f37
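
As a concrete picture of what folding `npu_moe_gating_top_k_softmax` into `npu_moe_gating_top_k` means, here is a minimal CPU reference sketch of the softmax top-k gating path (the generalized op's `norm_type=0` case with `topk_group=1` and `num_expert_group=1`, per the diff below). It is plain PyTorch for illustration, not the NPU kernel; the `[num_tokens, top_k]` shapes and int32 id dtype follow the updated tests.

```python
import torch


def softmax_topk_gating_reference(router_logits: torch.Tensor, k: int):
    """CPU stand-in for the softmax case of npu_moe_gating_top_k."""
    # Softmax over the expert dimension, then keep the k largest weights.
    scores = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k, dim=-1)
    return topk_weights, topk_ids.to(torch.int32)


weights, ids = softmax_topk_gating_reference(torch.randn(4, 256), k=8)
print(weights.shape, ids.dtype)  # torch.Size([4, 8]) torch.int32
```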

File tree

3 files changed (+74 lines, -69 lines):

- tests/ut/quantization/test_w8a8.py
- vllm_ascend/ascend_forward_context.py
- vllm_ascend/ops/fused_moe/experts_selector.py


tests/ut/quantization/test_w8a8.py

Lines changed: 35 additions & 28 deletions
@@ -754,6 +754,14 @@ def setUp(self):
 
         self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
         self.router_logits = torch.randn(self.num_tokens, self.num_experts)
+        """Mock custom routing"""
+        self.mock_custom_routing = MagicMock()
+        self.mock_custom_routing.return_value = (torch.ones(
+            self.num_tokens, self.top_k),
+                                                 torch.zeros(
+                                                     self.num_tokens,
+                                                     self.top_k,
+                                                     dtype=torch.int32))
 
         self.mock_ctx = MagicMock()
         self.mock_ctx.weight_prefetch_method = MagicMock()
@@ -763,7 +771,7 @@ def setUp(self):
         self.addCleanup(patcher.stop)
         patcher.start()
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_softmax_scoring(self, mock_topk):
         """Test softmax scoring function"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -790,12 +798,14 @@ def test_softmax_scoring(self, mock_topk):
     def test_sigmoid_scoring(self):
         """Test sigmoid scoring function"""
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=False,
-                                      renormalize=False,
-                                      scoring_func="sigmoid")
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=False,
+            renormalize=False,
+            scoring_func="sigmoid",
+            custom_routing_function=self.mock_custom_routing)
 
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -808,7 +818,8 @@ def test_invalid_scoring_func(self):
             top_k=self.top_k,
             use_grouped_topk=False,
             renormalize=False,
-            scoring_func="invalid_func")
+            scoring_func="invalid_func",
+            custom_routing_function=self.mock_custom_routing)
 
     @patch('torch.topk')
     def test_grouped_topk(self, mock_topk):
@@ -818,13 +829,15 @@ def test_grouped_topk(self, mock_topk):
                                          self.top_k,
                                          dtype=torch.long))
 
-        weights, ids = select_experts(hidden_states=self.hidden_states,
-                                      router_logits=self.router_logits,
-                                      top_k=self.top_k,
-                                      use_grouped_topk=True,
-                                      renormalize=False,
-                                      topk_group=4,
-                                      num_expert_group=2)
+        weights, ids = select_experts(
+            hidden_states=self.hidden_states,
+            router_logits=self.router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=True,
+            renormalize=False,
+            topk_group=4,
+            num_expert_group=2,
+            custom_routing_function=self.mock_custom_routing)
 
         mock_topk.assert_called()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -846,35 +859,29 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
             renormalize=False,
             topk_group=4,
             num_expert_group=2,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            custom_routing_function=self.mock_custom_routing)
 
         mock_grouped_topk.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
 
     def test_custom_routing_function(self):
         """Test custom routing function"""
-        mock_custom_routing = MagicMock()
-        mock_custom_routing.return_value = (torch.ones(self.num_tokens,
-                                                       self.top_k),
-                                            torch.zeros(self.num_tokens,
-                                                        self.top_k,
-                                                        dtype=torch.int32))
-
         weights, ids = select_experts(
             hidden_states=self.hidden_states,
             router_logits=self.router_logits,
             top_k=self.top_k,
             use_grouped_topk=False,
             renormalize=False,
-            custom_routing_function=mock_custom_routing)
+            custom_routing_function=self.mock_custom_routing)
 
-        mock_custom_routing.assert_called_once()
+        self.mock_custom_routing.assert_called_once()
         self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.dtype, torch.int32)
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_renormalize(self, mock_topk):
         """Test renormalization"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -900,13 +907,13 @@ def test_renormalize(self, mock_topk):
         sums = weights.sum(dim=-1)
         self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
 
-    @patch('torch_npu.npu_moe_gating_top_k_softmax')
+    @patch('torch_npu.npu_moe_gating_top_k')
     def test_output_dtypes(self, mock_topk):
         """Test output dtypes"""
         mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
                                   torch.zeros(self.num_tokens,
                                               self.top_k,
-                                              dtype=torch.long),
+                                              dtype=torch.int32),
                                   torch.arange(0,
                                                self.num_tokens * self.top_k,
                                                dtype=torch.int32).view(
vllm_ascend/ascend_forward_context.py

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ def set_ascend_forward_context(
     ep_size = (get_ep_group().world_size if
                vllm_config.parallel_config.enable_expert_parallel else 1)
 
+    # fused_moe_state is used in torchair, it will be deleted along with torchair
     is_deepseek_v3_r1 = hasattr(
         vllm_config.model_config.hf_config, 'n_routed_experts'
     ) and vllm_config.model_config.hf_config.n_routed_experts == 256

vllm_ascend/ops/fused_moe/experts_selector.py

Lines changed: 38 additions & 41 deletions
@@ -20,8 +20,6 @@
 import torch_npu
 from vllm.forward_context import get_forward_context
 
-from vllm_ascend.ascend_config import get_ascend_config
-
 
 def select_experts(hidden_states: torch.Tensor,
                    router_logits: torch.Tensor,
@@ -62,21 +60,20 @@ def select_experts(hidden_states: torch.Tensor,
     if weight_prefetch_method:
         weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
             hidden_states, "gate_up")
-    topk_weights, topk_ids = _select_experts_with_fusion_ops(
-        hidden_states=hidden_states,
-        router_logits=router_logits,
-        top_k=top_k,
-        use_grouped_topk=use_grouped_topk,
-        topk_group=topk_group,
-        renormalize=renormalize,
-        e_score_correction_bias=e_score_correction_bias,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-        scoring_func=scoring_func,
-        routed_scaling_factor=routed_scaling_factor,
-        global_num_experts=global_num_experts)
-
-    if topk_weights is None:
+    if custom_routing_function is None:
+        topk_weights, topk_ids = _select_experts_with_fusion_ops(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=top_k,
+            use_grouped_topk=use_grouped_topk,
+            topk_group=topk_group,
+            renormalize=renormalize,
+            e_score_correction_bias=e_score_correction_bias,
+            num_expert_group=num_expert_group,
+            scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
+            global_num_experts=global_num_experts)
+    else:
         topk_weights, topk_ids = _native_select_experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
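
A hedged sketch of the branch above: where the old code ran the fusion op first and fell back when `topk_weights` came back `None`, `select_experts` now decides up front, using the NPU fusion op unless the caller supplies its own routing function. The helper below imitates that contract in plain PyTorch; the names are illustrative, not the production `_native_select_experts` implementation.

```python
from typing import Callable, Optional

import torch


def route_tokens(hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
                 top_k: int,
                 renormalize: bool,
                 custom_routing_function: Optional[Callable] = None):
    if custom_routing_function is not None:
        # Native path: delegate routing entirely to the caller's function.
        return custom_routing_function(hidden_states, router_logits, top_k,
                                       renormalize)
    # Fused-path stand-in: softmax + top-k (the NPU op on real hardware).
    weights, ids = torch.topk(torch.softmax(router_logits, -1), top_k, -1)
    return weights, ids.to(torch.int32)
```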
@@ -171,34 +168,34 @@ def _select_experts_with_fusion_ops(
         e_score_correction_bias: Optional[torch.Tensor],
         topk_group: Optional[int],
         num_expert_group: Optional[int],
-        custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         routed_scaling_factor=1.0,
         global_num_experts: int = -1):
 
-    topk_weights, topk_ids = None, None
-    # NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
-    global_redundant_expert_num = get_ascend_config().init_redundancy_expert
-    is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
-    if is_deepseek_v3_r1:
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
-            router_logits,
-            k=top_k,  # topk currently 8
-            bias=e_score_correction_bias,
-            k_group=topk_group,  # fix: 4
-            group_count=num_expert_group,  # fix 8
-            group_select_mode=
-            1,  # 0: the maximum in the group; 1: topk2.sum(fix)
-            renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
-            norm_type=1,  # 0: softmax; 1: sigmoid(fix)
-            # out_flag=False, # todo new api; should the third output be output
-            # y2_flag=False, # old api; should the third output be output
-            routed_scaling_factor=1,
-            eps=float(1e-20))
-    if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
-        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
-            x=router_logits, finished=None, k=top_k)
-        topk_ids = topk_ids.to(torch.int32)
+    if scoring_func == "softmax":
+        norm_type = 0
+        topk_group = 1
+        num_expert_group = 1
+    else:
+        norm_type = 1
+        if e_score_correction_bias is not None and \
+                e_score_correction_bias.dtype != router_logits.dtype:
+            e_score_correction_bias = e_score_correction_bias.to(
+                router_logits.dtype)
+    topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+        router_logits,
+        k=top_k,
+        bias=e_score_correction_bias,
+        k_group=topk_group,
+        group_count=num_expert_group,
+        group_select_mode=1,  # 0: the maximum in the group; 1: topk2.sum(fix)
+        renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+        norm_type=norm_type,  # 0: softmax; 1: sigmoid
+        # out_flag=False, # todo new api; should the third output be output
+        # y2_flag=False, # old api; should the third output be output
+        routed_scaling_factor=1,
+        eps=float(1e-20))
+    if scoring_func == "softmax":
         topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
 
     return topk_weights, topk_ids
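
For the sigmoid/grouped case that the generalized op now accepts at any `group_count`, here is a rough CPU reference of the intended semantics, pieced together from the parameter comments above (`norm_type=1`, `group_select_mode=1`, i.e. groups ranked by the sum of their top-2 scores). The NPU kernel remains the source of truth; this sketch assumes `group_count` evenly divides the expert count and that each group holds at least two experts.

```python
from typing import Optional

import torch


def grouped_sigmoid_topk_reference(router_logits: torch.Tensor, k: int,
                                   k_group: int, group_count: int,
                                   bias: Optional[torch.Tensor] = None):
    scores = torch.sigmoid(router_logits)
    biased = scores + bias if bias is not None else scores
    n_tokens, n_experts = biased.shape
    group_size = n_experts // group_count  # assumes an even split
    grouped = biased.view(n_tokens, group_count, group_size)
    # group_select_mode=1: rank each group by the sum of its top-2 scores.
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)
    top_groups = group_scores.topk(k_group, dim=-1).indices
    keep = torch.zeros_like(group_scores, dtype=torch.bool)
    keep.scatter_(1, top_groups, True)
    expert_mask = keep.repeat_interleave(group_size, dim=1)
    masked = biased.masked_fill(~expert_mask, float("-inf"))
    topk_ids = masked.topk(k, dim=-1).indices
    # Gating weights come from the unbiased scores at the selected experts.
    topk_weights = scores.gather(1, topk_ids)
    return topk_weights, topk_ids.to(torch.int32)
```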
