3 changes: 3 additions & 0 deletions vllm/config/parallel.py
@@ -123,6 +123,7 @@ class ParallelConfig:
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"mori",
"allgather_reducescatter",
"flashinfer_all2allv",
]
@@ -135,6 +136,7 @@ class ParallelConfig:
- "pplx": Use pplx kernels
- "deepep_high_throughput": Use deepep high-throughput kernels
- "deepep_low_latency": Use deepep low-latency kernels
- "mori": Use mori kernels
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
num_redundant_experts: int | None = None
"""`num_redundant_experts` is deprecated and has been replaced with
@@ -370,6 +372,7 @@ def use_sequence_parallel_moe(self) -> bool:
"naive",
"deepep_high_throughput",
"deepep_low_latency",
"mori",
)
and self.enable_expert_parallel
and self.tensor_parallel_size > 1
79 changes: 78 additions & 1 deletion vllm/distributed/device_communicators/all2all.py
@@ -10,7 +10,7 @@
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
-from vllm.utils.import_utils import has_deep_ep, has_pplx
+from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx

from .base_device_communicator import All2AllManagerBase, Cache

@@ -488,3 +488,80 @@ def cleanup(self):
self.prepare_workspace_tensor = None
self.mapping = None
self.initialized = False


class MoriAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
assert has_mori(), (
"MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md"
" to install MoRI kernels."
) # noqa
import mori

super().__init__(cpu_group)
self.handle_cache = Cache()

torch._C._distributed_c10d._register_process_group("mori", cpu_group)
mori.shmem.shmem_torch_process_group_init("mori")

def _make_all2all_kwargs(
self,
rank: int,
num_ep_ranks: int,
input_dtype: torch.dtype,
quant_dtype: torch.dtype,
token_hidden_size: int,
scale_dim: int,
scale_type_size: int,
max_num_tokens_per_dp_rank: int,
num_local_experts: int,
num_experts_per_token: int,
):
import mori # type: ignore[import-not-found]

if not self.internode:
# single node
kernel_type = mori.ops.EpDispatchCombineKernelType.IntraNode
warp_num_per_block = 16
block_num = 80
rdma_block_num = 0
else:
# multi node
kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1
warp_num_per_block = 16
block_num = 32
rdma_block_num = 16

return dict(
rank=rank,
world_size=num_ep_ranks,
data_type=quant_dtype,
hidden_dim=token_hidden_size,
scale_dim=scale_dim,
scale_type_size=scale_type_size,
max_token_type_size=input_dtype.itemsize,
max_num_inp_token_per_rank=max_num_tokens_per_dp_rank,
num_experts_per_rank=num_local_experts,
num_experts_per_token=num_experts_per_token,
warp_num_per_block=warp_num_per_block,
block_num=block_num,
kernel_type=kernel_type,
rdma_block_num=rdma_block_num,
)

def _make_handle(self, **kwargs):
import mori # type: ignore[import-not-found]

mori_config = mori.ops.EpDispatchCombineConfig(**kwargs)
handle = mori.ops.EpDispatchCombineOp(mori_config)
return handle

def get_handle(self, kwargs):
import mori # type: ignore[import-not-found]

mori_kwargs = self._make_all2all_kwargs(**kwargs)
logger.debug("MoRI all2all args %s", mori_kwargs)
handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create(
mori_kwargs, self._make_handle
)
return handle
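
For orientation, a hedged sketch of what a caller would pass to `get_handle()`: the keys mirror the `_make_all2all_kwargs` signature above, while the values (including the FP8 quant dtype) are assumptions for illustration only. A real call also needs the mori shmem process group initialized by `__init__`.

```python
# Illustrative only: kwargs for MoriAll2AllManager.get_handle().
# Keys follow _make_all2all_kwargs above; values (including the FP8 quant
# dtype) are assumptions for this sketch, not requirements.
import torch

example_kwargs = dict(
    rank=0,
    num_ep_ranks=8,
    input_dtype=torch.bfloat16,
    quant_dtype=torch.float8_e4m3fnuz,      # assumed ROCm FP8 dispatch dtype
    token_hidden_size=4096,
    scale_dim=1,                            # PTPC quant: one scale per token
    scale_type_size=torch.float32.itemsize,
    max_num_tokens_per_dp_rank=256,
    num_local_experts=32,
    num_experts_per_token=8,
)
# handle = manager.get_handle(example_kwargs)  # needs mori shmem initialized
```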
4 changes: 4 additions & 0 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -110,6 +110,10 @@ def __init__(
from .all2all import DeepEPLLAll2AllManager

self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
elif self.all2all_backend == "mori":
from .all2all import MoriAll2AllManager

self.all2all_manager = MoriAll2AllManager(self.cpu_group)
elif self.all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager

7 changes: 5 additions & 2 deletions vllm/envs.py
@@ -111,7 +111,7 @@
VLLM_ROCM_USE_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
-VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = True
+VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
@@ -166,6 +166,7 @@
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"mori",
"allgather_reducescatter",
"flashinfer_all2allv",
] = "allgather_reducescatter"
@@ -935,7 +936,7 @@ def get_vllm_port() -> int | None:
# Whether to use aiter fusion shared experts ops.
# By default this is disabled.
"VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS": lambda: (
-os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "True").lower()
+os.getenv("VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS", "False").lower()
in ("true", "1")
),
# use rocm skinny gemms
@@ -1197,6 +1198,7 @@ def get_vllm_port() -> int | None:
# - "pplx": use pplx kernels
# - "deepep_high_throughput", use deepep high-throughput kernels
# - "deepep_low_latency", use deepep low-latency kernels
# - "mori", use MoRI kernels
# - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl
"VLLM_ALL2ALL_BACKEND": env_with_choices(
"VLLM_ALL2ALL_BACKEND",
@@ -1206,6 +1208,7 @@
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"mori",
"allgather_reducescatter",
"flashinfer_all2allv",
],
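
A minimal sketch of opting into the new backend through the environment variables declared in this file; set them in the launching process. The `=0` for the fusion-shared-experts flag anticipates the assertion added in `fused_moe/layer.py` further down.

```python
# Minimal sketch: selecting the MoRI all2all backend via the environment
# variables declared above. Set these before launching vLLM.
import os

os.environ["VLLM_ALL2ALL_BACKEND"] = "mori"
# MoRI does not yet work with AITER fused shared experts (see the assertion
# added in fused_moe/layer.py below), so keep that flag off.
os.environ["VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS"] = "0"
```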
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -64,6 +64,9 @@ def get_config() -> dict[str, Any] | None:
cutlass_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_aiter_moe import (
AiterExperts,
)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts,
)
@@ -93,6 +96,7 @@ def get_config() -> dict[str, Any] | None:
"BatchedDeepGemmExperts",
"TritonOrDeepGemmExperts",
"BatchedTritonOrDeepGemmExperts",
"AiterExperts",
]
else:
# Some model classes directly use the custom ops. Add placeholders
8 changes: 8 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
@@ -691,6 +691,10 @@ def use_deepep_ht_kernels(self):
def use_deepep_ll_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"

@property
def use_mori_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "mori"

@staticmethod
def flatten_tp_across_dp(
tp_size: int, dp_size: int, dp_rank: int
@@ -883,6 +887,10 @@ def use_deepep_ht_kernels(self):
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels

@property
def use_mori_kernels(self):
return self.moe_parallel_config.use_mori_kernels

@property
def use_flashinfer_cutlass_kernels(self):
"""
90 changes: 90 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_aiter_moe.py
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)


class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config: FusedMoEQuantConfig):
super().__init__(quant_config)

@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (
mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard,
)

def supports_chunking(self) -> bool:
return True

def supports_expert_map(self) -> bool:
return True

def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()

def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
workspace13 = (M, K)
workspace2 = (0,)
output = (M, K)
return (workspace13, workspace2, output)

def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
) -> None:
if expert_tokens_meta is not None:
num_local_tokens = expert_tokens_meta.expert_num_tokens
else:
num_local_tokens = None

result = rocm_aiter_fused_experts(
hidden_states,
w1,
w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
quant_config=self.quant_config,
a1q_scale=a1q_scale,
num_local_tokens=num_local_tokens,
output_dtype=output.dtype,
)
output.copy_(result)
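
AiterExperts is the experts half of a modular kernel: a prepare/finalize stage (such as the MoRI one wired up in `fused_moe/layer.py` below) hands it standard-format activations, and it runs the AITER fused-experts kernel. A hedged sketch of its workspace contract follows; `FUSED_MOE_UNQUANTIZED_CONFIG` is assumed to be the unquantized default exported by `fused_moe/config.py`, and the sizes are illustrative.

```python
# Hedged sketch: exercising the workspace contract of AiterExperts.
# FUSED_MOE_UNQUANTIZED_CONFIG is an assumed export of fused_moe/config.py;
# the sizes below are illustrative only.
from vllm.model_executor.layers.fused_moe.config import FUSED_MOE_UNQUANTIZED_CONFIG
from vllm.model_executor.layers.fused_moe.fused_aiter_moe import AiterExperts

experts = AiterExperts(FUSED_MOE_UNQUANTIZED_CONFIG)

M, N, K, topk = 128, 14336, 4096, 8      # tokens, intermediate dim, hidden dim, top-k
ws13, ws2, out = experts.workspace_shapes(
    M, N, K, topk, global_num_experts=64, local_num_experts=8,
    expert_tokens_meta=None,
)
# The output buffer matches the input shape because rocm_aiter_fused_experts
# already applies the top-k weights and reduces per token (hence
# TopKWeightAndReduceNoOP above), and no extra workspace2 is needed.
assert ws13 == (M, K) and ws2 == (0,) and out == (M, K)
```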
44 changes: 43 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -56,7 +56,7 @@
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
-from vllm.utils.import_utils import has_deep_ep, has_pplx
+from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import current_stream, direct_register_custom_op
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
@@ -76,6 +76,8 @@
DEEPEP_QUANT_BLOCK_SHAPE,
DeepEPLLPrepareAndFinalize,
)
if has_mori():
from .mori_prepare_finalize import MoriPrepareAndFinalize
else:
fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = object # type: ignore
@@ -233,6 +235,36 @@ def _maybe_make_prepare_finalize(
use_fp8_dispatch=use_fp8_dispatch,
)

elif moe.use_mori_kernels:
assert quant_config is not None
# For PTPC (per token per channel) quant, the scale dim for each token is 1
# For 1x128 quant, the scale dim for each token is hidden_dim // 128
scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
all_to_all_args = dict(
rank=all2all_manager.rank,
num_ep_ranks=all2all_manager.world_size,
quant_dtype=quant_config.quant_dtype,
token_hidden_size=moe.hidden_dim,
scale_dim=scale_dim,
scale_type_size=torch.float32.itemsize,
max_num_tokens_per_dp_rank=moe.max_num_tokens,
input_dtype=moe.in_dtype,
num_local_experts=moe.num_experts // all2all_manager.world_size,
num_experts_per_token=moe.experts_per_token,
)
handle = all2all_manager.get_handle(all_to_all_args)

# Note: We may want to use FP8 dispatch just to reduce
# data movement.
use_fp8_dispatch = is_rocm_aiter_moe_enabled()

prepare_finalize = MoriPrepareAndFinalize(
handle,
max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch,
)

return prepare_finalize

def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
Expand Down Expand Up @@ -1407,6 +1439,11 @@ def __init__(
is_act_and_mul=is_act_and_mul,
is_lora_enabled=vllm_config.lora_config is not None,
)
if self.use_mori_kernels:
assert not is_rocm_aiter_fusion_shared_expert_enabled(), (
"Mori does not support fusion shared expert now. "
"Turn it off by setting VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0"
)

self.moe_quant_config: FusedMoEQuantConfig | None = None
self.quant_config = quant_config
@@ -1538,6 +1575,10 @@ def use_deepep_ht_kernels(self):
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels

@property
def use_mori_kernels(self):
return self.moe_parallel_config.use_mori_kernels

@property
def use_flashinfer_cutlass_kernels(self):
return (
@@ -1551,6 +1592,7 @@ def use_dp_chunking(self) -> bool:
return (
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_mori_kernels
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
)

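
As a worked illustration of the `scale_dim` choice in the mori branch of `_maybe_make_prepare_finalize` above (the hidden size is assumed for the sketch):

```python
# Worked numbers for the scale_dim logic above; hidden_dim is illustrative.
import torch

hidden_dim = 4096
scale_dim_ptpc = 1                        # per-token-per-channel: one scale per token
scale_dim_block = hidden_dim // 128       # 1x128 block quant: 32 scales per token
scale_type_size = torch.float32.itemsize  # bytes per scale -> 4
```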