src/transformers/image_transforms.py (8 changes: 1 addition & 7 deletions)

@@ -821,12 +821,6 @@ def split_to_tiles(images: "torch.Tensor", num_tiles_height: int, num_tiles_width
     return image
 
 
-def _cast_tensor_to_float(x):
-    if x.is_floating_point():
-        return x
-    return x.float()
-
-
 def _group_images_by_shape(nested_images, *paired_inputs, is_nested: bool = False):
     """Helper function to flatten a single level of nested image and batch structures and group by shape."""
     grouped_images = defaultdict(list)
@@ -843,7 +837,7 @@ def _group_images_by_shape(nested_images, *paired_inputs, is_nested: bool = False):
     for i, (sublist, *paired_sublists) in enumerate(zip(normalized_images, *normalized_paired)):
         for j, (image, *paired_values) in enumerate(zip(sublist, *paired_sublists)):
             key = (i, j) if is_nested else j
-            shape = image.shape[1:]
+            shape = image.shape
 
             # Add to grouped structures
             grouped_images[shape].append(image)
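For context on the hunks above: `_group_images_by_shape` now keys each group on the full tensor shape instead of only the spatial dims (`image.shape[1:]`), so the channel count becomes part of the key. A minimal sketch of the difference, using illustrative tensors that are not taken from the PR:

import torch

image = torch.rand(3, 64, 64)   # a 3-channel image
trimap = torch.rand(1, 64, 64)  # a matching single-channel trimap

old_key = image.shape[1:]  # torch.Size([64, 64]): identical for image and trimap
new_key = image.shape      # torch.Size([3, 64, 64]): no longer matches the trimap's shape
assert trimap.shape[1:] == old_key
assert trimap.shape != new_key

This is why a caller that pairs image groups with trimap groups can no longer reuse the image key directly, which the change in the second file addresses.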
Second changed file (filename not shown in this excerpt):

@@ -152,7 +152,8 @@ def _preprocess(
         processed_images_grouped = {}
         for shape in grouped_images:
             stacked_images = grouped_images[shape]
-            stacked_trimaps = grouped_trimaps[shape]
+            trimaps_shape = torch.Size([1, *shape[1:]])  # Trimaps have single channel
+            stacked_trimaps = grouped_trimaps[trimaps_shape]
             # Fused rescale and normalize
             stacked_images = self.rescale_and_normalize(
                 stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
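Since groups are now keyed by full shape, the loop above rebuilds the single-channel trimap key from the spatial part of the image key rather than indexing `grouped_trimaps` with the image key itself. A rough sketch of that lookup, assuming groups are already stacked tensors keyed by per-image shape (simplified stand-ins, not the library's grouping helper):

import torch

# Hypothetical grouped batches: two 3-channel images and their two 1-channel trimaps.
grouped_images = {torch.Size([3, 64, 64]): torch.rand(2, 3, 64, 64)}
grouped_trimaps = {torch.Size([1, 64, 64]): torch.rand(2, 1, 64, 64)}

for shape, stacked_images in grouped_images.items():
    trimaps_shape = torch.Size([1, *shape[1:]])  # trimaps have a single channel
    stacked_trimaps = grouped_trimaps[trimaps_shape]
    print(stacked_images.shape, stacked_trimaps.shape)  # torch.Size([2, 3, 64, 64]) torch.Size([2, 1, 64, 64])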