Skip to content

Commit de52bfd

Browse files
authored
Arm backend: Add conv3d support to Tosa/Vgf backends (#16093)
Conv3d is not supported by Vela, so it cannot be supported on U55 or U85. * Adds Conv3D support for FP32, Int8 and int16A8W * Reworks to_tosa_memory_format_pass.py to handle spatial rank 3 tensors (DHW) * Adds support for rank 5 tensors to analyze_output_utils.py * Reworks conv2d passes to handle conv3d and renames them to be more generic Signed-off-by: Ryan O'Shea <[email protected]>
1 parent 2198973 commit de52bfd

20 files changed

+745
-256
lines changed

backends/arm/_passes/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@
4848
from .decompose_glu_pass import DecomposeGluPass # noqa
4949
from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa
5050
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
51-
from .decompose_int16_activation_conv2d_pass import ( # noqa
52-
DecomposeConv2dWithInt16ActivationPass,
51+
from .decompose_int16_activation_conv_pass import ( # noqa
52+
DecomposeConvWithInt16ActivationPass,
5353
)
5454
from .decompose_int32_clamp_pass import DecomposeInt32ClampPass # noqa
5555
from .decompose_int_pow_pass import DecomposeIntPowPass # noqa
@@ -109,7 +109,7 @@
109109
from .replace_scalar_with_tensor_pass import ( # noqa
110110
ReplaceScalarWithTensorByProfilePass,
111111
)
112-
from .rewrite_conv2d_pass import RewriteConv2dPass # noqa
112+
from .rewrite_conv_pass import RewriteConvPass # noqa
113113
from .rewrite_matmul import RewriteMatmulPass # noqa
114114
from .rewrite_upsample import RewriteUpsamplePass # noqa
115115
from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
DecomposeAtanPass,
4141
DecomposeAvgPool2dPass,
4242
DecomposeBatchNormNoStatsPass,
43-
DecomposeConv2dWithInt16ActivationPass,
43+
DecomposeConvWithInt16ActivationPass,
4444
DecomposeCoshPass,
4545
DecomposeCosineSimilarityPass,
4646
DecomposeCumsumPass,
@@ -101,7 +101,7 @@
101101
RemoveNoopPass,
102102
ReplaceInfAndLimitValuesPass,
103103
ReplaceScalarWithTensorByProfilePass,
104-
RewriteConv2dPass,
104+
RewriteConvPass,
105105
RewriteMatmulPass,
106106
RewriteUpsamplePass,
107107
ScalarsToAttributePass,
@@ -277,7 +277,7 @@ def _tosa_pipeline(
277277
BroadcastArgsPass(),
278278
ConvertPermuteSingletonToViewPass(),
279279
FuseViewCopyTransformPass(),
280-
DecomposeConv2dWithInt16ActivationPass(),
280+
DecomposeConvWithInt16ActivationPass(),
281281
DecomposeSumPass(),
282282
InsertTableOpsPass(exported_program),
283283
]
@@ -287,7 +287,7 @@ def _tosa_pipeline(
287287
self.add_passes(
288288
[
289289
RewriteUpsamplePass(),
290-
RewriteConv2dPass(exported_program),
290+
RewriteConvPass(exported_program),
291291
RewriteMatmulPass(),
292292
]
293293
)

backends/arm/_passes/arm_pass_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,20 @@ def get_param_tensor(
106106
raise RuntimeError(f"unsupported param type, {node.op}.")
107107

108108

109+
def expand_around_channel(param: Sequence[int] | int, spatial_rank: int) -> list[int]:
110+
"""
111+
Expand a scalar or 1-D parameter around the channel dimension into a broadcastable
112+
shape while preserving the channel location.
113+
"""
114+
if isinstance(param, int):
115+
return [param] * spatial_rank
116+
117+
param_list = list(param)
118+
if len(param_list) == 1 and spatial_rank > 1:
119+
param_list = param_list * spatial_rank
120+
return param_list
121+
122+
109123
def create_node(
110124
graph: torch.fx.Graph,
111125
op_target: OpOverload | EdgeOpOverload,

backends/arm/_passes/conv1d_unsqueeze_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from executorch.backends.arm._passes import ArmPass
1212

13-
from executorch.backends.arm._passes.rewrite_conv2d_pass import RewriteConv2dPass
13+
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
1414
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
1515

1616
from executorch.exir.dialects._ops import ops as exir_ops
@@ -29,7 +29,7 @@ class Conv1dUnsqueezePass(ArmPass):
2929
"""
3030

3131
_passes_required_after: Set[Type[ExportPass]] = {
32-
RewriteConv2dPass,
32+
RewriteConvPass,
3333
SizeAdjustInputPass,
3434
}
3535

backends/arm/_passes/decompose_cumsum_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from executorch.backends.arm._passes import ArmPass
1111
from executorch.backends.arm._passes.arm_pass_utils import create_node
1212
from executorch.backends.arm._passes.quant_args import QuantArgs
13-
from executorch.backends.arm._passes.rewrite_conv2d_pass import RewriteConv2dPass
13+
from executorch.backends.arm._passes.rewrite_conv_pass import RewriteConvPass
1414

1515
from executorch.backends.transforms.utils import create_constant_placeholder
1616
from executorch.exir import ExportedProgram
@@ -42,7 +42,7 @@ class DecomposeCumsumPass(ArmPass):
4242
And the convolution is applied over dimension H.
4343
"""
4444

45-
_passes_required_after: Set[Type[ExportPass]] = {RewriteConv2dPass}
45+
_passes_required_after: Set[Type[ExportPass]] = {RewriteConvPass}
4646

4747
def __init__(self, exported_program: ExportedProgram) -> None:
4848
super().__init__()

backends/arm/_passes/decompose_int16_activation_conv2d_pass.py renamed to backends/arm/_passes/decompose_int16_activation_conv_pass.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,37 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7-
from typing import cast, Set, Type
7+
from typing import cast, Sequence, Set, Type
88

99
import torch
10-
from executorch.backends.arm._passes.arm_pass import ArmPass
10+
from executorch.backends.arm._passes import ArmPass
1111
from executorch.backends.arm._passes.quant_args import QuantArgs
1212

1313
from executorch.backends.arm.tosa.specification import get_context_spec
1414
from executorch.exir.dialects._ops import ops as exir_ops
1515
from executorch.exir.pass_base import ExportPass
1616

1717

18-
class DecomposeConv2dWithInt16ActivationPass(ArmPass):
18+
class DecomposeConvWithInt16ActivationPass(ArmPass):
1919
"""
2020
This pass decomposes a convolution with input dtype int16 and bias
21-
into a convolution without bias followed by an addition of the bias
22-
since the TOSA op requires the bias to be int48 which is hard to represent
21+
into a convolution without bias followed by an addition of the bias.
22+
We also reshape the 1D bias to [1, C, 1, …] so it broadcasts along the channel
23+
dimension. The TOSA op requires the bias to be int48, which is hard to represent
2324
in torch, so instead we rescale the int48 output to int16 and add the bias in int16.
2425
"""
2526

27+
def __init__(self) -> None:
28+
# No pass-specific state is set up here; only the base-class initialiser runs.
super().__init__()
29+
2630
_passes_required_after: Set[Type[ExportPass]] = set()
2731

32+
def bias_view_shape(
    self, bias: torch.Tensor, activation_rank: int
) -> Sequence[int]:
    """Return the view shape that lets a 1-D bias broadcast over channels.

    The result is ``[1, bias.shape[0], 1, ..., 1]``: the bias length sits in
    position 1 and is followed by ``activation_rank - 2`` singleton entries,
    so the list has ``activation_rank`` entries in total.
    """
    trailing_singletons = [1] * (activation_rank - 2)
    return [1, bias.shape[0]] + trailing_singletons
37+
2838
def call_operator(self, op, args, kwargs, meta):
2939
if op != exir_ops.edge.aten.convolution.default:
3040
return super().call_operator(op, args, kwargs, meta)
@@ -37,18 +47,22 @@ def call_operator(self, op, args, kwargs, meta):
3747
if args[2] is None:
3848
return super().call_operator(op, args, kwargs, meta)
3949

40-
if args[0].data.dtype == torch.int8:
41-
return super().call_operator(op, args, kwargs, meta)
42-
elif args[0].data.dtype == torch.int16:
43-
if not tosa_spec.support_extension("int16"):
44-
raise ValueError(
45-
"int16 activation for convolution requires TOSA int16 extension"
46-
)
47-
else:
50+
activation_tensor = args[0].data
51+
activation_rank = activation_tensor.dim()
52+
53+
if activation_rank not in (4, 5) or activation_tensor.dtype != torch.int16:
4854
return super().call_operator(op, args, kwargs, meta)
4955

50-
# convolution with bias and activation is int16
51-
bias = args[2]
56+
if not tosa_spec.support_extension("int16"):
57+
raise ValueError(
58+
"int16 activation for convolution requires TOSA int16 extension"
59+
)
60+
61+
# convolution with bias and activation is int16 (expected activation rank enforced above)
62+
# The bias is assumed to be quantized with the same quantization parameters as
63+
# the output of the convolution
64+
bias_arg = args[2]
65+
bias_data = bias_arg.data
5266

5367
no_bias_args = list(args)
5468
no_bias_args[2] = None
@@ -63,7 +77,7 @@ def call_operator(self, op, args, kwargs, meta):
6377
# reshape the tensor to the same rank as the convolution output to add the bias to the channels
6478
channel_bias = super().call_operator(
6579
exir_ops.edge.aten.view_copy.default,
66-
(bias, [1, len(bias.data), 1, 1]),
80+
(bias_arg, self.bias_view_shape(bias_data, activation_rank)),
6781
{},
6882
new_meta,
6983
)

backends/arm/_passes/rewrite_conv2d_pass.py renamed to backends/arm/_passes/rewrite_conv_pass.py

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from executorch.backends.arm._passes.arm_pass_utils import (
1414
create_node,
15+
expand_around_channel,
1516
get_first_fake_tensor,
1617
get_param_tensor,
1718
is_buffer,
@@ -29,7 +30,7 @@
2930
from torch.export.graph_signature import InputKind
3031

3132

32-
class RewriteConv2dPass(ArmPass):
33+
class RewriteConvPass(ArmPass):
3334
"""Rewrites aten.convolution to tosa.CONV2D, tosa.DEPTHWISE_CONV2D or tosa.CONV3D."""
3435

3536
def __init__(self, exported_program: torch.export.ExportedProgram):
@@ -88,11 +89,27 @@ def _is_depthwise_conv2d(self, node: torch.fx.Node) -> bool:
8889
or node.target != exir_ops.edge.aten.convolution.default
8990
):
9091
return False
92+
input_tensor = get_first_fake_tensor(node.all_input_nodes[0])
93+
if len(input_tensor.shape) != 4:
94+
return False
9195
groups = node.args[-1]
92-
in_channels = get_first_fake_tensor(node.all_input_nodes[0]).shape[1]
96+
in_channels = input_tensor.shape[1]
9397
out_channels = get_first_fake_tensor(node).shape[1]
9498
return (in_channels == groups) and (out_channels % in_channels) == 0
9599

100+
def _is_conv3d(self, rank, groups) -> bool:
101+
if rank == 5:
102+
# A Conv3D is considered depthwise if Group == InChannels and
103+
# Group * N == OutChannels, where N is a possitive integer.
104+
# Currently we do not support depthwise or grouped conv3d.
105+
# @TODO Add grouped/depthwise conv3d support or reject in partitioner.
106+
if groups != 1:
107+
raise RuntimeError(
108+
"CONV3D with groups != 1 is not supported in the Arm backend."
109+
)
110+
return True
111+
return False
112+
96113
def _reshape_weights(self, weight_node: torch.fx.Node, in_channels: int) -> None:
97114
"""Reshape the weights for depthwise convolution such that when serialized to TOSA,
98115
the weights are in the format [H, W, in_channels, m_length] where
@@ -201,7 +218,7 @@ def insert_output_rescale(self, graph_module, node):
201218
)
202219
return rescale_node
203220

204-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
221+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # noqa: C901
205222
modified = False
206223
for node in graph_module.graph.nodes:
207224
if (
@@ -224,30 +241,40 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
224241
group,
225242
) = node.args
226243

227-
pad = [val for val in pad for _ in (0, 1)]
228244
input_fake_tensor = get_first_fake_tensor(x)
229245
weight_fake_tensor = get_first_fake_tensor(weight)
230-
# Adjust the pad value if needed to meet the
231-
# strict convolution output shape calculation.
232-
pad[1] = self._adjust_pad_if_needed(
233-
input_fake_tensor.shape[2],
234-
weight_fake_tensor.shape[2],
235-
stride[0],
236-
pad[1],
237-
dilation[0],
238-
)
239-
pad[3] = self._adjust_pad_if_needed(
240-
input_fake_tensor.shape[3],
241-
weight_fake_tensor.shape[3],
242-
stride[1],
243-
pad[3],
244-
dilation[1],
245-
)
246+
input_shape = input_fake_tensor.shape
247+
weight_shape = weight_fake_tensor.shape
248+
spatial_rank = len(input_shape) - 2
249+
stride_list = expand_around_channel(stride, spatial_rank)
250+
dilation_list = expand_around_channel(dilation, spatial_rank)
251+
pad_list = expand_around_channel(pad, spatial_rank)
252+
253+
pad_attr: list[int] = []
254+
for value in pad_list:
255+
pad_attr.extend([value, value]) # duplicate pad before/after per axis
256+
257+
for axis_index in range(spatial_rank):
258+
pad_index = axis_index * 2 + 1 # adjust trailing pad entry
259+
pad_attr[pad_index] = self._adjust_pad_if_needed(
260+
input_shape[axis_index + 2],
261+
weight_shape[axis_index + 2],
262+
stride_list[axis_index],
263+
pad_attr[pad_index],
264+
dilation_list[axis_index],
265+
)
266+
267+
stride = tuple(stride_list)
268+
dilation = tuple(dilation_list)
269+
pad = pad_attr
270+
246271
has_bias = bias is not None
247272
if not has_bias:
248273
bias = self._add_bias(graph_module, node, weight)
249274

250-
if self._is_depthwise_conv2d(node):
275+
if self._is_conv3d(len(input_shape), group):
276+
target_op = exir_ops.backend.tosa.CONV3D.default
277+
elif self._is_depthwise_conv2d(node):
251278
target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default
252279
# If there are any TOSA.DEPTHWISE_CONV2D nodes using the weights, we've already reshaped them.
253280
if all(user.target != target_op for user in weight.users):
@@ -256,7 +283,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
256283
else:
257284
target_op = exir_ops.backend.tosa.CONV2D.default
258285

259-
conv2d_args = (
286+
conv_args = (
260287
x,
261288
weight,
262289
bias,
@@ -272,7 +299,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
272299
tosa_op = create_node(
273300
graph=graph_module.graph,
274301
op_target=target_op,
275-
args=conv2d_args,
302+
args=conv_args,
276303
from_node=node,
277304
inherit_qparams=True,
278305
)
@@ -281,7 +308,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
281308
input_fake_tensor,
282309
weight_fake_tensor,
283310
bias_fake_tensor,
284-
*conv2d_args[3:],
311+
*conv_args[3:],
285312
)
286313

287314
if (

0 commit comments

Comments
 (0)