Skip to content

Commit f81e366

Browse files
authored
[Data] [stats] Add RefBundle retrieval time metric to iterator dataset stats (#58422)
## Why These Changes Are Needed

This PR adds a new metric to track the time spent retrieving `RefBundle` objects during dataset iteration. This metric provides better visibility into the performance breakdown of batch iteration, specifically capturing the time spent in `get_next_ref_bundle()` calls within the `prefetch_batches_locally` function.

## Related Issue Number

N/A

## Example

```
dataloader/train = {'producer_throughput': 8361.841782656593, 'iter_stats': {'prefetch_block-avg': inf, 'prefetch_block-min': inf, 'prefetch_block-max': 0, 'prefetch_block-total': 0, 'get_ref_bundles-avg': 0.05172277254545271, 'get_ref_bundles-min': 1.1991999997462699e-05, 'get_ref_bundles-max': 11.057470971999976, 'get_ref_bundles-total': 15.361663445999454, 'fetch_block-avg': 0.31572694455743233, 'fetch_block-min': 0.0006362799999806157, 'fetch_block-max': 2.1665870369999993, 'fetch_block-total': 93.45517558899996, 'block_to_batch-avg': 0.001048687573988573, 'block_to_batch-min': 2.10620000302697e-05, 'block_to_batch-max': 0.049948245999985375, 'block_to_batch-total': 2.048086831999683, 'format_batch-avg': 0.0001013781433686053, 'format_batch-min': 1.415700000961806e-05, 'format_batch-max': 0.009682661999988795, 'format_batch-total': 0.19799151399888615, 'collate-avg': 0.01303446213312943, 'collate-min': 0.00025646699998560507, 'collate-max': 0.9855495820000328, 'collate-total': 25.456304546001775, 'finalize-avg': 0.012211385266257683, 'finalize-min': 0.004209667999987232, 'finalize-max': 0.3785081949999949, 'finalize-total': 23.848835425001255, 'time_spent_blocked-avg': 0.04783407008137157, 'time_spent_blocked-min': 1.2316999971062614e-05, 'time_spent_blocked-max': 12.46102861700001, 'time_spent_blocked-total': 93.46777293900004, 'time_spent_training-avg': 0.015053571562211652, 'time_spent_training-min': 1.3704999958008557e-05, 'time_spent_training-max': 1.079616685000019, 'time_spent_training-total': 29.399625260999358}}
```

## Checks

- [ ] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR.
- [ ] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
    - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
    - [x] Unit tests
    - [ ] Release tests
    - [ ] This PR is not tested :(

---------

Signed-off-by: xgui <[email protected]>
Signed-off-by: Xinyuan <[email protected]>
1 parent dbb3909 commit f81e366

File tree

4 files changed

+41
-2
lines changed

4 files changed

+41
-2
lines changed

python/ray/data/_internal/block_batching/iter_batches.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def _prefetch_blocks(
146146
num_batches_to_prefetch=self._prefetch_batches,
147147
batch_size=self._batch_size,
148148
eager_free=self._eager_free,
149+
stats=self._stats,
149150
)
150151

151152
def _resolve_block_refs(
@@ -322,6 +323,7 @@ def prefetch_batches_locally(
322323
num_batches_to_prefetch: int,
323324
batch_size: Optional[int],
324325
eager_free: bool = False,
326+
stats: Optional[DatasetStats] = None,
325327
) -> Iterator[ObjectRef[Block]]:
326328
"""Given an iterator of batched RefBundles, returns an iterator over the
327329
corresponding block references while prefetching `num_batches_to_prefetch`
@@ -334,8 +336,13 @@ def prefetch_batches_locally(
334336
current batch during the scan.
335337
batch_size: User specified batch size, or None to let the system pick.
336338
eager_free: Whether to eagerly free the object reference from the object store.
339+
stats: Dataset stats object used to store ref bundle retrieval time.
337340
"""
338341

342+
def get_next_ref_bundle() -> RefBundle:
343+
with stats.iter_get_ref_bundles_s.timer() if stats else nullcontext():
344+
return next(ref_bundles)
345+
339346
sliding_window = collections.deque()
340347
current_window_size = 0
341348

@@ -358,7 +365,7 @@ def prefetch_batches_locally(
358365
batch_size is None and len(sliding_window) < num_batches_to_prefetch
359366
):
360367
try:
361-
next_ref_bundle = next(ref_bundles)
368+
next_ref_bundle = get_next_ref_bundle()
362369
sliding_window.extend(next_ref_bundle.blocks)
363370
current_window_size += next_ref_bundle.num_rows()
364371
except StopIteration:
@@ -371,7 +378,7 @@ def prefetch_batches_locally(
371378
current_window_size -= metadata.num_rows
372379
if batch_size is None or current_window_size < num_rows_to_prefetch:
373380
try:
374-
next_ref_bundle = next(ref_bundles)
381+
next_ref_bundle = get_next_ref_bundle()
375382
for block_ref_and_md in next_ref_bundle.blocks:
376383
sliding_window.append(block_ref_and_md)
377384
current_window_size += block_ref_and_md[1].num_rows

python/ray/data/_internal/stats.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,11 @@ def __init__(self, max_stats=1000):
327327
description="Seconds spent in iterator initialization code",
328328
tag_keys=iter_tag_keys,
329329
)
330+
self.iter_get_ref_bundles_s = Gauge(
331+
"data_iter_get_ref_bundles_seconds",
332+
description="Seconds spent getting RefBundles from the dataset iterator",
333+
tag_keys=iter_tag_keys,
334+
)
330335
self.iter_get_s = Gauge(
331336
"data_iter_get_seconds",
332337
description="Seconds spent in ray.get() while resolving block references",
@@ -565,6 +570,7 @@ def update_iteration_metrics(
565570
tags = self._create_tags(dataset_tag)
566571

567572
self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags)
573+
self.iter_get_ref_bundles_s.set(stats.iter_get_ref_bundles_s.get(), tags)
568574
self.iter_get_s.set(stats.iter_get_s.get(), tags)
569575
self.iter_next_batch_s.set(stats.iter_next_batch_s.get(), tags)
570576
self.iter_format_batch_s.set(stats.iter_format_batch_s.get(), tags)
@@ -1098,6 +1104,7 @@ def __init__(
10981104

10991105
# Iteration stats, filled out if the user iterates over the dataset.
11001106
self.iter_wait_s: Timer = Timer()
1107+
self.iter_get_ref_bundles_s: Timer = Timer()
11011108
self.iter_get_s: Timer = Timer()
11021109
self.iter_next_batch_s: Timer = Timer()
11031110
self.iter_format_batch_s: Timer = Timer()
@@ -1146,6 +1153,7 @@ def to_summary(self) -> "DatasetStatsSummary":
11461153

11471154
iter_stats = IterStatsSummary(
11481155
self.iter_wait_s,
1156+
self.iter_get_ref_bundles_s,
11491157
self.iter_get_s,
11501158
self.iter_next_batch_s,
11511159
self.iter_format_batch_s,
@@ -1843,6 +1851,8 @@ def __repr__(self, level=0) -> str:
18431851
class IterStatsSummary:
18441852
# Time spent in actor based prefetching, in seconds.
18451853
wait_time: Timer
1854+
# Time spent getting RefBundles from the dataset iterator, in seconds
1855+
get_ref_bundles_time: Timer
18461856
# Time spent in `ray.get()`, in seconds
18471857
get_time: Timer
18481858
# Time spent in batch building, in seconds
@@ -1880,6 +1890,7 @@ def to_string(self) -> str:
18801890
self.block_time.get()
18811891
or self.time_to_first_batch.get()
18821892
or self.total_time.get()
1893+
or self.get_ref_bundles_time.get()
18831894
or self.get_time.get()
18841895
or self.next_time.get()
18851896
or self.format_time.get()
@@ -1911,6 +1922,13 @@ def to_string(self) -> str:
19111922
out += (
19121923
"* Batch iteration time breakdown (summed across prefetch threads):\n"
19131924
)
1925+
if self.get_ref_bundles_time.get():
1926+
out += " * In get RefBundles: {} min, {} max, {} avg, {} total\n".format(
1927+
fmt(self.get_ref_bundles_time.min()),
1928+
fmt(self.get_ref_bundles_time.max()),
1929+
fmt(self.get_ref_bundles_time.avg()),
1930+
fmt(self.get_ref_bundles_time.get()),
1931+
)
19141932
if self.get_time.get():
19151933
out += " * In ray.get(): {} min, {} max, {} avg, {} total\n".format(
19161934
fmt(self.get_time.min()),
@@ -1973,6 +1991,7 @@ def __repr__(self, level=0) -> str:
19731991
return (
19741992
f"IterStatsSummary(\n"
19751993
f"{indent} wait_time={fmt(self.wait_time.get()) or None},\n"
1994+
f"{indent} get_ref_bundles_time={fmt(self.get_ref_bundles_time.get()) or None},\n"
19761995
f"{indent} get_time={fmt(self.get_time.get()) or None},\n"
19771996
f"{indent} iter_blocks_local={self.iter_blocks_local or None},\n"
19781997
f"{indent} iter_blocks_remote={self.iter_blocks_remote or None},\n"

python/ray/data/tests/test_stats.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ def test_streaming_split_stats(ray_start_regular_shared, restore_data_context):
501501
* Total time spent waiting for the first batch after starting iteration: T
502502
* Total execution time for user thread: T
503503
* Batch iteration time breakdown (summed across prefetch threads):
504+
* In get RefBundles: T min, T max, T avg, T total
504505
* In ray.get(): T min, T max, T avg, T total
505506
* In batch creation: T min, T max, T avg, T total
506507
* In batch formatting: T min, T max, T avg, T total
@@ -696,6 +697,7 @@ def test_dataset_stats_basic(
696697
f" * Total time spent waiting for the first batch after starting iteration: T\n"
697698
f" * Total execution time for user thread: T\n"
698699
f"* Batch iteration time breakdown (summed across prefetch threads):\n"
700+
f" * In get RefBundles: T min, T max, T avg, T total\n"
699701
f" * In ray.get(): T min, T max, T avg, T total\n"
700702
f" * In batch creation: T min, T max, T avg, T total\n"
701703
f" * In batch formatting: T min, T max, T avg, T total\n"
@@ -740,6 +742,7 @@ def test_block_location_nums(ray_start_regular_shared, restore_data_context):
740742
f" * Total time spent waiting for the first batch after starting iteration: T\n"
741743
f" * Total execution time for user thread: T\n"
742744
f"* Batch iteration time breakdown (summed across prefetch threads):\n"
745+
f" * In get RefBundles: T min, T max, T avg, T total\n"
743746
f" * In ray.get(): T min, T max, T avg, T total\n"
744747
f" * In batch creation: T min, T max, T avg, T total\n"
745748
f" * In batch formatting: T min, T max, T avg, T total\n"
@@ -845,6 +848,7 @@ def test_dataset__repr__(ray_start_regular_shared, restore_data_context):
845848
" ],\n"
846849
" iter_stats=IterStatsSummary(\n"
847850
" wait_time=T,\n"
851+
" get_ref_bundles_time=T,\n"
848852
" get_time=T,\n"
849853
" iter_blocks_local=None,\n"
850854
" iter_blocks_remote=None,\n"
@@ -866,6 +870,7 @@ def test_dataset__repr__(ray_start_regular_shared, restore_data_context):
866870
" operators_stats=[],\n"
867871
" iter_stats=IterStatsSummary(\n"
868872
" wait_time=T,\n"
873+
" get_ref_bundles_time=T,\n"
869874
" get_time=T,\n"
870875
" iter_blocks_local=None,\n"
871876
" iter_blocks_remote=None,\n"
@@ -985,6 +990,7 @@ def check_stats():
985990
" ],\n"
986991
" iter_stats=IterStatsSummary(\n"
987992
" wait_time=T,\n"
993+
" get_ref_bundles_time=T,\n"
988994
" get_time=T,\n"
989995
" iter_blocks_local=None,\n"
990996
" iter_blocks_remote=None,\n"
@@ -1080,6 +1086,7 @@ def check_stats():
10801086
" ],\n"
10811087
" iter_stats=IterStatsSummary(\n"
10821088
" wait_time=T,\n"
1089+
" get_ref_bundles_time=T,\n"
10831090
" get_time=T,\n"
10841091
" iter_blocks_local=None,\n"
10851092
" iter_blocks_remote=None,\n"
@@ -1101,6 +1108,7 @@ def check_stats():
11011108
" operators_stats=[],\n"
11021109
" iter_stats=IterStatsSummary(\n"
11031110
" wait_time=T,\n"
1111+
" get_ref_bundles_time=T,\n"
11041112
" get_time=T,\n"
11051113
" iter_blocks_local=None,\n"
11061114
" iter_blocks_remote=None,\n"
@@ -1537,6 +1545,7 @@ def test_streaming_stats_full(ray_start_regular_shared, restore_data_context):
15371545
* Total time spent waiting for the first batch after starting iteration: T
15381546
* Total execution time for user thread: T
15391547
* Batch iteration time breakdown (summed across prefetch threads):
1548+
* In get RefBundles: T min, T max, T avg, T total
15401549
* In ray.get(): T min, T max, T avg, T total
15411550
* In batch creation: T min, T max, T avg, T total
15421551
* In batch formatting: T min, T max, T avg, T total

release/train_tests/benchmark/ray_dataloader_factory.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ def get_metrics(self) -> Dict[str, Any]:
123123
"prefetch_block-min": iter_stats.wait_time.min(),
124124
"prefetch_block-max": iter_stats.wait_time.max(),
125125
"prefetch_block-total": iter_stats.wait_time.get(),
126+
"get_ref_bundles-avg": iter_stats.get_ref_bundles_time.avg(),
127+
"get_ref_bundles-min": iter_stats.get_ref_bundles_time.min(),
128+
"get_ref_bundles-max": iter_stats.get_ref_bundles_time.max(),
129+
"get_ref_bundles-total": iter_stats.get_ref_bundles_time.get(),
126130
"fetch_block-avg": iter_stats.get_time.avg(),
127131
"fetch_block-min": iter_stats.get_time.min(),
128132
"fetch_block-max": iter_stats.get_time.max(),

0 commit comments

Comments
 (0)