
Commit 53da451

Arm backend: Add support for quant. decomposition (#15993)
For mixed-type models we need to be able to switch between FP and INT at runtime, meaning we must quantize and dequantize on the fly. As there are no quantize or dequantize operators in TOSA, we need to decompose these operators into TOSA operators. This commit introduces a pass that decomposes q-dq nodes into more primitive nodes. This also affects the Cortex-M backend, as it uses the qdq-folding pass in arm/_passes.

Signed-off-by: Oscar Andersson <[email protected]>
1 parent 8af8252 commit 53da451

14 files changed: +304 −44 lines
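The decomposition implements the standard affine quantization formulas (restated in the new pass's docstring further down). As a plain-PyTorch reference, here is a minimal eager-mode sketch of the same arithmetic; the function names are illustrative only, not part of the backend's API:

import torch

def quantize(x, scale, zero_point, qmin, qmax, dtype=torch.int8):
    # The pass multiplies by 1/scale instead of dividing by scale.
    q = torch.round(x * (1.0 / scale)) + zero_point
    return torch.clamp(q, qmin, qmax).to(dtype)

def dequantize(q, scale, zero_point):
    # Cast to int32, subtract the zero point, cast to float32, rescale.
    return (q.to(torch.int32) - zero_point).to(torch.float32) * scale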

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@
 from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass  # noqa
 from .decompose_meandim_pass import DecomposeMeanDimPass  # noqa
 from .decompose_ne_pass import DecomposeNotEqualPass  # noqa
+from .decompose_quant_nodes import DecomposeQuantNodesPass  # noqa
 from .decompose_remainder_pass import DecomposeRemainderPass  # noqa
 from .decompose_round_pass import DecomposeRoundPass  # noqa
 from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 1 deletion
@@ -65,6 +65,7 @@
     DecomposeMaxPool2dPass,
     DecomposeMeanDimPass,
     DecomposeNotEqualPass,
+    DecomposeQuantNodesPass,
     DecomposeRemainderPass,
     DecomposeRoundPass,
     DecomposeScaledDotProductAttentionPass,
@@ -187,7 +188,7 @@ def _tosa_pipeline(
             ]
         )

-        # Fold Q/DQ nodes, insert INT8/INT32 rescales.
+        # Fold Q/DQ nodes, insert INT8/INT32 rescales, decompose quantization nodes.
        self.add_passes(
            [
                FoldAndAnnotateQParamsPass(exported_program),
@@ -198,6 +199,7 @@ def _tosa_pipeline(
                DecomposeLinearPass(),
                InsertRescaleInt32Pass(),
                InsertControlFlowRescalesPass(),
+                DecomposeQuantNodesPass(),
            ]
        )
backends/arm/_passes/convert_elu_params.py

Lines changed: 3 additions & 1 deletion
@@ -38,7 +38,9 @@ def call(self, graph_module: torch.fx.GraphModule):
         if not is_quantized:
             continue
         with graph.inserting_after(node):
-            replace_node = create_node(graph, exir_ops.edge.aten.elu.default)
+            replace_node = create_node(
+                graph, exir_ops.edge.aten.elu.default, from_node=node
+            )
         old_args = list(node.args)

         alpha = old_args[1] if len(old_args) > 1 else 1.0

backends/arm/_passes/convert_minmax_pass.py

Lines changed: 16 additions & 7 deletions
@@ -7,7 +7,10 @@

 import torch
 from executorch.backends.arm._passes.arm_pass import ArmPass
-from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+)
 from executorch.backends.arm._passes.convert_squeezes_to_view import (
     ConvertSqueezesToViewPass,
 )
@@ -131,15 +134,21 @@ def call(self, graph_module: torch.fx.GraphModule):

         for dim in dims:
             args = (input_node, dim, True)
-            input_node = graph_module.graph.create_node(
-                "call_function", op, args, node.kwargs
+            input_node = create_node(
+                graph=graph_module.graph,
+                op_target=op,
+                args=args,
+                kwargs={},
+                from_node=node,
             )

         if not keepdims:
-            input_node = graph_module.graph.create_node(
-                "call_function",
-                squeeze_op,
-                (input_node, dims),
+            input_node = create_node(
+                graph=graph_module.graph,
+                op_target=squeeze_op,
+                args=(input_node, dims),
+                kwargs={},
+                from_node=node,
            )

         replace_node.replace_all_uses_with(input_node)
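These call sites switch from the raw graph_module.graph.create_node to the backend's create_node helper, passing from_node so the replacement nodes stay associated with the node they replace. A minimal sketch of the pattern, assuming (based on the call sites in this diff) that the helper creates a call_function node and derives metadata from from_node:

from executorch.backends.arm._passes.arm_pass_utils import create_node

# Sketch only: `op`, `dim`, `input_node` and `node` come from the
# surrounding pass; whether create_node copies node.meta from
# `from_node` is an assumption here, not confirmed by this diff.
input_node = create_node(
    graph=graph_module.graph,
    op_target=op,                  # e.g. an edge-dialect reduction op
    args=(input_node, dim, True),  # keepdim=True; squeezed afterwards if needed
    kwargs={},
    from_node=node,                # provenance / metadata source
)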
backends/arm/_passes/decompose_quant_nodes.py

Lines changed: 156 additions & 0 deletions

@@ -0,0 +1,156 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast, Set, Type
+
+import torch
+from executorch.backends.arm._passes.arm_pass import ArmPass
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.backends.arm._passes.decompose_round_pass import DecomposeRoundPass
+from executorch.backends.arm.constants import DEQUANT_PER_TENSOR_OP, QUANT_PER_TENSOR_OP
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class DecomposeQuantNodesPass(ArmPass):
+    """Decomposes quantization nodes into more primitive operations by rewriting
+    the graph using the two formulas:
+
+        quantized_value = clamp(round(fp32_value / scale) + zero_point, qmin, qmax)
+
+        fp32_value = (quantized_value - zero_point) * scale
+
+    For quantization nodes, the pass replaces them with:
+
+    1. Multiplying the input by the inverse of the scale factor.
+    2. Rounding the result.
+    3. Adding the zero point.
+    4. Clamping the result to [qmin, qmax].
+    5. Casting to the target data type.
+
+    For dequantization nodes, the pass replaces them with:
+
+    1. Casting the input to int32.
+    2. Subtracting the zero point.
+    3. Casting to float32.
+    4. Multiplying by the scale factor.
+    """
+
+    _passes_required_after: Set[Type[ExportPass]] = {DecomposeRoundPass}
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified = False
+        for node in list(graph_module.graph.nodes):
+            if node.op != "call_function" or node.target not in (
+                QUANT_PER_TENSOR_OP,
+                DEQUANT_PER_TENSOR_OP,
+            ):
+                continue
+            if node.target == DEQUANT_PER_TENSOR_OP and all(
+                user.target == QUANT_PER_TENSOR_OP for user in node.users
+            ):
+                continue
+            elif (
+                node.target == QUANT_PER_TENSOR_OP
+                and node.all_input_nodes[0].target == DEQUANT_PER_TENSOR_OP
+            ):
+                continue
+            modified = True
+            args = node.args
+            input_rank = args[0].meta["val"].ndim
+            x, scale, zero_point, qmin, qmax, dtype = args
+            # Instead of dividing by scale in quantization, we multiply by
+            # 1/scale when quantizing.
+            scale = cast(float, scale)
+            scale = scale if node.target == DEQUANT_PER_TENSOR_OP else 1.0 / scale
+            with graph_module.graph.inserting_before(node):
+                scale_const = create_node(
+                    graph_module.graph,
+                    exir_ops.edge.aten.full.default,
+                    args=((1,) * input_rank, scale),
+                    kwargs={"dtype": torch.float32},
+                )
+                zp_const = create_node(
+                    graph_module.graph,
+                    exir_ops.edge.aten.full.default,
+                    args=((1,) * input_rank, zero_point),
+                    kwargs={
+                        "dtype": (
+                            torch.float32
+                            if node.target == QUANT_PER_TENSOR_OP
+                            else torch.int32
+                        )
+                    },
+                )
+                if node.target == QUANT_PER_TENSOR_OP:
+                    # TODO MLETORCH-1587: Decompose quantization nodes using more integer arithmetic
+                    scaled = create_node(
+                        graph_module.graph,
+                        exir_ops.edge.aten.mul.Tensor,
+                        args=(x, scale_const),
+                        from_node=node,
+                    )
+                    rounded = create_node(
+                        graph_module.graph,
+                        exir_ops.edge.aten.round.default,
+                        args=(scaled,),
+                        from_node=node,
+                    )
+                    shifted = create_node(
+                        graph_module.graph,
+                        exir_ops.edge.aten.add.Tensor,
+                        args=(rounded, zp_const),
+                        from_node=node,
+                    )
+                    clamped = create_node(
+                        graph_module.graph,
+                        exir_ops.edge.aten.clamp.default,
+                        args=(shifted, float(qmin), float(qmax)),
+                        from_node=node,
+                    )
+                    quantized = create_node(
+                        graph_module.graph,
+                        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                        args=(clamped,),
+                        kwargs={"dtype": dtype},
+                        from_node=node,
+                    )
+                    output = quantized
+                else:
+                    input_casted_to_zp_dtype = create_node(
+                        graph_module.graph,
+                        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                        args=(x,),
+                        kwargs={"dtype": torch.int32},
+                        from_node=node,
+                    )
+                    shifted = create_node(
+                        graph_module.graph,
+                        exir_ops.edge.aten.sub.Tensor,
+                        args=(input_casted_to_zp_dtype, zp_const),
+                        from_node=node,
+                    )
+                    casted_to_float = create_node(
+                        graph_module.graph,
+                        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+                        args=(shifted,),
+                        kwargs={"dtype": torch.float32},
+                        from_node=node,
+                    )
+                    dequantized = create_node(
+                        graph_module.graph,
+                        exir_ops.edge.aten.mul.Tensor,
+                        args=(casted_to_float, scale_const),
+                        from_node=node,
+                    )
+                    output = dequantized
+            node.replace_all_uses_with(output)
+            graph_module.graph.erase_node(node)
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, modified=modified)
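For a single quantize_per_tensor node, the rewrite above produces a chain like the following. This is a schematic of the edge-dialect graph after the pass (written as comments, with op names abbreviated), not literal pass output:

# Before:
#   q = quantized_decomposed.quantize_per_tensor(x, scale, zp, qmin, qmax, dtype)
#
# After (schematic; the constants are rank-matched (1, ..., 1) tensors):
#   scale_const = aten.full((1,) * rank, 1.0 / scale, dtype=float32)
#   zp_const    = aten.full((1,) * rank, zp, dtype=float32)
#   scaled      = aten.mul(x, scale_const)
#   rounded     = aten.round(scaled)       # decomposed later by DecomposeRoundPass
#   shifted     = aten.add(rounded, zp_const)
#   clamped     = aten.clamp(shifted, float(qmin), float(qmax))
#   q           = dim_order_ops._to_dim_order_copy(clamped, dtype=dtype)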

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 30 additions & 5 deletions
@@ -1,5 +1,4 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -16,10 +15,14 @@
     is_param_node,
     set_node_arg,
 )
+from executorch.backends.arm._passes.fuse_constant_ops_pass import (
+    ComputeConstantOpsAOTPass,
+)
 from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass

 from executorch.backends.arm._passes.quant_args import QuantArgs
 from executorch.backends.arm._passes.remove_noop_pass import RemoveNoopPass
+from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo
 from executorch.backends.arm.constants import DQ_OPS, Q_OPS
 from executorch.exir import ExportedProgram

@@ -230,15 +233,37 @@ def _handle_control_flow_node(self, node: Node, graph_module: GraphModule):
             submodule.graph.erase_node(node_to_remove)
         return

+    @staticmethod
+    def is_foldable(node: Node) -> bool:
+        if node.op != "call_function":
+            return False
+        # Don't fold chains of quant-ops into each other.
+        if node.target in (*Q_OPS, *DQ_OPS):
+            return False
+
+        # Always fold q-dq into constant ops.
+        if node.target in (
+            exir_ops.edge.aten.full_like.default,
+            *ComputeConstantOpsAOTPass.targeted_ops,
+        ):
+            return True
+
+        # We should not fold q-dq nodes into non-quantized nodes.
+        if not (
+            ArmAnnotationInfo.CUSTOM_META_KEY in node.meta.get("custom", {})
+            and ArmAnnotationInfo(
+                node.meta["custom"][ArmAnnotationInfo.CUSTOM_META_KEY]
+            ).quantized
+        ):
+            return False
+        return True
+
     def call(self, graph_module: GraphModule) -> PassResult:  # noqa: C901

         # Loop over the graph nodes and find any node in the 'targeted_ops' list.
         for n in graph_module.graph.nodes:
             n = cast(Node, n)
-            if n.op != "call_function":
-                continue
-            # Don't fold chains of quant-ops into each other.
-            if n.target in (*Q_OPS, *DQ_OPS):
+            if not FoldAndAnnotateQParamsPass.is_foldable(n):
                 continue

             # Make sure we haven't already set qparams meta information on the node
backends/arm/_passes/insert_table_ops.py

Lines changed: 2 additions & 2 deletions
@@ -235,8 +235,8 @@ def call(self, graph_module: GraphModule) -> PassResult:
         for node in graph_module.graph.nodes:
             if node.op != "call_function" or node not in self.table_ops:
                 continue
-            input_qparams = node.meta["input_qparams"]
-            output_qparams = node.meta["output_qparams"]
+            input_qparams = node.meta.get("input_qparams", {})
+            output_qparams = node.meta.get("output_qparams", {})
             if len(input_qparams) == 0 or len(output_qparams) == 0:
                 # We only want to replace the node if it's quantized
                 continue
Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+from executorch.backends.arm._passes import DecomposeQuantNodesPass
+from executorch.backends.arm.test.common import parametrize
+from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
+
+
+class Mul(torch.nn.Module):
+    test_data = {
+        "randn": (torch.randn(1, 3, 16, 16), torch.randn(1, 3, 16, 16)),
+        "large_randn": (10e10 * torch.randn(1, 3, 16, 16), torch.randn(1, 3, 16, 16)),
+    }
+
+    def forward(self, x, y):
+        return x * y
+
+
+@parametrize("test_data", Mul.test_data)
+def test_decompose_quant_nodes_pass(test_data: Tuple[torch.Tensor]):
+    module = Mul()
+    q_dq_ops = {
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
+    }
+    # Verify that DecomposeQuantNodesPass removes quantize/dequantize nodes
+    # and that the output is correct.
+    pipeline = PassPipeline(
+        module,
+        test_data,
+        quantize=True,
+        pass_list=[
+            DecomposeQuantNodesPass,
+        ],
+        ops_before_pass=q_dq_ops,
+        ops_not_after_pass=list(q_dq_ops.keys()),
+        tosa_extensions=["FP"],
+    )
+    pipeline.run()

backends/arm/test/passes/test_fold_qdq_pass.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ class SimpleQuantizeModel(torch.nn.Module):
     }

     def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        return x + torch.max((x + x), (y + y))
+        return x + torch.maximum((x + x), (y + y))


 @common.parametrize("test_data", SimpleQuantizeModel.test_data)
