diff --git a/backends/arm/test/ops/test_index_tensor.py b/backends/arm/test/ops/test_index_tensor.py index 557846922b8..bc19634bf30 100644 --- a/backends/arm/test/ops/test_index_tensor.py +++ b/backends/arm/test/ops/test_index_tensor.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. -from enum import IntEnum from typing import Tuple import torch @@ -25,22 +24,12 @@ class IndexTensorTestCommon: # Gathers and reshapes should result in no inaccuracies rtol = 0.0 atol = 0.0 + BEFORE = "BEFORE" + MIDDLE = "MIDDLE" + AFTER = "AFTER" - class OpPlacement(IntEnum): - """ - Simple enum used to indicate where slices or ellipsis should be placed - in tests. - IntEnum so that Dynamo does not complain about unsupported types. - """ - BEFORE = 1 - MIDDLE = 2 - AFTER = 3 - - -input_params_slice = Tuple[ - torch.Tensor, int, int, IndexTensorTestCommon.OpPlacement, Tuple[torch.Tensor] -] +input_params_slice = Tuple[torch.Tensor, int, int, str, Tuple[torch.Tensor]] input_params = Tuple[torch.Tensor, Tuple[torch.Tensor]] @@ -55,12 +44,12 @@ class IndexTensor_Ellipsis(torch.nn.Module): test_data_ellipsis: dict[input_params] = { "test_4d_ellipsis_before": ( torch.rand(size=(25, 5, 13, 7)), - IndexTensorTestCommon.OpPlacement.BEFORE, + IndexTensorTestCommon.BEFORE, (torch.arange(2, dtype=torch.int32),), ), "test_4d_ellipsis_middle": ( torch.rand(size=(25, 5, 13, 7)), - IndexTensorTestCommon.OpPlacement.MIDDLE, + IndexTensorTestCommon.MIDDLE, ( torch.arange(2, dtype=torch.int32), torch.arange(2, dtype=torch.int32), @@ -72,7 +61,7 @@ class IndexTensor_Ellipsis(torch.nn.Module): # partitioning is difficult and unreliable, as such # it is not xfail as the existing logic can handle it. torch.rand(size=(25, 5, 13, 7)), - IndexTensorTestCommon.OpPlacement.AFTER, + IndexTensorTestCommon.AFTER, (torch.arange(2, dtype=torch.int32),), ), } @@ -80,15 +69,15 @@ class IndexTensor_Ellipsis(torch.nn.Module): def forward( self, input_: torch.Tensor, - position: IndexTensorTestCommon.OpPlacement, + position: str, indices: Tuple[None | torch.Tensor], ): match position: - case IndexTensorTestCommon.OpPlacement.BEFORE: + case IndexTensorTestCommon.BEFORE: return input_[..., indices[0]] - case IndexTensorTestCommon.OpPlacement.MIDDLE: + case IndexTensorTestCommon.MIDDLE: return input_[indices[0], ..., indices[1]] - case IndexTensorTestCommon.OpPlacement.AFTER: + case IndexTensorTestCommon.AFTER: return input_[indices[0], ...] return input_[indices] @@ -154,7 +143,7 @@ class IndexTensor_Slice(torch.nn.Module): torch.rand(size=(5, 3, 4, 5)), 0, 2, - IndexTensorTestCommon.OpPlacement.BEFORE, + IndexTensorTestCommon.BEFORE, (torch.arange(2, dtype=torch.int32),), ), "test_3d_slice_before_2d_idx": ( @@ -164,14 +153,14 @@ class IndexTensor_Slice(torch.nn.Module): torch.arange(5 * 3 * 4, dtype=torch.float32).reshape(5, 3, 4), 0, 2, - IndexTensorTestCommon.OpPlacement.BEFORE, + IndexTensorTestCommon.BEFORE, (torch.arange(2, dtype=torch.int32).unsqueeze(0).tile(2, 1),), ), "test_4d_slice_middle": ( torch.arange(5 * 3 * 2, dtype=torch.int32).reshape(5, 3, 2), 0, 2, - IndexTensorTestCommon.OpPlacement.MIDDLE, + IndexTensorTestCommon.MIDDLE, ( torch.arange(2, dtype=torch.int32), torch.arange(2, dtype=torch.int32), @@ -185,7 +174,7 @@ class IndexTensor_Slice(torch.nn.Module): torch.rand(size=(25, 5, 13, 7)), 0, 2, - IndexTensorTestCommon.OpPlacement.AFTER, + IndexTensorTestCommon.AFTER, (torch.arange(2, dtype=torch.int32),), ), } @@ -195,15 +184,15 @@ def forward( input_: torch.Tensor, slice_start: int, slice_end: int, - position: IndexTensorTestCommon.OpPlacement, + position: str, indices: Tuple[None | torch.Tensor], ): match position: - case IndexTensorTestCommon.OpPlacement.BEFORE: + case IndexTensorTestCommon.BEFORE: return input_[slice_start:slice_end, indices[0]] - case IndexTensorTestCommon.OpPlacement.MIDDLE: + case IndexTensorTestCommon.MIDDLE: return input_[indices[0], slice_start:slice_end, indices[1]] - case IndexTensorTestCommon.OpPlacement.AFTER: + case IndexTensorTestCommon.AFTER: return input_[indices[0], slice_start:slice_end]