Commit 13e9aa3 (parent e82f6d5)

fix: replaced hardcoded value with a variable

File tree: 2 files changed, 4 additions (+) and 4 deletions (-)

src/transformers/models/lw_detr/modeling_lw_detr.py

Lines changed: 2 additions & 2 deletions

@@ -1644,7 +1644,7 @@ def forward(self, pixel_values: torch.Tensor = None, **kwargs: Unpack[Transforme
 
         window_height = height // self.config.num_windows_side
         window_width = width // self.config.num_windows_side
-        # (batch_size, height, width, channels) -> (batch_size*16, window_height*window_width, channels)
+        # (batch_size, height, width, channels) -> (batch_size*num_windows_side**2, window_height*window_width, channels)
         hidden_states = (
             hidden_states.reshape(
                 batch_size,
@@ -1655,7 +1655,7 @@ def forward(self, pixel_values: torch.Tensor = None, **kwargs: Unpack[Transforme
                 channels,
             )
             .permute(0, 1, 3, 2, 4, 5)
-            .reshape(batch_size * 16, window_height * window_width, channels)
+            .reshape(batch_size * self.config.num_windows_side**2, window_height * window_width, channels)
         )
 
         hidden_states = self.encoder(hidden_states, **kwargs)

src/transformers/models/lw_detr/modular_lw_detr.py

Lines changed: 2 additions & 2 deletions

@@ -1576,7 +1576,7 @@ def forward(self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwarg
 
         window_height = height // self.config.num_windows_side
         window_width = width // self.config.num_windows_side
-        # (batch_size, height, width, channels) -> (batch_size*16, window_height*window_width, channels)
+        # (batch_size, height, width, channels) -> (batch_size*num_windows_side**2, window_height*window_width, channels)
         hidden_states = (
             hidden_states.reshape(
                 batch_size,
@@ -1587,7 +1587,7 @@ def forward(self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwarg
                 channels,
             )
             .permute(0, 1, 3, 2, 4, 5)
-            .reshape(batch_size * 16, window_height * window_width, channels)
+            .reshape(batch_size * self.config.num_windows_side**2, window_height * window_width, channels)
         )
 
         hidden_states = self.encoder(hidden_states, **kwargs)
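The reshape/permute chain in both diffs partitions a `(batch, height, width, channels)` feature map into non-overlapping square windows before running the encoder; the fix replaces the hardcoded `16` (i.e. `4**2`) with `self.config.num_windows_side**2` so the window count follows the config. A minimal standalone sketch of the same partitioning, assuming `height` and `width` are divisible by `num_windows_side` (the function name and demo shapes are illustrative, not from the repo):

```python
import torch

def window_partition(hidden_states: torch.Tensor, num_windows_side: int) -> torch.Tensor:
    """Split (batch, height, width, channels) into per-window token sequences."""
    batch_size, height, width, channels = hidden_states.shape
    window_height = height // num_windows_side
    window_width = width // num_windows_side
    # (batch, height, width, channels)
    #   -> (batch * num_windows_side**2, window_height * window_width, channels)
    return (
        hidden_states.reshape(
            batch_size,
            num_windows_side,
            window_height,
            num_windows_side,
            window_width,
            channels,
        )
        # Bring the two window-grid axes next to each other so each window
        # becomes a contiguous block of (window_height, window_width) tokens.
        .permute(0, 1, 3, 2, 4, 5)
        .reshape(batch_size * num_windows_side**2, window_height * window_width, channels)
    )

x = torch.randn(2, 16, 16, 32)
windows = window_partition(x, 4)
print(windows.shape)  # torch.Size([32, 16, 32]): 2 * 4**2 windows of 4*4 tokens
```

The first output window is exactly the top-left 4x4 patch of the first batch element, flattened to a token sequence, which is why a data-dependent `num_windows_side` must also appear in the final `reshape` rather than a literal `16`.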
