Skip to content

Commit d279c56

Browse files
author
ylgh
committed
Merge remote-tracking branch 'origin/main' into v0.3.0
2 parents 8e4a255 + daabdf4 commit d279c56

File tree

2 files changed

+56
-45
lines changed

2 files changed

+56
-45
lines changed

torchrec/datasets/criteo.py

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,8 @@ def get_file_idx_to_row_range(
271271
lengths: List[int],
272272
rank: int,
273273
world_size: int,
274+
start_row: int = 0,
275+
last_row: Optional[int] = None,
274276
) -> Dict[int, Tuple[int, int]]:
275277
"""
276278
Given a rank, world_size, and the lengths (number of rows) for a list of files,
@@ -296,14 +298,26 @@ def get_file_idx_to_row_range(
296298
# All ..._g variables are globals indices (meaning they range from 0 to
297299
# total_length - 1). All ..._l variables are local indices (meaning they range
298300
# from 0 to lengths[i] - 1 for the ith file).
299-
300-
total_length = sum(lengths)
301+
if last_row is None:
302+
total_length = sum(lengths) - start_row
303+
else:
304+
total_length = last_row - start_row + 1
301305
rows_per_rank = total_length // world_size
306+
remainder = total_length % world_size
302307

303308
# Global indices that rank is responsible for. All ranges (left, right) are
304309
# inclusive.
305-
rank_left_g = rank * rows_per_rank
306-
rank_right_g = (rank + 1) * rows_per_rank - 1
310+
if rank < remainder:
311+
rank_left_g = rank * (rows_per_rank + 1)
312+
rank_right_g = (rank + 1) * (rows_per_rank + 1) - 1
313+
else:
314+
rank_left_g = (
315+
remainder * (rows_per_rank + 1) + (rank - remainder) * rows_per_rank
316+
)
317+
rank_right_g = rank_left_g + rows_per_rank - 1
318+
319+
rank_left_g += start_row
320+
rank_right_g += start_row
307321

308322
output = {}
309323

@@ -734,34 +748,31 @@ def __init__(
734748
}
735749

736750
def _load_data_for_rank(self) -> None:
737-
if self.stage == "train":
738-
file_idx_to_row_range = BinaryCriteoUtils.get_file_idx_to_row_range(
739-
lengths=[
740-
BinaryCriteoUtils.get_shape_from_npy(
741-
path, path_manager_key=self.path_manager_key
742-
)[0]
743-
for path in self.dense_paths
744-
],
745-
rank=self.rank,
746-
world_size=self.world_size,
747-
)
748-
elif self.stage in ["val", "test"]:
751+
start_row, last_row = 0, None
752+
if self.stage in ["val", "test"]:
749753
# Last day's dataset is split into 2 sets: 1st half for "val"; 2nd for "test"
750754
samples_in_file = BinaryCriteoUtils.get_shape_from_npy(
751755
self.dense_paths[0], path_manager_key=self.path_manager_key
752756
)[0]
753-
754-
dataset_start = 0
757+
start_row = 0
755758
dataset_len = int(np.ceil(samples_in_file / 2.0))
756-
757759
if self.stage == "test":
758-
dataset_start = dataset_len
759-
dataset_len = samples_in_file - dataset_len
760-
segment_len = dataset_len // self.world_size
761-
rank_start_row = dataset_start + self.rank * segment_len
762-
763-
rank_last_row = rank_start_row + segment_len - 1
764-
file_idx_to_row_range = {0: (rank_start_row, rank_last_row)}
760+
start_row = dataset_len
761+
dataset_len = samples_in_file - start_row
762+
last_row = start_row + dataset_len - 1
763+
764+
file_idx_to_row_range = BinaryCriteoUtils.get_file_idx_to_row_range(
765+
lengths=[
766+
BinaryCriteoUtils.get_shape_from_npy(
767+
path, path_manager_key=self.path_manager_key
768+
)[0]
769+
for path in self.dense_paths
770+
],
771+
rank=self.rank,
772+
world_size=self.world_size,
773+
start_row=start_row,
774+
last_row=last_row,
775+
)
765776

766777
self.dense_arrs, self.sparse_arrs, self.labels_arrs = [], [], []
767778
for arrs, paths in zip(

torchrec/datasets/tests/test_criteo.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# LICENSE file in the root directory of this source tree.
77

88
import contextlib
9+
import math
910
import os
1011
import random
1112
import tempfile
@@ -239,12 +240,9 @@ def _validate_sparse_to_contiguous_preproc(
239240
input_files, temp_output_dir, freq_threshold, columns
240241
)
241242

242-
output_files = list(
243-
map(
244-
lambda f: os.path.join(temp_output_dir, f),
245-
os.listdir(temp_output_dir),
246-
)
247-
)
243+
output_files = [
244+
os.path.join(temp_output_dir, f) for f in os.listdir(temp_output_dir)
245+
]
248246
output_files.sort()
249247
for day, file in enumerate(output_files):
250248
processed_data = np.load(file)
@@ -280,9 +278,9 @@ def test_shuffle(self) -> None:
280278
labels_data = [np.array([[i], [i + 3], [i + 6]]) for i in range(3)]
281279

282280
def save_data_list(data: List[np.ndarray], data_type: str) -> None:
283-
for day, data in enumerate(data):
281+
for day, data_ in enumerate(data):
284282
file = os.path.join(temp_input_dir, f"day_{day}_{data_type}.npy")
285-
np.save(file, data)
283+
np.save(file, data_)
286284

287285
save_data_list(dense_data, "dense")
288286
save_data_list(sparse_data, "sparse")
@@ -380,14 +378,14 @@ def _test_dataset(
380378
dataset_start = num_rows // 2 + num_rows % 2
381379
dataset_len = num_rows // 2
382380

383-
incomplete_last_batch_size = dataset_len // world_size % batch_size
384-
num_batches = dataset_len // world_size // batch_size + (
385-
incomplete_last_batch_size != 0
386-
)
387-
388381
lens = []
389-
samples_counts = []
382+
remainder = dataset_len % world_size
390383
for rank in range(world_size):
384+
incomplete_last_batch_size = (
385+
dataset_len // world_size % batch_size + int(rank < remainder)
386+
)
387+
num_samples = dataset_len // world_size + int(rank < remainder)
388+
num_batches = math.ceil(num_samples / batch_size)
391389
datapipe = InMemoryBinaryCriteoIterDataPipe(
392390
stage=stage,
393391
dense_paths=[f[0] for f in files],
@@ -421,12 +419,14 @@ def _test_dataset(
421419
# Check that dataset __len__ matches true length.
422420
self.assertEqual(datapipe_len, len_)
423421
lens.append(len_)
424-
self.assertEqual(samples_count, dataset_len // world_size)
425-
samples_counts.append(samples_count)
422+
self.assertEqual(samples_count, num_samples)
426423

427-
# Ensure all ranks' datapipes return the same number of batches.
428-
self.assertEqual(len(set(lens)), 1)
429-
self.assertEqual(len(set(samples_counts)), 1)
424+
# Ensure all ranks return the correct number of batches.
425+
if remainder > 0:
426+
self.assertEqual(len(set(lens[:remainder])), 1)
427+
self.assertEqual(len(set(lens[remainder:])), 1)
428+
else:
429+
self.assertEqual(len(set(lens)), 1)
430430

431431
def test_dataset_small_files(self) -> None:
432432
self._test_dataset([1] * 20, 4, 2)

0 commit comments

Comments
 (0)