Add raise_error option to TerminateOnNaN for immediate termination on NaN/Inf losses #21841
Conversation
Summary of Changes: Hello @MalyalaKarthik66, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request introduces a new callback, `HardTerminateOnNaN`.
Code Review
This pull request introduces a HardTerminateOnNaN callback designed to immediately halt training by raising a RuntimeError upon encountering a NaN or Inf loss. The goal is to prevent cleanup hooks like on_train_end from executing, thereby preserving the training state for debugging. My review identified a critical issue with this core premise: Keras' training loop utilizes a try...finally block, which ensures on_train_end is always called, even when an exception occurs. This means the callback does not function as described. The tests also appear to be based on this incorrect assumption. I have provided detailed feedback on this fundamental issue. Additionally, I've included a suggestion to improve the test implementation for better consistency with the existing test suite.
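For illustration, here is a simplified, hypothetical sketch of the `try...finally` shape described above (illustrative names like `run_epoch`; not the actual Keras source):

```python
def run_epoch(epoch):
    # Stand-in for the epoch loop; imagine a callback raising here.
    raise RuntimeError("NaN or Inf loss encountered.")


def fit(epochs, callbacks):
    callbacks.on_train_begin()
    try:
        for epoch in range(epochs):
            run_epoch(epoch)
    finally:
        # Runs even when an exception propagates out of the loop, so
        # on_train_end hooks (including cleanup) still execute.
        callbacks.on_train_end()
```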
````python
import numpy as np

from keras.callbacks import Callback  # imports added for context; the diff excerpt omitted them


class HardTerminateOnNaN(Callback):
    """Callback that terminates training immediately
    when NaN/Inf loss is detected.

    This callback raises a RuntimeError when a NaN or Inf loss is encountered,
    which immediately stops training without triggering `on_train_end()` hooks
    for other callbacks. This is useful when you want to preserve backup states
    or prevent early stopping from restoring weights after a NaN failure.

    Unlike `TerminateOnNaN`, which gracefully stops training using
    `model.stop_training = True` and triggers all callback cleanup methods,
    `HardTerminateOnNaN` crashes the training loop immediately.

    Example:

    ```
    callback = keras.callbacks.HardTerminateOnNaN()
    model.fit(x, y, callbacks=[callback])
    ```
    """

    def __init__(self):
        super().__init__()
        self._supports_tf_logs = True

    def on_batch_end(self, batch, logs=None):
        """Check for NaN/Inf loss at the end of each batch.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.train_step()`.

        Raises:
            RuntimeError: If loss is NaN or Inf.
        """
        logs = logs or {}
        loss = logs.get("loss")
        if loss is not None:
            if np.isnan(loss) or np.isinf(loss):
                raise RuntimeError(
                    f"NaN or Inf loss encountered at batch {batch}. "
                    f"Loss value: {loss}. Terminating training immediately."
                )
````
The premise of this callback appears to be that raising a RuntimeError will prevent on_train_end hooks from being called. However, the fit() method in Keras' Trainer class wraps the training loop in a try...finally block. This guarantees that callbacks.on_train_end() is executed even when an exception is raised during training.
This means that cleanup logic, such as deleting checkpoints in BackupAndRestore.on_train_end, will still run. This behavior contradicts the main goal of this "hard" termination callback, as stated in the docstring: "which immediately stops training without triggering on_train_end() hooks for other callbacks".
The tests for this callback also seem to reflect some confusion around this behavior:
- `test_hard_terminate_does_not_trigger_on_train_end` asserts that `on_train_end` is not called, which seems incorrect given the `fit` loop's implementation.
- `test_hard_terminate_preserves_backup` works around this by monkeypatching `on_train_end`, which sidesteps testing the actual behavior in a real-world scenario where the backup would be deleted.
To achieve the desired behavior, a different mechanism might be necessary. For example, you could introduce a state on the model that other callbacks can check within their on_train_end methods to determine if they should skip their cleanup logic due to a hard termination.
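A minimal sketch of that suggestion (the flag name `_hard_terminated` and the `BackupAwareCallback` class are hypothetical illustrations, not part of the PR):

```python
import numpy as np

from keras.callbacks import Callback


class HardTerminateOnNaN(Callback):
    def on_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and (np.isnan(loss) or np.isinf(loss)):
            # Flag the model so other callbacks can detect the hard stop.
            self.model._hard_terminated = True
            raise RuntimeError(f"NaN or Inf loss encountered at batch {batch}.")


class BackupAwareCallback(Callback):
    def on_train_end(self, logs=None):
        # on_train_end still runs (fit() uses try/finally), but cleanup
        # can be skipped when the hard-termination flag is set.
        if getattr(self.model, "_hard_terminated", False):
            return  # preserve backups/checkpoints for debugging
        # ...normal cleanup logic would go here...
```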
```python
# Create a fake file in the backup folder
fake_file = os.path.join(backup_dir, "checkpoint.txt")
with open(fake_file, "w") as f:  # context manager avoids a leaked file handle
    f.write("dummy checkpoint")
```
```python
assert os.path.exists(backup_dir), (
    f"Backup dir deleted: {backup_dir}"
)
assert os.path.exists(fake_file), (
    "Backup file missing unexpectedly."
)
```
For consistency with the `testing.TestCase` base class (which inherits from `unittest.TestCase`), it's better to use `self.assertTrue()` for assertions instead of the native `assert` statement:

```python
self.assertTrue(os.path.exists(backup_dir), f"Backup dir deleted: {backup_dir}")
self.assertTrue(os.path.exists(fake_file), "Backup file missing unexpectedly.")
```
Codecov Report: ✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff            @@
##           master   #21841     +/-   ##
=========================================
  Coverage   82.57%   82.58%
=========================================
  Files         577      577
  Lines       59586    59592       +6
  Branches     9347     9348       +1
=========================================
+ Hits        49205    49212       +7
  Misses       7975     7975
+ Partials     2406     2405       -1
```
Thanks for the PR! This sounds like it should be an option in `TerminateOnNaN`.
@fchollet |
hertschuh left a comment
Thanks for the PR!
```python
    backend.backend() in ["numpy", "openvino"],
    reason="TerminateOnNaN not supported for NumPy or OpenVINO backend",
)
class TerminateOnNaNTest(testing.TestCase):
```
There are some tests in `keras/src/callbacks/terminate_on_nan_test.py` already.
Please move the new tests there and combine them.
````python
    ```
    """

    def __init__(self, hard: bool = False):
````
How about we call the new argument `raise_error` instead of `hard`? Hopefully that name makes it obvious how it's different.
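With the rename, the `__init__` shown in the diff above might read something like this (a sketch, not the merged code):

```python
def __init__(self, raise_error: bool = False):
    super().__init__()
    self.raise_error = raise_error  # True: raise RuntimeError; False: stop gracefully
```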
@hertschuh |
Fix: #21771
Added a `raise_error=True` option to `keras.callbacks.TerminateOnNaN` that raises a `RuntimeError` immediately when NaN or Inf losses occur.
This initial implementation focuses on strict failure behavior.
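A minimal usage sketch of the option as described in this PR (written against the API discussed above; the merged signature may differ):

```python
import numpy as np

import keras

# Tiny model; TerminateOnNaN only inspects the "loss" entry in the logs.
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

# raise_error=False (default) keeps the existing graceful behavior
# (model.stop_training = True); raise_error=True raises a RuntimeError
# as soon as a NaN or Inf loss is seen.
callback = keras.callbacks.TerminateOnNaN(raise_error=True)

try:
    model.fit(x, y, epochs=1, callbacks=[callback])
except RuntimeError as err:
    print(f"Training aborted: {err}")
```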