
Commit 32916d3

Arm backend: Add Quantization test pipeline (pytorch#16151)
Add a new test pipeline to check that nodes are annotated in the expected way. The pipeline does not currently test numerical accuracy. Signed-off-by: Oscar Andersson <[email protected]>
1 parent c398ff4 commit 32916d3

File tree

4 files changed: +460 -3 lines changed


backends/arm/quantizer/arm_quantizer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,6 @@ def set_module_name(
403403
404404
"""
405405
# Validate that quantization_config is provided
406-
if quantization_config is None:
407-
raise ValueError("quantization_config == None is not supported yet")
408406
self.module_name_config[module_name] = quantization_config
409407
return self
410408
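With the check removed, set_module_name now accepts None as the config, which the new tests below rely on to leave individual submodules unquantized. A minimal sketch of the now-permitted call (quantizer setup copied from the test file added in this commit):

    from executorch.backends.arm.quantizer import (
        get_symmetric_quantization_config,
        TOSAQuantizer,
    )
    from executorch.backends.arm.tosa import TosaSpecification

    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
    quantizer.set_global(get_symmetric_quantization_config())
    # Previously rejected with a ValueError; now marks submodule "add_0" as not-to-be-quantized.
    quantizer.set_module_name("add_0", None)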

Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Dict

import torch
from executorch.backends.arm.quantizer import (
    get_symmetric_a16w8_quantization_config,
    get_symmetric_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import QuantizationPipeline
from executorch.backends.arm.tosa import TosaSpecification
from torchvision import models, transforms  # type: ignore[import-untyped]
from torchvision.ops.misc import Conv2dNormActivation  # type: ignore[import-untyped]


def get_quantizer():
    tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
    quantizer = TOSAQuantizer(tosa_spec)
    quantizer.set_global(get_symmetric_quantization_config())
    return quantizer


def get_selective_quantizer_by_module(
    module_types: Dict[torch.nn.Module, QuantizationConfig]
):
    quantizer = get_quantizer()
    quantizer.set_global(get_symmetric_quantization_config())
    for module_type, config in module_types.items():
        quantizer.set_module_type(module_type, config)

    return quantizer


def get_selective_quantizer_by_module_name(module_names: Dict[str, QuantizationConfig]):
    quantizer = get_quantizer()
    quantizer.set_global(get_symmetric_quantization_config())
    for module_name, config in module_names.items():
        quantizer.set_module_name(module_name, config)

    return quantizer


class Add(torch.nn.Module):

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


class AddSoftmaxAdd(torch.nn.Module):
    module_names = {"add_0": None, "add_1": None}
    module_types = {
        Add: None,
    }
    quantized_aten_targets = {"aten.relu.default": 1}
    non_quantized_aten_targets = {"aten.add.Tensor": 2}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.relu = torch.nn.ReLU()
        self.add_0 = Add()
        self.add_1 = Add()

    def get_inputs(self):
        return (torch.randn(1, 10), torch.randn(1, 10))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        z = self.add_0(x, y)
        z = self.relu(z)
        z = self.softmax(z)
        return self.add_1(z, y)


test_models = {
    "add_softmax_add": AddSoftmaxAdd,
}


@common.parametrize("model", test_models)
def test_selective_quant_module_name_tosa_INT(model):
    model = model()
    inputs = model.get_inputs()
    quantzed_aten_targets = model.quantized_aten_targets
    non_quantized_aten_targets = model.non_quantized_aten_targets
    quantization_annotations = {}
    for target, count in quantzed_aten_targets.items():
        quantization_annotations[target] = {
            get_symmetric_quantization_config().output_activation: count
        }
    for target, count in non_quantized_aten_targets.items():
        quantization_annotations[target] = {None: count}

    pipeline = QuantizationPipeline[tuple[torch.Tensor, torch.Tensor]](
        model,
        inputs,
        quantizer=get_selective_quantizer_by_module_name(model.module_names),
        qspecs=quantization_annotations,
    )

    pipeline.run()


@common.parametrize("model", test_models)
def test_selective_quant_module_type_tosa_INT(model):
    model = model()
    inputs = model.get_inputs()
    quantzed_aten_targets = model.quantized_aten_targets
    non_quantized_aten_targets = model.non_quantized_aten_targets
    quantization_annotations = {}
    for target, count in quantzed_aten_targets.items():
        quantization_annotations[target] = {
            get_symmetric_quantization_config().output_activation: count
        }
    for target, count in non_quantized_aten_targets.items():
        quantization_annotations[target] = {None: count}

    pipeline = QuantizationPipeline[tuple[torch.Tensor, torch.Tensor]](
        model,
        inputs,
        quantizer=get_selective_quantizer_by_module(model.module_types),
        qspecs=quantization_annotations,
    )

    pipeline.run()


mv3 = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights)
mv3.eval()
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


def test_mv3_selective_quant_int16():
    model = mv3
    inputs = (normalize(torch.randn(1, 3, 224, 224)),)

    a16w8_config = get_symmetric_a16w8_quantization_config()
    quantization_annotations = {
        "aten.conv2d.default": {
            a16w8_config.output_activation: 29,
        },
        "aten.hardswish_.default": {
            a16w8_config.output_activation: 18,
        },
        "aten.relu_.default": {
            a16w8_config.output_activation: 5,
        },
    }

    pipeline = QuantizationPipeline[tuple[torch.Tensor]](
        model,
        inputs,
        quantizer=get_selective_quantizer_by_module(
            {
                Conv2dNormActivation: a16w8_config,
            }
        ),
        qspecs=quantization_annotations,
    )

    pipeline.run()


def test_mv3_selective_quant_float32():
    model = mv3
    inputs = (normalize(torch.randn(1, 3, 224, 224)),)

    quantization_annotations = {
        "aten.adaptive_avg_pool2d.default": {
            None: 1,
        },
    }

    pipeline = QuantizationPipeline[tuple[torch.Tensor]](
        model,
        inputs,
        quantizer=get_selective_quantizer_by_module_name(
            {
                "features.11.block.2.avgpool": None,
            }
        ),
        qspecs=quantization_annotations,
    )

    pipeline.run()


def test_mv3_io_quant():
    model = mv3
    inputs = (normalize(torch.randn(1, 3, 224, 224)),)

    quantizer = get_quantizer()
    # Workaround to disable quantization for all modules
    quantizer.set_module_type(torch.nn.Module, None)
    # Only quantize IO
    quantizer.set_io(get_symmetric_quantization_config())

    pipeline = QuantizationPipeline[tuple[torch.Tensor]](
        model,
        inputs,
        quantizer=quantizer,
        input_qspecs={get_symmetric_quantization_config().input_activation: 1},
        output_qspecs={get_symmetric_quantization_config().output_activation: 1},
    )

    pipeline.run()
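The qspecs dictionaries above map an aten target to the expected number of nodes carrying a given quantization spec (with None counting unannotated nodes). As a rough idea of what such a check amounts to at the graph level (an assumption, since QuantizationPipeline's implementation is not part of this diff), counting pt2e annotations per target on a captured graph module could look like this:

    from collections import Counter

    import torch


    def count_annotated_targets(graph_module: torch.fx.GraphModule) -> Counter:
        """Count call_function nodes per aten target that carry a pt2e quantization annotation."""
        counts: Counter = Counter()
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            # pt2e quantizers record their decision under node.meta["quantization_annotation"].
            if node.meta.get("quantization_annotation") is not None:
                counts[str(node.target)] += 1  # e.g. "aten.relu.default"
        return counts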
