diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index 82d575f24690..0c11ba92ade1 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -123,6 +123,7 @@ class ParallelConfig:
         "pplx",
         "deepep_high_throughput",
         "deepep_low_latency",
+        "mori",
         "allgather_reducescatter",
         "flashinfer_all2allv",
     ]
@@ -135,6 +136,7 @@ class ParallelConfig:
     - "pplx": Use pplx kernels
     - "deepep_high_throughput": Use deepep high-throughput kernels
     - "deepep_low_latency": Use deepep low-latency kernels
+    - "mori": Use mori kernels
     - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
     num_redundant_experts: int | None = None
     """`num_redundant_experts` is deprecated and has been replaced with
@@ -370,6 +372,7 @@ def use_sequence_parallel_moe(self) -> bool:
                 "naive",
                 "deepep_high_throughput",
                 "deepep_low_latency",
+                "mori",
             )
             and self.enable_expert_parallel
             and self.tensor_parallel_size > 1
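For reviewers trying the branch: the new value is selected through the existing `VLLM_ALL2ALL_BACKEND` environment variable, nothing else in the launch path changes. A minimal sketch, where the model name and parallel sizes are placeholders rather than anything taken from this diff:

```python
import os

# The new choice added above; all other values keep their current behavior.
os.environ["VLLM_ALL2ALL_BACKEND"] = "mori"

from vllm import LLM

# Expert parallelism must be enabled for any all2all backend to take effect;
# the model and sizes below are illustrative only.
llm = LLM(
    model="some/moe-model",
    tensor_parallel_size=8,
    enable_expert_parallel=True,
)
```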
diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py
index c40dde26b741..39d52a8e421e 100644
--- a/vllm/distributed/device_communicators/all2all.py
+++ b/vllm/distributed/device_communicators/all2all.py
@@ -10,7 +10,7 @@
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.utils.flashinfer import has_flashinfer_all2all
-from vllm.utils.import_utils import has_deep_ep, has_pplx
+from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx

 from .base_device_communicator import All2AllManagerBase, Cache

@@ -488,3 +488,80 @@ def cleanup(self):
         self.prepare_workspace_tensor = None
         self.mapping = None
         self.initialized = False
+
+
+class MoriAll2AllManager(All2AllManagerBase):
+    def __init__(self, cpu_group):
+        assert has_mori(), (
+            "MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md"
+            " to install MoRI kernels."
+        )  # noqa
+        import mori
+
+        super().__init__(cpu_group)
+        self.handle_cache = Cache()
+
+        torch._C._distributed_c10d._register_process_group("mori", cpu_group)
+        mori.shmem.shmem_torch_process_group_init("mori")
+
+    def _make_all2all_kwargs(
+        self,
+        rank: int,
+        num_ep_ranks: int,
+        input_dtype: torch.dtype,
+        quant_dtype: torch.dtype,
+        token_hidden_size: int,
+        scale_dim: int,
+        scale_type_size: int,
+        max_num_tokens_per_dp_rank: int,
+        num_local_experts: int,
+        num_experts_per_token: int,
+    ):
+        import mori  # type: ignore[import-not-found]
+
+        if not self.internode:
+            # single node
+            kernel_type = mori.ops.EpDispatchCombineKernelType.IntraNode
+            warp_num_per_block = 16
+            block_num = 80
+            rdma_block_num = 0
+        else:
+            # multi node
+            kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1
+            warp_num_per_block = 16
+            block_num = 32
+            rdma_block_num = 16
+
+        return dict(
+            rank=rank,
+            world_size=num_ep_ranks,
+            data_type=quant_dtype,
+            hidden_dim=token_hidden_size,
+            scale_dim=scale_dim,
+            scale_type_size=scale_type_size,
+            max_token_type_size=input_dtype.itemsize,
+            max_num_inp_token_per_rank=max_num_tokens_per_dp_rank,
+            num_experts_per_rank=num_local_experts,
+            num_experts_per_token=num_experts_per_token,
+            warp_num_per_block=warp_num_per_block,
+            block_num=block_num,
+            kernel_type=kernel_type,
+            rdma_block_num=rdma_block_num,
+        )
+
+    def _make_handle(self, **kwargs):
+        import mori  # type: ignore[import-not-found]
+
+        mori_config = mori.ops.EpDispatchCombineConfig(**kwargs)
+        handle = mori.ops.EpDispatchCombineOp(mori_config)
+        return handle
+
+    def get_handle(self, kwargs):
+        import mori  # type: ignore[import-not-found]
+
+        mori_kwargs = self._make_all2all_kwargs(**kwargs)
+        logger.debug("MoRI all2all args %s", mori_kwargs)
+        handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create(
+            mori_kwargs, self._make_handle
+        )
+        return handle
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 2e878eef908a..eeac7a549f19 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -110,6 +110,10 @@ def __init__(
             from .all2all import DeepEPLLAll2AllManager

             self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
+        elif self.all2all_backend == "mori":
+            from .all2all import MoriAll2AllManager
+
+            self.all2all_manager = MoriAll2AllManager(self.cpu_group)
         elif self.all2all_backend == "flashinfer_all2allv":
             from .all2all import FlashInferAllToAllManager

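A note on the handle cache above: `get_handle` keys the cache on the full kwargs dict, so repeated MoE layers with the same shape reuse one `EpDispatchCombineOp` instead of re-initializing shared memory. Roughly, the kwargs assembled for a single-node, 8-way EP setup would look like the dict below; every value is illustrative, not taken from a real run, and the `None` entries stand in for mori/torch objects:

```python
# Hypothetical output of _make_all2all_kwargs for an intra-node deployment.
example_kwargs = dict(
    rank=0,
    world_size=8,                  # num_ep_ranks
    data_type=None,                # quant_dtype, e.g. the platform fp8 dtype
    hidden_dim=7168,               # token_hidden_size
    scale_dim=56,                  # 7168 // 128 for 1x128 block quant
    scale_type_size=4,             # torch.float32.itemsize
    max_token_type_size=2,         # e.g. bfloat16 activations
    max_num_inp_token_per_rank=1024,
    num_experts_per_rank=32,
    num_experts_per_token=8,
    warp_num_per_block=16,         # IntraNode defaults chosen above
    block_num=80,
    kernel_type=None,              # mori.ops.EpDispatchCombineKernelType.IntraNode
    rdma_block_num=0,
)
```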
"VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower() + os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "False").lower() in ("true", "1") ), # use rocm skinny gemms @@ -1197,6 +1198,7 @@ def get_vllm_port() -> int | None: # - "pplx": use pplx kernels # - "deepep_high_throughput", use deepep high-throughput kernels # - "deepep_low_latency", use deepep low-latency kernels + # - "mori", use MoRI kernels # - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl "VLLM_ALL2ALL_BACKEND": env_with_choices( "VLLM_ALL2ALL_BACKEND", @@ -1206,6 +1208,7 @@ def get_vllm_port() -> int | None: "pplx", "deepep_high_throughput", "deepep_low_latency", + "mori", "allgather_reducescatter", "flashinfer_all2allv", ], diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index cb31045971bd..2424783a86b5 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -64,6 +64,9 @@ def get_config() -> dict[str, Any] | None: cutlass_moe_fp8, ) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts + from vllm.model_executor.layers.fused_moe.fused_aiter_moe import ( + AiterExperts, + ) from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts, ) @@ -93,6 +96,7 @@ def get_config() -> dict[str, Any] | None: "BatchedDeepGemmExperts", "TritonOrDeepGemmExperts", "BatchedTritonOrDeepGemmExperts", + "AiterExperts", ] else: # Some model classes directly use the custom ops. Add placeholders diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index cbc3caafcf2f..201e56e86196 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -691,6 +691,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency" + @property + def use_mori_kernels(self): + return self.use_all2all_kernels and self.all2all_backend == "mori" + @staticmethod def flatten_tp_across_dp( tp_size: int, dp_size: int, dp_rank: int @@ -883,6 +887,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_mori_kernels(self): + return self.moe_parallel_config.use_mori_kernels + @property def use_flashinfer_cutlass_kernels(self): """ diff --git a/vllm/model_executor/layers/fused_moe/fused_aiter_moe.py b/vllm/model_executor/layers/fused_moe/fused_aiter_moe.py new file mode 100644 index 000000000000..236fc690ef9c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_aiter_moe.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + rocm_aiter_fused_experts, +) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP, +) + + +class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__(self, quant_config: FusedMoEQuantConfig): + super().__init__(quant_config) + + @property + def activation_formats( + self, + ) -> tuple[mk.FusedMoEActivationFormat, 
diff --git a/vllm/model_executor/layers/fused_moe/fused_aiter_moe.py b/vllm/model_executor/layers/fused_moe/fused_aiter_moe.py
new file mode 100644
index 000000000000..236fc690ef9c
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/fused_aiter_moe.py
@@ -0,0 +1,90 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    rocm_aiter_fused_experts,
+)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceNoOP,
+)
+
+
+class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        super().__init__(quant_config)
+
+    @property
+    def activation_formats(
+        self,
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (
+            mk.FusedMoEActivationFormat.Standard,
+            mk.FusedMoEActivationFormat.Standard,
+        )
+
+    def supports_chunking(self) -> bool:
+        return True
+
+    def supports_expert_map(self) -> bool:
+        return True
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        return TopKWeightAndReduceNoOP()
+
+    def workspace_shapes(
+        self,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+        workspace13 = (M, K)
+        workspace2 = (0,)
+        output = (M, K)
+        return (workspace13, workspace2, output)
+
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: torch.Tensor | None,
+        a1q_scale: torch.Tensor | None,
+        a2_scale: torch.Tensor | None,
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
+        apply_router_weight_on_input: bool,
+    ) -> None:
+        if expert_tokens_meta is not None:
+            num_local_tokens = expert_tokens_meta.expert_num_tokens
+        else:
+            num_local_tokens = None
+
+        result = rocm_aiter_fused_experts(
+            hidden_states,
+            w1,
+            w2,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            quant_config=self.quant_config,
+            a1q_scale=a1q_scale,
+            num_local_tokens=num_local_tokens,
+            output_dtype=output.dtype,
+        )
+        output.copy_(result)
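`AiterExperts` is a thin adapter that lets the existing AITER fused-MoE kernel act as the expert stage of a modular kernel, with dispatch/combine handled by a separate prepare/finalize object. A hedged sketch of how the two halves compose; the constructor call is abbreviated, so treat the exact signature of `FusedMoEModularKernel` as an assumption and check modular_kernel.py for the full form:

```python
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_aiter_moe import AiterExperts


def build_aiter_modular_kernel(prepare_finalize, quant_config):
    # prepare_finalize would be e.g. MoriPrepareAndFinalize from this PR;
    # AiterExperts then runs the local expert GEMMs on the dispatched tokens.
    return mk.FusedMoEModularKernel(
        prepare_finalize,
        AiterExperts(quant_config=quant_config),
    )
```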
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 1236116386c9..a09ceff8c5fe 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -56,7 +56,7 @@
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
-from vllm.utils.import_utils import has_deep_ep, has_pplx
+from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
 from vllm.utils.math_utils import cdiv, round_up
 from vllm.utils.torch_utils import current_stream, direct_register_custom_op
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id
@@ -76,6 +76,8 @@
         DEEPEP_QUANT_BLOCK_SHAPE,
         DeepEPLLPrepareAndFinalize,
     )
+    if has_mori():
+        from .mori_prepare_finalize import MoriPrepareAndFinalize
 else:
     fused_experts = None  # type: ignore
     FusedMoEPermuteExpertsUnpermute = object  # type: ignore
@@ -233,6 +235,36 @@ def _maybe_make_prepare_finalize(
                 use_fp8_dispatch=use_fp8_dispatch,
             )
+        elif moe.use_mori_kernels:
+            assert quant_config is not None
+            # For PTPC (per token per channel) quant, the scale dim for each token is 1
+            # For 1x128 quant, the scale dim for each token is hidden_dim // 128
+            scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
+            all_to_all_args = dict(
+                rank=all2all_manager.rank,
+                num_ep_ranks=all2all_manager.world_size,
+                quant_dtype=quant_config.quant_dtype,
+                token_hidden_size=moe.hidden_dim,
+                scale_dim=scale_dim,
+                scale_type_size=torch.float32.itemsize,
+                max_num_tokens_per_dp_rank=moe.max_num_tokens,
+                input_dtype=moe.in_dtype,
+                num_local_experts=moe.num_experts // all2all_manager.world_size,
+                num_experts_per_token=moe.experts_per_token,
+            )
+            handle = all2all_manager.get_handle(all_to_all_args)
+
+            # Note: We may want to use FP8 dispatch just to reduce
+            # data movement.
+            use_fp8_dispatch = is_rocm_aiter_moe_enabled()
+
+            prepare_finalize = MoriPrepareAndFinalize(
+                handle,
+                max_tokens_per_rank=moe.max_num_tokens,
+                num_dispatchers=all2all_manager.world_size,
+                use_fp8_dispatch=use_fp8_dispatch,
+            )
+

         return prepare_finalize

     def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
@@ -1407,6 +1439,11 @@ def __init__(
             is_act_and_mul=is_act_and_mul,
             is_lora_enabled=vllm_config.lora_config is not None,
         )
+        if self.use_mori_kernels:
+            assert not is_rocm_aiter_fusion_shared_expert_enabled(), (
+                "MoRI does not support fused shared experts yet. "
+                "Turn it off by setting VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0"
+            )

         self.moe_quant_config: FusedMoEQuantConfig | None = None
         self.quant_config = quant_config
@@ -1538,6 +1575,10 @@ def use_deepep_ht_kernels(self):
     def use_deepep_ll_kernels(self):
         return self.moe_parallel_config.use_deepep_ll_kernels

+    @property
+    def use_mori_kernels(self):
+        return self.moe_parallel_config.use_mori_kernels
+
     @property
     def use_flashinfer_cutlass_kernels(self):
         return (
@@ -1551,6 +1592,7 @@ def use_dp_chunking(self) -> bool:
         return (
             self.moe_parallel_config.use_pplx_kernels
             or self.moe_parallel_config.use_deepep_ll_kernels
+            or self.moe_parallel_config.use_mori_kernels
             or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
         )

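The `scale_dim` passed into the mori handle is the number of scale elements shipped per token, which depends on the activation quantization scheme. A worked example of the computation above, using a made-up hidden size:

```python
# Illustrative only: hidden_dim is a hypothetical value, not from this diff.
hidden_dim = 7168

scale_dim_ptpc = 1                   # per-token-per-channel: one scale per token
scale_dim_1x128 = hidden_dim // 128  # 1x128 block quant: one scale per 128 elements

assert scale_dim_1x128 == 56
```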
diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py
new file mode 100644
index 000000000000..930e7ae3ff33
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py
@@ -0,0 +1,121 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import mori
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+
+class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
+    """
+    Prepare/Finalize using MoRI kernels.
+    """
+
+    def __init__(
+        self,
+        mori_op: mori.ops.EpDispatchCombineOp,
+        max_tokens_per_rank: int,
+        num_dispatchers: int,
+        use_fp8_dispatch: bool = False,
+    ):
+        super().__init__()
+        self.mori_op = mori_op
+        self.num_dispatchers_ = num_dispatchers
+        self.max_tokens_per_rank = max_tokens_per_rank
+        self.use_fp8_dispatch = use_fp8_dispatch
+
+    @property
+    def activation_format(self) -> mk.FusedMoEActivationFormat:
+        return mk.FusedMoEActivationFormat.Standard
+
+    def output_is_reduced(self) -> bool:
+        return True
+
+    def num_dispatchers(self):
+        return self.num_dispatchers_
+
+    def max_num_tokens_per_rank(self) -> int | None:
+        return self.max_tokens_per_rank
+
+    def topk_indices_dtype(self) -> torch.dtype | None:
+        return torch.int32
+
+    def supports_async(self) -> bool:
+        return False
+
+    def prepare(
+        self,
+        a1: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: torch.Tensor | None,
+        apply_router_weight_on_input: bool,
+        quant_config: FusedMoEQuantConfig,
+    ) -> mk.PrepareResultType:
+        """
+        Returns a tuple of:
+        - quantized + dispatched a.
+        - Optional quantized + dispatched a1_scales.
+        - Optional ExpertTokensMetadata containing gpu/cpu tensors
+          as big as the number of local experts with the information about the
+          number of tokens assigned to each local expert.
+        - Optional dispatched expert topk IDs
+        - Optional dispatched expert topk weight
+        """
+        assert not apply_router_weight_on_input, (
+            "mori does not support apply_router_weight_on_input=True now."
+        )
+        scale = None
+        if self.use_fp8_dispatch:
+            from aiter import QuantType, get_hip_quant
+
+            if quant_config.is_block_quantized:
+                quant_func = get_hip_quant(QuantType.per_1x128)
+                a1, scale = quant_func(a1, quant_dtype=current_platform.fp8_dtype())
+            elif quant_config.is_per_act_token:
+                quant_func = get_hip_quant(QuantType.per_Token)
+                a1, scale = quant_func(a1, quant_dtype=current_platform.fp8_dtype())
+
+        (
+            dispatch_a1,
+            dispatch_weights,
+            dispatch_scale,
+            dispatch_ids,
+            dispatch_recv_token_num,
+        ) = self.mori_op.dispatch(a1, topk_weights, scale, topk_ids)
+
+        expert_tokens_meta = mk.ExpertTokensMetadata(
+            expert_num_tokens=dispatch_recv_token_num, expert_num_tokens_cpu=None
+        )
+
+        return (
+            dispatch_a1,
+            dispatch_scale,
+            expert_tokens_meta,
+            dispatch_ids,
+            dispatch_weights,
+        )
+
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
+        num_token = output.shape[0]
+        result = self.mori_op.combine(
+            fused_expert_output,
+            None,
+            topk_ids,
+        )[0]
+        output.copy_(result[:num_token])
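For orientation, the per-forward flow implemented above is: optional fp8 quantization of the activations, `mori_op.dispatch` to the expert-owning ranks, local expert computation, then `mori_op.combine` back to the source ranks. A condensed sketch of that round trip, with argument meanings inferred from this file rather than from mori documentation:

```python
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk


def mori_round_trip_sketch(
    pf,                      # a MoriPrepareAndFinalize instance
    run_experts,             # callable running the local expert GEMMs
    a1: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    quant_config,
    weight_and_reduce: mk.TopKWeightAndReduce,
) -> torch.Tensor:
    # 1) quantize (optionally) and all-to-all dispatch tokens to expert owners
    a1q, a1q_scale, tokens_meta, ids, weights = pf.prepare(
        a1, topk_weights, topk_ids, num_experts,
        expert_map=None,
        apply_router_weight_on_input=False,
        quant_config=quant_config,
    )
    # 2) local expert computation on the dispatched tokens
    expert_out = run_experts(a1q, a1q_scale, ids, weights, tokens_meta)
    # 3) all-to-all combine; mori reduces over top-k, so the output is final
    out = torch.empty_like(a1)
    pf.finalize(out, expert_out, weights, ids,
                apply_router_weight_on_input=False,
                weight_and_reduce_impl=weight_and_reduce)
    return out
```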
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index e18514ad43f6..f236e9cfe8dd 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -281,6 +281,8 @@ def rocm_aiter_fused_moe_impl(
     w2_scale: torch.Tensor | None = None,
     a1_scale: torch.Tensor | None = None,
     a2_scale: torch.Tensor | None = None,
+    num_local_tokens: torch.Tensor | None = None,
+    output_dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
@@ -302,6 +304,8 @@ def rocm_aiter_fused_moe_impl(
         w2_scale,
         a1_scale,
         a2_scale,
+        num_local_tokens=num_local_tokens,
+        dtype=output_dtype,
     )

@@ -319,6 +323,8 @@ def rocm_aiter_fused_moe_fake(
     w2_scale: torch.Tensor | None = None,
     a1_scale: torch.Tensor | None = None,
     a2_scale: torch.Tensor | None = None,
+    num_local_tokens: torch.Tensor | None = None,
+    output_dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)

@@ -435,6 +441,9 @@ def rocm_aiter_fused_experts(
     apply_router_weight_on_input: bool = False,
     expert_map: torch.Tensor | None = None,
     quant_config: FusedMoEQuantConfig | None = None,
+    a1q_scale: torch.Tensor | None = None,
+    num_local_tokens: torch.Tensor | None = None,
+    output_dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
     if quant_config is None:
         quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
@@ -463,6 +472,9 @@ def rocm_aiter_fused_experts(
         assert topk_weights.shape[-1] == 1, (
             "Only support topk=1 when `apply_router_weight_on_input` is True"
         )
+        assert num_local_tokens is None, (
+            "AITER tkw1 kernel does not support `num_local_tokens`"
+        )

         return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
             hidden_states,
@@ -518,9 +530,11 @@ def rocm_aiter_fused_experts(
             activation_method=activation_method,
             w1_scale=quant_config.w1_scale,
             w2_scale=quant_config.w2_scale,
-            a1_scale=quant_config.a1_scale,
+            a1_scale=quant_config.a1_scale if a1q_scale is None else a1q_scale,
             a2_scale=quant_config.a2_scale,
             doweight_stage1=apply_router_weight_on_input,
+            num_local_tokens=num_local_tokens,
+            output_dtype=output_dtype,
         )
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index d95d49eddfe3..5e8706bc3a33 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -896,7 +896,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         )

     def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
-        if self.use_marlin or self.rocm_aiter_moe_enabled:
+        if self.use_marlin:
             return None
         else:
             return super().maybe_make_prepare_finalize()
@@ -950,6 +950,9 @@ def select_gemm_impl(
             return experts

+        # aiter path
+        from vllm.model_executor.layers.fused_moe import AiterExperts
+
         # triton path
         from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
             BatchedTritonOrDeepGemmExperts,
@@ -958,7 +961,7 @@ def select_gemm_impl(
             TritonOrDeepGemmExperts,
         )

-        assert not self.rocm_aiter_moe_enabled and not self.use_marlin
+        assert not self.use_marlin

         if (
             prepare_finalize.activation_format
@@ -973,6 +976,15 @@ def select_gemm_impl(
                 num_dispatchers=prepare_finalize.num_dispatchers(),
                 quant_config=self.moe_quant_config,
             )
+        elif self.rocm_aiter_moe_enabled:
+            logger.debug(
+                "AiterExperts(%s): per_act_token=%s",
+                self.__class__.__name__,
+                True,
+            )
+            return AiterExperts(
+                quant_config=self.moe_quant_config,
+            )
         else:
             logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
             return TritonOrDeepGemmExperts(self.moe_quant_config, allow_deep_gemm=True)
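The net effect of the rocm_aiter_fused_moe.py changes is that a pre-quantized, dispatched activation can carry its own scale (`a1q_scale`) and per-expert token counts (`num_local_tokens`) straight into the AITER kernel. A hedged call sketch wrapping the extended signature; the tensors are parameters here because the snippet is illustrative, not a runnable benchmark:

```python
import torch

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    rocm_aiter_fused_experts,
)


def run_dispatched_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_config,
    dispatched_scale: torch.Tensor | None,
    expert_num_tokens: torch.Tensor | None,
) -> torch.Tensor:
    return rocm_aiter_fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation="silu",
        expert_map=None,
        quant_config=quant_config,
        a1q_scale=dispatched_scale,        # overrides quant_config.a1_scale when set
        num_local_tokens=expert_num_tokens,
        output_dtype=torch.bfloat16,
    )
```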
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index ce40645782e5..a1b77e32300c 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -1096,8 +1096,7 @@ def process_weights_after_loading(self, layer: Module) -> None:

     def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
-            self.rocm_aiter_moe_enabled
-            or self.use_marlin
+            self.use_marlin
             or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
             return None
@@ -1116,13 +1115,12 @@ def select_gemm_impl(
         layer: torch.nn.Module,
     ) -> FusedMoEPermuteExpertsUnpermute:
         from vllm.model_executor.layers.fused_moe import (
+            AiterExperts,
             BatchedTritonOrDeepGemmExperts,
             TritonOrDeepGemmExperts,
         )

-        assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
-            "Marlin and ROCm AITER are not supported with all2all yet."
-        )
+        assert not self.use_marlin, "Marlin is not supported with all2all yet."

         assert self.moe_quant_config is not None

@@ -1146,6 +1144,15 @@ def select_gemm_impl(
                 quant_config=self.moe_quant_config,
                 allow_deep_gemm=self.allow_deep_gemm,
             )
+        elif self.rocm_aiter_moe_enabled:
+            logger.debug(
+                "AiterExperts(%s): per_act_token=%s",
+                self.__class__.__name__,
+                False,
+            )
+            return AiterExperts(
+                quant_config=self.moe_quant_config,
+            )
         elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
             experts = select_cutlass_fp8_gemm_impl(
                 self.moe,
diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py
index f01d2c7a6a33..5fdcec6afc6a 100644
--- a/vllm/utils/import_utils.py
+++ b/vllm/utils/import_utils.py
@@ -409,3 +409,8 @@ def has_arctic_inference() -> bool:
     """Whether the optional `arctic_inference` package is available."""
     return _has_module("arctic_inference")
+
+
+def has_mori() -> bool:
+    """Whether the optional `mori` package is available."""
+    return _has_module("mori")
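Finally, a quick way to sanity-check an environment before flipping the backend on; this mirrors the `has_mori()` guard used by `MoriAll2AllManager`, and the print is only for illustration:

```python
from vllm.utils.import_utils import has_mori

if not has_mori():
    print(
        "mori is not installed: VLLM_ALL2ALL_BACKEND=mori will fail the "
        "assertion in MoriAll2AllManager at engine start"
    )
```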