diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
index c60015539e9..7e80b55c930 100644
--- a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
+++ b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
@@ -556,7 +556,13 @@ async def generate_stream():
                         instance_info.prefiller_idx,
                         instance_info.prefiller_score)
                     released_kv = True
-                chunk_str = chunk.decode("utf-8").strip()
+                try:
+                    chunk_str = chunk.decode("utf-8").strip()
+                except UnicodeDecodeError:
+                    logger.debug(
+                        f"Skipping chunk: {chunk}")
+                    yield chunk
+                    continue
                 if not chunk_str:
                     continue
                 if chunk_str.startswith("data: "):
diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py
index 880ed69e5f8..0694aced4af 100644
--- a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py
+++ b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py
@@ -539,7 +539,13 @@ async def generate_stream():
                         instance_info.prefiller_idx,
                         instance_info.prefiller_score)
                     released_kv = True
-                chunk_str = chunk.decode("utf-8").strip()
+                try:
+                    chunk_str = chunk.decode("utf-8").strip()
+                except UnicodeDecodeError:
+                    logger.debug(
+                        f"Skipping chunk: {chunk}")
+                    yield chunk
+                    continue
                 if not chunk_str:
                     continue
                 if chunk_str.startswith("data: "):
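# A minimal, standalone sketch (not part of the patch) of the chunk-decoding
# guard that both proxy examples above adopt: bytes that are not valid UTF-8
# (for example, a multi-byte character split across two network chunks) are
# forwarded as-is instead of crashing the stream. `chunks` is a stand-in for
# the byte iterator the real generate_stream() loops over.
import logging

logger = logging.getLogger(__name__)


def forward_sse_chunks(chunks):
    for chunk in chunks:
        try:
            chunk_str = chunk.decode("utf-8").strip()
        except UnicodeDecodeError:
            logger.debug("Skipping undecodable chunk: %r", chunk)
            yield chunk
            continue
        if not chunk_str:
            continue
        # ... the real proxy parses "data: " payloads here ...
        yield chunk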
diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py
index fae3ecb09aa..853baa9370b 100644
--- a/tests/e2e/singlecard/ops/test_fused_moe.py
+++ b/tests/e2e/singlecard/ops/test_fused_moe.py
@@ -28,9 +28,10 @@
 import torch_npu
 from vllm.model_executor.layers.activation import SiluAndMul
 
-from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
-from vllm_ascend.ops.moe.token_dispatcher import TokenDispatcherWithAllGather
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
+from vllm_ascend.ops.fused_moe.token_dispatcher import \
+    TokenDispatcherWithAllGather
 
 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1]
@@ -182,7 +183,7 @@ def test_token_dispatcher_with_all_gather_quant(
 ):
     context_mock = MagicMock()
     context_mock.fused_moe_state = 0
-    with patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context",
+    with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context",
                return_value=context_mock):
         a = torch.randn((m, k), device=device, dtype=dtype) / 10
         w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
@@ -282,9 +283,9 @@ def test_select_experts(
                                 dtype=torch.int32)
         custom_routing_function.return_value = (mock_weights, mock_ids)
 
-    with patch("vllm_ascend.ops.moe.experts_selector._native_grouped_topk"
+    with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
                ) as mock_native_grouped_topk, \
-        patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
+        patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
               return_value=MagicMock(weight_prefetch_method=MagicMock())):
         mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
             x)
@@ -318,7 +319,7 @@
 
 @pytest.mark.parametrize("device", DEVICE)
 def test_select_experts_invalid_scoring_func(device: str):
-    with patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
+    with patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
                return_value=MagicMock(weight_prefetch_method=MagicMock())), \
             pytest.raises(ValueError,
                           match="Unsupported scoring function: invalid"):
diff --git a/tests/ut/models/conftest.py b/tests/ut/models/conftest.py
index 88b8cfa0f2d..4f17e2df78b 100644
--- a/tests/ut/models/conftest.py
+++ b/tests/ut/models/conftest.py
@@ -90,9 +90,9 @@ def mock_distributed():
     mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
     mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
 
-    with patch("vllm_ascend.ops.common_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
-        patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
-        patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
+    with patch("vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
+        patch("vllm_ascend.ops.fused_moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
+        patch("vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
         patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group,
                    _EP=ep_group, _DP=dp_group, _PP=pp_group), \
         patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \
diff --git a/tests/ut/ops/test_comm_utils.py b/tests/ut/ops/test_comm_utils.py
index 5b4071cea70..16d9eaef41c 100644
--- a/tests/ut/ops/test_comm_utils.py
+++ b/tests/ut/ops/test_comm_utils.py
@@ -20,7 +20,7 @@
 from pytest_mock import MockerFixture
 
 from tests.ut.base import PytestBase
-from vllm_ascend.ops.moe.comm_utils import (
+from vllm_ascend.ops.fused_moe.comm_utils import (
     _gather_along_first_dim, async_all_to_all,
     gather_from_sequence_parallel_region)
 
diff --git a/tests/ut/ops/test_common_fused_moe.py b/tests/ut/ops/test_common_fused_moe.py
deleted file mode 100644
index 6153a4e93d2..00000000000
--- a/tests/ut/ops/test_common_fused_moe.py
+++ /dev/null
@@ -1,56 +0,0 @@
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# This file is a part of the vllm-ascend project.
-#
-from unittest.mock import patch
-
-import torch
-
-from tests.ut.base import TestBase
-from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
-
-
-class TestLoadWeight(TestBase):
-
-    def test_load_w13_transpose(self):
-        with patch.object(AscendFusedMoE, "__init__",
-                          lambda self, *args, **kwargs: None):
-            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
-
-            expert_data = torch.randn(128, 8)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
-
-            expert_data = torch.randn(8, 128)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
-
-            expert_data = torch.randn(128, 8)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
-
-            expert_data = torch.randn(8, 128)
-            loaded_weight = torch.randn(128, 4)
-            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
-
-    def test_load_w2_transpose(self):
-        with patch.object(AscendFusedMoE, "__init__",
-                          lambda self, *args, **kwargs: None):
-            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
-            expert_data = torch.randn(128, 4)
-            loaded_weight = torch.randn(128, 8)
-            moe._load_w2(expert_data, 1, loaded_weight, 0)
-
-            expert_data = torch.randn(4, 128)
-            loaded_weight = torch.randn(128, 8)
-            moe._load_w2(expert_data, 1, loaded_weight, 0)
diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_moe.py
similarity index 86%
rename from tests/ut/ops/test_fused_ops.py
rename to tests/ut/ops/test_fused_moe.py
index a6b0ae02419..8cd2961b703 100644
--- a/tests/ut/ops/test_fused_ops.py
+++ b/tests/ut/ops/test_fused_moe.py
@@ -24,9 +24,11 @@
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import MoECommType
-from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
-from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.fused_moe import (
+    AscendFusedMoE, AscendUnquantizedFusedMoEMethod)
+from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
+                                               unified_apply_mlp)
 from vllm_ascend.utils import AscendSocVersion, adapt_patch
 
 adapt_patch(True)
@@ -69,10 +71,11 @@ def setup_vllm_config_mock(mocker: MockerFixture):
     mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
     mock_vllm_config.model_config.max_model_len = 2048
 
-    mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config',
-                 return_value=mock_vllm_config)
-    mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
+    mocker.patch('vllm_ascend.ops.fused_moe.fused_moe.get_current_vllm_config',
                  return_value=mock_vllm_config)
+    mocker.patch(
+        'vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config',
+        return_value=mock_vllm_config)
 
 
 @pytest.fixture
@@ -105,37 +108,37 @@ def mock_finalize(hidden_states, **kwargs):
 
     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
-        patch('vllm_ascend.ops.common_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
-        patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
-        patch('vllm_ascend.ops.common_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
-        patch('vllm_ascend.ops.common_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm_ascend.ops.common_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.model_executor.layers.fused_moe.config.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
-        patch('vllm_ascend.ops.common_fused_moe.get_ascend_config',
+        patch('vllm_ascend.ops.fused_moe.fused_moe.get_ascend_config',
               return_value=MagicMock(
                   torchair_graph_config=MagicMock(enabled=False),
                   enable_multistream_moe=False,
                   expert_map_path=None
               )), \
-        patch('vllm_ascend.ops.common_fused_moe.determine_expert_map',
+        patch('vllm_ascend.ops.fused_moe.fused_moe.determine_expert_map',
              return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
-        patch('vllm_ascend.ops.common_fused_moe.get_forward_context',
+        patch('vllm_ascend.ops.fused_moe.fused_moe.get_forward_context',
              return_value=mock_forward_context_obj), \
-        patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
+        patch('vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context',
              return_value=mock_forward_context_obj), \
         patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
-        patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
+        patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context',
              return_value=mock_forward_context_obj), \
-        patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
+        patch('vllm_ascend.ops.fused_moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
              return_value=None), \
-        patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
+        patch('vllm_ascend.ops.fused_moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
             return_value=None), \
-        patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
+        patch('vllm_ascend.ops.fused_moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
             return_value=None), \
-        patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
+        patch('vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
             return_value=mock_forward_context_obj):
 
         yield {
@@ -319,8 +322,8 @@ def test_cumsum_group_list_with_type_2(self):
 
 class TestUnifiedApplyMLP(TestBase):
 
-    @patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
-    @patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_dynamic_quant')
     @patch('torch_npu.npu_dequant_swiglu_quant')
@@ -384,7 +387,7 @@ def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
 
         self.assertEqual(result.dtype, torch.bfloat16)
 
-    @patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
    @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
@@ -426,7 +429,7 @@ def test_unified_apply_mlp_without_quantization(self,
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.float16)
 
-    @patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
@@ -486,7 +489,7 @@ def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.bfloat16)
 
-    @patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
+    @patch('vllm_ascend.ops.fused_moe.moe_mlp.is_310p')
     @patch('torch_npu.npu_grouped_matmul')
     @patch('torch_npu.npu_swiglu')
     @patch('torch_npu.npu_dynamic_quant')
@@ -531,7 +534,7 @@ def test_unified_apply_mlp_without_quantization_310p(
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.float16)
 
-    @patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context")
+    @patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
     @patch("torch_npu.npu_grouped_matmul")
     @patch("torch_npu.npu_swiglu")
     @patch("torch_npu.npu_grouped_matmul_swiglu_quant")
@@ -595,3 +598,39 @@ def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
         self.assertTrue(mock_forward_context.with_quant)
         self.assertEqual(result.shape, hidden_states.shape)
         self.assertEqual(result.dtype, torch.bfloat16)
+
+
+class TestLoadWeight(TestBase):
+
+    def test_load_w13_transpose(self):
+        with patch.object(AscendFusedMoE, "__init__",
+                          lambda self, *args, **kwargs: None):
+            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
+
+            expert_data = torch.randn(128, 8)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
+
+            expert_data = torch.randn(8, 128)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
+
+            expert_data = torch.randn(128, 8)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
+
+            expert_data = torch.randn(8, 128)
+            loaded_weight = torch.randn(128, 4)
+            moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
+
+    def test_load_w2_transpose(self):
+        with patch.object(AscendFusedMoE, "__init__",
+                          lambda self, *args, **kwargs: None):
+            moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
+            expert_data = torch.randn(128, 4)
+            loaded_weight = torch.randn(128, 8)
+            moe._load_w2(expert_data, 1, loaded_weight, 0)
+
+            expert_data = torch.randn(4, 128)
+            loaded_weight = torch.randn(128, 8)
+            moe._load_w2(expert_data, 1, loaded_weight, 0)
diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py
index a3ef4410499..643faa067ac 100644
--- a/tests/ut/ops/test_moe_comm_method.py
+++ b/tests/ut/ops/test_moe_comm_method.py
@@ -4,8 +4,9 @@
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
 from tests.ut.base import TestBase
-from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
-                                                 AlltoAllCommImpl, MC2CommImpl)
+from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
+                                                       AlltoAllCommImpl,
+                                                       MC2CommImpl)
 
 
 class TestMoECommMethod(TestBase):
@@ -24,12 +25,14 @@ def setUp(self):
         self.moe_config.dp_group = MagicMock()
         self.moe_config.num_global_redundant_experts = 0
 
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
-    @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") + @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config") + @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context") @patch( - "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather" + "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather" + ) + @patch( + "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather" ) - @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather") def test_all_gather_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize, mock_get_forward_context, @@ -72,12 +75,11 @@ def test_all_gather_comm_impl(self, mock_token_dispatcher, context_metadata=context_metadata) mock_pf_instance.finalize.assert_called_once_with(h_out, True, None) - @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") - @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") + @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config") + @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context") @patch( - "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2" - ) - @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2") + "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithMC2") + @patch("vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithMC2") def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize, mock_get_forward_context, mock_get_current_vllm_config): @@ -121,12 +123,14 @@ def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize, context_metadata=context_metadata) mock_pf_instance.finalize.assert_called_once_with(h_out, True, None) - @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") - @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") + @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config") + @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context") @patch( - "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All" + "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAll2All" + ) + @patch( + "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAll2AllV" ) - @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV") def test_alltoall_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize, mock_get_forward_context, @@ -163,13 +167,15 @@ def test_alltoall_comm_impl(self, mock_token_dispatcher, mock_pf_instance.prepare.assert_called_once_with( hidden_states, router_logits, False, False, None) - @patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config") - @patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context") + @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config") + @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context") + @patch( + "vllm_ascend.ops.fused_moe.moe_comm_method.PrepareAndFinalizeWithAllGather" + ) @patch( - "vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather" + "vllm_ascend.ops.fused_moe.moe_comm_method.TokenDispatcherWithAllGather" ) - @patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather") - @patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp") + @patch("vllm_ascend.ops.fused_moe.moe_comm_method.unified_apply_mlp") def test_fused_experts_method(self, mock_unified_apply_mlp, mock_token_dispatcher, mock_prepare_finalize, 
diff --git a/tests/ut/ops/test_fused_moe_prepare_and_finalize.py b/tests/ut/ops/test_prepare_finalize.py
similarity index 77%
rename from tests/ut/ops/test_fused_moe_prepare_and_finalize.py
rename to tests/ut/ops/test_prepare_finalize.py
index 93b73ecfa2b..557a9f61e56 100644
--- a/tests/ut/ops/test_fused_moe_prepare_and_finalize.py
+++ b/tests/ut/ops/test_prepare_finalize.py
@@ -4,13 +4,12 @@
 import torch
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
-from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
-    FusedMoEPrepareAndFinalizeWithAll2All,
-    FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
-    FusedMoEPrepareAndFinalizeWithNaiveMulticast)
+from vllm_ascend.ops.fused_moe.prepare_finalize import (
+    PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
+    PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
 
 
-class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
+class TestPrepareAndFinalize(unittest.TestCase):
 
     def setUp(self):
         # Mock FusedMoEConfig
@@ -24,14 +23,12 @@ def setUp(self):
         self.moe_config.original_num_experts = 8
 
     @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
+        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
         return_value=1)
     @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
+        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
         return_value=0)
-    @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
-    )
+    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
     def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
                                   mock_tp_size):
         mock_context = MagicMock()
@@ -39,7 +36,7 @@ def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
         mock_context.padded_num_tokens = 4
         mock_get_forward_context.return_value = mock_context
 
-        layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
+        layer = PrepareAndFinalizeWithMC2(self.moe_config)
 
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
@@ -59,14 +56,12 @@ def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
         self.assertEqual(result.shape[0], 3)
 
     @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
+        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size",
        return_value=2)
     @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
+        "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank",
        return_value=0)
-    @patch(
-        "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
-    )
+    @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context")
     @patch("torch.distributed.all_gather")
     def test_mc2_tp_split_allgather(self, mock_all_gather,
                                     mock_get_forward_context, mock_tp_rank,
@@ -76,7 +71,7 @@ def test_mc2_tp_split_allgather(self, mock_all_gather,
         mock_context.padded_num_tokens = 4
         mock_get_forward_context.return_value = mock_context
 
-        layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
+        layer = PrepareAndFinalizeWithMC2(self.moe_config)
 
         hidden_states = torch.randn(4, 8)
         router_logits = torch.randn(4, 2)
@@ -108,13 +103,13 @@ def mock_all_gather_func(tensor_list, tensor, group=None):
         self.assertEqual(final_result.shape[0], 4)
 
     @patch(
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", + "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size", return_value=1) @patch( - "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank", + "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank", return_value=0) def test_all2all_prepare_finalize(self, mock_tp_rank, mock_tp_size): - layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config) + layer = PrepareAndFinalizeWithAll2All(self.moe_config) hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) @@ -130,15 +125,15 @@ def test_all2all_prepare_finalize(self, mock_tp_rank, mock_tp_size): self.assertEqual(result.shape[0], 3) @patch( - "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size", + "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_world_size", return_value=2) @patch( - "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank", + "vllm_ascend.ops.fused_moe.prepare_finalize.get_tensor_model_parallel_rank", return_value=0) @patch("torch.distributed.all_gather") def test_all2all_tp_split_allgather(self, mock_all_gather, mock_tp_rank, mock_tp_size): - layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config) + layer = PrepareAndFinalizeWithAll2All(self.moe_config) hidden_states = torch.randn(2, 8) router_logits = torch.randn(2, 2) @@ -169,14 +164,15 @@ def mock_all_gather_func(tensor_list, tensor, group=None): # Should concat back self.assertEqual(final_result.shape[0], 2) - @patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group") + @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group") @patch( - "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce" + "vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce" ) - @patch( - "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context" - ) - def test_allgather_prepare_finalize(self, mock_get_forward_context, + @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context") + @patch("vllm_ascend.ops.fused_moe.prepare_finalize.enable_sp", + return_value=False) + def test_allgather_prepare_finalize(self, mock_enable_sp, + mock_get_forward_context, mock_tp_all_reduce, mock_get_dp_group): # Mock forward context mock_context = MagicMock() @@ -198,7 +194,7 @@ def mock_all_gather_func(tensor, dim): self.moe_config.ep_size = 1 self.moe_config.dp_group = mock_dp_group - layer = FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config) + layer = PrepareAndFinalizeWithAllGather(self.moe_config) hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) @@ -232,13 +228,11 @@ def mock_reduce_scatter_func(tensor, dim): result_with_tp = layer.finalize(h_out, reduce_results=True) self.assertEqual(result_with_tp.shape[0], 3) - @patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group") - @patch( - "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce" - ) + @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_dp_group") @patch( - "vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context" + "vllm_ascend.ops.fused_moe.prepare_finalize.tensor_model_parallel_all_reduce" ) + @patch("vllm_ascend.ops.fused_moe.prepare_finalize.get_forward_context") def test_naive_multicast_prepare_finalize(self, mock_get_forward_context, mock_tp_all_reduce, mock_get_dp_group): @@ -266,7 +260,7 
         self.moe_config.tp_size = 1
         self.moe_config.ep_size = 1
 
-        layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
+        layer = PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
 
         # Local inputs
         hidden_states = torch.randn(3, 8)
diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py
index 02045119072..4abe3d7f6b7 100644
--- a/tests/ut/ops/test_token_dispatcher.py
+++ b/tests/ut/ops/test_token_dispatcher.py
@@ -21,7 +21,7 @@
 
 from tests.ut.base import TestBase
 
-from vllm_ascend.ops.moe.token_dispatcher import (  # isort: skip
+from vllm_ascend.ops.fused_moe.token_dispatcher import (  # isort: skip
     AscendSocVersion, TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
     TokenDispatcherWithMC2)
 
@@ -34,7 +34,7 @@ def setUp(self):
         self.mc2_group.rank_in_group = 0
         self.mc2_group.world_size = 8
         self.mc2_group_patch = patch(
-            "vllm_ascend.ops.moe.token_dispatcher.get_mc2_group",
+            "vllm_ascend.ops.fused_moe.token_dispatcher.get_mc2_group",
             return_value=self.mc2_group)
         self.mc2_group_patch.start()
 
@@ -52,7 +52,7 @@ def setUp(self):
 
         # Mock get_ascend_soc_version()
         self.ascend_soc_version_patch = patch(
-            "vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version",
+            "vllm_ascend.ops.fused_moe.token_dispatcher.get_ascend_soc_version",
             return_value=AscendSocVersion.A3)
         self.ascend_soc_version_patch.start()
 
@@ -369,7 +369,8 @@ def setUp(self):
         self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
 
         # Mock async_all_to_all
-        patcher6 = patch('vllm_ascend.ops.moe.comm_utils.async_all_to_all')
+        patcher6 = patch(
+            'vllm_ascend.ops.fused_moe.comm_utils.async_all_to_all')
         self.mock_async_all_to_all = patcher6.start()
         self.addCleanup(patcher6.stop)
         self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
@@ -377,7 +378,7 @@ def setUp(self):
 
         # Mock gather_from_sequence_parallel_region
         patcher7 = patch(
-            'vllm_ascend.ops.moe.token_dispatcher.gather_from_sequence_parallel_region'
+            'vllm_ascend.ops.fused_moe.token_dispatcher.gather_from_sequence_parallel_region'
         )
         self.mock_gather_from_sequence_parallel_region = patcher7.start()
         self.addCleanup(patcher7.stop)
diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py
index b88e78f7da6..ba41e06a73f 100644
--- a/tests/ut/quantization/test_w8a8.py
+++ b/tests/ut/quantization/test_w8a8.py
@@ -5,8 +5,8 @@
 
 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.ops.moe.experts_selector import (_native_grouped_topk,
-                                                  select_experts)
+from vllm_ascend.ops.fused_moe.experts_selector import (_native_grouped_topk,
+                                                        select_experts)
 from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
                                            AscendW8A8FusedMoEMethod,
                                            AscendW8A8LinearMethod,
@@ -758,7 +758,7 @@ def setUp(self):
         self.mock_ctx = MagicMock()
         self.mock_ctx.weight_prefetch_method = MagicMock()
         patcher = patch(
-            'vllm_ascend.ops.moe.experts_selector.get_forward_context',
+            'vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
             return_value=self.mock_ctx)
         self.addCleanup(patcher.stop)
         patcher.start()
@@ -831,7 +831,7 @@ def test_grouped_topk(self, mock_topk):
         self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
         self.assertEqual(ids.dtype, torch.int32)
 
-    @patch('vllm_ascend.ops.moe.experts_selector._native_grouped_topk')
+    @patch('vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk')
     def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
         """Test grouped topk with expert score correction bias"""
         mock_grouped_topk.return_value = torch.ones(self.num_tokens,
diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 85348dbf367..c6a0a7314b2 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -87,7 +87,8 @@ def set_ascend_forward_context(
 ):
     forward_context = get_forward_context()
 
-    from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method
+    from vllm_ascend.ops.fused_moe.moe_comm_method import \
+        get_moe_comm_method
     forward_context.moe_comm_type = moe_comm_type
     forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
 
diff --git a/vllm_ascend/models/deepseek_v3_2.py b/vllm_ascend/models/deepseek_v3_2.py
index 700d94296c1..bf17c977829 100644
--- a/vllm_ascend/models/deepseek_v3_2.py
+++ b/vllm_ascend/models/deepseek_v3_2.py
@@ -66,7 +66,7 @@
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.models.layers.sfa import AscendSFAModules, Indexer
-from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
+from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE
 from vllm_ascend.ops.linear import AscendLinearBase
 from vllm_ascend.utils import vllm_version_is
 
diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py
index 53f2d065905..e121f2a442c 100644
--- a/vllm_ascend/ops/__init__.py
+++ b/vllm_ascend/ops/__init__.py
@@ -17,7 +17,7 @@
 
 import torch
 
-import vllm_ascend.ops.common_fused_moe  # noqa
+import vllm_ascend.ops.fused_moe.fused_moe  # noqa
 import vllm_ascend.ops.layernorm  # noqa
 import vllm_ascend.ops.register_custom_ops  # noqa
 import vllm_ascend.ops.vocab_parallel_embedding  # noqa
diff --git a/vllm_ascend/ops/moe/__init__.py b/vllm_ascend/ops/fused_moe/__init__.py
similarity index 100%
rename from vllm_ascend/ops/moe/__init__.py
rename to vllm_ascend/ops/fused_moe/__init__.py
diff --git a/vllm_ascend/ops/moe/comm_utils.py b/vllm_ascend/ops/fused_moe/comm_utils.py
similarity index 100%
rename from vllm_ascend/ops/moe/comm_utils.py
rename to vllm_ascend/ops/fused_moe/comm_utils.py
diff --git a/vllm_ascend/ops/moe/experts_selector.py b/vllm_ascend/ops/fused_moe/experts_selector.py
similarity index 100%
rename from vllm_ascend/ops/moe/experts_selector.py
rename to vllm_ascend/ops/fused_moe/experts_selector.py
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
similarity index 99%
rename from vllm_ascend/ops/common_fused_moe.py
rename to vllm_ascend/ops/fused_moe/fused_moe.py
index 23b3d5d9a5a..4df7948189b 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -35,8 +35,8 @@
 from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
                                               determine_default_log2phy_map)
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
-from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
                                is_enable_nz, npu_stream_switch,
                                shared_expert_dp_enabled,
diff --git a/vllm_ascend/ops/moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
similarity index 85%
rename from vllm_ascend/ops/moe/moe_comm_method.py
rename to vllm_ascend/ops/fused_moe/moe_comm_method.py
index a8364433871..8b2a879eac7 100644
--- a/vllm_ascend/ops/moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -24,15 +24,13 @@
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
 from vllm_ascend.ascend_forward_context import MoECommType
-from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
-    FusedMoEPrepareAndFinalizeWithAll2All,
-    FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
-    FusedMoEPrepareAndFinalizeWithNaiveMulticast)
-from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
-from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV,
-                                                  TokenDispatcherWithAllGather,
-                                                  TokenDispatcherWithMC2,
-                                                  TokenDispatcherWithMoge)
+from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
+from vllm_ascend.ops.fused_moe.prepare_finalize import (
+    PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
+    PrepareAndFinalizeWithMC2, PrepareAndFinalizeWithNaiveMulticast)
+from vllm_ascend.ops.fused_moe.token_dispatcher import (
+    TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
+    TokenDispatcherWithMC2, TokenDispatcherWithMoge)
 
 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
 
@@ -59,8 +57,8 @@ def __init__(self, moe_config: FusedMoEConfig):
         self.moe_config = moe_config
 
         self.token_dispatcher = self._get_token_dispatcher()
-        self.fused_moe_prepare_finalize = self._get_fused_moe_prepare_finalize(
-        )
+        self.prepare_finalize = self._get_prepare_finalize()
 
     def prepare(
         self,
@@ -71,7 +68,7 @@ def prepare(
         gate=None
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
-        hidden_states, router_logits, mc2_mask, context_metadata = self.fused_moe_prepare_finalize.prepare(
+        hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
             hidden_states, router_logits, enable_shared_expert_dp,
             replace_allreduce, gate)
         return hidden_states, router_logits, mc2_mask, context_metadata
@@ -80,8 +77,9 @@ def finalize(self,
                  hidden_states: torch.Tensor,
                  reduce_results: bool,
                  context_metadata: Optional[dict] = None) -> torch.Tensor:
-        hidden_states = self.fused_moe_prepare_finalize.finalize(
-            hidden_states, reduce_results, context_metadata)
+        hidden_states = self.prepare_finalize.finalize(hidden_states,
+                                                       reduce_results,
+                                                       context_metadata)
         return hidden_states
 
     def fused_experts(
@@ -169,9 +167,9 @@ def _get_token_dispatcher(self):
             "_get_token_dispatcher function not implemented.")
 
     @abstractmethod
-    def _get_fused_moe_prepare_finalize(self):
+    def _get_prepare_finalize(self):
         raise NotImplementedError(
-            "_get_fused_moe_prepare_finalize function not implemented.")
+            "_get_prepare_finalize function not implemented.")
 
 
 class AllGatherCommImpl(MoECommMethod):
@@ -205,8 +203,8 @@ def _get_token_dispatcher(self):
             num_experts=self.moe_config.num_experts,
             num_local_experts=self.moe_config.num_local_experts)
 
-    def _get_fused_moe_prepare_finalize(self):
-        return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
+    def _get_prepare_finalize(self):
+        return PrepareAndFinalizeWithAllGather(self.moe_config)
 
 
 class MC2CommImpl(MoECommMethod):
@@ -222,8 +220,8 @@ class MC2CommImpl(MoECommMethod):
     def _get_token_dispatcher(self):
         return TokenDispatcherWithMC2()
 
-    def _get_fused_moe_prepare_finalize(self):
-        return FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
+    def _get_prepare_finalize(self):
+        return PrepareAndFinalizeWithMC2(self.moe_config)
 
 
 class AlltoAllCommImpl(MoECommMethod):
@@ -242,8 +240,8 @@ def _get_token_dispatcher(self):
             num_experts=self.moe_config.num_experts,
             num_local_experts=self.moe_config.num_local_experts)
 
-    def _get_fused_moe_prepare_finalize(self):
-        return FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
+    def _get_prepare_finalize(self):
+        return PrepareAndFinalizeWithAll2All(self.moe_config)
 
 
 class NaiveMulticastCommImpl(MoECommMethod):
@@ -271,5 +269,5 @@ def _get_token_dispatcher(self):
             num_experts=self.moe_config.num_experts,
             num_local_experts=self.moe_config.num_local_experts)
 
-    def _get_fused_moe_prepare_finalize(self):
-        return FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
+    def _get_prepare_finalize(self):
+        return PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
diff --git a/vllm_ascend/ops/moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py
similarity index 100%
rename from vllm_ascend/ops/moe/moe_mlp.py
rename to vllm_ascend/ops/fused_moe/moe_mlp.py
diff --git a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py
similarity index 96%
rename from vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py
rename to vllm_ascend/ops/fused_moe/prepare_finalize.py
index 7533ccebe12..accda224b55 100644
--- a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py
+++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py
@@ -30,7 +30,7 @@
 from vllm_ascend.utils import enable_sp, get_rm_router_logits_state
 
 
-class FusedMoEPrepareAndFinalize(ABC):
+class PrepareAndFinalize(ABC):
     """
     Abstract base class for MoE (Mixture-of-Experts) tensor preparation and finalization
     in distributed environments. Subclasses implement specific communication strategies
@@ -103,7 +103,7 @@ def finalize(self,
         raise NotImplementedError("Finalize function not implemented.")
 
 
-class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
+class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
     """
     MoE communication strategy using All-to-All style slicing.
     Similar to MC2 but does not use mc2_mask; instead pads to TP size for uniform slicing.
@@ -195,7 +195,7 @@ def finalize(self,
         return hidden_states
 
 
-class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
+class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
     """
     MoE communication strategy using MC2, which is based on All2All.
     Hence, it inherits All2All and share the same finalize method.
@@ -275,7 +275,7 @@ def prepare(
         return hidden_states, router_logits, mc2_mask, context_metadata
 
 
-class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
+class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
     """
     MoE communication strategy using All-Gather + Reduce-Scatter on EP group.
     There are two sets of prepare and finalize:
@@ -429,7 +429,7 @@ def _finalize_with_dp_group(self, hidden_states: torch.Tensor,
         return hidden_states
 
 
-class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
+class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize):
     """
     MoE communication strategy using Naive Multicast (point-to-point broadcast).
     Will be used in prefill when using allgather in decode.
     Each DP rank broadcasts its slice to all others.
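# A quick before/after import map for the module rename applied above
# (vllm_ascend.ops.moe -> vllm_ascend.ops.fused_moe, common_fused_moe.py ->
# fused_moe/fused_moe.py, and the shortened PrepareAndFinalize* class names).
# Illustrative only; it restates paths that already appear in the hunks above.
#
# Before:
#   from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
#   from vllm_ascend.ops.moe.experts_selector import select_experts
#   from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
#       FusedMoEPrepareAndFinalizeWithAllGather)
# After:
from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.prepare_finalize import PrepareAndFinalizeWithAllGather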
diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py
similarity index 99%
rename from vllm_ascend/ops/moe/token_dispatcher.py
rename to vllm_ascend/ops/fused_moe/token_dispatcher.py
index 7f74c53fac8..222344e1ebe 100644
--- a/vllm_ascend/ops/moe/token_dispatcher.py
+++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -28,7 +28,7 @@
 from vllm.distributed.parallel_state import get_ep_group
 
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.moe.comm_utils import (
+from vllm_ascend.ops.fused_moe.comm_utils import (
     async_all_to_all, gather_from_sequence_parallel_region)
 from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
                                is_hierarchical_communication_enabled)
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 8ab85ea7c2a..36dbcb17563 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -37,7 +37,7 @@
 
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
-from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
+from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
 from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
 from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
                                oproj_tp_enable)
diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py
index ccacb69101d..77f0f4b23cb 100644
--- a/vllm_ascend/quantization/w4a8_dynamic.py
+++ b/vllm_ascend/quantization/w4a8_dynamic.py
@@ -26,7 +26,7 @@
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
 
 
diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py
index 9df640c1893..07b7cac2557 100644
--- a/vllm_ascend/quantization/w8a8.py
+++ b/vllm_ascend/quantization/w8a8.py
@@ -24,7 +24,7 @@
 from vllm.forward_context import get_forward_context
 
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
-from vllm_ascend.ops.moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz
 
 
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 8fe6cbd8d90..1a9d0b5c970 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -25,7 +25,7 @@
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.moe.experts_selector import select_experts
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_enable_nz,
                                vllm_version_is)
 
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 1f496d42eec..786557624b2 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -538,8 +538,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
     from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
     from vllm_ascend.models.layers.sfa import AscendSparseFlashAttention
     from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
-    from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
-                                                  AscendSharedFusedMoE)
+    from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE,
+                                                     AscendSharedFusedMoE)
     from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm,
                                            AscendQuantRMSNorm, AscendRMSNorm)
     from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 8fe6df2a534..1cc93533da3 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -586,7 +586,9 @@ def _init_mc2_tokens_capacity(self):
         if self.compilation_config.cudagraph_capture_sizes:
             max_num_tokens = self.compilation_config.cudagraph_capture_sizes[0]
         else:
-            max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
+            # NOTE: To save memory, we cap the max number of tokens to 512.
+            max_num_tokens = min(
+                self.max_num_reqs * self.uniform_decode_query_len, 512)
         tp_size = self.parallel_config.tensor_parallel_size
         # Use integer arithmetic for ceiling division.
         num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
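# A small, self-contained illustration (not part of the patch) of the capacity
# math in _init_mc2_tokens_capacity after the change above: the worst-case
# decode batch is capped at 512 tokens, then split across TP ranks with a
# ceiling division. The function name and arguments are stand-ins for the
# runner's attributes.
def mc2_tokens_per_tp_rank(max_num_reqs: int, uniform_decode_query_len: int,
                           tp_size: int) -> int:
    max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
    return (max_num_tokens + tp_size - 1) // tp_size


# 256 requests * 4 tokens each would be 1024, but the cap keeps it at 512;
# with tp_size=4 every rank is then sized for 128 tokens.
assert mc2_tokens_per_tp_rank(256, 4, 4) == 128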