19 changes: 18 additions & 1 deletion keras/src/trainers/data_adapters/py_dataset_adapter.py
@@ -40,6 +40,12 @@ class PyDataset:
multiprocessed setting.
Reduce this value to reduce the CPU memory consumption of
your dataset. Defaults to 10.
shuffle: Whether to shuffle the sample ordering at the end of
each epoch.This argument passed to `model.fit()`. when
`model.fit(.., shuffle=True)`, the training loop
automatically calls `on_epoch_end()` at each epoch
boundary, allowing datasets to implement custom
shuffling logic. Defaults to False.
Contributor
medium

There are a few grammatical and formatting issues in the docstring for the shuffle argument. It would be clearer with some corrections for spacing, sentence structure, and consistent code formatting.

Suggested change
shuffle: Whether to shuffle the sample ordering at the end of
each epoch.This argument passed to `model.fit()`. when
`model.fit(.., shuffle=True)`, the training loop
automatically calls `on_epoch_end()` at each epoch
boundary, allowing datasets to implement custom
shuffling logic. Defaults to False.
shuffle: Whether to shuffle the sample ordering at the end of
each epoch. This argument is passed to `model.fit()`. When
`model.fit(..., shuffle=True)`, the training loop
automatically calls `on_epoch_end()` at each epoch
boundary, allowing datasets to implement custom
shuffling logic. Defaults to `False`.


Notes:

@@ -52,6 +58,9 @@ class PyDataset:
over the dataset. They are not being used by the `PyDataset` class
directly. When you are manually iterating over a `PyDataset`,
no parallelism is applied.
- `shuffle=False` keeps the sample order fixed across epochs.
For distributed or deterministic training, prefer
`shuffle=False` and manage the sample order externally.
Comment on lines +61 to +63
Collaborator

The shuffle argument is not part of PyDataset.__init__, so it should be removed from here. It's just an argument in your example.


Example:

@@ -66,10 +75,12 @@ class PyDataset:

class CIFAR10PyDataset(keras.utils.PyDataset):

def __init__(self, x_set, y_set, batch_size, **kwargs):
def __init__(self, x_set, y_set, batch_size, shuffle=False, **kwargs):
super().__init__(**kwargs)
self.x, self.y = x_set, y_set
self.batch_size = batch_size
self.shuffle = shuffle
self.indices = np.arange(len(self.x))

def __len__(self):
# Return number of batches.
@@ -87,6 +98,12 @@ def __getitem__(self, idx):
return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)

def on_epoch_end(self):
# Called automatically by model.fit() at each epoch
# boundary when shuffle=True.
if self.shuffle:
np.random.shuffle(self.indices)
```
"""
