From a8d3fa2176ebc88fe831c8bfb687140c20eb9dec Mon Sep 17 00:00:00 2001 From: xgui Date: Wed, 5 Nov 2025 22:51:12 +0000 Subject: [PATCH 1/5] add refbundle time Signed-off-by: xgui --- .../_internal/block_batching/iter_batches.py | 11 +++++++++-- python/ray/data/_internal/stats.py | 19 +++++++++++++++++++ .../benchmark/ray_dataloader_factory.py | 4 ++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 9dc052d12aa..110a58db404 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -141,6 +141,7 @@ def _prefetch_blocks( self, ref_bundles: Iterator[RefBundle] ) -> Iterator[ObjectRef[Block]]: return prefetch_batches_locally( + stats=self._stats, ref_bundles=ref_bundles, prefetcher=self._prefetcher, num_batches_to_prefetch=self._prefetch_batches, @@ -320,6 +321,7 @@ def prefetch_batches_locally( ref_bundles: Iterator[RefBundle], prefetcher: BlockPrefetcher, num_batches_to_prefetch: int, + stats: Optional[DatasetStats], batch_size: Optional[int], eager_free: bool = False, ) -> Iterator[ObjectRef[Block]]: @@ -332,10 +334,15 @@ def prefetch_batches_locally( prefetcher: The prefetcher to use. num_batches_to_prefetch: The number of batches to prefetch ahead of the current batch during the scan. + stats: Dataset stats object used to store ref bundle retrieval time. batch_size: User specified batch size, or None to let the system pick. eager_free: Whether to eagerly free the object reference from the object store. """ + def get_next_ref_bundle() -> RefBundle: + with stats.iter_ref_bundle_retrieval_s.timer() if stats else nullcontext(): + return next(ref_bundles) + sliding_window = collections.deque() current_window_size = 0 @@ -358,7 +365,7 @@ def prefetch_batches_locally( batch_size is None and len(sliding_window) < num_batches_to_prefetch ): try: - next_ref_bundle = next(ref_bundles) + next_ref_bundle = get_next_ref_bundle() sliding_window.extend(next_ref_bundle.blocks) current_window_size += next_ref_bundle.num_rows() except StopIteration: @@ -371,7 +378,7 @@ def prefetch_batches_locally( current_window_size -= metadata.num_rows if batch_size is None or current_window_size < num_rows_to_prefetch: try: - next_ref_bundle = next(ref_bundles) + next_ref_bundle = get_next_ref_bundle() for block_ref_and_md in next_ref_bundle.blocks: sliding_window.append(block_ref_and_md) current_window_size += block_ref_and_md[1].num_rows diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index f41d3434349..f059bd4e569 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -327,6 +327,11 @@ def __init__(self, max_stats=1000): description="Seconds spent in iterator initialization code", tag_keys=iter_tag_keys, ) + self.iter_get_ref_bundles_s = Gauge( + "data_iter_get_ref_bundles_seconds", + description="Seconds spent getting RefBundles from the dataset iterator", + tag_keys=iter_tag_keys, + ) self.iter_get_s = Gauge( "data_iter_get_seconds", description="Seconds spent in ray.get() while resolving block references", @@ -565,6 +570,7 @@ def update_iteration_metrics( tags = self._create_tags(dataset_tag) self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags) + self.iter_get_ref_bundles_s.set(stats.iter_get_ref_bundles_s.get(), tags) self.iter_get_s.set(stats.iter_get_s.get(), tags) 
self.iter_next_batch_s.set(stats.iter_next_batch_s.get(), tags) self.iter_format_batch_s.set(stats.iter_format_batch_s.get(), tags) @@ -1098,6 +1104,7 @@ def __init__( # Iteration stats, filled out if the user iterates over the dataset. self.iter_wait_s: Timer = Timer() + self.iter_get_ref_bundles_s: Timer = Timer() self.iter_get_s: Timer = Timer() self.iter_next_batch_s: Timer = Timer() self.iter_format_batch_s: Timer = Timer() @@ -1146,6 +1153,7 @@ def to_summary(self) -> "DatasetStatsSummary": iter_stats = IterStatsSummary( self.iter_wait_s, + self.iter_get_ref_bundles_s, self.iter_get_s, self.iter_next_batch_s, self.iter_format_batch_s, @@ -1843,6 +1851,8 @@ def __repr__(self, level=0) -> str: class IterStatsSummary: # Time spent in actor based prefetching, in seconds. wait_time: Timer + # Time spent getting RefBundles from the dataset iterator, in seconds + get_ref_bundles_time: Timer # Time spent in `ray.get()`, in seconds get_time: Timer # Time spent in batch building, in seconds @@ -1880,6 +1890,7 @@ def to_string(self) -> str: self.block_time.get() or self.time_to_first_batch.get() or self.total_time.get() + or self.get_ref_bundles_time.get() or self.get_time.get() or self.next_time.get() or self.format_time.get() @@ -1911,6 +1922,13 @@ def to_string(self) -> str: out += ( "* Batch iteration time breakdown (summed across prefetch threads):\n" ) + if self.get_ref_bundles_time.get(): + out += " * In get RefBundles: {} min, {} max, {} avg, {} total\n".format( + fmt(self.get_ref_bundles_time.min()), + fmt(self.get_ref_bundles_time.max()), + fmt(self.get_ref_bundles_time.avg()), + fmt(self.get_ref_bundles_time.get()), + ) if self.get_time.get(): out += " * In ray.get(): {} min, {} max, {} avg, {} total\n".format( fmt(self.get_time.min()), @@ -1973,6 +1991,7 @@ def __repr__(self, level=0) -> str: return ( f"IterStatsSummary(\n" f"{indent} wait_time={fmt(self.wait_time.get()) or None},\n" + f"{indent} get_ref_bundles_time={fmt(self.get_ref_bundles_time.get()) or None},\n" f"{indent} get_time={fmt(self.get_time.get()) or None},\n" f"{indent} iter_blocks_local={self.iter_blocks_local or None},\n" f"{indent} iter_blocks_remote={self.iter_blocks_remote or None},\n" diff --git a/release/train_tests/benchmark/ray_dataloader_factory.py b/release/train_tests/benchmark/ray_dataloader_factory.py index e3a7c7dd8ec..62f4e51aa67 100644 --- a/release/train_tests/benchmark/ray_dataloader_factory.py +++ b/release/train_tests/benchmark/ray_dataloader_factory.py @@ -123,6 +123,10 @@ def get_metrics(self) -> Dict[str, Any]: "prefetch_block-min": iter_stats.wait_time.min(), "prefetch_block-max": iter_stats.wait_time.max(), "prefetch_block-total": iter_stats.wait_time.get(), + "get_ref_bundles-avg": iter_stats.get_ref_bundles_time.avg(), + "get_ref_bundles-min": iter_stats.get_ref_bundles_time.min(), + "get_ref_bundles-max": iter_stats.get_ref_bundles_time.max(), + "get_ref_bundles-total": iter_stats.get_ref_bundles_time.get(), "fetch_block-avg": iter_stats.get_time.avg(), "fetch_block-min": iter_stats.get_time.min(), "fetch_block-max": iter_stats.get_time.max(), From 8ea8c38cb80f229da30aacdc7c0e4907301fc876 Mon Sep 17 00:00:00 2001 From: xgui Date: Thu, 6 Nov 2025 04:21:02 +0000 Subject: [PATCH 2/5] remove one test Signed-off-by: xgui --- python/ray/data/_internal/block_batching/iter_batches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 110a58db404..17b447ad1b5 
100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -340,7 +340,7 @@ def prefetch_batches_locally( """ def get_next_ref_bundle() -> RefBundle: - with stats.iter_ref_bundle_retrieval_s.timer() if stats else nullcontext(): + with stats.iter_get_ref_bundles_s.timer() if stats else nullcontext(): return next(ref_bundles) sliding_window = collections.deque() From 614580de28107b6fc0a068ffcf6b958589b3465e Mon Sep 17 00:00:00 2001 From: xgui Date: Thu, 6 Nov 2025 04:32:03 +0000 Subject: [PATCH 3/5] fix unittest Signed-off-by: xgui --- python/ray/data/tests/test_stats.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index 3a66d0be8c0..7db5d2e4e59 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -501,6 +501,7 @@ def test_streaming_split_stats(ray_start_regular_shared, restore_data_context): * Total time spent waiting for the first batch after starting iteration: T * Total execution time for user thread: T * Batch iteration time breakdown (summed across prefetch threads): + * In get RefBundles: T min, T max, T avg, T total * In ray.get(): T min, T max, T avg, T total * In batch creation: T min, T max, T avg, T total * In batch formatting: T min, T max, T avg, T total @@ -696,6 +697,7 @@ def test_dataset_stats_basic( f" * Total time spent waiting for the first batch after starting iteration: T\n" f" * Total execution time for user thread: T\n" f"* Batch iteration time breakdown (summed across prefetch threads):\n" + f" * In get RefBundles: T min, T max, T avg, T total\n" f" * In ray.get(): T min, T max, T avg, T total\n" f" * In batch creation: T min, T max, T avg, T total\n" f" * In batch formatting: T min, T max, T avg, T total\n" @@ -740,6 +742,7 @@ def test_block_location_nums(ray_start_regular_shared, restore_data_context): f" * Total time spent waiting for the first batch after starting iteration: T\n" f" * Total execution time for user thread: T\n" f"* Batch iteration time breakdown (summed across prefetch threads):\n" + f" * In get RefBundles: T min, T max, T avg, T total\n" f" * In ray.get(): T min, T max, T avg, T total\n" f" * In batch creation: T min, T max, T avg, T total\n" f" * In batch formatting: T min, T max, T avg, T total\n" @@ -845,6 +848,7 @@ def test_dataset__repr__(ray_start_regular_shared, restore_data_context): " ],\n" " iter_stats=IterStatsSummary(\n" " wait_time=T,\n" + " get_ref_bundles_time=T,\n" " get_time=T,\n" " iter_blocks_local=None,\n" " iter_blocks_remote=None,\n" @@ -866,6 +870,7 @@ def test_dataset__repr__(ray_start_regular_shared, restore_data_context): " operators_stats=[],\n" " iter_stats=IterStatsSummary(\n" " wait_time=T,\n" + " get_ref_bundles_time=T,\n" " get_time=T,\n" " iter_blocks_local=None,\n" " iter_blocks_remote=None,\n" @@ -985,6 +990,7 @@ def check_stats(): " ],\n" " iter_stats=IterStatsSummary(\n" " wait_time=T,\n" + " get_ref_bundles_time=T,\n" " get_time=T,\n" " iter_blocks_local=None,\n" " iter_blocks_remote=None,\n" @@ -1080,6 +1086,7 @@ def check_stats(): " ],\n" " iter_stats=IterStatsSummary(\n" " wait_time=T,\n" + " get_ref_bundles_time=T,\n" " get_time=T,\n" " iter_blocks_local=None,\n" " iter_blocks_remote=None,\n" @@ -1101,6 +1108,7 @@ def check_stats(): " operators_stats=[],\n" " iter_stats=IterStatsSummary(\n" " wait_time=T,\n" + " get_ref_bundles_time=T,\n" " get_time=T,\n" " iter_blocks_local=None,\n" " 
iter_blocks_remote=None,\n" @@ -1537,6 +1545,7 @@ def test_streaming_stats_full(ray_start_regular_shared, restore_data_context): * Total time spent waiting for the first batch after starting iteration: T * Total execution time for user thread: T * Batch iteration time breakdown (summed across prefetch threads): + * In get RefBundles: T min, T max, T avg, T total * In ray.get(): T min, T max, T avg, T total * In batch creation: T min, T max, T avg, T total * In batch formatting: T min, T max, T avg, T total From 64f52f1980148c9b3a96c11362819a41a61cdc69 Mon Sep 17 00:00:00 2001 From: Xinyuan <43737116+xinyuangui2@users.noreply.github.com> Date: Wed, 5 Nov 2025 20:37:14 -0800 Subject: [PATCH 4/5] Set default value for stats parameter in prefetch_batches_locally Signed-off-by: Xinyuan <43737116+xinyuangui2@users.noreply.github.com> --- python/ray/data/_internal/block_batching/iter_batches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index 17b447ad1b5..a1a51611d87 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -321,7 +321,7 @@ def prefetch_batches_locally( ref_bundles: Iterator[RefBundle], prefetcher: BlockPrefetcher, num_batches_to_prefetch: int, - stats: Optional[DatasetStats], + stats: Optional[DatasetStats] = None, batch_size: Optional[int], eager_free: bool = False, ) -> Iterator[ObjectRef[Block]]: From b77e37148736187cdedc5fe2dbdf00fb313e439b Mon Sep 17 00:00:00 2001 From: xgui Date: Thu, 6 Nov 2025 06:51:13 +0000 Subject: [PATCH 5/5] fix parameter Signed-off-by: xgui --- python/ray/data/_internal/block_batching/iter_batches.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/ray/data/_internal/block_batching/iter_batches.py b/python/ray/data/_internal/block_batching/iter_batches.py index a1a51611d87..599e8c767f1 100644 --- a/python/ray/data/_internal/block_batching/iter_batches.py +++ b/python/ray/data/_internal/block_batching/iter_batches.py @@ -141,12 +141,12 @@ def _prefetch_blocks( self, ref_bundles: Iterator[RefBundle] ) -> Iterator[ObjectRef[Block]]: return prefetch_batches_locally( - stats=self._stats, ref_bundles=ref_bundles, prefetcher=self._prefetcher, num_batches_to_prefetch=self._prefetch_batches, batch_size=self._batch_size, eager_free=self._eager_free, + stats=self._stats, ) def _resolve_block_refs( @@ -321,9 +321,9 @@ def prefetch_batches_locally( ref_bundles: Iterator[RefBundle], prefetcher: BlockPrefetcher, num_batches_to_prefetch: int, - stats: Optional[DatasetStats] = None, batch_size: Optional[int], eager_free: bool = False, + stats: Optional[DatasetStats] = None, ) -> Iterator[ObjectRef[Block]]: """Given an iterator of batched RefBundles, returns an iterator over the corresponding block references while prefetching `num_batches_to_prefetch` @@ -334,9 +334,9 @@ def prefetch_batches_locally( prefetcher: The prefetcher to use. num_batches_to_prefetch: The number of batches to prefetch ahead of the current batch during the scan. - stats: Dataset stats object used to store ref bundle retrieval time. batch_size: User specified batch size, or None to let the system pick. eager_free: Whether to eagerly free the object reference from the object store. + stats: Dataset stats object used to store ref bundle retrieval time. """ def get_next_ref_bundle() -> RefBundle:
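
The pattern these patches converge on is worth spelling out. After PATCH 5, `prefetch_batches_locally` pulls the next RefBundle through a small helper that times `next(ref_bundles)` with `stats.iter_get_ref_bundles_s.timer()` when a stats object is supplied, and falls back to `contextlib.nullcontext()` when it is not, so callers that pass no stats pay essentially nothing. The sketch below reproduces that idiom in isolation as a minimal illustration, not Ray's implementation: `SimpleTimer` is a hypothetical stand-in for `ray.data._internal.stats.Timer` (only the `timer()`, `get()`, `min()`, `max()`, `avg()` surface used by the summary is assumed), and `get_next_timed` plays the role of the patch's `get_next_ref_bundle` helper. Only `nullcontext` and the conditional-timer idiom are taken directly from the diff.

import time
from contextlib import contextmanager, nullcontext
from typing import Any, Iterator, Optional


class SimpleTimer:
    """Hypothetical stand-in for ray.data._internal.stats.Timer: accumulates
    per-sample durations so min/max/avg/total can be reported later."""

    def __init__(self) -> None:
        self._total = 0.0
        self._min = float("inf")
        self._max = 0.0
        self._count = 0

    @contextmanager
    def timer(self):
        # Record the wall-clock time spent inside the `with` block.
        start = time.perf_counter()
        try:
            yield
        finally:
            elapsed = time.perf_counter() - start
            self._total += elapsed
            self._min = min(self._min, elapsed)
            self._max = max(self._max, elapsed)
            self._count += 1

    def get(self) -> float:
        return self._total

    def min(self) -> float:
        return self._min if self._count else 0.0

    def max(self) -> float:
        return self._max

    def avg(self) -> float:
        return self._total / self._count if self._count else 0.0


def get_next_timed(source: Iterator[Any], stats_timer: Optional[SimpleTimer]) -> Any:
    """Mirror of the patch's get_next_ref_bundle(): time next() only when a
    stats object was provided, otherwise use a no-op nullcontext()."""
    with stats_timer.timer() if stats_timer is not None else nullcontext():
        return next(source)


if __name__ == "__main__":
    timer = SimpleTimer()
    bundles = iter(["bundle-0", "bundle-1", "bundle-2"])  # stand-in for an Iterator[RefBundle]
    while True:
        try:
            get_next_timed(bundles, timer)
        except StopIteration:
            break
    # Same min/max/avg/total shape that IterStatsSummary prints for
    # "In get RefBundles".
    print(
        f"get_ref_bundles: {timer.min():.6f} min, {timer.max():.6f} max, "
        f"{timer.avg():.6f} avg, {timer.get():.6f} total"
    )

Passing `stats=None` keeps the fast path cheap: `nullcontext()` is a no-op context manager, so timing overhead exists only when a caller actually asked for iteration metrics, which is also why PATCH 5 moves `stats` to the end of the signature as an optional keyword rather than a required positional argument.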