
Commit 262d263

[Bugfix] Eliminate tuple inputs to submodules in graph partitioning (vllm-project#28533)
Signed-off-by: Yanan Cao <[email protected]>
1 parent 968060c commit 262d263
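
Below is a minimal, self-contained sketch (not part of this commit) of the failure mode being fixed, written against torch.fx.passes.split_module directly rather than vLLM's split_graph; the helper names model_fn, tuple_inputs, and partition are made up for illustration. When getitem nodes are assigned to a different partition than the torch.split call that produces the tuple, the consuming submodule receives the whole tuple through a placeholder; assigning each getitem to its producer's partition, as the backends.py change below does inside split_graph, keeps submodule inputs to plain tensors.

# Sketch only: reproduce the tuple-input problem with torch.fx.passes.split_module,
# then show the effect of keeping getitem nodes next to their tuple producer.
import operator

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.split_module import split_module


def model_fn(x):
    # aten::split.Tensor returns a tuple; indexing it creates operator.getitem nodes
    chunks = torch.split(x, 2, dim=0)
    return torch.relu(chunks[0]) + torch.relu(chunks[1])


gm = make_fx(model_fn)(torch.randn(4, 3))


def tuple_inputs(split_gm):
    # Names of submodules that apply getitem to one of their own placeholders,
    # i.e. submodules that receive an entire tuple as an input.
    bad = []
    for name, submod in split_gm.named_children():
        for node in submod.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == operator.getitem
                and node.args[0].op == "placeholder"
            ):
                bad.append(name)
    return bad


def partition(node, keep_getitem_with_producer):
    # Naive rule: the split op goes to partition 0, everything after it to partition 1.
    if node.op == "call_function" and node.target == torch.ops.aten.split.Tensor:
        return 0
    if (
        keep_getitem_with_producer
        and node.op == "call_function"
        and node.target == operator.getitem
    ):
        return 0  # keep getitem in the producer's partition, mirroring the fix
    return 1


naive = split_module(gm, None, lambda n: partition(n, False))
fixed = split_module(gm, None, lambda n: partition(n, True))
print(tuple_inputs(naive))  # a consumer submodule receives the whole tuple
print(tuple_inputs(fixed))  # []: only plain tensors cross submodule boundaries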

3 files changed (+140, −2 lines)

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -445,6 +445,7 @@ steps:
   - vllm/
   - tests/compile
   commands:
+  - pytest -v -s compile/test_graph_partition.py
   - pytest -v -s compile/test_config.py
   - pytest -v -s compile/test_pass_manager.py
   - pytest -v -s compile/test_fusion.py
tests/compile/test_graph_partition.py

Lines changed: 124 additions & 0 deletions (new file)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import operator

import pytest
import torch
from torch.fx.experimental.proxy_tensor import make_fx

from vllm.compilation.backends import split_graph


def test_getitem_moved_to_producer_subgraph():
    """
    Test that getitem operations are moved to the same subgraph as their input,
    preventing tuple inputs to submodules.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # torch.split returns a tuple, creating real getitem operations
        # Should become first submodule that produces tuple
        chunks = torch.split(x, x.shape[0] // 2, dim=0)

        # Following ops should become second submodule that consumes tuple
        result_0 = torch.relu(chunks[0])
        result_1 = torch.relu(chunks[1])
        return torch.cat([result_0, result_1], dim=0)

    x = torch.randn(4, 3)
    gm = make_fx(model_fn)(x)

    has_getitem = any(
        node.op == "call_function" and node.target == operator.getitem
        for node in gm.graph.nodes
    )
    assert has_getitem, "Test setup failed: graph should contain getitem operations"

    # Split on tuple producer aten::split
    split_ops = ["aten::split.Tensor"]
    split_gm, split_items = split_graph(gm, split_ops)
    assert len(split_items) == 2, "Graph should be split into 2 submodules"

    for split_item in split_items:
        submodule = split_item.graph

        getitem_on_placeholder = []
        for node in submodule.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == operator.getitem
                and node.args[0].op == "placeholder"
            ):
                getitem_on_placeholder.append(node)

        assert len(getitem_on_placeholder) == 0, (
            f"Submodule {split_item.submod_name} has getitem operations on "
            f"placeholder nodes: {[n.name for n in getitem_on_placeholder]}. "
            "This means tuple inputs were not properly eliminated."
        )

    new_x = torch.randn(4, 3)
    output_original = gm(new_x)
    output_split = split_gm(new_x)

    assert torch.allclose(output_original, output_split), "Output mismatch"


def test_no_tuple_inputs_with_multiple_consumers():
    """
    Test that when a tuple is consumed by multiple split operations,
    getitem operations are properly moved to avoid tuple inputs.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # torch.split returns a tuple, creating real getitem operations
        # Should become first submodule that produces tuple
        chunks = torch.split(x, x.shape[0] // 2, dim=0)

        # These should become second submodule consuming tuple
        result_1 = torch.relu(chunks[0])
        result_2 = torch.relu(chunks[1])

        # Artificial graph splitting point to create another
        # independent submodule that consumes tuple later
        # This would become the third submodule
        result_1 = torch.sigmoid(result_1)

        # Fourth submodule that consumes tuple
        result = torch.cat([chunks[0], chunks[1], result_1, result_2])
        return result

    x = torch.randn(4, 3)
    gm = make_fx(model_fn)(x)

    has_getitem = any(
        node.op == "call_function" and node.target == operator.getitem
        for node in gm.graph.nodes
    )
    assert has_getitem, "Test setup failed: graph should contain getitem operations"

    split_ops = ["aten::split.Tensor", "aten::sigmoid"]
    split_gm, split_items = split_graph(gm, split_ops)
    assert len(split_items) == 4, "Graph should be split into 4 submodules"

    for split_item in split_items:
        submodule = split_item.graph

        for node in submodule.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == operator.getitem
                and node.args[0].op == "placeholder"
            ):
                pytest.fail(
                    f"Submodule {split_item.submod_name} has getitem on "
                    f"placeholder {node.args[0].name}, indicating it receives "
                    "a tuple input"
                )

    new_x = torch.randn(4, 3)
    output_original = gm(new_x)
    output_split = split_gm(new_x)

    assert torch.allclose(output_original, output_split), "Output mismatch after split"

vllm/compilation/backends.py

Lines changed: 15 additions & 2 deletions
@@ -4,6 +4,7 @@
 import ast
 import dataclasses
 import hashlib
+import operator
 import os
 import pprint
 import time
@@ -307,12 +308,24 @@ def split_graph(
 ) -> tuple[fx.GraphModule, list[SplitItem]]:
     # split graph by ops
     subgraph_id = 0
-    node_to_subgraph_id = {}
-    split_op_graphs = []
+    node_to_subgraph_id: dict[fx.Node, int] = {}
+    split_op_graphs: list[int] = []
     for node in graph.graph.nodes:
         if node.op in ("output", "placeholder"):
             continue
 
+        # Check if this is a getitem operation on a node from an earlier subgraph.
+        # If so, assign it to the same subgraph as its input to avoid passing entire
+        # tuple as input to submodules, which is against standalone_compile and
+        # AoTAutograd input requirement.
+        if node.op == "call_function" and node.target == operator.getitem:
+            # Assign this getitem to the same subgraph as its input
+            input_node = node.args[0]
+            if input_node.op != "placeholder":
+                assert input_node in node_to_subgraph_id
+                node_to_subgraph_id[node] = node_to_subgraph_id[input_node]
+                continue
+
         if should_split(node, splitting_ops):
             subgraph_id += 1
             node_to_subgraph_id[node] = subgraph_id