Skip to content

Commit 2bc465a

Browse files
Arm backend: Fix broken index_tensor tests (#16220)
index_tensors tests used an IntEnum to indicate which positions to index. This caused export to fail. This patches changes the IntEnum to simply use strings instead. cc @freddan80 @per @zingo @digantdesai Signed-off-by: Oscar Andersson <[email protected]>
1 parent cdc2701 commit 2bc465a

File tree

1 file changed

+19
-30
lines changed

1 file changed

+19
-30
lines changed

backends/arm/test/ops/test_index_tensor.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7-
from enum import IntEnum
87
from typing import Tuple
98

109
import torch
@@ -25,22 +24,12 @@ class IndexTensorTestCommon:
2524
# Gathers and reshapes should result in no inaccuracies
2625
rtol = 0.0
2726
atol = 0.0
27+
BEFORE = "BEFORE"
28+
MIDDLE = "MIDDLE"
29+
AFTER = "AFTER"
2830

29-
class OpPlacement(IntEnum):
30-
"""
31-
Simple enum used to indicate where slices or ellipsis should be placed
32-
in tests.
33-
IntEnum so that Dynamo does not complain about unsupported types.
34-
"""
3531

36-
BEFORE = 1
37-
MIDDLE = 2
38-
AFTER = 3
39-
40-
41-
input_params_slice = Tuple[
42-
torch.Tensor, int, int, IndexTensorTestCommon.OpPlacement, Tuple[torch.Tensor]
43-
]
32+
input_params_slice = Tuple[torch.Tensor, int, int, str, Tuple[torch.Tensor]]
4433
input_params = Tuple[torch.Tensor, Tuple[torch.Tensor]]
4534

4635

@@ -55,12 +44,12 @@ class IndexTensor_Ellipsis(torch.nn.Module):
5544
test_data_ellipsis: dict[input_params] = {
5645
"test_4d_ellipsis_before": (
5746
torch.rand(size=(25, 5, 13, 7)),
58-
IndexTensorTestCommon.OpPlacement.BEFORE,
47+
IndexTensorTestCommon.BEFORE,
5948
(torch.arange(2, dtype=torch.int32),),
6049
),
6150
"test_4d_ellipsis_middle": (
6251
torch.rand(size=(25, 5, 13, 7)),
63-
IndexTensorTestCommon.OpPlacement.MIDDLE,
52+
IndexTensorTestCommon.MIDDLE,
6453
(
6554
torch.arange(2, dtype=torch.int32),
6655
torch.arange(2, dtype=torch.int32),
@@ -72,23 +61,23 @@ class IndexTensor_Ellipsis(torch.nn.Module):
7261
# partitioning is difficult and unreliable, as such
7362
# it is not xfail as the existing logic can handle it.
7463
torch.rand(size=(25, 5, 13, 7)),
75-
IndexTensorTestCommon.OpPlacement.AFTER,
64+
IndexTensorTestCommon.AFTER,
7665
(torch.arange(2, dtype=torch.int32),),
7766
),
7867
}
7968

8069
def forward(
8170
self,
8271
input_: torch.Tensor,
83-
position: IndexTensorTestCommon.OpPlacement,
72+
position: str,
8473
indices: Tuple[None | torch.Tensor],
8574
):
8675
match position:
87-
case IndexTensorTestCommon.OpPlacement.BEFORE:
76+
case IndexTensorTestCommon.BEFORE:
8877
return input_[..., indices[0]]
89-
case IndexTensorTestCommon.OpPlacement.MIDDLE:
78+
case IndexTensorTestCommon.MIDDLE:
9079
return input_[indices[0], ..., indices[1]]
91-
case IndexTensorTestCommon.OpPlacement.AFTER:
80+
case IndexTensorTestCommon.AFTER:
9281
return input_[indices[0], ...]
9382

9483
return input_[indices]
@@ -154,7 +143,7 @@ class IndexTensor_Slice(torch.nn.Module):
154143
torch.rand(size=(5, 3, 4, 5)),
155144
0,
156145
2,
157-
IndexTensorTestCommon.OpPlacement.BEFORE,
146+
IndexTensorTestCommon.BEFORE,
158147
(torch.arange(2, dtype=torch.int32),),
159148
),
160149
"test_3d_slice_before_2d_idx": (
@@ -164,14 +153,14 @@ class IndexTensor_Slice(torch.nn.Module):
164153
torch.arange(5 * 3 * 4, dtype=torch.float32).reshape(5, 3, 4),
165154
0,
166155
2,
167-
IndexTensorTestCommon.OpPlacement.BEFORE,
156+
IndexTensorTestCommon.BEFORE,
168157
(torch.arange(2, dtype=torch.int32).unsqueeze(0).tile(2, 1),),
169158
),
170159
"test_4d_slice_middle": (
171160
torch.arange(5 * 3 * 2, dtype=torch.int32).reshape(5, 3, 2),
172161
0,
173162
2,
174-
IndexTensorTestCommon.OpPlacement.MIDDLE,
163+
IndexTensorTestCommon.MIDDLE,
175164
(
176165
torch.arange(2, dtype=torch.int32),
177166
torch.arange(2, dtype=torch.int32),
@@ -185,7 +174,7 @@ class IndexTensor_Slice(torch.nn.Module):
185174
torch.rand(size=(25, 5, 13, 7)),
186175
0,
187176
2,
188-
IndexTensorTestCommon.OpPlacement.AFTER,
177+
IndexTensorTestCommon.AFTER,
189178
(torch.arange(2, dtype=torch.int32),),
190179
),
191180
}
@@ -195,15 +184,15 @@ def forward(
195184
input_: torch.Tensor,
196185
slice_start: int,
197186
slice_end: int,
198-
position: IndexTensorTestCommon.OpPlacement,
187+
position: str,
199188
indices: Tuple[None | torch.Tensor],
200189
):
201190
match position:
202-
case IndexTensorTestCommon.OpPlacement.BEFORE:
191+
case IndexTensorTestCommon.BEFORE:
203192
return input_[slice_start:slice_end, indices[0]]
204-
case IndexTensorTestCommon.OpPlacement.MIDDLE:
193+
case IndexTensorTestCommon.MIDDLE:
205194
return input_[indices[0], slice_start:slice_end, indices[1]]
206-
case IndexTensorTestCommon.OpPlacement.AFTER:
195+
case IndexTensorTestCommon.AFTER:
207196
return input_[indices[0], slice_start:slice_end]
208197

209198

0 commit comments

Comments
 (0)