Skip to content

Commit bdbffce

Browse files
authored
Arm backend: Refactor pass skipping mechanisms (#16153)
* Add config class for pass manager * Add option to skip passes to compile spec * Skip Fuse duplicate user pass for vgf Signed-off-by: Ryan O'Shea <[email protected]>
1 parent de52bfd commit bdbffce

File tree

13 files changed

+243
-38
lines changed

13 files changed

+243
-38
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
8+
import logging
99
from collections import defaultdict
1010
from collections.abc import Sequence
1111

@@ -110,8 +110,13 @@
110110
UnsqueezeBeforeRepeatPass,
111111
UnsqueezeScalarPlaceholdersPass,
112112
)
113-
114113
from executorch.backends.arm._passes.arm_pass import ArmPass
114+
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
115+
from executorch.backends.arm.common.pipeline_config import (
116+
ArmPassPipelineConfig,
117+
FuseDuplicateUsersConfig,
118+
SoftmaxDecompositionConfig,
119+
)
115120
from executorch.backends.arm.tosa.specification import (
116121
tosa_spec_in_set,
117122
TosaLoweringContext,
@@ -124,11 +129,45 @@
124129
from torch.fx.passes.infra.pass_base import PassResult
125130
from torch.nn.modules import Module
126131

132+
logger = logging.getLogger(__name__)
133+
127134

128135
class ArmPassManager(PassManager):
129-
def __init__(self, tosa_spec: TosaSpecification) -> None:
130-
self.tosa_spec = tosa_spec
136+
def __init__(self, compile_spec: ArmCompileSpec) -> None:
137+
self.compile_spec = compile_spec
138+
self.tosa_spec = compile_spec.tosa_spec
139+
self._skip_pass_types: tuple[type, ...] = ()
131140
super().__init__()
141+
self.configure_skip_passes()
142+
143+
def configure_skip_passes(
144+
self,
145+
override_config: ArmPassPipelineConfig | None = None,
146+
) -> tuple[type, ...]:
147+
"""
148+
Configures the pass manager to skip certain passes based on the ArmPassPipelineConfig class
149+
found in the compile spec.
150+
"""
151+
skip_set: set[type] = set()
152+
153+
config = override_config or self.compile_spec.get_pass_pipeline_config()
154+
logger.debug(f"Skip Config: {config}")
155+
156+
match config.softmax:
157+
case SoftmaxDecompositionConfig.MASKED:
158+
skip_set.add(DecomposeSoftmaxUnstablePass)
159+
case SoftmaxDecompositionConfig.UNSTABLE:
160+
skip_set.add(DecomposeSoftmaxPass)
161+
skip_set.add(DecomposeMaskedFillPass)
162+
163+
if config.fuse_duplicate_users is FuseDuplicateUsersConfig.DISABLED:
164+
skip_set.add(FuseDuplicateUsersPass)
165+
166+
self._skip_pass_types = tuple(skip_set)
167+
skip_names = [skipped_pass.__name__ for skipped_pass in self._skip_pass_types]
168+
logger.debug(f"Passes in skip list: {skip_names}")
169+
170+
return self._skip_pass_types
132171

133172
def validate_constraints_mandatory(self):
134173
"""
@@ -165,6 +204,11 @@ def _transform(self, graph_module: GraphModule):
165204
with TosaLoweringContext(self.tosa_spec):
166205
return self(graph_module).graph_module
167206

207+
def add_pass(self, pipeline_pass):
208+
if type(pipeline_pass) in self._skip_pass_types:
209+
return
210+
super().add_pass(pipeline_pass)
211+
168212
def _tosa_pipeline(
169213
self, exported_program: ExportedProgram, graph_module: GraphModule
170214
) -> GraphModule:
@@ -373,11 +417,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
373417
DecomposeSqrtPass(),
374418
DecomposeSiluPass(),
375419
DecomposeAvgPool2dPass(),
376-
(
377-
DecomposeSoftmaxUnstablePass()
378-
if self.tosa_spec.is_U55_subset
379-
else DecomposeSoftmaxPass()
380-
),
420+
DecomposeSoftmaxUnstablePass(),
421+
DecomposeSoftmaxPass(),
381422
ConvertMinMaxPass(),
382423
]
383424
)
@@ -386,7 +427,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
386427
self.add_passes(
387428
[
388429
ReplaceInfAndLimitValuesPass(),
389-
DecomposeMaskedFillPass() if not self.tosa_spec.is_U55_subset else None,
430+
DecomposeMaskedFillPass(),
390431
]
391432
)
392433

backends/arm/common/arm_compile_spec.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010
# JIT compiler flows.
1111
#
1212

13+
import json
1314
from abc import ABC, abstractmethod
1415
from dataclasses import dataclass, field
1516
from enum import Enum
1617

18+
from executorch.backends.arm.common.pipeline_config import ArmPassPipelineConfig
1719
from executorch.backends.arm.tosa import TosaSpecification
1820

1921
from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -36,6 +38,7 @@ class DebugMode(Enum):
3638
_DEBUG_ARTIFACT_KEY = "debug_artifact_path"
3739
_DEBUG_MODE_KEY = "dump_debug_info"
3840
_OUTPUT_REORDER_KEY = "ouput_reorder_workaround"
41+
_TRANSFORM_PIPELINE_CONFIG_KEY = "transform_pipeline_config"
3942

4043
def _set_compile_specs(
4144
self,
@@ -44,13 +47,15 @@ def _set_compile_specs(
4447
path_for_intermediates: str | None = None,
4548
tosa_debug_mode: DebugMode | None = None,
4649
output_order_workaround: bool = True,
50+
pipeline_config: ArmPassPipelineConfig | None = None,
4751
):
4852
"""Set all values of dataclass directly."""
4953
self.tosa_spec = tosa_spec
5054
self.compiler_flags = compiler_flags
5155
self.path_for_intermediates = path_for_intermediates
5256
self.tosa_debug_mode = tosa_debug_mode
5357
self.output_order_workaround = output_order_workaround
58+
self._pipeline_config = pipeline_config
5459

5560
@classmethod
5661
def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
@@ -60,6 +65,7 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
6065
path_for_intermediates: str | None = None
6166
tosa_debug_mode: ArmCompileSpec.DebugMode | None = None
6267
output_order_workaround: bool = True
68+
pipeline_config: ArmPassPipelineConfig | None = None
6369
unknown_specs: dict[str, str] = {}
6470
for spec in compile_specs:
6571
key = spec.key
@@ -98,6 +104,12 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
98104
tosa_debug_mode = ArmCompileSpec.DebugMode[val]
99105
elif key == ArmCompileSpec._OUTPUT_REORDER_KEY:
100106
output_order_workaround = val # type: ignore[assignment]
107+
elif key == ArmCompileSpec._TRANSFORM_PIPELINE_CONFIG_KEY:
108+
if pipeline_config is not None:
109+
raise ValueError(
110+
"More than one transform pipeline entry in compile spec."
111+
)
112+
pipeline_config = ArmPassPipelineConfig.from_dict(json.loads(val))
101113
else:
102114
unknown_specs[key] = val
103115

@@ -120,6 +132,7 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
120132
path_for_intermediates=path_for_intermediates,
121133
tosa_debug_mode=tosa_debug_mode,
122134
output_order_workaround=output_order_workaround,
135+
pipeline_config=pipeline_config,
123136
)
124137
cls.from_list_hook(compile_spec, unknown_specs)
125138
compile_spec.validate()
@@ -189,8 +202,33 @@ def to_list(self):
189202
)
190203
)
191204

205+
if self._pipeline_config is not None and not self._pipeline_config.is_default():
206+
compile_spec.append(
207+
CompileSpec(
208+
ArmCompileSpec._TRANSFORM_PIPELINE_CONFIG_KEY,
209+
self._pipeline_config.serialize(),
210+
)
211+
)
192212
return compile_spec
193213

214+
def get_pass_pipeline_config(self) -> ArmPassPipelineConfig:
215+
"""
216+
Returns configuration that controls how the Arm pass pipeline should behave.
217+
Subclasses may override to tweak defaults for specific targets.
218+
"""
219+
if self._pipeline_config is None:
220+
self._pipeline_config = self._create_default_pipeline_config()
221+
return self._pipeline_config
222+
223+
def set_pass_pipeline_config(self, config: ArmPassPipelineConfig) -> None:
224+
self._pipeline_config = config
225+
226+
def _create_default_pipeline_config(self) -> ArmPassPipelineConfig:
227+
config = ArmPassPipelineConfig()
228+
if self.tosa_spec.is_U55_subset:
229+
config.disable_masked_softmax()
230+
return config
231+
194232
def get_intermediate_path(self) -> str | None:
195233
"""
196234
Gets the path used for dumping intermediate results such as tosa and pte.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import json
7+
from dataclasses import dataclass, fields
8+
from enum import auto, Enum
9+
from typing import Any
10+
11+
12+
class SoftmaxDecompositionConfig(Enum):
13+
MASKED = auto()
14+
UNSTABLE = auto()
15+
16+
17+
class FuseDuplicateUsersConfig(Enum):
18+
ENABLED = auto()
19+
DISABLED = auto()
20+
21+
22+
@dataclass
23+
class ArmPassPipelineConfig:
24+
softmax: SoftmaxDecompositionConfig = SoftmaxDecompositionConfig.MASKED
25+
fuse_duplicate_users: FuseDuplicateUsersConfig = FuseDuplicateUsersConfig.ENABLED
26+
27+
def disable_masked_softmax(self) -> None:
28+
self.softmax = SoftmaxDecompositionConfig.UNSTABLE
29+
30+
def disable_fuse_duplicate_users(self) -> None:
31+
self.fuse_duplicate_users = FuseDuplicateUsersConfig.DISABLED
32+
33+
def is_default(self) -> bool:
34+
return (
35+
self.softmax is SoftmaxDecompositionConfig.MASKED
36+
and self.fuse_duplicate_users is FuseDuplicateUsersConfig.ENABLED
37+
)
38+
39+
def to_dict(self) -> dict[str, str]:
40+
return {f.name: getattr(self, f.name).name for f in fields(self)}
41+
42+
@classmethod
43+
def from_dict(cls, data: dict[str, Any]) -> "ArmPassPipelineConfig":
44+
config = cls()
45+
for f in fields(cls):
46+
raw_value = data.get(f.name)
47+
if raw_value is None:
48+
continue
49+
enum_type = f.type
50+
setattr(config, f.name, enum_type[raw_value])
51+
return config
52+
53+
def serialize(self) -> bytes:
54+
"""Return a serialized representation of this config."""
55+
return json.dumps(self.to_dict()).encode()
56+
57+
def __repr__(self):
58+
fields = ", ".join(f"{name}={value!r}" for name, value in self.__dict__.items())
59+
return f"({fields})"

backends/arm/ethosu/compile_spec.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,13 @@
44
# LICENSE file in the root directory of this source tree.
55

66
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
7-
7+
from executorch.backends.arm.common.pipeline_config import ( # noqa: unused
8+
ArmPassPipelineConfig,
9+
)
810
from executorch.backends.arm.tosa import ( # type: ignore[import-not-found]
911
TosaSpecification,
1012
)
11-
12-
from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found]
13-
CompileSpec,
14-
)
13+
from executorch.exir.backend.compile_spec_schema import CompileSpec
1514

1615

1716
class EthosUCompileSpec(ArmCompileSpec):
@@ -43,7 +42,6 @@ def __init__(
4342
4443
"""
4544
self.target = target
46-
4745
# Set vela compiler flags
4846
if config_ini is None:
4947
config_ini = "Arm/vela.ini"
@@ -57,25 +55,26 @@ def __init__(
5755
]
5856
)
5957
# default system config and memory mode
60-
if "ethos-u55" in self.target:
58+
target_lower = self.target.lower()
59+
if "ethos-u55" in target_lower:
6160
if system_config is None:
6261
system_config = "Ethos_U55_High_End_Embedded"
6362
if memory_mode is None:
6463
memory_mode = "Shared_Sram"
65-
elif "ethos-u85" in self.target:
64+
elif "ethos-u85" in target_lower:
6665
if system_config is None:
6766
system_config = "Ethos_U85_SYS_DRAM_Mid"
6867
if memory_mode is None:
6968
memory_mode = "Sram_Only"
7069
else:
71-
raise RuntimeError(f"Unknown ethos target: {self.target}")
70+
raise RuntimeError(f"Unknown ethos target: {target}")
7271

7372
compiler_flags.append(f"--system-config={system_config}")
7473
compiler_flags.append(f"--memory-mode={memory_mode}")
7574

7675
# Set TOSA version.
7776
base_tosa_version = "TOSA-1.0+INT+int16"
78-
if "u55" in self.target:
77+
if "u55" in target_lower:
7978
# Add the Ethos-U55 extension marker
8079
base_tosa_version += "+u55"
8180
tosa_spec = TosaSpecification.create_from_string(base_tosa_version)
@@ -109,3 +108,8 @@ def validate(self):
109108
def get_output_format(cls) -> str:
110109
"""Return the artifact format emitted by this compile spec."""
111110
return "vela"
111+
112+
def _create_default_pipeline_config(self) -> ArmPassPipelineConfig:
113+
# Any u55 subset passes are treated as tosa specification configs
114+
# As such, they should be added to the base class default.
115+
return super()._create_default_pipeline_config()

backends/arm/quantizer/arm_quantizer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -339,11 +339,13 @@ class TOSAQuantizer(Quantizer):
339339
def __init__(
340340
self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec
341341
) -> None:
342-
343342
super().__init__()
343+
self.compile_spec: ArmCompileSpec
344344
if isinstance(compile_spec_or_tosa_spec, TosaSpecification):
345-
self.tosa_spec = compile_spec_or_tosa_spec
346-
self.compile_spec = None
345+
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
346+
347+
self.compile_spec = TosaCompileSpec(compile_spec_or_tosa_spec)
348+
self.tosa_spec = self.compile_spec.tosa_spec
347349
elif isinstance(compile_spec_or_tosa_spec, ArmCompileSpec):
348350
self.compile_spec = compile_spec_or_tosa_spec
349351
self.tosa_spec = self.compile_spec.tosa_spec
@@ -432,9 +434,8 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
432434
# TODO: Fix the need to lazily import this.
433435
from executorch.backends.arm._passes import ArmPassManager
434436

435-
return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline(
436-
graph_module=model
437-
)
437+
pass_manager = ArmPassManager(self.compile_spec)
438+
return pass_manager.transform_for_annotation_pipeline(graph_module=model)
438439

439440
def annotate(self, model: GraphModule) -> GraphModule:
440441
"""Annotate the graph with the configured quantization settings.

backends/arm/test/misc/test_call_operator_submodule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from executorch.backends.arm._passes.arm_pass import ArmPass
1111
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
12-
from executorch.backends.arm.tosa.specification import TosaSpecification
12+
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
1313
from torch.fx import GraphModule
1414
from torch.fx.passes.infra.pass_base import PassResult
1515

@@ -58,7 +58,7 @@ def test_call_operator_runs_once_for_cond_submodules() -> None:
5858
graph_module = exported.graph_module
5959

6060
recording_pass = _DepthRecordingPass(graph_module)
61-
pass_manager = ArmPassManager(TosaSpecification.create_from_string("TOSA-1.00+FP"))
61+
pass_manager = ArmPassManager(TosaCompileSpec("TOSA-1.00+FP"))
6262
pass_manager.add_pass(recording_pass)
6363
pass_manager._transform(graph_module)
6464

0 commit comments

Comments
 (0)