diff --git a/examples/offline_data_parallel.py b/examples/offline_data_parallel.py
index b16d50ffd1d..b7193f583f2 100644
--- a/examples/offline_data_parallel.py
+++ b/examples/offline_data_parallel.py
@@ -111,6 +111,10 @@ def parse_args():
     parser.add_argument("--enable-expert-parallel",
                         action="store_true",
                         help="Enable expert parallel, used in MOE models.")
+    parser.add_argument("--quantization",
+                        type=str,
+                        default="",
+                        help="Quantization method to use (e.g. ascend)")
     return parser.parse_args()
 
 
@@ -134,6 +138,7 @@ def main(
     enable_expert_parallel,
     enforce_eager,
     trust_remote_code,
+    quantization,
 ):
     # DP only support on V1 engine
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
@@ -185,6 +190,7 @@ def start(rank):
         enforce_eager=enforce_eager,
         enable_expert_parallel=enable_expert_parallel,
         trust_remote_code=trust_remote_code,
+        quantization=quantization,
     )
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
@@ -220,6 +226,8 @@ def start(rank):
     assert dp_size % node_size == 0, "dp_size should be divisible by node_size"
     dp_per_node = dp_size // node_size
 
+    quantization = args.quantization if args.quantization else None
+
     from multiprocessing import Process
 
     procs = []
@@ -238,6 +246,7 @@ def start(rank):
                 args.enable_expert_parallel,
                 args.enforce_eager,
                 args.trust_remote_code,
+                quantization,
             ),
         )
         proc.start()
diff --git a/tests/e2e/multicard/test_data_parallel.py b/tests/e2e/multicard/test_data_parallel.py
index 3839eb8edf2..94c95887149 100644
--- a/tests/e2e/multicard/test_data_parallel.py
+++ b/tests/e2e/multicard/test_data_parallel.py
@@ -27,13 +27,17 @@
 
 import pytest
 
-MODELS = ["Qwen/Qwen3-0.6B", "Qwen/Qwen3-30B-A3B"]
+MODELS = [
+    "Qwen/Qwen3-0.6B", "Qwen/Qwen3-30B-A3B", "vllm-ascend/Qwen3-30B-A3B-W8A8"
+]
 
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [32])
 @patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
 def test_data_parallel_inference(model, max_tokens):
+    moe_models = ["Qwen/Qwen3-30B-A3B", "vllm-ascend/Qwen3-30B-A3B-W8A8"]
+    quantization_models = ["vllm-ascend/Qwen3-30B-A3B-W8A8"]
     script = "examples/offline_data_parallel.py"
     env = os.environ.copy()
 
@@ -54,8 +58,11 @@ def test_data_parallel_inference(model, max_tokens):
         "--trust-remote-code",
     ]
 
-    if model == "Qwen/Qwen3-30B-A3B":
+    if model in moe_models:
         cmd.append("--enable-expert-parallel")
+    if model in quantization_models:
+        cmd.append("--quantization")
+        cmd.append("ascend")
 
     print(f"Running subprocess: {' '.join(cmd)}")
     proc = subprocess.run(cmd,
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 72c04e50b70..e9d0c97f942 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -408,11 +408,10 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
         quant_config: The Ascend quantization config.
""" - def __init__(self, - quant_config: AscendQuantConfig, - prefix: str, - packed_modules_mapping: Dict[str, Any], - layer: torch.nn.Module = None): + def __init__(self, quant_config: AscendQuantConfig, prefix: str, + packed_modules_mapping: Dict[str, + Any], layer: torch.nn.Module): + super().__init__(layer.moe_config) self.quant_method = get_quant_method(quant_config.quant_description, prefix, "moe", diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index a3a39176127..87f23b9b3bf 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -16,7 +16,7 @@ # Adapted from vllm/tests/kernels/test_moe.py import os -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch import torch.distributed as dist @@ -45,7 +45,9 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer -from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod +from vllm_ascend.quantization.quant_config import (AscendFusedMoEMethod, + AscendQuantConfig) +from vllm_ascend.quantization.utils import get_quant_method from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding from vllm_ascend.torchair.utils import (get_all_reduce_merge_state, get_rm_router_logits_state, @@ -936,6 +938,15 @@ def apply( ep_group=get_ep_group()) +class TorchairAscendFusedMoEMethod(AscendFusedMoEMethod): + + def __init__(self, quant_config: AscendQuantConfig, prefix: str, + packed_modules_mapping: Dict[str, Any]): + self.quant_method = get_quant_method(quant_config.quant_description, + prefix, "moe", + packed_modules_mapping) + + class TorchairAscendFusedMoE(FusedMoE): # The moe_counter parameter is required during the initialization of EPLB @@ -1115,7 +1126,7 @@ def __init__( self.quant_method = TorchairAscendUnquantizedFusedMoEMethod( self.moe) else: - self.quant_method = AscendFusedMoEMethod( + self.quant_method = TorchairAscendFusedMoEMethod( quant_config, prefix, quant_config.packed_modules_mapping) assert self.quant_method is not None