
Commit e38ef2c

momo609 and wangxiaoxin-sherie authored

support FULL graph mode for GQA (#3970)

### What this PR does / why we need it?

The library currently supports only the FULL_DECODE_ONLY graph mode, which enables full graph execution during the decode phase. This PR extends support to allow full graph execution in both the prefill and decode phases, referred to as FULL graph mode.

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@2918c1b

Signed-off-by: wangxiaoxin-sherie <[email protected]>
Co-authored-by: wangxiaoxin-sherie <[email protected]>

1 parent c334114 commit e38ef2c
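
As a quick illustration (not part of the diff), here is a minimal sketch of how the two graph modes are selected through `compilation_config`, mirroring the e2e tests added below. The model name and capture sizes are taken from those tests; passing the config as a plain dict to `vllm.LLM` is an assumption based on how the test harness forwards it.

```python
# Minimal sketch, not part of this PR: selecting the graph mode via
# compilation_config, mirroring the e2e tests added below.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",  # model used in the e2e tests
    tensor_parallel_size=2,
    enforce_eager=False,  # graph capture requires eager mode off
    compilation_config={
        # "FULL_DECODE_ONLY": full graphs are captured for decode only.
        # "FULL" (enabled by this PR): full graphs cover prefill too.
        "cudagraph_mode": "FULL",
        "cudagraph_capture_sizes": [4, 8, 24, 48, 60],
    },
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(max_tokens=32, temperature=0.0))
```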

11 files changed: +328 −296 lines changed

.github/workflows/_e2e_test.yaml

Lines changed: 1 addition & 0 deletions

```diff
@@ -180,6 +180,7 @@ jobs:
         if: ${{ inputs.type == 'full' }}
         run: |
           pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py
+          pytest -sv tests/e2e/multicard/test_full_graph_mode.py
           pytest -sv tests/e2e/multicard/test_data_parallel.py
           pytest -sv tests/e2e/multicard/test_expert_parallel.py
           pytest -sv tests/e2e/multicard/test_external_launcher.py
```

tests/e2e/multicard/test_full_graph_mode.py

Lines changed: 53 additions & 4 deletions

```diff
@@ -29,7 +29,7 @@
 from tests.e2e.model_utils import check_outputs_equal


-def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
+def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL_DECODE_ONLY():
     if 'HCCL_OP_EXPANSION_MODE' in os.environ:
         del os.environ['HCCL_OP_EXPANSION_MODE']
     prompts = [
@@ -42,15 +42,64 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
                     max_model_len=1024,
                     tensor_parallel_size=2,
                     enforce_eager=False,
-                    compilation_config={"cudagraph_mode":
-                                        "FULL_DECODE_ONLY"}) as runner:
+                    compilation_config={
+                        "cudagraph_mode": "FULL_DECODE_ONLY",
+                        "cudagraph_capture_sizes": [4, 8, 24, 48, 60]
+                    }) as runner:
         vllm_fullgraph_outputs = runner.model.generate(prompts,
                                                        sampling_params)

     with VllmRunner(
             model,
             max_model_len=1024,
-            enforce_eager=True,
+            tensor_parallel_size=2,
+            enforce_eager=False,
+    ) as runner:
+        vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
+
+    vllm_fullgraph_outputs_list = []
+    for output in vllm_fullgraph_outputs:
+        vllm_fullgraph_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    vllm_eager_outputs_list = []
+    for output in vllm_eager_outputs:
+        vllm_eager_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_eager_outputs_list,
+        outputs_1_lst=vllm_fullgraph_outputs_list,
+        name_0="vllm_eager_outputs",
+        name_1="vllm_fullgraph_outputs",
+    )
+
+
+def test_models_distributed_Qwen3_MOE_TP2_WITH_FULL():
+    if 'HCCL_OP_EXPANSION_MODE' in os.environ:
+        del os.environ['HCCL_OP_EXPANSION_MODE']
+    prompts = [
+        "Hello, my name is", "The president of the United States is",
+        "The capital of France is", "The future of AI is"
+    ]
+    model = "Qwen/Qwen3-30B-A3B"
+    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
+    with VllmRunner(model,
+                    max_model_len=1024,
+                    tensor_parallel_size=2,
+                    enforce_eager=False,
+                    compilation_config={
+                        "cudagraph_mode": "FULL",
+                        "cudagraph_capture_sizes": [4, 8, 24, 48, 60]
+                    }) as runner:
+        vllm_fullgraph_outputs = runner.model.generate(prompts,
+                                                       sampling_params)
+
+    with VllmRunner(
+            model,
+            max_model_len=1024,
+            tensor_parallel_size=2,
+            enforce_eager=False,
     ) as runner:
         vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
```
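Both tests use greedy sampling (`temperature=0.0`), so the graph-mode run is expected to reproduce the eager baseline token-for-token; `check_outputs_equal` compares the `(index, text)` pairs collected from the two runs.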
tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -46,7 +46,7 @@ def mtp_correctness(sampling_config: SamplingParams,

     graph_mode_str = "PIECEWISE"
     if graph_mode == CUDAGraphMode.FULL:
-        graph_mode_str = "FULL"
+        graph_mode_str = "FULL_DECODE_ONLY"

     with VllmRunner(
             model_name,
@@ -63,7 +63,9 @@ def mtp_correctness(sampling_config: SamplingParams,
             enforce_eager=enforce_eager,
             max_model_len=2000,
             compilation_config=CompilationConfig(
-                cudagraph_mode=graph_mode_str),
+                cudagraph_mode=graph_mode_str,
+                cudagraph_capture_sizes=[12],
+            ),
             additional_config={"ascend_scheduler_config": {
                 "enabled": False
             }}) as spec_llm:
```
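
For reference, a minimal sketch of the dataclass form used here, assuming `CompilationConfig` is importable from `vllm.config` and accepts the same fields as the dict form in the e2e tests:

```python
# Minimal sketch, assuming vllm.config.CompilationConfig accepts the
# same fields as the dict form used in the e2e tests above.
from vllm.config import CompilationConfig

compilation_config = CompilationConfig(
    cudagraph_mode="FULL_DECODE_ONLY",  # full graphs for decode only
    cudagraph_capture_sizes=[12],  # batch sizes to capture graphs for
)
```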
