44# LICENSE file in the root directory of this source tree.
55
66
7- from enum import IntEnum
87from typing import Tuple
98
109import torch
@@ -25,22 +24,12 @@ class IndexTensorTestCommon:
2524 # Gathers and reshapes should result in no inaccuracies
2625 rtol = 0.0
2726 atol = 0.0
27+ BEFORE = "BEFORE"
28+ MIDDLE = "MIDDLE"
29+ AFTER = "AFTER"
2830
29- class OpPlacement (IntEnum ):
30- """
31- Simple enum used to indicate where slices or ellipsis should be placed
32- in tests.
33- IntEnum so that Dynamo does not complain about unsupported types.
34- """
3531
36- BEFORE = 1
37- MIDDLE = 2
38- AFTER = 3
39-
40-
41- input_params_slice = Tuple [
42- torch .Tensor , int , int , IndexTensorTestCommon .OpPlacement , Tuple [torch .Tensor ]
43- ]
32+ input_params_slice = Tuple [torch .Tensor , int , int , str , Tuple [torch .Tensor ]]
4433input_params = Tuple [torch .Tensor , Tuple [torch .Tensor ]]
4534
4635
@@ -55,12 +44,12 @@ class IndexTensor_Ellipsis(torch.nn.Module):
5544 test_data_ellipsis : dict [input_params ] = {
5645 "test_4d_ellipsis_before" : (
5746 torch .rand (size = (25 , 5 , 13 , 7 )),
58- IndexTensorTestCommon .OpPlacement . BEFORE ,
47+ IndexTensorTestCommon .BEFORE ,
5948 (torch .arange (2 , dtype = torch .int32 ),),
6049 ),
6150 "test_4d_ellipsis_middle" : (
6251 torch .rand (size = (25 , 5 , 13 , 7 )),
63- IndexTensorTestCommon .OpPlacement . MIDDLE ,
52+ IndexTensorTestCommon .MIDDLE ,
6453 (
6554 torch .arange (2 , dtype = torch .int32 ),
6655 torch .arange (2 , dtype = torch .int32 ),
@@ -72,23 +61,23 @@ class IndexTensor_Ellipsis(torch.nn.Module):
7261 # partitioning is difficult and unreliable, as such
7362 # it is not xfail as the existing logic can handle it.
7463 torch .rand (size = (25 , 5 , 13 , 7 )),
75- IndexTensorTestCommon .OpPlacement . AFTER ,
64+ IndexTensorTestCommon .AFTER ,
7665 (torch .arange (2 , dtype = torch .int32 ),),
7766 ),
7867 }
7968
8069 def forward (
8170 self ,
8271 input_ : torch .Tensor ,
83- position : IndexTensorTestCommon . OpPlacement ,
72+ position : str ,
8473 indices : Tuple [None | torch .Tensor ],
8574 ):
8675 match position :
87- case IndexTensorTestCommon .OpPlacement . BEFORE :
76+ case IndexTensorTestCommon .BEFORE :
8877 return input_ [..., indices [0 ]]
89- case IndexTensorTestCommon .OpPlacement . MIDDLE :
78+ case IndexTensorTestCommon .MIDDLE :
9079 return input_ [indices [0 ], ..., indices [1 ]]
91- case IndexTensorTestCommon .OpPlacement . AFTER :
80+ case IndexTensorTestCommon .AFTER :
9281 return input_ [indices [0 ], ...]
9382
9483 return input_ [indices ]
@@ -154,7 +143,7 @@ class IndexTensor_Slice(torch.nn.Module):
154143 torch .rand (size = (5 , 3 , 4 , 5 )),
155144 0 ,
156145 2 ,
157- IndexTensorTestCommon .OpPlacement . BEFORE ,
146+ IndexTensorTestCommon .BEFORE ,
158147 (torch .arange (2 , dtype = torch .int32 ),),
159148 ),
160149 "test_3d_slice_before_2d_idx" : (
@@ -164,14 +153,14 @@ class IndexTensor_Slice(torch.nn.Module):
164153 torch .arange (5 * 3 * 4 , dtype = torch .float32 ).reshape (5 , 3 , 4 ),
165154 0 ,
166155 2 ,
167- IndexTensorTestCommon .OpPlacement . BEFORE ,
156+ IndexTensorTestCommon .BEFORE ,
168157 (torch .arange (2 , dtype = torch .int32 ).unsqueeze (0 ).tile (2 , 1 ),),
169158 ),
170159 "test_4d_slice_middle" : (
171160 torch .arange (5 * 3 * 2 , dtype = torch .int32 ).reshape (5 , 3 , 2 ),
172161 0 ,
173162 2 ,
174- IndexTensorTestCommon .OpPlacement . MIDDLE ,
163+ IndexTensorTestCommon .MIDDLE ,
175164 (
176165 torch .arange (2 , dtype = torch .int32 ),
177166 torch .arange (2 , dtype = torch .int32 ),
@@ -185,7 +174,7 @@ class IndexTensor_Slice(torch.nn.Module):
185174 torch .rand (size = (25 , 5 , 13 , 7 )),
186175 0 ,
187176 2 ,
188- IndexTensorTestCommon .OpPlacement . AFTER ,
177+ IndexTensorTestCommon .AFTER ,
189178 (torch .arange (2 , dtype = torch .int32 ),),
190179 ),
191180 }
@@ -195,15 +184,15 @@ def forward(
195184 input_ : torch .Tensor ,
196185 slice_start : int ,
197186 slice_end : int ,
198- position : IndexTensorTestCommon . OpPlacement ,
187+ position : str ,
199188 indices : Tuple [None | torch .Tensor ],
200189 ):
201190 match position :
202- case IndexTensorTestCommon .OpPlacement . BEFORE :
191+ case IndexTensorTestCommon .BEFORE :
203192 return input_ [slice_start :slice_end , indices [0 ]]
204- case IndexTensorTestCommon .OpPlacement . MIDDLE :
193+ case IndexTensorTestCommon .MIDDLE :
205194 return input_ [indices [0 ], slice_start :slice_end , indices [1 ]]
206- case IndexTensorTestCommon .OpPlacement . AFTER :
195+ case IndexTensorTestCommon .AFTER :
207196 return input_ [indices [0 ], slice_start :slice_end ]
208197
209198
0 commit comments