10 | 10 | from vllm.forward_context import get_forward_context |
11 | 11 | from vllm.logger import init_logger |
12 | 12 | from vllm.utils.flashinfer import has_flashinfer_all2all |
13 | | -from vllm.utils.import_utils import has_deep_ep, has_pplx |
| 13 | +from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx |
14 | 14 |
15 | 15 | from .base_device_communicator import All2AllManagerBase, Cache |
16 | 16 |
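
The import hunk adds `has_mori` alongside the existing `has_deep_ep` and `has_pplx` availability guards. For orientation only, here is a hedged sketch of the shape such a helper typically takes, assuming it mirrors the other checks in `vllm.utils.import_utils` (the actual implementation is not part of this diff):

```python
# Hedged sketch (not from this diff): a plausible has_mori() availability
# check, assuming it follows the style of has_deep_ep()/has_pplx().
import importlib.util
from functools import cache


@cache
def has_mori() -> bool:
    """Return True if the MoRI kernels package is importable."""
    return importlib.util.find_spec("mori") is not None
```
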
@@ -488,3 +488,77 @@ def cleanup(self): |
488 | 488 | self.prepare_workspace_tensor = None |
489 | 489 | self.mapping = None |
490 | 490 | self.initialized = False |
| 491 | + |
| 492 | + |
| 493 | +class MoriAll2AllManager(All2AllManagerBase): |
| 494 | + def __init__(self, cpu_group): |
| 495 | + assert has_mori(), ( |
| 496 | + "MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md" |
| 497 | + " to install MoRI kernels." |
| 498 | + ) # noqa |
| 499 | + import mori |
| 500 | + |
| 501 | + super().__init__(cpu_group) |
| 502 | + self.handle_cache = Cache() |
| 503 | + |
| 504 | + torch._C._distributed_c10d._register_process_group("mori", cpu_group) |
| 505 | + mori.shmem.shmem_torch_process_group_init("mori") |
| 506 | + |
| 507 | + def _make_all2all_kwargs( |
| 508 | + self, |
| 509 | + rank: int, |
| 510 | + num_ep_ranks: int, |
| 511 | + input_dtype: torch.dtype, |
| 512 | + quant_dtype: torch.dtype, |
| 513 | + token_hidden_size: int, |
| 514 | + scale_dim: int, |
| 515 | + scale_type_size: int, |
| 516 | + max_num_tokens_per_dp_rank: int, |
| 517 | + num_local_experts: int, |
| 518 | + num_experts_per_token: int, |
| 519 | + ): |
| 520 | + import mori # type: ignore[import-not-found] |
| 521 | + |
| 522 | + if self.internode: |
| 523 | + # multi node |
| 524 | + kernel_type = mori.ops.EpDispatchCombineKernelType.InterNode |
| 525 | + warp_num_per_block = 16 |
| 526 | + block_num = 64 |
| 527 | + else: |
| 528 | + # single node |
| 529 | + kernel_type = mori.ops.EpDispatchCombineKernelType.IntraNode |
| 530 | + warp_num_per_block = 16 |
| 531 | + block_num = 64 |
| 532 | + |
| 533 | + return dict( |
| 534 | + rank=rank, |
| 535 | + world_size=num_ep_ranks, |
| 536 | + data_type=quant_dtype, |
| 537 | + hidden_dim=token_hidden_size, |
| 538 | + scale_dim=scale_dim, |
| 539 | + scale_type_size=scale_type_size, |
| 540 | + max_token_type_size=input_dtype.itemsize, |
| 541 | + max_num_inp_token_per_rank=max_num_tokens_per_dp_rank, |
| 542 | + num_experts_per_rank=num_local_experts, |
| 543 | + num_experts_per_token=num_experts_per_token, |
| 544 | + warp_num_per_block=warp_num_per_block, |
| 545 | + block_num=block_num, |
| 546 | + kernel_type=kernel_type, |
| 547 | + ) |
| 548 | + |
| 549 | + def _make_handle(self, **kwargs): |
| 550 | + import mori # type: ignore[import-not-found] |
| 551 | + |
| 552 | + mori_config = mori.ops.EpDispatchCombineConfig(**kwargs) |
| 553 | + handle = mori.ops.EpDispatchCombineOp(mori_config) |
| 554 | + return handle |
| 555 | + |
| 556 | + def get_handle(self, kwargs): |
| 557 | + import mori # type: ignore[import-not-found] |
| 558 | + |
| 559 | + mori_kwargs = self._make_all2all_kwargs(**kwargs) |
| 560 | + logger.debug("MoRI all2all args %s", mori_kwargs) |
| 561 | + handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create( |
| 562 | + mori_kwargs, self._make_handle |
| 563 | + ) |
| 564 | + return handle |
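
For context, a hedged usage sketch of how an expert-parallel caller might obtain a dispatch/combine handle from the new manager. Only the kwarg names mirror `_make_all2all_kwargs` in the diff above; `cpu_group`, the dtypes, and all sizing values are placeholders for whatever the caller's EP setup provides, not values taken from this PR:

```python
# Hedged usage sketch: sizes, dtypes and the process group are assumptions
# supplied by the caller's expert-parallel setup, not taken from this diff.
import torch

manager = MoriAll2AllManager(cpu_group)  # cpu_group: an existing CPU process group

handle = manager.get_handle(
    dict(
        rank=ep_rank,                       # this rank's index in the EP group
        num_ep_ranks=ep_world_size,         # total EP ranks
        input_dtype=torch.bfloat16,         # activation dtype before quantization
        quant_dtype=torch.float8_e4m3fn,    # dtype sent over the wire
        token_hidden_size=hidden_size,
        scale_dim=scale_dim,
        scale_type_size=scale_type_size,
        max_num_tokens_per_dp_rank=max_tokens_per_rank,
        num_local_experts=num_local_experts,
        num_experts_per_token=top_k,
    )
)
# Repeated calls with identical kwargs hit the handle cache and return the
# same mori.ops.EpDispatchCombineOp instance instead of rebuilding it.
```
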