-
Notifications
You must be signed in to change notification settings - Fork 19.7k
📝docs: clarify shuffle behavior and example in PyDataset #21847
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
646c3cd
16058c4
fa549b9
78ca394
54cae65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,6 +40,12 @@ class PyDataset: | |
| multiprocessed setting. | ||
| Reduce this value to reduce the CPU memory consumption of | ||
| your dataset. Defaults to 10. | ||
| shuffle: Whether to shuffle the sample ordering at the end of | ||
| each epoch. This argument is passed to `model.fit()`. When | ||
| calling `model.fit(..., shuffle=True)`, the training loop | ||
| automatically calls `on_epoch_end()` at each epoch | ||
| boundary, allowing datasets to implement custom | ||
| shuffling logic. Defaults to `False`. | ||
|
|
||
| Notes: | ||
|
|
||
|
|
@@ -52,6 +58,9 @@ class PyDataset: | |
| over the dataset. They are not being used by the `PyDataset` class | ||
| directly. When you are manually iterating over a `PyDataset`, | ||
| no parallelism is applied. | ||
| - `shuffle=False` keeps the sample order fixed across epochs. | ||
| For distributed or deterministic training prefer | ||
| `shuffle=False` and manage the order externally. | ||
|
Comment on lines
+61
to
+63
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
|
||
| Example: | ||
|
|
||
|
|
@@ -66,10 +75,17 @@ class PyDataset: | |
|
|
||
| class CIFAR10PyDataset(keras.utils.PyDataset): | ||
|
|
||
| def __init__(self, x_set, y_set, batch_size, **kwargs): | ||
| def __init__(self, x_set, y_set, batch_size, shuffle=False, **kwargs): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The But in Keras, we add shuffling via
I guess my question is, what is the reason to add |
||
| super().__init__(**kwargs) | ||
| self.x, self.y = x_set, y_set | ||
| self.batch_size = batch_size | ||
| self.shuffle = shuffle | ||
| # create index array for shuffling | ||
| self.indices = np.arange(len(self.x)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the example, also do |
||
| # Shuffle once at initialization when shuffle=True | ||
| if self.shuffle: | ||
| np.random.shuffle(self.indices) | ||
|
|
||
|
|
||
| def __len__(self): | ||
| # Return number of batches. | ||
|
|
@@ -81,12 +97,19 @@ def __getitem__(self, idx): | |
| # Cap upper bound at array length; the last batch may be smaller | ||
| # if the total number of items is not a multiple of batch size. | ||
| high = min(low + self.batch_size, len(self.x)) | ||
| batch_x = self.x[low:high] | ||
| batch_y = self.y[low:high] | ||
| # Retrieve a batch using shuffled indices | ||
| batch_indices = self.indices[low:high] | ||
| batch_x = self.x[batch_indices] | ||
| batch_y = self.y[batch_indices] | ||
|
|
||
| return np.array([ | ||
| resize(imread(file_name), (200, 200)) | ||
| for file_name in batch_x]), np.array(batch_y) | ||
|
|
||
| def on_epoch_end(self): | ||
| # Shuffle indices at the end of each epoch if enabled | ||
| if self.shuffle: | ||
| np.random.shuffle(self.indices) | ||
| ``` | ||
| """ | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
shuffleargument is not part ofPyDataset.__init__, so it should be removed from here. It's just an argument in your example.