60 changes: 55 additions & 5 deletions keras/src/callbacks/terminate_on_nan.py
@@ -7,14 +7,64 @@

@keras_export("keras.callbacks.TerminateOnNaN")
class TerminateOnNaN(Callback):
"""Callback that terminates training when a NaN loss is encountered."""
"""Callback that terminates training when a NaN loss is encountered.

    This callback monitors the loss during training and terminates
    training when a NaN or Inf loss is detected. By default, training
    is stopped gracefully by setting `model.stop_training = True`,
    which triggers all callback cleanup methods, including
    `on_train_end()`.

    Alternatively, you can pass `raise_error=True` to raise a
    `RuntimeError` immediately when a NaN/Inf loss is detected. Raising
    prevents `on_train_end()` from being called on other callbacks,
    which is useful for preserving backup state or avoiding unintended
    cleanup when training fails.

Args:
        raise_error: Boolean, defaults to `False`. If `False`, training
            stops gracefully via `model.stop_training = True`. If `True`,
            a `RuntimeError` is raised immediately on a NaN/Inf loss,
            bypassing the callback cleanup methods.

Example:

```
# Graceful termination (default)
callback = keras.callbacks.TerminateOnNaN()
model.fit(x, y, callbacks=[callback])

    # Strict termination (raises immediately on NaN/Inf)
callback = keras.callbacks.TerminateOnNaN(raise_error=True)
model.fit(x, y, callbacks=[callback])
```
"""

def __init__(self, raise_error: bool = False):
super().__init__()
self.raise_error = raise_error
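        # Allows the training loop to pass raw tensor logs to this
        # callback without first converting them to numpy (carried over
        # from the tf.keras implementation of this callback).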
self._supports_tf_logs = True

def on_batch_end(self, batch, logs=None):
"""Check for NaN/Inf loss at the end of each batch.

Args:
batch: Integer, index of batch within the current epoch.
logs: Dict, contains the return value of `model.train_step()`.

        Raises:
            RuntimeError: If the loss is NaN or Inf and `raise_error=True`.
"""
logs = logs or {}
loss = logs.get("loss")
        if loss is not None:
            if np.isnan(loss) or np.isinf(loss):
                if self.raise_error:
                    raise RuntimeError(
                        f"NaN or Inf loss encountered at batch {batch}. "
                        f"Loss value: {loss}. Terminating training "
                        "immediately."
                    )
                else:
                    io_utils.print_msg(
                        f"Batch {batch}: Invalid loss, terminating training"
                    )
                    self.model.stop_training = True
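For reviewers, a minimal sketch of the strict-failure path described in the docstring above; the toy model and data are illustrative assumptions, not part of this change:

```
import numpy as np
import keras

# Illustrative toy setup (assumption): a single-weight regression model
# fed an Inf target so the very first batch produces a non-finite loss.
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")
x = np.array([[1.0]])
y = np.array([[np.inf]])

try:
    model.fit(
        x,
        y,
        epochs=1,
        callbacks=[keras.callbacks.TerminateOnNaN(raise_error=True)],
        verbose=0,
    )
except RuntimeError as e:
    # The failure surfaces here; on_train_end() was never invoked, so
    # state held by other callbacks (e.g. a BackupAndRestore backup
    # directory) is left untouched.
    print(f"Training aborted: {e}")
```

The graceful default needs no such wrapper: `fit()` returns normally and `on_train_end()` fires as usual.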
164 changes: 163 additions & 1 deletion keras/src/callbacks/terminate_on_nan_test.py
@@ -1,16 +1,25 @@
import os
import tempfile

import numpy as np
import pytest

import keras
from keras.src import callbacks
from keras.src import initializers
from keras.src import layers
from keras.src import models
from keras.src import testing
from keras.src.callbacks import BackupAndRestore
from keras.src.callbacks import TerminateOnNaN
from keras.src.models import Sequential
from keras.src.utils import numerical_utils


@pytest.mark.requires_trainable_backend
class TerminateOnNaNTest(testing.TestCase):
    """Test suite for TerminateOnNaN callback."""

def test_TerminateOnNaN(self):
TRAIN_SAMPLES = 10
TEST_SAMPLES = 10
@@ -50,3 +59,156 @@ def test_TerminateOnNaN(self):
loss = history.history["loss"]
self.assertEqual(len(loss), 1)
self.assertTrue(np.isnan(loss[0]) or np.isinf(loss[0]))

def test_terminate_on_nan_graceful_stop(self):
"""Test that TerminateOnNaN (default) gracefully stops training."""
model = models.Sequential([layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")

x = np.array([[1.0], [2.0]])
y = np.array([[np.inf], [np.inf]])

callback = TerminateOnNaN(raise_error=False)

# Training should complete without raising RuntimeError
history = model.fit(
x, y, epochs=2, batch_size=1, callbacks=[callback], verbose=0
)

        # `history.history["loss"]` records one entry per completed epoch,
        # so stopping early means fewer than the 2 requested epochs ran.
        self.assertLess(len(history.history["loss"]), 2)

def test_terminate_on_nan_raise_error_raises_error(self):
"""Test that TerminateOnNaN(raise_error=True) raises
RuntimeError on NaN loss.
"""
model = models.Sequential([layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")

x = np.array([[1.0], [2.0]])
y = np.array([[np.inf], [np.inf]])

callback = TerminateOnNaN(raise_error=True)

# Training should raise RuntimeError
with pytest.raises(RuntimeError, match="NaN or Inf loss encountered"):
model.fit(
x, y, epochs=1, batch_size=1, callbacks=[callback], verbose=0
)

def test_raise_error_terminate_does_not_trigger_on_train_end(self):
"""Test that on_train_end is NOT called when
TerminateOnNaN(raise_error=True) raises.
"""

        class TrackingCallback(callbacks.Callback):
def __init__(self):
super().__init__()
self.train_end_called = False

def on_train_end(self, logs=None):
self.train_end_called = True

model = models.Sequential([layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")

x = np.array([[1.0]])
y = np.array([[np.inf]])

tracking_callback = TrackingCallback()
raise_error_terminate_callback = TerminateOnNaN(raise_error=True)

# Should raise RuntimeError
with pytest.raises(RuntimeError):
model.fit(
x,
y,
epochs=1,
callbacks=[tracking_callback, raise_error_terminate_callback],
verbose=0,
)

# on_train_end should NOT have been called
self.assertFalse(tracking_callback.train_end_called)

def test_raise_error_terminate_preserves_backup(self):
"""Ensure BackupAndRestore directory is preserved when
TerminateOnNaN(raise_error=True) triggers.
"""
with tempfile.TemporaryDirectory() as tmpdir:
backup_dir = os.path.join(tmpdir, "backups")
os.makedirs(backup_dir, exist_ok=True)

fake_file = os.path.join(backup_dir, "checkpoint.txt")
with open(fake_file, "w") as f:
f.write("dummy checkpoint")

model = models.Sequential([layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")

x_nan = np.array([[1.0]])
y_nan = np.array([[np.inf]])

raise_error_terminate_callback = TerminateOnNaN(raise_error=True)
backup_callback = BackupAndRestore(backup_dir=backup_dir)

            # Defensively disable BackupAndRestore cleanup; with
            # raise_error=True, on_train_end() should never run anyway.
backup_callback.on_train_end = lambda logs=None: None

# Training should raise RuntimeError
with pytest.raises(RuntimeError):
model.fit(
x_nan,
y_nan,
epochs=1,
callbacks=[backup_callback, raise_error_terminate_callback],
verbose=0,
)

# Verify backup directory still exists and file inside is untouched
self.assertTrue(
os.path.exists(backup_dir),
f"Backup dir deleted: {backup_dir}",
)
self.assertTrue(
os.path.exists(fake_file),
"Backup file missing unexpectedly.",
)

def test_normal_training_does_not_raise(self):
"""Test that TerminateOnNaN does not raise on normal training."""
model = models.Sequential([layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")

x = np.array([[1.0], [2.0]])
y = np.array([[1.0], [2.0]])

# Test both raise_error=False and raise_error=True with normal data
for raise_error in [False, True]:
callback = TerminateOnNaN(raise_error=raise_error)

# Should complete without raising RuntimeError
history = model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)

# Should have completed 2 epochs
self.assertEqual(len(history.history["loss"]), 2)

def test_raise_error_terminate_stops_on_later_batch(self):
"""Ensure TerminateOnNaN(raise_error=True) stops training
if NaN appears in later batch.
"""
model = models.Sequential([layers.Dense(1, input_shape=(1,))])
model.compile(optimizer="sgd", loss="mse")

        # One sample yields a finite loss and the other an Inf loss; with
        # `fit`'s default shuffling, the bad sample may land in batch 0 or 1.
        x = np.array([[1.0], [2.0]])
        y = np.array([[1.0], [np.inf]])

callback = TerminateOnNaN(raise_error=True)

with pytest.raises(RuntimeError) as exc:
model.fit(
x, y, epochs=1, batch_size=1, callbacks=[callback], verbose=0
)

        self.assertTrue(
            any(f"batch {i}" in str(exc.value) for i in [0, 1])
        )