Skip to content

Commit 2da9794

Browse files
committed
Use default FakeTensorMode when calling module without input
Before, _ExportPassBase set self.tracer.fake_tensor_mode to a default value, but didn't use it when tracing. This caused operators that only have fake tensor implementations to crash. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I8d7ef0cc841b0e46cd04ea4ed941b761798a76d2
1 parent d968e47 commit 2da9794

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

backends/arm/test/ops/test_cond.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,6 @@ def test_cond_tosa_FP(case: Callable[[], tuple[torch.nn.Module, tuple]]):
237237
"case",
238238
test_cases,
239239
xfails={
240-
"zero_args_one_output": "Since the submodules have no input, the tracer fails finding a fake tensor mode,"
241-
" and traces the graph with real tensors, which tosa.RESCALE can't handle.",
242240
"one_arg_and_scalar_one_output": "Incorrect quantization on the scalar.",
243241
"nested_one_arg_one_output": "Node submodule_0 target submodule_0 references nonexistent attribute submodule_0",
244242
},

exir/pass_base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -593,14 +594,13 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:
593594
), "Multiple fake tensor mode detected."
594595
fake_tensor_mode = i.fake_mode
595596
if fake_tensor_mode is None:
596-
self.tracer.fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
597-
fake_tensor_mode = nullcontext() # type: ignore[assignment]
597+
fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True)
598598
dispatcher_mode = nullcontext() # type: ignore[assignment]
599599
else:
600600
fake_tensor_mode.allow_non_fake_inputs = True
601-
self.tracer.fake_tensor_mode = fake_tensor_mode
602601
dispatcher_mode = enable_python_dispatcher() # type: ignore[assignment]
603-
self.fake_tensor_mode = self.tracer.fake_tensor_mode
602+
self.tracer.fake_tensor_mode = fake_tensor_mode
603+
self.fake_tensor_mode = fake_tensor_mode
604604

605605
with fake_tensor_mode, dispatcher_mode: # type: ignore[assignment, union-attr]
606606
result = self.call_submodule(graph_module, tuple(inputs))

0 commit comments

Comments
 (0)