Skip to content

Commit 8083ea7

Browse files
angelayimeta-codesync[bot]
authored andcommitted
Fix fake mode detection
Summary: it's not necessarily true that the 0th placeholder's meta["val"] contains a tensor. replaced this with _detect_fake_mode_from_gm which is safer Differential Revision: D87577206
1 parent b4d72f1 commit 8083ea7

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

exir/program/_program.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
get_aten_verifier,
8181
)
8282
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
83+
from torch._export.utils import _detect_fake_mode_from_gm
8384
from torch._export.verifier import Verifier
8485
from torch.export import ExportedProgram
8586
from torch.export._remove_auto_functionalized_pass import (
@@ -333,7 +334,8 @@ def lift_constant_tensor_pass(ep):
333334
graph_signature = ep.graph_signature
334335
buffers = list(graph_signature.buffers)
335336

336-
fake_mode = list(ep.graph.nodes)[0].meta["val"].fake_mode
337+
fake_mode = _detect_fake_mode_from_gm(ep.graph_module)
338+
337339
first_user_input = None
338340
lifted_constants = []
339341
for node in ep.graph.nodes:

0 commit comments

Comments
 (0)