
Conversation


@alexsun07 alexsun07 commented Oct 21, 2025

Purpose

This PR integrates MoRI-EP, a high-performance all2all communication kernel, into vLLM as an all2all backend. See the MoRI project here. MoRI also supports CUDA graphs.

This PR follows the design of vLLM's Fused MoE Modular Kernel, which is composed of the following components:
[Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine]

For the MoRI+AITER path, AMD's high-performance recipe, the pipeline becomes:
[Router] → [Quantize-Dispatch] → [Experts] → [Combine]

Two new classes are introduced:

  • MoriPrepareAndFinalize: performs the [Quantize-Dispatch] and [Combine] steps
  • AiterExperts: performs the [Experts] step; no permute or unpermute is needed (see the sketch below)
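
A minimal sketch of how the two classes compose, assuming the two-argument shape of vLLM's FusedMoEModularKernel; the constructor argument lists are illustrative placeholders, not the PR's exact code:

```python
# Illustrative sketch only: how the two new classes plug into vLLM's
# Fused MoE Modular Kernel. Constructor arguments are placeholders,
# not the PR's actual signatures.
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel,
)

kernel = FusedMoEModularKernel(
    MoriPrepareAndFinalize(...),  # [Quantize-Dispatch] before, [Combine] after the experts
    AiterExperts(...),            # [Experts] via the AITER fused MoE kernel
)
```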

Summary of the performance comparison between the MoRI-EP and naive backends (bs=128 per DP rank):

| all2all       | EP size | Mean TPOT (ms) | Output tok/s per node | perf  |
|---------------|--------:|---------------:|----------------------:|------:|
| naive         |       8 |         128.42 |               7119.64 | 1.00x |
| mori          |       8 |          94.14 |               9439.57 | 1.33x |
| naive (eager) |      16 |         305.36 |               2740.34 | 1.00x |
| mori          |      16 |         110.87 |               7343.28 | 2.68x |
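
The perf column is the ratio of output throughput against the naive baseline at the same EP size; the EP16 per-node values are the two-node totals from the detailed results below, divided by 2:

```python
# perf = mori output throughput / naive output throughput at the same EP size
print(f"{9439.57 / 7119.64:.2f}x")  # EP8:  1.33x
print(f"{7343.28 / 2740.34:.2f}x")  # EP16: 2.68x
```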

How to install MoRI

See https://github.com/ROCm/mori

Test Plan

Test platform: MI300X

Accuracy

Serve DeepSeek-V3/R1 (block-scale quantization)

VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_MOE=1 \
VLLM_USE_V1=1 \
vllm serve deepseek-ai/DeepSeek-V3 \
    -tp 1 \
    -dp 8 \
    --port 30000 \
    --all2all-backend mori \
    --enable-expert-parallel

Serve DeepSeek-R1-PTPC (per-token, per-channel quantization); see here for more info about PTPC.

VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_MOE=1 \
VLLM_USE_V1=1 \
vllm serve EmbeddedLLM/deepseek-r1-FP8-Dynamic \
    -tp 1 \
    -dp 8 \
    --port 30000 \
    --all2all-backend mori \
    --enable-expert-parallel
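
Once either server is up, a quick sanity check against vLLM's OpenAI-compatible endpoint (the model name must match the served model; the prompt is an arbitrary placeholder):

```python
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={
        "model": "EmbeddedLLM/deepseek-r1-FP8-Dynamic",  # must match the served model
        "prompt": "The quick brown fox",                 # arbitrary placeholder
        "max_tokens": 16,
    },
)
print(resp.json()["choices"][0]["text"])
```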

Evaluate with gsm8k

lm_eval --model local-completions \
    --tasks gsm8k \
    --model_args model=<model_path>,base_url=http://localhost:30000/v1/completions,num_concurrent=256,max_retries=3,tokenized_requests=False 

Performance

Test EP8 and EP16 performance and compare against the naive all2all backend.

EP8 with mori backend

VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_MOE=1 \
VLLM_USE_V1=1 \
VLLM_MOE_DP_CHUNK_SIZE=512 \
vllm serve EmbeddedLLM/deepseek-r1-FP8-Dynamic \
    -tp 1 \
    -dp 8 \
    --port 30000 \
    --all2all-backend mori \
    --max-num-seqs 128 \
    --enable-expert-parallel \
    --cudagraph-capture-sizes 1 2 4 8 16 32 64 128

EP8 with naive backend:
Same command as above, but replace --all2all-backend mori with --all2all-backend naive.

EP16 with mori backend

# node0
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_MOE=1 \
VLLM_USE_V1=1 \
VLLM_MOE_DP_CHUNK_SIZE=512 \
vllm serve /nfs/DeepSeek-R1-PTPC \
    -dp 16 \
    --data-parallel-size-local 8 \
    --data-parallel-address <node-0-ip> --data-parallel-rpc-port <node-0-port> \
    --enable-expert-parallel \
    --all2all-backend mori \
    --port 30000 \
    --max-num-seqs 128 \
    --cuda-graph-sizes 1 2 4 8 16 32 64 128 \
    --trust-remote-code 

# node1
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_MOE=1 \
VLLM_USE_V1=1 \
VLLM_MOE_DP_CHUNK_SIZE=512 \
vllm serve /nfs/DeepSeek-R1-PTPC \
    -dp 16 \
    --headless \
    --data-parallel-size-local 8 \
    --data-parallel-start-rank 8 \
    --data-parallel-address <node-0-ip> --data-parallel-rpc-port <node-0-port> \
    --enable-expert-parallel \
    --all2all-backend mori \
    --port 30000 \
    --max-num-seqs 128 \
    --cuda-graph-sizes 1 2 4 8 16 32 64 128 \
    --trust-remote-code
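
For reference, the DP/EP rank layout implied by the flags above (tp=1, so the EP size equals the global DP size of 16):

```python
# node0 is the head and owns DP ranks 0-7 (--data-parallel-size-local 8);
# node1 runs headless and owns ranks 8-15 (--data-parallel-start-rank 8).
dp_rank_layout = {
    "node0": list(range(0, 8)),
    "node1": list(range(8, 16)),
}
```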

EP16 with naive backend:
Same commands as above, but replace --all2all-backend mori with --all2all-backend naive and add --enforce-eager (hence "naive (eager)" in the summary tables).

Benchmark:
We use --random-input-len 1 --random-prefix-len 1023 to simulate PD (prefill/decode) disaggregation and measure decode performance without meaningful prefill work.

vllm bench serve \
    --max-concurrency <1024 * node_num> \
    --num-prompts <4096 * node_num> \
    --model <model_path> \
    --port 30000 \
    --ignore-eos \
    --trust-remote-code \
    --dataset-name random \
    --seed 2025 \
    --random-input-len 1 \
    --random-prefix-len 1023 \
    --random-output-len 500
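
The token totals in the EP8 results below follow directly from these parameters; a quick consistency check (assuming the benchmark counts the 1023-token prefix as the prompt input):

```python
num_prompts = 4096                      # --num-prompts with node_num = 1
assert num_prompts * 500 == 2_048_000   # total generated tokens
assert num_prompts * 1023 == 4_190_208  # total input tokens (1023-token prefix)
```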

Test Result

Accuracy

MoRI-EP with DeepSeek-R1-PTPC

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9538|±  |0.0058|
|     |       |strict-match    |     5|exact_match|↑  |0.9530|±  |0.0058|

Decode Performance

Summary

| all2all       | EP size | Mean TPOT (ms) | Output tok/s per node | perf  |
|---------------|--------:|---------------:|----------------------:|------:|
| naive         |       8 |         128.42 |               7119.64 | 1.00x |
| mori          |       8 |          94.14 |               9439.57 | 1.33x |
| naive (eager) |      16 |         305.36 |               2740.34 | 1.00x |
| mori          |      16 |         110.87 |               7343.28 | 2.68x |

EP8 mori all2all backend

============ Serving Benchmark Result ============
Successful requests:                     4096      
Failed requests:                         0         
Maximum request concurrency:             1024      
Benchmark duration (s):                  216.96    
Total input tokens:                      4190208   
Total generated tokens:                  2048000   
Request throughput (req/s):              18.88     
Output token throughput (tok/s):         9439.57   
Peak output token throughput (tok/s):    13171.00  
Peak concurrent requests:                1152.00   
Total Token throughput (tok/s):          28752.92  
---------------Time to First Token----------------
Mean TTFT (ms):                          3079.99   
Median TTFT (ms):                        1172.27   
P99 TTFT (ms):                           14658.47  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          94.14     
Median TPOT (ms):                        95.69     
P99 TPOT (ms):                           98.65     
---------------Inter-token Latency----------------
Mean ITL (ms):                           105.46    
Median ITL (ms):                         84.14     
P99 ITL (ms):                            503.41    
==================================================

EP8 naive all2all backend

============ Serving Benchmark Result ============
Successful requests:                     4096      
Failed requests:                         0         
Maximum request concurrency:             1024      
Benchmark duration (s):                  287.65    
Total input tokens:                      4190208   
Total generated tokens:                  2048000   
Request throughput (req/s):              14.24     
Output token throughput (tok/s):         7119.64   
Peak output token throughput (tok/s):    10230.00  
Peak concurrent requests:                1152.00   
Total Token throughput (tok/s):          21686.42  
---------------Time to First Token----------------
Mean TTFT (ms):                          3118.80   
Median TTFT (ms):                        1093.97   
P99 TTFT (ms):                           15430.33  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          128.42    
Median TPOT (ms):                        129.82    
P99 TPOT (ms):                           137.51    
---------------Inter-token Latency----------------
Mean ITL (ms):                           133.46    
Median ITL (ms):                         112.55    
P99 ITL (ms):                            513.15    
==================================================

EP16 mori all2all backend

============ Serving Benchmark Result ============
Successful requests:                     8192
Failed requests:                         0
Maximum request concurrency:             2048
Benchmark duration (s):                  278.89
Total input tokens:                      8380416
Total generated tokens:                  4096000
Request throughput (req/s):              29.37
Output token throughput (tok/s):         14686.55
Peak output token throughput (tok/s):    20942.00
Peak concurrent requests:                2271.00
Total Token throughput (tok/s):          44735.22
---------------Time to First Token----------------
Mean TTFT (ms):                          10838.91
Median TTFT (ms):                        7431.13
P99 TTFT (ms):                           34603.02
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          110.87
Median TPOT (ms):                        111.76
P99 TPOT (ms):                           127.69
---------------Inter-token Latency----------------
Mean ITL (ms):                           209.21
Median ITL (ms):                         94.86
P99 ITL (ms):                            864.02
==================================================

EP16 naive all2all backend

============ Serving Benchmark Result ============
Successful requests:                     8192
Failed requests:                         0
Maximum request concurrency:             2048
Benchmark duration (s):                  747.35
Total input tokens:                      8380416
Total generated tokens:                  4096000
Request throughput (req/s):              10.96
Output token throughput (tok/s):         5480.68
Peak output token throughput (tok/s):    9665.00
Peak concurrent requests:                2187.00
Total Token throughput (tok/s):          16694.17
---------------Time to First Token----------------
Mean TTFT (ms):                          10112.99
Median TTFT (ms):                        7514.72
P99 TTFT (ms):                           35132.93
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          305.36
Median TPOT (ms):                        305.49
P99 TPOT (ms):                           317.03
---------------Inter-token Latency----------------
Mean ITL (ms):                           328.70
Median ITL (ms):                         297.74
P99 ITL (ms):                            857.16
==================================================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of tests to catch errors quickly.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@alexsun07 alexsun07 marked this pull request as draft October 21, 2025 17:03
@mergify mergify bot added the rocm Related to AMD ROCm label Oct 21, 2025

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request integrates MoRI, a high-performance all-to-all communication kernel, as a new backend for vLLM, primarily targeting AMD GPUs. The changes span across several files to add the necessary configurations, manager class, and logic to use this new backend. While the integration is mostly well-structured, I've identified a couple of areas for improvement related to code duplication and consistency, which I've detailed in the comments.

@HAIAI HAIAI self-requested a review October 21, 2025 18:03

mergify bot commented Oct 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @alexsun07.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: Alex Sun <[email protected]>
@alexsun07 alexsun07 changed the title from "[WIP][AMD] MoRI EP integration" to "[AMD][ROCm] MoRI EP: a high-performance all2all backend" on Nov 4, 2025
@alexsun07 alexsun07 marked this pull request as ready for review November 5, 2025 02:45

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.
