Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions keras/src/utils/rng_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def set_random_seed(seed):

# Store seed in global state so we can query it if set.
global_state.set_global_attribute(GLOBAL_RANDOM_SEED, seed)
# Remove global SeedGenerator, it will be recreated from the seed.
global_state.set_global_attribute("global_seed_generator", None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve maintainability and avoid using 'magic strings', consider defining "global_seed_generator" as a module-level constant, similar to how GLOBAL_RANDOM_SEED is handled. This would make the code easier to read and maintain.

For example, you could add at the top of the file:

GLOBAL_SEED_GENERATOR_KEY = "global_seed_generator"

And then use it here:

global_state.set_global_attribute(GLOBAL_SEED_GENERATOR_KEY, None)

Ideally, this constant would be shared with keras/src/random/seed_generator.py to ensure consistency.

random.seed(seed)
np.random.seed(seed)
if tf.available:
Expand Down
14 changes: 8 additions & 6 deletions keras/src/utils/rng_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
import pytest
import tensorflow as tf

import keras
Expand All @@ -9,11 +8,7 @@


class TestRandomSeedSetting(test_case.TestCase):
@pytest.mark.skipif(
backend.backend() == "numpy",
reason="Numpy backend does not support random seed setting.",
)
def test_set_random_seed(self):
def test_set_random_seed_with_seed_generator(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This test was previously skipped for the NumPy backend. By removing the skipif decorator, this test will now run for the NumPy backend. However, the get_model_output function uses tf.data.Dataset, which is specific to TensorFlow and will likely cause the test to fail when run with the NumPy backend.

To ensure this test can run across all backends, you could modify get_model_output to not use tf.data.Dataset. For example:

def get_model_output():
    model = keras.Sequential(
        [
            keras.layers.Dense(10),
            keras.layers.Dropout(0.5),
            keras.layers.Dense(10),
        ]
    )
    x = np.random.random((32, 10)).astype("float32")
    return model.predict(x)

Alternatively, if the intention is to keep this test for TensorFlow-based backends only, the skipif decorator should be restored.

def get_model_output():
model = keras.Sequential(
[
Expand All @@ -31,3 +26,10 @@ def get_model_output():
rng_utils.set_random_seed(42)
y2 = get_model_output()
self.assertAllClose(y1, y2)

def test_set_random_seed_with_global_seed_generator(self):
rng_utils.set_random_seed(42)
y1 = backend.random.randint((32, 10), minval=0, maxval=1000)
rng_utils.set_random_seed(42)
y2 = backend.random.randint((32, 10), minval=0, maxval=1000)
self.assertAllClose(y1, y2)
Loading