
Commit c5ae6bd (parent 938a816)

mori ep

Signed-off-by: Alex Sun <[email protected]>

13 files changed (+387, -12 lines)

vllm/config/parallel.py

Lines changed: 3 additions & 0 deletions

@@ -123,6 +123,7 @@ class ParallelConfig:
         "pplx",
         "deepep_high_throughput",
         "deepep_low_latency",
+        "mori",
         "allgather_reducescatter",
         "flashinfer_all2allv",
     ]
@@ -135,6 +136,7 @@ class ParallelConfig:
     - "pplx": Use pplx kernels
     - "deepep_high_throughput": Use deepep high-throughput kernels
     - "deepep_low_latency": Use deepep low-latency kernels
+    - "mori": Use mori kernels
     - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
     num_redundant_experts: int | None = None
     """`num_redundant_experts` is deprecated and has been replaced with
@@ -370,6 +372,7 @@ def use_sequence_parallel_moe(self) -> bool:
             "naive",
             "deepep_high_throughput",
             "deepep_low_latency",
+            "mori",
         )
         and self.enable_expert_parallel
         and self.tensor_parallel_size > 1

vllm/distributed/device_communicators/all2all.py

Lines changed: 75 additions & 1 deletion

@@ -10,7 +10,7 @@
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.utils.flashinfer import has_flashinfer_all2all
-from vllm.utils.import_utils import has_deep_ep, has_pplx
+from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx

 from .base_device_communicator import All2AllManagerBase, Cache

@@ -488,3 +488,77 @@ def cleanup(self):
         self.prepare_workspace_tensor = None
         self.mapping = None
         self.initialized = False
+
+
+class MoriAll2AllManager(All2AllManagerBase):
+    def __init__(self, cpu_group):
+        assert has_mori(), (
+            "MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md"
+            " to install MoRI kernels."
+        )  # noqa
+        import mori
+
+        super().__init__(cpu_group)
+        self.handle_cache = Cache()
+
+        torch._C._distributed_c10d._register_process_group("mori", cpu_group)
+        mori.shmem.shmem_torch_process_group_init("mori")
+
+    def _make_all2all_kwargs(
+        self,
+        rank: int,
+        num_ep_ranks: int,
+        input_dtype: torch.dtype,
+        quant_dtype: torch.dtype,
+        token_hidden_size: int,
+        scale_dim: int,
+        scale_type_size: int,
+        max_num_tokens_per_dp_rank: int,
+        num_local_experts: int,
+        num_experts_per_token: int,
+    ):
+        import mori  # type: ignore[import-not-found]
+
+        if self.internode:
+            # multi node
+            kernel_type = mori.ops.EpDispatchCombineKernelType.InterNode
+            warp_num_per_block = 16
+            block_num = 64
+        else:
+            # single node
+            kernel_type = mori.ops.EpDispatchCombineKernelType.IntraNode
+            warp_num_per_block = 16
+            block_num = 64
+
+        return dict(
+            rank=rank,
+            world_size=num_ep_ranks,
+            data_type=quant_dtype,
+            hidden_dim=token_hidden_size,
+            scale_dim=scale_dim,
+            scale_type_size=scale_type_size,
+            max_token_type_size=input_dtype.itemsize,
+            max_num_inp_token_per_rank=max_num_tokens_per_dp_rank,
+            num_experts_per_rank=num_local_experts,
+            num_experts_per_token=num_experts_per_token,
+            warp_num_per_block=warp_num_per_block,
+            block_num=block_num,
+            kernel_type=kernel_type,
+        )
+
+    def _make_handle(self, **kwargs):
+        import mori  # type: ignore[import-not-found]
+
+        mori_config = mori.ops.EpDispatchCombineConfig(**kwargs)
+        handle = mori.ops.EpDispatchCombineOp(mori_config)
+        return handle
+
+    def get_handle(self, kwargs):
+        import mori  # type: ignore[import-not-found]
+
+        mori_kwargs = self._make_all2all_kwargs(**kwargs)
+        logger.debug("MoRI all2all args %s", mori_kwargs)
+        handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create(
+            mori_kwargs, self._make_handle
+        )
+        return handle
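
The manager builds one mori EpDispatchCombineOp per unique argument set and caches it in handle_cache, so repeated get_handle calls with the same kwargs reuse the same op. A minimal sketch of a handle request, assuming an already-initialized MoriAll2AllManager bound to all2all_manager; the kwargs keys mirror _make_all2all_kwargs and the values are placeholders, not taken from this commit:

import torch

kwargs = dict(
    rank=0,
    num_ep_ranks=8,
    input_dtype=torch.bfloat16,          # dtype of the activations being dispatched
    quant_dtype=torch.float8_e4m3fnuz,   # on-the-wire dtype when FP8 dispatch is used
    token_hidden_size=4096,
    scale_dim=4096 // 128,               # one scale per 1x128 block
    scale_type_size=torch.float32.itemsize,
    max_num_tokens_per_dp_rank=1024,
    num_local_experts=8,
    num_experts_per_token=8,
)

# First call constructs EpDispatchCombineConfig/EpDispatchCombineOp; later calls
# with identical kwargs return the cached handle.
handle = all2all_manager.get_handle(kwargs)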

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 4 additions & 0 deletions

@@ -110,6 +110,10 @@ def __init__(
             from .all2all import DeepEPLLAll2AllManager

             self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
+        elif self.all2all_backend == "mori":
+            from .all2all import MoriAll2AllManager
+
+            self.all2all_manager = MoriAll2AllManager(self.cpu_group)
         elif self.all2all_backend == "flashinfer_all2allv":
             from .all2all import FlashInferAllToAllManager

vllm/envs.py

Lines changed: 3 additions & 0 deletions

@@ -165,6 +165,7 @@
         "pplx",
         "deepep_high_throughput",
         "deepep_low_latency",
+        "mori",
         "allgather_reducescatter",
         "flashinfer_all2allv",
     ] = "allgather_reducescatter"
@@ -1187,6 +1188,7 @@ def get_vllm_port() -> int | None:
     # - "pplx": use pplx kernels
     # - "deepep_high_throughput", use deepep high-throughput kernels
     # - "deepep_low_latency", use deepep low-latency kernels
+    # - "mori", use MoRI kernels
     # - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl
     "VLLM_ALL2ALL_BACKEND": env_with_choices(
         "VLLM_ALL2ALL_BACKEND",
@@ -1196,6 +1198,7 @@ def get_vllm_port() -> int | None:
         "pplx",
         "deepep_high_throughput",
         "deepep_low_latency",
+        "mori",
         "allgather_reducescatter",
         "flashinfer_all2allv",
     ],
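
With "mori" registered as a valid choice, the backend is selected through VLLM_ALL2ALL_BACKEND just like the existing options. A minimal sketch, assuming the variable is set before vLLM reads its environment (exporting it in the launching shell works the same way):

import os

# Select the MoRI all2all backend for expert-parallel MoE dispatch/combine.
os.environ["VLLM_ALL2ALL_BACKEND"] = "mori"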

vllm/model_executor/layers/fused_moe/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -64,6 +64,9 @@ def get_config() -> dict[str, Any] | None:
         cutlass_moe_fp8,
     )
     from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
+    from vllm.model_executor.layers.fused_moe.fused_aiter_moe import (
+        AiterExperts,
+    )
     from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
         BatchedTritonExperts,
     )
@@ -93,6 +96,7 @@ def get_config() -> dict[str, Any] | None:
         "BatchedDeepGemmExperts",
         "TritonOrDeepGemmExperts",
         "BatchedTritonOrDeepGemmExperts",
+        "AiterExperts",
     ]
 else:
     # Some model classes directly use the custom ops. Add placeholders

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 8 additions & 0 deletions

@@ -683,6 +683,10 @@ def use_deepep_ht_kernels(self):
     def use_deepep_ll_kernels(self):
         return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"

+    @property
+    def use_mori_kernels(self):
+        return self.use_all2all_kernels and self.all2all_backend == "mori"
+
     @staticmethod
     def flatten_tp_across_dp(
         tp_size: int, dp_size: int, dp_rank: int
@@ -875,6 +879,10 @@ def use_deepep_ht_kernels(self):
     def use_deepep_ll_kernels(self):
         return self.moe_parallel_config.use_deepep_ll_kernels

+    @property
+    def use_mori_kernels(self):
+        return self.moe_parallel_config.use_mori_kernels
+
     @property
     def use_flashinfer_cutlass_kernels(self):
         """
vllm/model_executor/layers/fused_moe/fused_aiter_moe.py (new file)

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    rocm_aiter_fused_experts,
+)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceNoOP,
+)
+
+
+class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
+    def __init__(self, quant_config: FusedMoEQuantConfig):
+        super().__init__(quant_config)
+
+    @property
+    def activation_formats(
+        self,
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (
+            mk.FusedMoEActivationFormat.Standard,
+            mk.FusedMoEActivationFormat.Standard,
+        )
+
+    def supports_chunking(self) -> bool:
+        return True
+
+    def supports_expert_map(self) -> bool:
+        return True
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        return TopKWeightAndReduceNoOP()
+
+    def workspace_shapes(
+        self,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+        workspace13 = (M, K)
+        workspace2 = (0,)
+        output = (M, K)
+        return (workspace13, workspace2, output)
+
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: torch.Tensor | None,
+        a1q_scale: torch.Tensor | None,
+        a2_scale: torch.Tensor | None,
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: mk.ExpertTokensMetadata | None,
+        apply_router_weight_on_input: bool,
+    ) -> None:
+        if expert_tokens_meta is not None:
+            num_local_tokens = expert_tokens_meta.expert_num_tokens
+        else:
+            num_local_tokens = None
+
+        result = rocm_aiter_fused_experts(
+            hidden_states,
+            w1,
+            w2,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            quant_config=self.quant_config,
+            a1q_scale=a1q_scale,
+            num_local_tokens=num_local_tokens,
+            output_dtype=output.dtype,
+        )
+        output.copy_(result)
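
AiterExperts delegates the fused computation to rocm_aiter_fused_experts and writes the result into the caller-provided output buffer, so its workspace needs are minimal. An illustrative check of the shape contract above, with hypothetical sizes (not taken from this commit):

# Hypothetical sizes: M tokens with hidden size K.
M, K = 256, 4096

# Mirrors AiterExperts.workspace_shapes: the staging buffer matches the output
# shape and the intermediate workspace is effectively unused.
workspace13 = (M, K)
workspace2 = (0,)
output_shape = (M, K)
assert workspace13 == output_shape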

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 34 additions & 1 deletion

@@ -56,7 +56,7 @@
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
-from vllm.utils.import_utils import has_deep_ep, has_pplx
+from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
 from vllm.utils.math_utils import cdiv, round_up
 from vllm.utils.torch_utils import current_stream, direct_register_custom_op
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id
@@ -76,6 +76,8 @@
         DEEPEP_QUANT_BLOCK_SHAPE,
         DeepEPLLPrepareAndFinalize,
     )
+    if has_mori():
+        from .mori_prepare_finalize import MoriPrepareAndFinalize
 else:
     fused_experts = None  # type: ignore
     FusedMoEPermuteExpertsUnpermute = object  # type: ignore
@@ -233,6 +235,36 @@ def _maybe_make_prepare_finalize(
                 use_fp8_dispatch=use_fp8_dispatch,
             )

+        elif moe.use_mori_kernels:
+            assert quant_config is not None
+            # For PTPC (per token per channel) quant, the scale dim for each token is 1
+            # For 1x128 quant, the scale dim for each token is hidden_dim // 128
+            scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
+            all_to_all_args = dict(
+                rank=all2all_manager.rank,
+                num_ep_ranks=all2all_manager.world_size,
+                quant_dtype=quant_config.quant_dtype,
+                token_hidden_size=moe.hidden_dim,
+                scale_dim=scale_dim,
+                scale_type_size=torch.float32.itemsize,
+                max_num_tokens_per_dp_rank=moe.max_num_tokens,
+                input_dtype=moe.in_dtype,
+                num_local_experts=moe.num_experts // all2all_manager.world_size,
+                num_experts_per_token=moe.experts_per_token,
+            )
+            handle = all2all_manager.get_handle(all_to_all_args)
+
+            # Note: We may want to use FP8 dispatch just to reduce
+            # data movement.
+            use_fp8_dispatch = is_rocm_aiter_moe_enabled()
+
+            prepare_finalize = MoriPrepareAndFinalize(
+                handle,
+                max_tokens_per_rank=moe.max_num_tokens,
+                num_dispatchers=all2all_manager.world_size,
+                use_fp8_dispatch=use_fp8_dispatch,
+            )
+
         return prepare_finalize

     def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
@@ -1551,6 +1583,7 @@ def use_dp_chunking(self) -> bool:
         return (
             self.moe_parallel_config.use_pplx_kernels
             or self.moe_parallel_config.use_deepep_ll_kernels
+            or self.moe_parallel_config.use_mori_kernels
             or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
         )
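
The scale_dim passed to the MoRI handle follows the comment in the hunk above: PTPC (per token per channel) quantization carries a single scale per token, while 1x128 block quantization carries one scale per 128-element block of the hidden dimension. A small worked example with a hypothetical hidden size:

import torch

hidden_dim = 7168  # hypothetical model hidden size

# PTPC quant: one scale per token.
scale_dim_ptpc = 1

# 1x128 block quant: one scale per 128-wide block along the hidden dim.
scale_dim_blocked = hidden_dim // 128  # 56

# Scales travel as float32, matching scale_type_size in all_to_all_args above.
scale_type_size = torch.float32.itemsize  # 4 bytes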
