Skip to content

Commit 900d90b

Browse files
[TORCH] Modified fx_importer to support hop_while_loop (#4338)
This PR adds support for emitting graphs for Pytorch HOPs, beginning with `torch._higher_order_ops.while_loop`. The proposed change is to modify the `import_program` to call function `_import_all_child_modules`, which recursively imports the stateless graph for all the children modules. Since HOP operator graphs are stateless graphs with no mutation, it is correct to import them as stateless graphs, although the method `import_stateless_graph` is marked as "deprecated". --------- Signed-off-by: Keshav Vinayak Jha <[email protected]>
1 parent 288cd5e commit 900d90b

File tree

6 files changed

+351
-14
lines changed

6 files changed

+351
-14
lines changed

lib/Conversion/TorchToSCF/TorchToSCF.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,10 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern<PrimLoopOp> {
150150
targetType = Torch::IntType::get(op->getContext());
151151
torchArg = typeConverter->materializeSourceConversion(
152152
rewriter, scfWhileOp.getLoc(), targetType, {to});
153+
} else if (auto tty = dyn_cast<RankedTensorType>(targetType)) {
154+
targetType = op.getIterArgsInit()[barg.index()].getType();
155+
torchArg = typeConverter->materializeSourceConversion(
156+
rewriter, scfWhileOp.getLoc(), targetType, {to});
153157
}
154158
if (!torchArg)
155159
return rewriter.notifyMatchFailure(op,
@@ -173,14 +177,6 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern<PrimLoopOp> {
173177
"unsupported type of the operand");
174178
loopConditionIterArgs.push_back(shouldContinue);
175179
for (auto torchArg : primLoopConditionOp.getIterArgs()) {
176-
Type torchType = torchArg.getType();
177-
178-
// If the argument is a torch tensor, directly add it in the list of
179-
// iter args.
180-
if (isa<Torch::BaseTensorType>(torchType)) {
181-
loopConditionIterArgs.push_back(torchArg);
182-
continue;
183-
}
184180
Value arg = typeConverter->materializeTargetConversion(
185181
rewriter, scfWhileOp->getLoc(),
186182
typeConverter->convertType(torchArg.getType()), {torchArg});

lib/Dialect/Torch/Transforms/Passes.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline(
7070

7171
void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline(
7272
OpPassManager &pm, const TorchLoweringPipelineOptions &options) {
73+
// Inline func.call operations created by higher-order ops like while_loop
74+
// to conform to the linalg-on-tensors backend contract.
75+
pm.addPass(createInlinerPass());
7376
pm.addNestedPass<func::FuncOp>(
7477
createReduceOpVariantsPass(options.extraLibrary));
7578
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,8 @@
246246
"IsFloatingPointInt_False",
247247
"TorchPrimLoopForLikeModule_basic",
248248
"TorchPrimLoopWhileLikeModule_basic",
249+
# torch._dynamo.exc.BackendCompilerFailed: Unsupported op: get_attr
250+
"TorchPrimLoopWhileLikeHOPModule_basic",
249251
"ScalarConstantTupleModule_basic",
250252
# END tests failing due to: empty graph in dynamo
251253
# ERROR due to: backend never runs because of empty frame
@@ -481,6 +483,7 @@
481483
"TensorToBoolZeroRank_basic",
482484
"TensorToBool_basic",
483485
"ThresholdBackward2dMixedModule_basic",
486+
"TorchPrimLoopWhileLikeHOPModule_basic", # Compilation error: failed to legalize operation 'func.call'
484487
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
485488
"UpSampleNearest2dDynamicFactor_basic",
486489
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
@@ -982,6 +985,8 @@
982985
"ElementwiseClampMinModule_bfloat16",
983986
"ElementwiseClampModule_bfloat16",
984987
"ElementwiseReluModule_bfloat16",
988+
# Runtime error: failed to legalize operation 'torch.constant.int'
989+
"TorchPrimLoopWhileLikeHOPModule_basic",
985990
}
986991

987992
FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -2564,6 +2569,7 @@
25642569

25652570
LTC_XFAIL_SET = {
25662571
"TorchPrimLoopForLikeTensorArgModule_basic" "CollapseAllDimensionsModule_basic",
2572+
"TorchPrimLoopWhileLikeHOPModule_basic",
25672573
"CollapseRank1DynamicModule_basic",
25682574
"CollapseStaticModule_basic",
25692575
"CollapsePartialDynamicModule_basic",
@@ -3253,6 +3259,8 @@
32533259
"ToCopyWithDTypeModule_basic",
32543260
"TorchPrimLoopForLikeModule_basic",
32553261
"TorchPrimLoopWhileLikeModule_basic",
3262+
# RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function
3263+
"TorchPrimLoopWhileLikeHOPModule_basic",
32563264
"TraceModule_basic",
32573265
"TraceModule_empty",
32583266
"TraceModule_nonsquare",
@@ -3952,6 +3960,8 @@
39523960
"ThresholdBackward2dMixedModule_basic",
39533961
"TorchPrimLoopForLikeModule_basic",
39543962
"TorchPrimLoopWhileLikeModule_basic",
3963+
# Runtime error: failed to legalize operation 'torch.aten.Bool.Tensor'
3964+
"TorchPrimLoopWhileLikeHOPModule_basic",
39553965
"TraceModule_empty",
39563966
"TraceUnsignedIntModule_empty",
39573967
"TransposedConv1dNegativePadding_basic",
@@ -5024,6 +5034,8 @@
50245034
"ToDtypeFloatFromIntModule_basic",
50255035
"TorchPrimLoopForLikeModule_basic",
50265036
"TorchPrimLoopWhileLikeModule_basic",
5037+
# RuntimeError: Detected that you are using FX to torch.jit.trace a dynamo-optimized function
5038+
"TorchPrimLoopWhileLikeHOPModule_basic",
50275039
"TraceModule_basic",
50285040
"TraceModule_empty",
50295041
"TraceModule_nonsquare",

projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from torch_mlir_e2e_test.framework import TestUtils
1111
from torch_mlir_e2e_test.registry import register_test_case
1212
from torch_mlir_e2e_test.annotations import annotate_args, export
13+
from torch._higher_order_ops.while_loop import while_loop
1314

1415
# ==============================================================================
1516

@@ -78,3 +79,36 @@ def TorchPrimLoopForLikeTensorArgModule_basic(module, tu: TestUtils):
7879
x_test = torch.zeros([7, 9]).float()
7980

8081
module.forward(x_test)
82+
83+
84+
# ==============================================================================
85+
86+
87+
class TorchPrimLoopWhileLikeHOPModule(torch.nn.Module):
88+
def __init__(self):
89+
super().__init__()
90+
91+
def body_fn(self, i, x):
92+
return i + 1, x + 1
93+
94+
def cond_fn(self, i, x):
95+
return i < 3
96+
97+
@export
98+
@annotate_args(
99+
[
100+
None,
101+
([7, 9], torch.float32, True),
102+
]
103+
)
104+
def forward(self, x: torch.Tensor) -> torch.Tensor:
105+
i0 = torch.tensor(0)
106+
out_i, out_x = while_loop(self.cond_fn, self.body_fn, (i0, x))
107+
return out_i, out_x
108+
109+
110+
@register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeHOPModule())
111+
def TorchPrimLoopWhileLikeHOPModule_basic(module, tu: TestUtils):
112+
x_test = torch.zeros([7, 9]).float()
113+
114+
module.forward(x_test)

0 commit comments

Comments
 (0)