63 changes: 28 additions & 35 deletions tests/ut/quantization/test_w8a8.py
@@ -754,14 +754,6 @@ def setUp(self):

self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
"""Mock custom routing"""
self.mock_custom_routing = MagicMock()
self.mock_custom_routing.return_value = (torch.ones(
self.num_tokens, self.top_k),
torch.zeros(
self.num_tokens,
self.top_k,
dtype=torch.int32))

self.mock_ctx = MagicMock()
self.mock_ctx.weight_prefetch_method = MagicMock()
@@ -771,7 +763,7 @@ def setUp(self):
self.addCleanup(patcher.stop)
patcher.start()

@patch('torch_npu.npu_moe_gating_top_k')
@patch('torch_npu.npu_moe_gating_top_k_softmax')
def test_softmax_scoring(self, mock_topk):
"""Test softmax scoring function"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
@@ -798,14 +790,12 @@ def test_softmax_scoring(self, mock_topk):
def test_sigmoid_scoring(self):
"""Test sigmoid scoring function"""

weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid",
custom_routing_function=self.mock_custom_routing)
weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid")

self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
@@ -818,8 +808,7 @@ def test_invalid_scoring_func(self):
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="invalid_func",
custom_routing_function=self.mock_custom_routing)
scoring_func="invalid_func")

@patch('torch.topk')
def test_grouped_topk(self, mock_topk):
@@ -829,15 +818,13 @@ def test_grouped_topk(self, mock_topk):
self.top_k,
dtype=torch.long))

weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2,
custom_routing_function=self.mock_custom_routing)
weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2)

mock_topk.assert_called()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
@@ -859,29 +846,35 @@ def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
renormalize=False,
topk_group=4,
num_expert_group=2,
e_score_correction_bias=e_score_correction_bias,
custom_routing_function=self.mock_custom_routing)
e_score_correction_bias=e_score_correction_bias)

mock_grouped_topk.assert_called_once()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))

def test_custom_routing_function(self):
"""Test custom routing function"""
mock_custom_routing = MagicMock()
mock_custom_routing.return_value = (torch.ones(self.num_tokens,
self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.int32))

weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
custom_routing_function=self.mock_custom_routing)
custom_routing_function=mock_custom_routing)

self.mock_custom_routing.assert_called_once()
mock_custom_routing.assert_called_once()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.dtype, torch.int32)

@patch('torch_npu.npu_moe_gating_top_k')
@patch('torch_npu.npu_moe_gating_top_k_softmax')
def test_renormalize(self, mock_topk):
"""Test renormalization"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
Expand All @@ -907,13 +900,13 @@ def test_renormalize(self, mock_topk):
sums = weights.sum(dim=-1)
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))

@patch('torch_npu.npu_moe_gating_top_k')
@patch('torch_npu.npu_moe_gating_top_k_softmax')
def test_output_dtypes(self, mock_topk):
"""Test output dtypes"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.int32),
dtype=torch.long),
torch.arange(0,
self.num_tokens * self.top_k,
dtype=torch.int32).view(
1 change: 0 additions & 1 deletion vllm_ascend/ascend_forward_context.py
@@ -96,7 +96,6 @@ def set_ascend_forward_context(
ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)

# fused_moe_state is used in torchair, it will be deleted along with torchair
is_deepseek_v3_r1 = hasattr(
vllm_config.model_config.hf_config, 'n_routed_experts'
) and vllm_config.model_config.hf_config.n_routed_experts == 256
79 changes: 41 additions & 38 deletions vllm_ascend/ops/moe/experts_selector.py
@@ -20,6 +20,8 @@
import torch_npu
from vllm.forward_context import get_forward_context

from vllm_ascend.ascend_config import get_ascend_config


def select_experts(hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -60,20 +62,21 @@ def select_experts(hidden_states: torch.Tensor,
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
hidden_states, "gate_up")
if custom_routing_function is None:
topk_weights, topk_ids = _select_experts_with_fusion_ops(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
global_num_experts=global_num_experts)
else:
topk_weights, topk_ids = _select_experts_with_fusion_ops(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
global_num_experts=global_num_experts)

if topk_weights is None:
topk_weights, topk_ids = _native_select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
@@ -168,34 +171,34 @@ def _select_experts_with_fusion_ops(
e_score_correction_bias: Optional[torch.Tensor],
topk_group: Optional[int],
num_expert_group: Optional[int],
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor=1.0,
global_num_experts: int = -1):

if scoring_func == "softmax":
norm_type = 0
topk_group = 1
num_expert_group = 1
else:
norm_type = 1
if e_score_correction_bias is not None and \
e_score_correction_bias.dtype != router_logits.dtype:
e_score_correction_bias = e_score_correction_bias.to(
router_logits.dtype)
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k,
bias=e_score_correction_bias,
k_group=topk_group,
group_count=num_expert_group,
group_select_mode=1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=norm_type, # 0: softmax; 1: sigmoid
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
if scoring_func == "softmax":
topk_weights, topk_ids = None, None
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
global_redundant_expert_num = get_ascend_config().init_redundancy_expert
is_deepseek_v3_r1 = global_num_experts - global_redundant_expert_num == 256
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
Contributor comment (severity: high):

The if statement on this line could conflict with the preceding if is_deepseek_v3_r1: block. If the conditions for both blocks are met (e.g., for a deepseek_v3_r1 model with scoring_func="softmax"), the results from the first block will be overwritten, which could lead to incorrect expert selection. Using elif would make these conditions mutually exclusive and prevent this potential bug.

Suggested change
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
elif not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":

topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
x=router_logits, finished=None, k=top_k)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)

return topk_weights, topk_ids
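
For reference, a minimal runnable sketch of the control-flow issue raised in the contributor comment above. The function pick_experts and its arguments are hypothetical stand-ins, not the actual vllm_ascend code; it only illustrates why two independent if blocks can overwrite the deepseek_v3_r1 result when both conditions hold, and why the suggested elif keeps the branches mutually exclusive.

import torch

def pick_experts(router_logits, top_k, is_deepseek_v3_r1,
                 use_grouped_topk=False, custom_routing_function=None,
                 scoring_func="softmax"):
    topk_weights, topk_ids = None, None
    if is_deepseek_v3_r1:
        # First branch: stands in for the sigmoid-based grouped gating path.
        topk_weights, topk_ids = torch.topk(torch.sigmoid(router_logits), top_k, dim=-1)
    if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
        # With a plain `if`, this branch also runs and overwrites the result above;
        # the suggested `elif` would skip it once the first branch has executed.
        topk_weights, topk_ids = torch.topk(torch.softmax(router_logits, dim=-1), top_k, dim=-1)
    return topk_weights, topk_ids

# Both conditions true: with `if`, the softmax branch silently wins.
logits = torch.randn(4, 256)
weights, ids = pick_experts(logits, top_k=8, is_deepseek_v3_r1=True)

Whether the two conditions can actually co-occur in the merged code depends on how scoring_func and use_grouped_topk are configured for deepseek_v3_r1 models; the sketch only shows the mechanism the reviewer is pointing at.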