Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions python/ray/data/_internal/block_batching/iter_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def _prefetch_blocks(
num_batches_to_prefetch=self._prefetch_batches,
batch_size=self._batch_size,
eager_free=self._eager_free,
stats=self._stats,
)

def _resolve_block_refs(
Expand Down Expand Up @@ -322,6 +323,7 @@ def prefetch_batches_locally(
num_batches_to_prefetch: int,
batch_size: Optional[int],
eager_free: bool = False,
stats: Optional[DatasetStats] = None,
) -> Iterator[ObjectRef[Block]]:
"""Given an iterator of batched RefBundles, returns an iterator over the
corresponding block references while prefetching `num_batches_to_prefetch`
Expand All @@ -334,8 +336,13 @@ def prefetch_batches_locally(
current batch during the scan.
batch_size: User specified batch size, or None to let the system pick.
eager_free: Whether to eagerly free the object reference from the object store.
stats: Dataset stats object used to store ref bundle retrieval time.
"""

def get_next_ref_bundle() -> RefBundle:
with stats.iter_get_ref_bundles_s.timer() if stats else nullcontext():
return next(ref_bundles)

sliding_window = collections.deque()
current_window_size = 0

Expand All @@ -358,7 +365,7 @@ def prefetch_batches_locally(
batch_size is None and len(sliding_window) < num_batches_to_prefetch
):
try:
next_ref_bundle = next(ref_bundles)
next_ref_bundle = get_next_ref_bundle()
sliding_window.extend(next_ref_bundle.blocks)
current_window_size += next_ref_bundle.num_rows()
except StopIteration:
Expand All @@ -371,7 +378,7 @@ def prefetch_batches_locally(
current_window_size -= metadata.num_rows
if batch_size is None or current_window_size < num_rows_to_prefetch:
try:
next_ref_bundle = next(ref_bundles)
next_ref_bundle = get_next_ref_bundle()
for block_ref_and_md in next_ref_bundle.blocks:
sliding_window.append(block_ref_and_md)
current_window_size += block_ref_and_md[1].num_rows
Expand Down
19 changes: 19 additions & 0 deletions python/ray/data/_internal/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,11 @@ def __init__(self, max_stats=1000):
description="Seconds spent in iterator initialization code",
tag_keys=iter_tag_keys,
)
self.iter_get_ref_bundles_s = Gauge(
"data_iter_get_ref_bundles_seconds",
description="Seconds spent getting RefBundles from the dataset iterator",
tag_keys=iter_tag_keys,
)
self.iter_get_s = Gauge(
"data_iter_get_seconds",
description="Seconds spent in ray.get() while resolving block references",
Expand Down Expand Up @@ -565,6 +570,7 @@ def update_iteration_metrics(
tags = self._create_tags(dataset_tag)

self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags)
self.iter_get_ref_bundles_s.set(stats.iter_get_ref_bundles_s.get(), tags)
self.iter_get_s.set(stats.iter_get_s.get(), tags)
self.iter_next_batch_s.set(stats.iter_next_batch_s.get(), tags)
self.iter_format_batch_s.set(stats.iter_format_batch_s.get(), tags)
Expand Down Expand Up @@ -1098,6 +1104,7 @@ def __init__(

# Iteration stats, filled out if the user iterates over the dataset.
self.iter_wait_s: Timer = Timer()
self.iter_get_ref_bundles_s: Timer = Timer()
self.iter_get_s: Timer = Timer()
self.iter_next_batch_s: Timer = Timer()
self.iter_format_batch_s: Timer = Timer()
Expand Down Expand Up @@ -1146,6 +1153,7 @@ def to_summary(self) -> "DatasetStatsSummary":

iter_stats = IterStatsSummary(
self.iter_wait_s,
self.iter_get_ref_bundles_s,
self.iter_get_s,
self.iter_next_batch_s,
self.iter_format_batch_s,
Expand Down Expand Up @@ -1843,6 +1851,8 @@ def __repr__(self, level=0) -> str:
class IterStatsSummary:
# Time spent in actor based prefetching, in seconds.
wait_time: Timer
# Time spent getting RefBundles from the dataset iterator, in seconds
get_ref_bundles_time: Timer
# Time spent in `ray.get()`, in seconds
get_time: Timer
# Time spent in batch building, in seconds
Expand Down Expand Up @@ -1880,6 +1890,7 @@ def to_string(self) -> str:
self.block_time.get()
or self.time_to_first_batch.get()
or self.total_time.get()
or self.get_ref_bundles_time.get()
or self.get_time.get()
or self.next_time.get()
or self.format_time.get()
Expand Down Expand Up @@ -1911,6 +1922,13 @@ def to_string(self) -> str:
out += (
"* Batch iteration time breakdown (summed across prefetch threads):\n"
)
if self.get_ref_bundles_time.get():
out += " * In get RefBundles: {} min, {} max, {} avg, {} total\n".format(
fmt(self.get_ref_bundles_time.min()),
fmt(self.get_ref_bundles_time.max()),
fmt(self.get_ref_bundles_time.avg()),
fmt(self.get_ref_bundles_time.get()),
)
if self.get_time.get():
out += " * In ray.get(): {} min, {} max, {} avg, {} total\n".format(
fmt(self.get_time.min()),
Expand Down Expand Up @@ -1973,6 +1991,7 @@ def __repr__(self, level=0) -> str:
return (
f"IterStatsSummary(\n"
f"{indent} wait_time={fmt(self.wait_time.get()) or None},\n"
f"{indent} get_ref_bundles_time={fmt(self.get_ref_bundles_time.get()) or None},\n"
f"{indent} get_time={fmt(self.get_time.get()) or None},\n"
f"{indent} iter_blocks_local={self.iter_blocks_local or None},\n"
f"{indent} iter_blocks_remote={self.iter_blocks_remote or None},\n"
Expand Down
9 changes: 9 additions & 0 deletions python/ray/data/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def test_streaming_split_stats(ray_start_regular_shared, restore_data_context):
* Total time spent waiting for the first batch after starting iteration: T
* Total execution time for user thread: T
* Batch iteration time breakdown (summed across prefetch threads):
* In get RefBundles: T min, T max, T avg, T total
* In ray.get(): T min, T max, T avg, T total
* In batch creation: T min, T max, T avg, T total
* In batch formatting: T min, T max, T avg, T total
Expand Down Expand Up @@ -696,6 +697,7 @@ def test_dataset_stats_basic(
f" * Total time spent waiting for the first batch after starting iteration: T\n"
f" * Total execution time for user thread: T\n"
f"* Batch iteration time breakdown (summed across prefetch threads):\n"
f" * In get RefBundles: T min, T max, T avg, T total\n"
f" * In ray.get(): T min, T max, T avg, T total\n"
f" * In batch creation: T min, T max, T avg, T total\n"
f" * In batch formatting: T min, T max, T avg, T total\n"
Expand Down Expand Up @@ -740,6 +742,7 @@ def test_block_location_nums(ray_start_regular_shared, restore_data_context):
f" * Total time spent waiting for the first batch after starting iteration: T\n"
f" * Total execution time for user thread: T\n"
f"* Batch iteration time breakdown (summed across prefetch threads):\n"
f" * In get RefBundles: T min, T max, T avg, T total\n"
f" * In ray.get(): T min, T max, T avg, T total\n"
f" * In batch creation: T min, T max, T avg, T total\n"
f" * In batch formatting: T min, T max, T avg, T total\n"
Expand Down Expand Up @@ -845,6 +848,7 @@ def test_dataset__repr__(ray_start_regular_shared, restore_data_context):
" ],\n"
" iter_stats=IterStatsSummary(\n"
" wait_time=T,\n"
" get_ref_bundles_time=T,\n"
" get_time=T,\n"
" iter_blocks_local=None,\n"
" iter_blocks_remote=None,\n"
Expand All @@ -866,6 +870,7 @@ def test_dataset__repr__(ray_start_regular_shared, restore_data_context):
" operators_stats=[],\n"
" iter_stats=IterStatsSummary(\n"
" wait_time=T,\n"
" get_ref_bundles_time=T,\n"
" get_time=T,\n"
" iter_blocks_local=None,\n"
" iter_blocks_remote=None,\n"
Expand Down Expand Up @@ -985,6 +990,7 @@ def check_stats():
" ],\n"
" iter_stats=IterStatsSummary(\n"
" wait_time=T,\n"
" get_ref_bundles_time=T,\n"
" get_time=T,\n"
" iter_blocks_local=None,\n"
" iter_blocks_remote=None,\n"
Expand Down Expand Up @@ -1080,6 +1086,7 @@ def check_stats():
" ],\n"
" iter_stats=IterStatsSummary(\n"
" wait_time=T,\n"
" get_ref_bundles_time=T,\n"
" get_time=T,\n"
" iter_blocks_local=None,\n"
" iter_blocks_remote=None,\n"
Expand All @@ -1101,6 +1108,7 @@ def check_stats():
" operators_stats=[],\n"
" iter_stats=IterStatsSummary(\n"
" wait_time=T,\n"
" get_ref_bundles_time=T,\n"
" get_time=T,\n"
" iter_blocks_local=None,\n"
" iter_blocks_remote=None,\n"
Expand Down Expand Up @@ -1537,6 +1545,7 @@ def test_streaming_stats_full(ray_start_regular_shared, restore_data_context):
* Total time spent waiting for the first batch after starting iteration: T
* Total execution time for user thread: T
* Batch iteration time breakdown (summed across prefetch threads):
* In get RefBundles: T min, T max, T avg, T total
* In ray.get(): T min, T max, T avg, T total
* In batch creation: T min, T max, T avg, T total
* In batch formatting: T min, T max, T avg, T total
Expand Down
4 changes: 4 additions & 0 deletions release/train_tests/benchmark/ray_dataloader_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def get_metrics(self) -> Dict[str, Any]:
"prefetch_block-min": iter_stats.wait_time.min(),
"prefetch_block-max": iter_stats.wait_time.max(),
"prefetch_block-total": iter_stats.wait_time.get(),
"get_ref_bundles-avg": iter_stats.get_ref_bundles_time.avg(),
"get_ref_bundles-min": iter_stats.get_ref_bundles_time.min(),
"get_ref_bundles-max": iter_stats.get_ref_bundles_time.max(),
"get_ref_bundles-total": iter_stats.get_ref_bundles_time.get(),
"fetch_block-avg": iter_stats.get_time.avg(),
"fetch_block-min": iter_stats.get_time.min(),
"fetch_block-max": iter_stats.get_time.max(),
Expand Down