@@ -271,6 +271,8 @@ def get_file_idx_to_row_range(
271271 lengths : List [int ],
272272 rank : int ,
273273 world_size : int ,
274+ start_row : int = 0 ,
275+ last_row : Optional [int ] = None ,
274276 ) -> Dict [int , Tuple [int , int ]]:
275277 """
276278 Given a rank, world_size, and the lengths (number of rows) for a list of files,
@@ -296,14 +298,26 @@ def get_file_idx_to_row_range(
296298 # All ..._g variables are global indices (meaning they index into the
297299 # concatenation of all files, from 0 to sum(lengths) - 1). All ..._l variables are local indices (meaning they range
298300 # from 0 to lengths[i] - 1 for the ith file).
299-
300- total_length = sum (lengths )
301+ if last_row is None :
302+ total_length = sum (lengths ) - start_row
303+ else :
304+ total_length = last_row - start_row + 1
301305 rows_per_rank = total_length // world_size
306+ remainder = total_length % world_size
302307
303308 # Global indices that rank is responsible for. All ranges (left, right) are
304309 # inclusive.
305- rank_left_g = rank * rows_per_rank
306- rank_right_g = (rank + 1 ) * rows_per_rank - 1
310+ if rank < remainder :
311+ rank_left_g = rank * (rows_per_rank + 1 )
312+ rank_right_g = (rank + 1 ) * (rows_per_rank + 1 ) - 1
313+ else :
314+ rank_left_g = (
315+ remainder * (rows_per_rank + 1 ) + (rank - remainder ) * rows_per_rank
316+ )
317+ rank_right_g = rank_left_g + rows_per_rank - 1
318+
319+ rank_left_g += start_row
320+ rank_right_g += start_row
307321
308322 output = {}
309323
@@ -734,34 +748,31 @@ def __init__(
734748 }
735749
736750 def _load_data_for_rank (self ) -> None :
737- if self .stage == "train" :
738- file_idx_to_row_range = BinaryCriteoUtils .get_file_idx_to_row_range (
739- lengths = [
740- BinaryCriteoUtils .get_shape_from_npy (
741- path , path_manager_key = self .path_manager_key
742- )[0 ]
743- for path in self .dense_paths
744- ],
745- rank = self .rank ,
746- world_size = self .world_size ,
747- )
748- elif self .stage in ["val" , "test" ]:
751+ start_row , last_row = 0 , None
752+ if self .stage in ["val" , "test" ]:
749753 # Last day's dataset is split into 2 sets: 1st half for "val"; 2nd for "test"
750754 samples_in_file = BinaryCriteoUtils .get_shape_from_npy (
751755 self .dense_paths [0 ], path_manager_key = self .path_manager_key
752756 )[0 ]
753-
754- dataset_start = 0
757+ start_row = 0
755758 dataset_len = int (np .ceil (samples_in_file / 2.0 ))
756-
757759 if self .stage == "test" :
758- dataset_start = dataset_len
759- dataset_len = samples_in_file - dataset_len
760- segment_len = dataset_len // self .world_size
761- rank_start_row = dataset_start + self .rank * segment_len
762-
763- rank_last_row = rank_start_row + segment_len - 1
764- file_idx_to_row_range = {0 : (rank_start_row , rank_last_row )}
760+ start_row = dataset_len
761+ dataset_len = samples_in_file - start_row
762+ last_row = start_row + dataset_len - 1
763+
764+ file_idx_to_row_range = BinaryCriteoUtils .get_file_idx_to_row_range (
765+ lengths = [
766+ BinaryCriteoUtils .get_shape_from_npy (
767+ path , path_manager_key = self .path_manager_key
768+ )[0 ]
769+ for path in self .dense_paths
770+ ],
771+ rank = self .rank ,
772+ world_size = self .world_size ,
773+ start_row = start_row ,
774+ last_row = last_row ,
775+ )
765776
766777 self .dense_arrs , self .sparse_arrs , self .labels_arrs = [], [], []
767778 for arrs , paths in zip (
0 commit comments