Skip to content

Commit 4614144

Browse files
lucylqfacebook-github-bot
authored andcommitted
Save external constant tensors to custom filename (pytorch#15862)
Summary: Adding a Callable option in the config, so we can customize the nodes to save, and the filenames to save to. The pass is applied in _program.py, it's hard to do this on the eager model (like we do with delegates) as we have to run the pass after SpecPropPass. Reviewed By: ethansfng Differential Revision: D87280747
1 parent b4d72f1 commit 4614144

File tree

3 files changed

+49
-7
lines changed

3 files changed

+49
-7
lines changed

exir/capture/_config.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-unsafe
88
from dataclasses import dataclass, field
9-
from typing import Dict, List, Optional, Union
9+
from typing import Callable, Dict, List, Optional, Union
1010

1111
import torch
1212

@@ -94,9 +94,14 @@ class ExecutorchBackendConfig:
9494
# Moreover, static views will be elided from the ExecuTorch graph
9595
remove_view_copy: bool = True
9696

97-
# If set to true, all constant tensors will be stored in a separate file,
98-
# external to the PTE file.
99-
external_constants: bool = False
97+
# Bool: if True, all constant tensors will be stored in a separate file. If False,
98+
# all constant tensors will be stored in the PTE file.
99+
# Callable: a function from torch.fx.Node to Optional[str]. This will be called for each
100+
# placeholder (constant tensor) node, and if it returns a string, that node will be
101+
# tagged with the string. If None, the constant tensor is stored in the PTE file.
102+
# Otherwise, it is stored in a file named by the string. E.g., a function
103+
# lambda x: "model_weights" will save all constants into a file "model_weights.ptd".
104+
external_constants: Union[bool, Callable[[torch.fx.Node], Optional[str]]] = False
100105

101106
# If set to true, all trainable weights will be stored in a separate file,
102107
# external to the PTE file.

exir/emit/test/test_emit.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,9 +1717,38 @@ def forward(self, x):
17171717
external_map = emitter_output.external_constant_map[
17181718
"_default_external_constant"
17191719
]
1720+
self.assertEqual(len(external_map), 2)
17201721
self.assertEqual(external_map["linear.weight"], 0)
17211722
self.assertEqual(external_map["linear.bias"], 1)
17221723

1724+
def test_constant_tagged_tensors_custom(self) -> None:
1725+
class LinearModule(torch.nn.Module):
1726+
def __init__(self):
1727+
super().__init__()
1728+
self.linear = torch.nn.Linear(5, 5)
1729+
1730+
def forward(self, x):
1731+
return self.linear(x)
1732+
1733+
model = to_edge(
1734+
export(LinearModule(), (torch.ones(5, 5),), strict=True)
1735+
).to_executorch(
1736+
config=ExecutorchBackendConfig(
1737+
external_constants=lambda x: (
1738+
"linear_weight" if "weight" in x.name else None
1739+
),
1740+
)
1741+
)
1742+
emitter_output = model._emitter_output
1743+
# constant_buffer contains placeholder and linear bias.
1744+
self.assertEqual(len(emitter_output.program.constant_buffer), 2)
1745+
# external constant buffer contains linear weight.
1746+
self.assertEqual(len(emitter_output.external_constant_buffer), 1)
1747+
# The lambda saves all constants to the key 'linear_weight'.
1748+
external_map = emitter_output.external_constant_map["linear_weight"]
1749+
self.assertEqual(len(external_map), 1)
1750+
self.assertEqual(external_map["linear.weight"], 0)
1751+
17231752
def test_constant_tagged_tensor_dedup(self) -> None:
17241753
class ConstantModule(nn.Module):
17251754
def __init__(self):

exir/program/_program.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1737,11 +1737,19 @@ def to_executorch( # noqa (FLAKE8) C901
17371737
# TODO(who?)
17381738
p.update_placeholder_tensor_specs(program, new_gm)
17391739

1740-
# Extract constants if the config says too.
1741-
if config.external_constants:
1740+
# Tag constant weights.
1741+
if (
1742+
isinstance(config.external_constants, bool)
1743+
and config.external_constants
1744+
):
17421745
new_gm_res = external_constants_pass(new_gm)
17431746
new_gm = new_gm_res.graph_module
1744-
elif config.external_mutable_weights:
1747+
elif callable(config.external_constants):
1748+
new_gm_res = external_constants_pass(new_gm, config.external_constants)
1749+
new_gm = new_gm_res.graph_module
1750+
1751+
# Tag mutable weights.
1752+
if config.external_mutable_weights:
17451753
new_gm_res = external_mutable_weights_pass(new_gm, program)
17461754
new_gm = new_gm_res.graph_module
17471755

0 commit comments

Comments
 (0)