Skip to content

Conversation

@utsab345
Copy link
Contributor

Fix RandomCrop validation behavior and edge cases

Fixes #21868

Problem Description

The RandomCrop layer had two critical issues:

  1. During validation/inference, the layer did nothing - Images were returned unchanged instead of being center-cropped, causing shape mismatches between training and validation pipelines.

  2. Strict inequality prevented valid crops - The condition input_height > self.height and input_width > self.width prevented random cropping when image dimensions exactly matched the target size (e.g., cropping 512x512 to 256x512).

Changes Made

Core Fixes in random_crop.py:

  • get_random_transformation: Changed condition from input_height > self.height and input_width > self.width to input_height >= self.height and input_width >= self.width to allow cropping when dimensions exactly match
  • transform_images: Removed if training: condition to ensure cropping is applied during both training and validation
  • Consistent transformation: Applied transformations to bounding boxes and segmentation masks during validation as well

Enhanced Testing in random_crop_test.py:

  • Added test_validation_center_crop() to verify validation mode performs center cropping
  • Added test_edge_case_exact_dimensions() to test cropping when image dimensions exactly match target size

Behavior After Fix

  • Training: Applies random crops when image is large enough
  • Validation/Inference: Always applies center cropping (was previously doing nothing)
  • Edge Cases: Properly handles exact dimension matches and smaller images

Testing

  • ✅ All existing tests pass
  • ✅ New tests verify the specific fixes
  • ✅ No breaking changes introduced
  • ✅ Maintains backward compatibility

Impact

  • Resolves shape mismatch errors between training and validation
  • Enables proper use cases like cropping 512x512 images to 256x512
  • Ensures consistent behavior across training and inference pipelines

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @utsab345, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request primarily focuses on enhancing the robustness and correctness of the RandomCrop preprocessing layer. It resolves critical issues where the layer failed to perform any cropping during validation and incorrectly handled images with dimensions exactly matching the target crop size. The changes ensure consistent center cropping during validation and proper application of transformations for bounding boxes and segmentation masks across all operational modes. Additionally, it includes improvements to mask handling within the JAX backend's flash attention implementation, making it more resilient to various input mask configurations.

Highlights

  • RandomCrop Validation Fix: The RandomCrop layer now correctly performs center cropping during validation and inference, resolving previous issues where images were returned unchanged, leading to shape mismatches between training and validation pipelines.
  • Edge Case Handling for RandomCrop: The condition for applying random cropping has been updated from a strict inequality (>) to an inclusive one (>=), allowing the RandomCrop layer to correctly handle cases where image dimensions exactly match the target crop size (e.g., cropping 512x512 to 256x512).
  • Consistent Transformations: Transformations for bounding boxes and segmentation masks are now consistently applied during both training and validation, ensuring uniform behavior across different operational modes.
  • Enhanced Testing for RandomCrop: New tests (test_validation_center_crop and test_edge_case_exact_dimensions) have been added to specifically verify the corrected validation behavior and the handling of exact dimension matches in the RandomCrop layer.
  • JAX Flash Attention Mask Handling: The JAX backend's wrap_flash_attention function has been updated to improve how custom and causal masks are handled, including robust shape validation, broadcasting for multi-head attention, and consistent application of MultiHeadMask.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request effectively resolves two critical issues within the RandomCrop layer. The changes ensure that center-cropping is correctly applied during validation/inference, fixing previous shape mismatches, and also corrects the logic to handle cases where input dimensions exactly match the target crop size. The new tests added are valuable for verifying this corrected behavior. I have one suggestion to further strengthen the new validation test by asserting the content of the cropped image, not just its shape.

Comment on lines +174 to +200
# Create a test image with distinct corners
if backend.config.image_data_format() == "channels_last":
test_image = np.zeros((4, 4, 3))
# Mark corners with different values
test_image[0, 0] = [1, 0, 0] # Top-left red
test_image[0, 3] = [0, 1, 0] # Top-right green
test_image[3, 0] = [0, 0, 1] # Bottom-left blue
test_image[3, 3] = [1, 1, 0] # Bottom-right yellow
else:
test_image = np.zeros((3, 4, 4))
# Mark corners with different values
test_image[0, 0, 0] = 1 # Top-left red
test_image[1, 0, 3] = 1 # Top-right green
test_image[2, 3, 0] = 1 # Bottom-left blue
test_image[0, 3, 3] = 1 # Bottom-right yellow (red channel)
test_image[1, 3, 3] = 1 # Bottom-right yellow (green channel)

# Test validation mode (should center crop)
validation_output = layer(test_image, training=False)

# Center crop should capture the middle 2x2 region
expected_shape = (
(2, 2, 3)
if backend.config.image_data_format() == "channels_last"
else (3, 2, 2)
)
self.assertEqual(validation_output.shape, expected_shape)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test currently verifies that validation mode produces an output of the correct shape, which is a great check for the fix. To make it more robust, I suggest also asserting that the content of the output is correct. This can be done by creating a test image with a distinct pattern in the center and then verifying that the output of the layer matches that pattern. This will confirm that the layer is performing a center crop, not just any crop of the correct size.

Suggested change
# Create a test image with distinct corners
if backend.config.image_data_format() == "channels_last":
test_image = np.zeros((4, 4, 3))
# Mark corners with different values
test_image[0, 0] = [1, 0, 0] # Top-left red
test_image[0, 3] = [0, 1, 0] # Top-right green
test_image[3, 0] = [0, 0, 1] # Bottom-left blue
test_image[3, 3] = [1, 1, 0] # Bottom-right yellow
else:
test_image = np.zeros((3, 4, 4))
# Mark corners with different values
test_image[0, 0, 0] = 1 # Top-left red
test_image[1, 0, 3] = 1 # Top-right green
test_image[2, 3, 0] = 1 # Bottom-left blue
test_image[0, 3, 3] = 1 # Bottom-right yellow (red channel)
test_image[1, 3, 3] = 1 # Bottom-right yellow (green channel)
# Test validation mode (should center crop)
validation_output = layer(test_image, training=False)
# Center crop should capture the middle 2x2 region
expected_shape = (
(2, 2, 3)
if backend.config.image_data_format() == "channels_last"
else (3, 2, 2)
)
self.assertEqual(validation_output.shape, expected_shape)
# Create a test image with a distinct center
if backend.config.image_data_format() == "channels_last":
test_image = np.zeros((4, 4, 3))
# Center 2x2 region with ones
test_image[1:3, 1:3, :] = 1.0
else:
test_image = np.zeros((3, 4, 4))
# Center 2x2 region with ones
test_image[:, 1:3, 1:3] = 1.0
# Test validation mode (should center crop)
validation_output = layer(test_image, training=False)
# Expected output is the center 2x2 region of ones
if backend.config.image_data_format() == "channels_last":
expected_output = np.ones((2, 2, 3))
else:
expected_output = np.ones((3, 2, 2))
self.assertAllClose(validation_output, expected_output)

@codecov-commenter
Copy link

codecov-commenter commented Nov 26, 2025

Codecov Report

❌ Patch coverage is 56.89655% with 25 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.95%. Comparing base (8287e48) to head (4b94a3b).
⚠️ Report is 2 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/backend/jax/nn.py 0.00% 17 Missing ⚠️
...s/preprocessing/image_preprocessing/random_crop.py 80.00% 4 Missing and 4 partials ⚠️

❗ There is a different number of reports uploaded between BASE (8287e48) and HEAD (4b94a3b). Click for more details.

HEAD has 2 uploads less than BASE
Flag BASE (8287e48) HEAD (4b94a3b)
keras 5 4
keras-torch 1 0
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21871      +/-   ##
==========================================
- Coverage   82.57%   76.95%   -5.63%     
==========================================
  Files         577      577              
  Lines       59586    59614      +28     
  Branches     9347     9356       +9     
==========================================
- Hits        49205    45873    -3332     
- Misses       7975    11409    +3434     
+ Partials     2406     2332      -74     
Flag Coverage Δ
keras 76.85% <56.89%> (-5.55%) ⬇️
keras-jax 62.86% <56.89%> (-0.02%) ⬇️
keras-numpy 57.51% <56.89%> (-0.02%) ⬇️
keras-openvino 34.33% <0.00%> (-0.02%) ⬇️
keras-tensorflow 64.39% <56.89%> (-0.02%) ⬇️
keras-torch ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ma7555
Copy link
Contributor

ma7555 commented Nov 28, 2025

@utsab345, please check original issue of whether to allow the choice of center crop or normal resize for validation images before merging.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

RandomCrop does nothing at validation/inference

4 participants