Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 19 additions & 30 deletions backends/arm/test/ops/test_index_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.


from enum import IntEnum
from typing import Tuple

import torch
Expand All @@ -25,22 +24,12 @@ class IndexTensorTestCommon:
# Gathers and reshapes should result in no inaccuracies
rtol = 0.0
atol = 0.0
BEFORE = "BEFORE"
MIDDLE = "MIDDLE"
AFTER = "AFTER"

class OpPlacement(IntEnum):
"""
Simple enum used to indicate where slices or ellipsis should be placed
in tests.
IntEnum so that Dynamo does not complain about unsupported types.
"""

BEFORE = 1
MIDDLE = 2
AFTER = 3


input_params_slice = Tuple[
torch.Tensor, int, int, IndexTensorTestCommon.OpPlacement, Tuple[torch.Tensor]
]
input_params_slice = Tuple[torch.Tensor, int, int, str, Tuple[torch.Tensor]]
input_params = Tuple[torch.Tensor, Tuple[torch.Tensor]]


Expand All @@ -55,12 +44,12 @@ class IndexTensor_Ellipsis(torch.nn.Module):
test_data_ellipsis: dict[input_params] = {
"test_4d_ellipsis_before": (
torch.rand(size=(25, 5, 13, 7)),
IndexTensorTestCommon.OpPlacement.BEFORE,
IndexTensorTestCommon.BEFORE,
(torch.arange(2, dtype=torch.int32),),
),
"test_4d_ellipsis_middle": (
torch.rand(size=(25, 5, 13, 7)),
IndexTensorTestCommon.OpPlacement.MIDDLE,
IndexTensorTestCommon.MIDDLE,
(
torch.arange(2, dtype=torch.int32),
torch.arange(2, dtype=torch.int32),
Expand All @@ -72,23 +61,23 @@ class IndexTensor_Ellipsis(torch.nn.Module):
# partitioning is difficult and unreliable, as such
# it is not xfail as the existing logic can handle it.
torch.rand(size=(25, 5, 13, 7)),
IndexTensorTestCommon.OpPlacement.AFTER,
IndexTensorTestCommon.AFTER,
(torch.arange(2, dtype=torch.int32),),
),
}

def forward(
self,
input_: torch.Tensor,
position: IndexTensorTestCommon.OpPlacement,
position: str,
indices: Tuple[None | torch.Tensor],
):
match position:
case IndexTensorTestCommon.OpPlacement.BEFORE:
case IndexTensorTestCommon.BEFORE:
return input_[..., indices[0]]
case IndexTensorTestCommon.OpPlacement.MIDDLE:
case IndexTensorTestCommon.MIDDLE:
return input_[indices[0], ..., indices[1]]
case IndexTensorTestCommon.OpPlacement.AFTER:
case IndexTensorTestCommon.AFTER:
return input_[indices[0], ...]

return input_[indices]
Expand Down Expand Up @@ -154,7 +143,7 @@ class IndexTensor_Slice(torch.nn.Module):
torch.rand(size=(5, 3, 4, 5)),
0,
2,
IndexTensorTestCommon.OpPlacement.BEFORE,
IndexTensorTestCommon.BEFORE,
(torch.arange(2, dtype=torch.int32),),
),
"test_3d_slice_before_2d_idx": (
Expand All @@ -164,14 +153,14 @@ class IndexTensor_Slice(torch.nn.Module):
torch.arange(5 * 3 * 4, dtype=torch.float32).reshape(5, 3, 4),
0,
2,
IndexTensorTestCommon.OpPlacement.BEFORE,
IndexTensorTestCommon.BEFORE,
(torch.arange(2, dtype=torch.int32).unsqueeze(0).tile(2, 1),),
),
"test_4d_slice_middle": (
torch.arange(5 * 3 * 2, dtype=torch.int32).reshape(5, 3, 2),
0,
2,
IndexTensorTestCommon.OpPlacement.MIDDLE,
IndexTensorTestCommon.MIDDLE,
(
torch.arange(2, dtype=torch.int32),
torch.arange(2, dtype=torch.int32),
Expand All @@ -185,7 +174,7 @@ class IndexTensor_Slice(torch.nn.Module):
torch.rand(size=(25, 5, 13, 7)),
0,
2,
IndexTensorTestCommon.OpPlacement.AFTER,
IndexTensorTestCommon.AFTER,
(torch.arange(2, dtype=torch.int32),),
),
}
Expand All @@ -195,15 +184,15 @@ def forward(
input_: torch.Tensor,
slice_start: int,
slice_end: int,
position: IndexTensorTestCommon.OpPlacement,
position: str,
indices: Tuple[None | torch.Tensor],
):
match position:
case IndexTensorTestCommon.OpPlacement.BEFORE:
case IndexTensorTestCommon.BEFORE:
return input_[slice_start:slice_end, indices[0]]
case IndexTensorTestCommon.OpPlacement.MIDDLE:
case IndexTensorTestCommon.MIDDLE:
return input_[indices[0], slice_start:slice_end, indices[1]]
case IndexTensorTestCommon.OpPlacement.AFTER:
case IndexTensorTestCommon.AFTER:
return input_[indices[0], slice_start:slice_end]


Expand Down
Loading