Open · Changes from 4 commits
24 changes: 21 additions & 3 deletions keras/src/trainers/data_adapters/py_dataset_adapter.py
@@ -40,6 +40,12 @@ class PyDataset:
multiprocessed setting.
Reduce this value to reduce the CPU memory consumption of
your dataset. Defaults to 10.
shuffle: Whether to shuffle the sample ordering at the end of
each epoch. This argument is passed to `model.fit()`. When calling
`model.fit(..., shuffle=True)`, the training loop
> **Collaborator:** Add verb: "When calling".
automatically calls `on_epoch_end()` at each epoch
boundary, allowing datasets to implement custom
shuffling logic. Defaults to `False`.

Notes:

@@ -52,6 +58,9 @@ class PyDataset:
over the dataset. They are not being used by the `PyDataset` class
directly. When you are manually iterating over a `PyDataset`,
no parallelism is applied.
- `shuffle=False` keeps the sample order fixed across epochs.
For distributed or deterministic training, prefer
`shuffle=False` and manage the sample order externally.
> **Collaborator** (on lines +61 to +63): The `shuffle` argument is not part of `PyDataset.__init__`, so it should be removed from here. It's just an argument in your example.

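The last note above ("manage the order externally") could look like the following minimal sketch. `RangeDataset` is a hypothetical stand-in for a `keras.utils.PyDataset` subclass, kept dependency-free here; a fixed seed reproduces the same sample order on every run:

```python
import random


class RangeDataset:
    """Hypothetical stand-in for a keras.utils.PyDataset subclass.

    The sample order is fixed up front by an external seed, so every
    run (and every worker given the same seed) sees the same order.
    """

    def __init__(self, n_samples, batch_size, seed=0):
        self.batch_size = batch_size
        # Deterministic permutation: the same seed yields the same order.
        self.indices = list(range(n_samples))
        random.Random(seed).shuffle(self.indices)

    def __len__(self):
        # Number of batches, rounding up for a smaller final batch.
        return (len(self.indices) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        low = idx * self.batch_size
        high = min(low + self.batch_size, len(self.indices))
        return self.indices[low:high]
```

With `shuffle=False` in `model.fit()`, nothing reshuffles between epochs, so the seed alone determines the order seen in training.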
Example:

@@ -66,10 +75,12 @@ class PyDataset:

class CIFAR10PyDataset(keras.utils.PyDataset):

def __init__(self, x_set, y_set, batch_size, **kwargs):
def __init__(self, x_set, y_set, batch_size, shuffle=False, **kwargs):
> **Collaborator:** The `shuffle` argument is not part of `PyDataset.__init__`. I understand that you can do it this way here, as in your example. But in Keras, we add shuffling via `PyDatasetAdapter`, so this becomes confusing:
>
> - `CIFAR10PyDataset(shuffle=True)` with `model.fit(shuffle=False)`: the data will still be shuffled.
> - `CIFAR10PyDataset(shuffle=False)` with `model.fit(shuffle=True)`: the data will still be shuffled.
>
> I guess my question is: what is the reason to add `shuffle` directly in the `PyDataset`?

super().__init__(**kwargs)
self.x, self.y = x_set, y_set
self.batch_size = batch_size
self.shuffle = shuffle
self.indices = np.arange(len(self.x))
> **Collaborator:** In the example, also do `np.random.shuffle(self.indices)` here when `shuffle` is `True`, since the best practice when shuffling is to do it for every epoch, not just epochs >= 1.

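The suggestion above (shuffle in `__init__` as well, so epoch 0 is shuffled too) could be sketched like this. `ShuffledDataset` is a hypothetical, dependency-free stand-in that uses the stdlib `random` module instead of NumPy:

```python
import random


class ShuffledDataset:
    """Hypothetical sketch of the reviewer's suggestion: shuffle once at
    construction so epoch 0 is already shuffled, then reshuffle at every
    epoch boundary via on_epoch_end()."""

    def __init__(self, n_samples, batch_size, shuffle=False):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = list(range(n_samples))
        if self.shuffle:
            random.shuffle(self.indices)  # epoch 0 is shuffled too

    def on_epoch_end(self):
        if self.shuffle:
            random.shuffle(self.indices)  # reshuffle for the next epoch
```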
def __len__(self):
# Return number of batches.
@@ -81,12 +92,19 @@ def __getitem__(self, idx):
# Cap upper bound at array length; the last batch may be smaller
# if the total number of items is not a multiple of batch size.
high = min(low + self.batch_size, len(self.x))
batch_x = self.x[low:high]
batch_y = self.y[low:high]
# Retrieve a batch of data by index
batch_indices = self.indices[low:high]
batch_x = self.x[batch_indices]
batch_y = self.y[batch_indices]

return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)

def on_epoch_end(self):
# Shuffle indices at the end of each epoch if enabled
if self.shuffle:
np.random.shuffle(self.indices)
```
"""
