Skip to content

Commit 2f9e574

Browse files
Cortex_m backend: Add permute op (pytorch#15848)
Since the transpose op doesn't require qparams but still expects input to be int8, the check in quantized_op_fusion_pass is moved from the call_operator level to the _get_replacement level. This way different ops can have different checks. --------- Signed-off-by: Adrian Lundell <[email protected]>
1 parent 3e9d22c commit 2f9e574

File tree

6 files changed

+275
-6
lines changed

6 files changed

+275
-6
lines changed

backends/cortex_m/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ set(_cortex_m_kernels__srcs
5858
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
5959
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp
6060
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
61+
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_transpose.cpp
6162
)
6263

6364
# Generate C++ bindings to register kernels into Executorch
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright 2025 Arm Limited and/or its affiliates.
3+
*
4+
* This source code is licensed under the BSD-style license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include "cortex_m_ops_common.h"
9+
10+
#include <array>
11+
#include <limits>
12+
#include <vector>
13+
14+
// Include CMSIS-NN headers with C linkage
15+
extern "C" {
16+
#include "arm_nnfunctions.h"
17+
}
18+
19+
namespace cortex_m {
20+
namespace native {
21+
22+
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
23+
24+
namespace {
25+
26+
constexpr size_t kMaxSupportedDims = 4;
27+
28+
} // namespace
29+
30+
/**
 * Permutes the dimensions of an int8 tensor using CMSIS-NN's arm_transpose_s8.
 *
 * @param context Kernel runtime context. context.fail() is called with
 *     Error::InvalidArgument when an argument is rejected and with
 *     Error::Internal when arm_transpose_s8 reports a failure.
 * @param input Source tensor; must be int8 (ScalarType::Char) with a rank in
 *     [1, kMaxSupportedDims].
 * @param perm Axis order: perm[i] is the input dimension that becomes output
 *     dimension i. Must be a permutation of [0, rank).
 * @param out Destination tensor; must be int8, have the same rank as input,
 *     and satisfy out.size(i) == input.size(perm[i]).
 * @return out. When an argument is rejected, out is returned untouched.
 */
Tensor& transpose_out(
    KernelRuntimeContext& context,
    const Tensor& input,
    const IntArrayRef perm,
    Tensor& out) {
  if (input.scalar_type() != ScalarType::Char ||
      out.scalar_type() != ScalarType::Char) {
    ET_LOG(
        Error,
        "transpose_out: only int8 tensors are supported (input=%d, out=%d)",
        static_cast<int>(input.scalar_type()),
        static_cast<int>(out.scalar_type()));
    context.fail(Error::InvalidArgument);
    return out;
  }

  const size_t rank = static_cast<size_t>(input.dim());
  if (rank == 0 || rank > kMaxSupportedDims) {
    ET_LOG(
        Error,
        "transpose_out: expected tensor rank in [1, %zu], got %zu",
        kMaxSupportedDims,
        rank);
    context.fail(Error::InvalidArgument);
    return out;
  }

  // out.size(i) is indexed for every i < rank below, so the ranks must agree.
  const size_t out_rank = static_cast<size_t>(out.dim());
  if (out_rank != rank) {
    ET_LOG(
        Error,
        "transpose_out: output rank %zu does not match input rank %zu",
        out_rank,
        rank);
    context.fail(Error::InvalidArgument);
    return out;
  }

  const size_t perm_length = perm.size();
  if (perm_length != rank) {
    ET_LOG(
        Error,
        "transpose_out: permutation length %zu does not match tensor rank %zu",
        perm_length,
        rank);
    context.fail(Error::InvalidArgument);
    return out;
  }

  // Validate the permutation before it is handed to CMSIS-NN: every axis must
  // be in [0, rank) and appear exactly once. Unused trailing slots keep the
  // identity mapping.
  std::array<uint32_t, kMaxSupportedDims> perm_buffer{0, 1, 2, 3};
  uint32_t seen_axes = 0;
  for (size_t i = 0; i < rank; ++i) {
    const int64_t axis = perm[i];
    if (axis < 0 || axis >= static_cast<int64_t>(rank)) {
      ET_LOG(
          Error,
          "transpose_out: permutation entry %lld is out of range for rank %zu",
          static_cast<long long>(axis),
          rank);
      context.fail(Error::InvalidArgument);
      return out;
    }
    const uint32_t axis_bit = uint32_t{1} << static_cast<uint32_t>(axis);
    if ((seen_axes & axis_bit) != 0) {
      ET_LOG(
          Error,
          "transpose_out: permutation entry %lld appears more than once",
          static_cast<long long>(axis));
      context.fail(Error::InvalidArgument);
      return out;
    }
    seen_axes |= axis_bit;
    perm_buffer[i] = static_cast<uint32_t>(axis);
  }

  // Dimensions beyond the tensor rank stay at 1.
  std::array<int32_t, kMaxSupportedDims> input_dims_arr{1, 1, 1, 1};
  std::array<int32_t, kMaxSupportedDims> output_dims_arr{1, 1, 1, 1};
  for (size_t i = 0; i < rank; ++i) {
    const auto in_size = input.size(i);
    const auto out_size = out.size(i);
    if (in_size > std::numeric_limits<int32_t>::max() ||
        out_size > std::numeric_limits<int32_t>::max()) {
      ET_LOG(
          Error,
          "transpose_out: dimension size exceeds int32_t range (input=%lld, output=%lld)",
          static_cast<long long>(in_size),
          static_cast<long long>(out_size));
      context.fail(Error::InvalidArgument);
      return out;
    }

    // The output buffer must already have the permuted shape; otherwise the
    // CMSIS-NN kernel would index outside one of the two buffers.
    const auto expected_out_size =
        input.size(static_cast<size_t>(perm_buffer[i]));
    if (out_size != expected_out_size) {
      ET_LOG(
          Error,
          "transpose_out: output dim %zu has size %lld, expected %lld",
          i,
          static_cast<long long>(out_size),
          static_cast<long long>(expected_out_size));
      context.fail(Error::InvalidArgument);
      return out;
    }

    input_dims_arr[i] = static_cast<int32_t>(in_size);
    output_dims_arr[i] = static_cast<int32_t>(out_size);
  }

  cmsis_nn_dims input_dims = {
      input_dims_arr[0],
      input_dims_arr[1],
      input_dims_arr[2],
      input_dims_arr[3]};
  cmsis_nn_dims output_dims = {
      output_dims_arr[0],
      output_dims_arr[1],
      output_dims_arr[2],
      output_dims_arr[3]};

  const cmsis_nn_transpose_params transpose_params{
      static_cast<int32_t>(rank), perm_buffer.data()};

  const int8_t* input_data = input.const_data_ptr<int8_t>();
  int8_t* output_data = out.mutable_data_ptr<int8_t>();

  const arm_cmsis_nn_status status = arm_transpose_s8(
      input_data, output_data, &input_dims, &output_dims, &transpose_params);

  if (status != ARM_CMSIS_NN_SUCCESS) {
    ET_LOG(
        Error,
        "transpose_out: arm_transpose_s8 failed with status [%d]",
        static_cast<int>(status));
    context.fail(Error::Internal);
    return out;
  }

  return out;
}
122+
123+
} // namespace native
124+
} // namespace cortex_m

backends/cortex_m/ops/operators.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,3 +349,21 @@ def quantized_linear_impl(
349349
output += output_offset
350350
output = torch.clamp(output, activation_min, activation_max).to(torch.int8)
351351
return output
352+
353+
354+
# ===================================================================
355+
# TRANSPOSE OPERATION DEFINITION
356+
# ===================================================================
357+
lib.define("transpose(Tensor input, int[] perm) -> Tensor")
358+
lib.define("transpose.out(Tensor input, int[] perm, *, Tensor(a!) out) -> Tensor(a!)")
359+
360+
361+
@register_fake("cortex_m::transpose")
def transpose_meta(input: torch.Tensor, perm) -> torch.Tensor:
    """Fake kernel for cortex_m::transpose.

    Returns an uninitialized tensor whose i-th dimension has the size of
    ``input``'s dimension ``perm[i]``, with the dtype and device of ``input``.
    """
    permuted_sizes = []
    for axis in perm:
        permuted_sizes.append(input.shape[axis])
    return torch.empty(permuted_sizes, dtype=input.dtype, device=input.device)
365+
366+
367+
@impl(lib, "transpose", "CompositeExplicitAutograd")
def transpose_impl(input: torch.Tensor, perm) -> torch.Tensor:
    """Reference implementation of cortex_m::transpose.

    Permutes ``input`` according to ``perm`` and returns the result as a
    contiguous tensor.
    """
    axis_order = tuple(perm)
    permuted_view = input.permute(axis_order)
    return permuted_view.contiguous()

backends/cortex_m/ops/operators.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,9 @@
3434
kernels:
3535
- arg_meta: null
3636
kernel_name: cortex_m::quantized_linear_out
37+
38+
- func: cortex_m::transpose.out(Tensor input, int[] perm, *, Tensor(a!) out) -> Tensor(a!)
39+
variants: function
40+
kernels:
41+
- arg_meta: null
42+
kernel_name: cortex_m::transpose_out

backends/cortex_m/passes/quantized_op_fusion_pass.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
from typing import Dict
99

10+
import torch
11+
1012
from executorch.backends.cortex_m.passes.passes_utils import (
1113
quantize_multiplier_aot,
1214
SHIFT_INT8,
@@ -30,6 +32,11 @@ class QuantizedOpFusionPass(ExportPass):
3032
"""
3133

3234
def _get_add_replacement(self, args, meta):
35+
if (
36+
meta.data.get("input_qparams", {}) == {}
37+
or meta.data.get("output_qparams", {}) == {}
38+
):
39+
return exir_ops.edge.aten.add.Tensor, args
3340

3441
# Extract values
3542
scale1 = meta["input_qparams"][0].scale
@@ -64,7 +71,12 @@ def _get_add_replacement(self, args, meta):
6471

6572
return exir_ops.edge.cortex_m.quantized_add.default, args
6673

67-
def _get_mul_replacement(self, args, meta) -> int:
74+
def _get_mul_replacement(self, args, meta):
75+
if (
76+
meta.data.get("input_qparams", {}) == {}
77+
or meta.data.get("output_qparams", {}) == {}
78+
):
79+
return exir_ops.edge.aten.mul.Tensor, args
6880

6981
# Extract values
7082
scale1 = meta["input_qparams"][0].scale
@@ -89,24 +101,30 @@ def _get_mul_replacement(self, args, meta) -> int:
89101

90102
return exir_ops.edge.cortex_m.quantized_mul.default, args
91103

104+
def _get_permute_replacement(self, args, meta):
    """Pick the operator that replaces an ``aten.permute_copy`` node.

    The node is rewritten to ``cortex_m.transpose`` only when its input is
    int8 and its rank is one the cortex_m transpose kernel accepts (1 to 4);
    in every other case the original ``permute_copy`` op and args are
    returned unchanged. No qparams are required for this op.

    Returns:
        A ``(op, args)`` tuple. For the cortex_m op, the permutation is
        normalized so that every index is non-negative.
    """
    # The cortex_m transpose kernel rejects rank 0 and ranks above 4 at
    # runtime, so such nodes must stay on the generic permute_copy op.
    max_supported_rank = 4

    rank = len(args[0].data.shape)
    if args[0].data.dtype != torch.int8 or not 1 <= rank <= max_supported_rank:
        return exir_ops.edge.aten.permute_copy.default, args

    # Map negative axes (e.g. -1) onto their non-negative equivalents.
    perms = [p % rank for p in args[1]]
    args = (args[0], perms)
    return exir_ops.edge.cortex_m.transpose.default, args
112+
92113
def call_operator(
93114
self,
94115
op: EdgeOpOverload,
95116
args: tuple[Argument, ...],
96117
kwargs: Dict[str, Argument],
97118
meta: NodeMetadata,
98119
) -> ProxyValue:
99-
if (
100-
meta.data.get("input_qparams", {}) == {}
101-
or meta.data.get("output_qparams", {}) == {}
102-
):
103-
return super().call_operator(op, args, {}, meta)
104120

105121
match op:
106122
case exir_ops.edge.aten.add.Tensor:
107123
op, args = self._get_add_replacement(args, meta)
108124
case exir_ops.edge.aten.mul.Tensor:
109125
op, args = self._get_mul_replacement(args, meta)
126+
case exir_ops.edge.aten.permute_copy.default:
127+
op, args = self._get_permute_replacement(args, meta)
110128
case _:
111129
pass
112130

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import torch
8+
from executorch.backends.arm.test.common import parametrize
9+
from executorch.backends.cortex_m.test.tester import (
10+
CortexMTester,
11+
McuTestCase,
12+
ramp_tensor,
13+
)
14+
15+
OPS_BEFORE_PASSES = {
16+
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
17+
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
18+
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
19+
}
20+
21+
OPS_AFTER_PASSES = {
22+
"executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1,
23+
"executorch_exir_dialects_edge__ops_cortex_m_transpose_default": 1,
24+
"executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
25+
}
26+
27+
28+
class CortexMPermute(torch.nn.Module):
    """Module that applies ``permute`` with a fixed axis order to its input."""

    # Expected operator counts before/after the cortex_m passes.
    ops_before_transforms = OPS_BEFORE_PASSES
    ops_after_transforms = OPS_AFTER_PASSES

    def __init__(self, perms):
        super().__init__()
        self.perms = perms

    def forward(self, x):
        permuted = x.permute(self.perms)
        return permuted
38+
39+
40+
class CortexMTranspose(torch.nn.Module):
    """Module that swaps two fixed dimensions of its input via ``transpose``."""

    # Expected operator counts before/after the cortex_m passes.
    ops_before_transforms = OPS_BEFORE_PASSES
    ops_after_transforms = OPS_AFTER_PASSES

    def __init__(self, dim0, dim1):
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x):
        swapped = x.transpose(self.dim0, self.dim1)
        return swapped
51+
52+
53+
class CortexMT(torch.nn.Module):
    """Module that applies the ``t()`` operator to its input."""

    # Expected operator counts before/after the cortex_m passes.
    ops_before_transforms = OPS_BEFORE_PASSES
    ops_after_transforms = OPS_AFTER_PASSES

    def forward(self, x):
        transposed = x.t()
        return transposed
59+
60+
61+
# Test matrix: each entry pairs a module with the example inputs to run it on.
# Shapes are always given as tuples, including the rank-1 case.
test_cases = {
    "permute_nhwc_to_nchw": McuTestCase(
        CortexMPermute((0, 3, 1, 2)),
        (ramp_tensor(-0.5, 0.5, (2, 3, 4, 2)),),
    ),
    "permute_nchw_to_nhwc_neg_index": McuTestCase(
        CortexMPermute((0, -2, -1, -3)),
        (ramp_tensor(10, 100, (2, 3, 4, 2)),),
    ),
    "permute_rank_1": McuTestCase(
        CortexMPermute((0,)),
        # `(3)` is just the int 3; a one-element shape needs the trailing comma.
        (ramp_tensor(10, 100, (3,)),),
    ),
    "transpose_1_2": McuTestCase(
        CortexMTranspose(1, 2),
        (ramp_tensor(-1.0, 1.0, (1, 3, 4)),),
    ),
    "transpose_0_1": McuTestCase(
        CortexMTranspose(0, 1),
        (ramp_tensor(-2.0, 2.0, (2, 3, 4, 3)),),
    ),
    "t_operator": McuTestCase(
        CortexMT(),
        (ramp_tensor(-0.5, 0.5, (4, 2)),),
    ),
}
87+
88+
89+
@parametrize("test_case", test_cases)
def test_dialect_transpose(test_case):
    """Run the dialect test for one case, passing the model's expected
    before/after operator counts and a quantization tolerance of 1."""
    model = test_case.model
    CortexMTester(model, test_case.example_inputs).test_dialect(
        model.ops_before_transforms,
        model.ops_after_transforms,
        qtol=1,
    )
97+
98+
99+
@parametrize("test_case", test_cases)
def test_implementation_transpose(test_case):
    """Run the implementation test for one case with a quantization
    tolerance of 1."""
    model, example_inputs = test_case.model, test_case.example_inputs
    CortexMTester(model, example_inputs).test_implementation(qtol=1)

0 commit comments

Comments
 (0)