import unittest

from transformers import LwDetrViTConfig, is_torch_available
from transformers.testing_utils import require_torch, torch_device

from ...test_backbone_common import BackboneTesterMixin
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor


if is_torch_available():
    import torch
    from torch import nn

    from transformers import LwDetrViTBackbone
@@ -22,10 +25,10 @@ def __init__(
2225 is_training = True ,
2326 image_size = 256 ,
2427 hidden_size = 16 ,
25- num_hidden_layers = 6 ,
28+ num_hidden_layers = 2 ,
2629 num_attention_heads = 2 ,
27- window_block_indices = [0 , 2 ],
28- out_indices = [1 , 3 , 5 ],
30+ window_block_indices = [1 ],
31+ out_indices = [0 ],
2932 num_windows = 16 ,
3033 dropout_prob = 0.0 ,
3134 attn_implementation = "eager" ,
@@ -78,12 +81,179 @@ def prepare_config_and_inputs_for_common(self):
7881 inputs_dict = {"pixel_values" : pixel_values }
7982 return config , inputs_dict
8083
84+ def create_and_check_backbone (self , config , pixel_values , labels ):
85+ model = LwDetrViTBackbone (config = config )
86+ model .to (torch_device )
87+ model .eval ()
88+ result = model (pixel_values )
89+
90+ # verify hidden states
91+ self .parent .assertEqual (len (result .feature_maps ), len (config .out_features ))
92+ self .parent .assertListEqual (
93+ list (result .feature_maps [0 ].shape ),
94+ [
95+ self .batch_size ,
96+ self .hidden_size ,
97+ self .get_config ().num_windows_side ** 2 ,
98+ self .get_config ().num_windows_side ** 2 ,
99+ ],
100+ )
101+
102+ # verify channels
103+ self .parent .assertEqual (len (model .channels ), len (config .out_features ))
104+ self .parent .assertListEqual (model .channels , [config .hidden_size ])
105+
106+ # verify backbone works with out_features=None
107+ config .out_features = None
108+ model = LwDetrViTBackbone (config = config )
109+ model .to (torch_device )
110+ model .eval ()
111+ result = model (pixel_values )
112+
113+ # verify feature maps
114+ self .parent .assertEqual (len (result .feature_maps ), 1 )
115+ self .parent .assertListEqual (
116+ list (result .feature_maps [0 ].shape ),
117+ [self .batch_size , config .hidden_size , config .patch_size , config .patch_size ],
118+ )
119+
120+ # verify channels
121+ self .parent .assertEqual (len (model .channels ), 1 )
122+ self .parent .assertListEqual (model .channels , [config .hidden_size ])
123+
81124
@require_torch
class LwDetrViTBackboneTest(ModelTesterMixin, BackboneTesterMixin, unittest.TestCase):
    """Common model and backbone tests for `LwDetrViTBackbone`."""

    all_model_classes = (LwDetrViTBackbone,) if is_torch_available() else ()
    config_class = LwDetrViTConfig
    test_resize_embeddings = False
    test_torch_exportable = True

    def setUp(self):
        self.model_tester = LwDetrVitModelTester(self)

    def test_backbone(self):
        """Run the tester's backbone shape/channel checks."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_backbone(*config_and_inputs)

    def test_model_get_set_embeddings(self):
        """Input embeddings must be an `nn.Module`; output embeddings, if any, an `nn.Linear`."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), nn.Module)
            output_embeddings = model.get_output_embeddings()
            self.assertTrue(output_embeddings is None or isinstance(output_embeddings, nn.Linear))

    def test_attention_outputs(self):
        """Check the shapes of windowed vs. global (unwindowed) attention maps."""

        def check_attention_output(inputs_dict, config, model_class):
            # Eager attention is required so attention probabilities are materialized.
            config._attn_implementation = "eager"
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            attentions = outputs.attentions

            windowed_attentions = [attentions[i] for i in self.model_tester.window_block_indices]
            unwindowed_attentions = [attentions[i] for i in self.model_tester.out_indices]

            # Windowed blocks attend within each window: batch is expanded by the
            # number of windows, and the sequence length is the tokens per window
            # (num_windows_side ** 2).
            expected_windowed_attention_shape = [
                self.model_tester.batch_size * self.model_tester.num_windows,
                self.model_tester.num_attention_heads,
                self.model_tester.get_config().num_windows_side**2,
                self.model_tester.get_config().num_windows_side**2,
            ]

            # NOTE(review): unwindowed attention uses image_size for the last two
            # dims — presumably the full token sequence length; confirm.
            expected_unwindowed_attention_shape = [
                self.model_tester.batch_size,
                self.model_tester.num_attention_heads,
                self.model_tester.image_size,
                self.model_tester.image_size,
            ]

            for attention in windowed_attentions:
                self.assertListEqual(
                    list(attention.shape),
                    expected_windowed_attention_shape,
                )

            for attention in unwindowed_attentions:
                self.assertListEqual(
                    list(attention.shape),
                    expected_unwindowed_attention_shape,
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            check_attention_output(inputs_dict, config, model_class)

            # check that output_attentions also works using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True

            check_attention_output(inputs_dict, config, model_class)

    def test_hidden_states_output(self):
        """Check the number and spatial shape of the returned hidden states."""

        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.hidden_states

            # One hidden state per layer plus the embedding output.
            expected_num_stages = self.model_tester.num_hidden_layers
            self.assertEqual(len(hidden_states), expected_num_stages + 1)

            # ViTDet-style feature maps are (batch_size, num_channels, height, width).
            # NOTE(review): spatial dims are asserted equal to hidden_size here —
            # looks intentional for this tester configuration, but confirm.
            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [
                    self.model_tester.hidden_size,
                    self.model_tester.hidden_size,
                ],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also works using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    # overwrite since LwDetrViT only supports retaining gradients of hidden states
    def test_retain_grad_hidden_states_attentions(self):
        """Gradients must flow back to the retained first hidden state."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = self.has_attentions

        # no need to test all models as different heads yield the same functionality
        model_class = self.all_model_classes[0]
        model = model_class(config)
        model.to(torch_device)

        inputs = self._prepare_for_class(inputs_dict, model_class)

        outputs = model(**inputs)

        output = outputs.feature_maps[0]

        # Encoder-/Decoder-only models
        hidden_states = outputs.hidden_states[0]
        hidden_states.retain_grad()

        output.flatten()[0].backward(retain_graph=True)

        self.assertIsNotNone(hidden_states.grad)
0 commit comments