
Commit 082715b

tests: added VitDet and attention tests
1 parent 13e9aa3

File tree

1 file changed: +177 −7 lines


tests/models/lw_detr/test_modeling_lw_detr_vit.py

Lines changed: 177 additions & 7 deletions
@@ -1,13 +1,16 @@
 import unittest

 from transformers import LwDetrViTConfig, is_torch_available
-from transformers.testing_utils import require_torch
+from transformers.testing_utils import require_torch, torch_device

 from ...test_backbone_common import BackboneTesterMixin
-from ...test_modeling_common import floats_tensor, ids_tensor
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor


 if is_torch_available():
+    import torch
+    from torch import nn
+
     from transformers import LwDetrViTBackbone


@@ -22,10 +25,10 @@ def __init__(
         is_training=True,
         image_size=256,
         hidden_size=16,
-        num_hidden_layers=6,
+        num_hidden_layers=2,
         num_attention_heads=2,
-        window_block_indices=[0, 2],
-        out_indices=[1, 3, 5],
+        window_block_indices=[1],
+        out_indices=[0],
         num_windows=16,
         dropout_prob=0.0,
         attn_implementation="eager",
@@ -78,12 +81,179 @@ def prepare_config_and_inputs_for_common(self):
         inputs_dict = {"pixel_values": pixel_values}
         return config, inputs_dict

+    def create_and_check_backbone(self, config, pixel_values, labels):
+        model = LwDetrViTBackbone(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values)
+
+        # verify hidden states
+        self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
+        self.parent.assertListEqual(
+            list(result.feature_maps[0].shape),
+            [
+                self.batch_size,
+                self.hidden_size,
+                self.get_config().num_windows_side ** 2,
+                self.get_config().num_windows_side ** 2,
+            ],
+        )
+
+        # verify channels
+        self.parent.assertEqual(len(model.channels), len(config.out_features))
+        self.parent.assertListEqual(model.channels, [config.hidden_size])
+
+        # verify backbone works with out_features=None
+        config.out_features = None
+        model = LwDetrViTBackbone(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values)
+
+        # verify feature maps
+        self.parent.assertEqual(len(result.feature_maps), 1)
+        self.parent.assertListEqual(
+            list(result.feature_maps[0].shape),
+            [self.batch_size, config.hidden_size, config.patch_size, config.patch_size],
+        )
+
+        # verify channels
+        self.parent.assertEqual(len(model.channels), 1)
+        self.parent.assertListEqual(model.channels, [config.hidden_size])
+

 @require_torch
-class LwDetrViTBackboneTest(BackboneTesterMixin, unittest.TestCase):
+class LwDetrViTBackboneTest(ModelTesterMixin, BackboneTesterMixin, unittest.TestCase):
     all_model_classes = (LwDetrViTBackbone,) if is_torch_available() else ()
-    has_attentions = False
     config_class = LwDetrViTConfig
+    test_resize_embeddings = False
+    test_torch_exportable = True

     def setUp(self):
         self.model_tester = LwDetrVitModelTester(self)
+
+    def test_backbone(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_backbone(*config_and_inputs)
+
+    def test_model_get_set_embeddings(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+            x = model.get_output_embeddings()
+            self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+    def test_attention_outputs(self):
+        def check_attention_output(inputs_dict, config, model_class):
+            config._attn_implementation = "eager"
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            attentions = outputs.attentions
+
+            windowed_attentions = [attentions[i] for i in self.model_tester.window_block_indices]
+            unwindowed_attentions = [attentions[i] for i in self.model_tester.out_indices]
+
+            expected_windowed_attention_shape = [
+                self.model_tester.batch_size * self.model_tester.num_windows,
+                self.model_tester.num_attention_heads,
+                self.model_tester.get_config().num_windows_side ** 2,
+                self.model_tester.get_config().num_windows_side ** 2,
+            ]
+
+            expected_unwindowed_attention_shape = [
+                self.model_tester.batch_size,
+                self.model_tester.num_attention_heads,
+                self.model_tester.image_size,
+                self.model_tester.image_size,
+            ]
+
+            for i, attention in enumerate(windowed_attentions):
+                self.assertListEqual(
+                    list(attention.shape),
+                    expected_windowed_attention_shape,
+                )
+
+            for i, attention in enumerate(unwindowed_attentions):
+                self.assertListEqual(
+                    list(attention.shape),
+                    expected_unwindowed_attention_shape,
+                )
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_attentions"] = True
+            check_attention_output(inputs_dict, config, model_class)
+
+            # check that output_attentions also works when set via config
+            del inputs_dict["output_attentions"]
+            config.output_attentions = True
+
+            check_attention_output(inputs_dict, config, model_class)
+
+    def test_hidden_states_output(self):
+        def check_hidden_states_output(inputs_dict, config, model_class):
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            hidden_states = outputs.hidden_states
+
+            expected_num_stages = self.model_tester.num_hidden_layers
+            self.assertEqual(len(hidden_states), expected_num_stages + 1)
+
+            # VitDet's feature maps are of shape (batch_size, num_channels, height, width)
+            self.assertListEqual(
+                list(hidden_states[0].shape[-2:]),
+                [
+                    self.model_tester.hidden_size,
+                    self.model_tester.hidden_size,
+                ],
+            )
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_hidden_states"] = True
+            check_hidden_states_output(inputs_dict, config, model_class)
+
+            # check that output_hidden_states also works when set via config
+            del inputs_dict["output_hidden_states"]
+            config.output_hidden_states = True
+
+            check_hidden_states_output(inputs_dict, config, model_class)
+
+    # overwrite since LwDetrVitDet only supports retaining gradients of hidden states
+    def test_retain_grad_hidden_states_attentions(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.output_hidden_states = True
+        config.output_attentions = self.has_attentions
+
+        # no need to test all models as different heads yield the same functionality
+        model_class = self.all_model_classes[0]
+        model = model_class(config)
+        model.to(torch_device)
+
+        inputs = self._prepare_for_class(inputs_dict, model_class)
+
+        outputs = model(**inputs)
+
+        output = outputs.feature_maps[0]
+
+        # Encoder-/Decoder-only models
+        hidden_states = outputs.hidden_states[0]
+        hidden_states.retain_grad()
+
+        output.flatten()[0].backward(retain_graph=True)
+
+        self.assertIsNotNone(hidden_states.grad)
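
For readers skimming the diff, here is a minimal sketch (not part of the commit) of the shape arithmetic that `test_attention_outputs` asserts: windowed blocks fold the windows into the batch dimension, so their attention maps are square in the per-window token count, while unwindowed blocks attend over the full feature map. The concrete numbers below are hypothetical stand-ins, and `num_windows_side` is assumed to be a property derived on `LwDetrViTConfig`; its definition is not shown in this diff.

```python
# Sketch only -- all values are hypothetical stand-ins for the tester's defaults.
batch_size = 2              # hypothetical; the tester's batch_size is not shown in this diff
num_attention_heads = 2     # matches the tester default in the hunk above
num_windows = 16            # matches the tester default in the hunk above
num_windows_side = 4        # hypothetical value of the derived config property

# number of tokens attending to each other inside one window
tokens_per_window = num_windows_side**2

# Windowed blocks: windows are folded into the batch dimension, so each
# attention map covers only the tokens of a single window.
windowed_shape = (batch_size * num_windows, num_attention_heads, tokens_per_window, tokens_per_window)
print(windowed_shape)  # (32, 2, 16, 16)
```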

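To exercise just this file after checking out the commit, the standard transformers test invocation should work (assuming a dev install with torch available, since the whole class is gated by `@require_torch`): `python -m pytest tests/models/lw_detr/test_modeling_lw_detr_vit.py -v`, narrowing to the new tests with `-k "backbone or attention_outputs or hidden_states or retain_grad"` if desired.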