
Commit 37119a9

p2p shuffle without pyArrow
1 parent 3f210fd commit 37119a9

File tree

15 files changed: +136 -549 lines changed


distributed/shuffle/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -1,14 +1,12 @@
 from __future__ import annotations

-from distributed.shuffle._arrow import check_minimal_arrow_version
 from distributed.shuffle._merge import HashJoinP2PLayer, hash_join_p2p
 from distributed.shuffle._rechunk import rechunk_p2p
 from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin
 from distributed.shuffle._shuffle import P2PShuffleLayer, rearrange_by_column_p2p
 from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin

 __all__ = [
-    "check_minimal_arrow_version",
     "hash_join_p2p",
     "HashJoinP2PLayer",
     "P2PShuffleLayer",

distributed/shuffle/_arrow.py

Lines changed: 0 additions & 201 deletions
This file was deleted.

distributed/shuffle/_core.py

Lines changed: 3 additions & 12 deletions

@@ -19,7 +19,6 @@
 from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generic, NewType, TypeVar, cast

 from tornado.ioloop import IOLoop

@@ -38,6 +37,7 @@
 from distributed.shuffle._exceptions import ShuffleClosedError
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.shuffle._memory import MemoryShardsBuffer
+from distributed.sizeof import safe_sizeof as sizeof
 from distributed.utils import run_in_executor_with_context, sync
 from distributed.utils_comm import retry

@@ -116,11 +116,10 @@ def __init__(
         if disk:
             self._disk_buffer = DiskShardsBuffer(
                 directory=directory,
-                read=self.read,
                 memory_limiter=memory_limiter_disk,
             )
         else:
-            self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize)
+            self._disk_buffer = MemoryShardsBuffer()

         with self._capture_metrics("background-comms"):
             self._comm_buffer = CommShardsBuffer(

@@ -372,14 +371,6 @@ def _get_output_partition(
     ) -> _T_partition_type:
         """Get an output partition to the shuffle run"""

-    @abc.abstractmethod
-    def read(self, path: Path) -> tuple[Any, int]:
-        """Read shards from disk"""
-
-    @abc.abstractmethod
-    def deserialize(self, buffer: Any) -> Any:
-        """Deserialize shards"""
-

 def get_worker_plugin() -> ShuffleWorkerPlugin:
     from distributed import get_worker

@@ -518,7 +509,7 @@ def _mean_shard_size(shards: Iterable) -> int:
         if not isinstance(shard, int):
             # This also asserts that shard is a Buffer and that we didn't forget
             # a container or metadata type above
-            size += memoryview(shard).nbytes
+            size += sizeof(shard)
             count += 1
             if count == 10:
                 break
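The last hunk above swaps the raw memoryview(shard).nbytes measurement for distributed's generic safe_sizeof, so the estimate no longer requires each shard to expose the buffer protocol. A minimal sketch of the sampling idea follows; the helper name mean_shard_size_sketch and everything beyond the context lines shown in the hunk are illustrative assumptions, not the actual _mean_shard_size body:

from collections.abc import Iterable

from distributed.sizeof import safe_sizeof as sizeof


def mean_shard_size_sketch(shards: Iterable, sample: int = 10) -> int:
    """Estimate the mean shard size by measuring at most `sample` shards.

    Illustrative sketch only; the real _mean_shard_size in
    distributed/shuffle/_core.py also unwraps nested containers and
    metadata, which is omitted here.
    """
    size = 0
    count = 0
    for shard in shards:
        if isinstance(shard, int):
            continue  # plain ints are metadata, not payload
        size += sizeof(shard)  # generic sizeof instead of memoryview(shard).nbytes
        count += 1
        if count == sample:  # a small sample is enough for an estimate
            break
    return size // count if count else 0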

distributed/shuffle/_disk.py

Lines changed: 36 additions & 26 deletions

@@ -1,21 +1,21 @@
 from __future__ import annotations

 import contextlib
+import mmap
 import pathlib
 import shutil
 import threading
-from collections.abc import Callable, Generator, Iterable
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
+from pathlib import Path
 from typing import Any

-from toolz import concat
-
 from distributed.metrics import context_meter, thread_time
 from distributed.shuffle._buffer import ShardsBuffer
 from distributed.shuffle._exceptions import DataUnavailable
 from distributed.shuffle._limiter import ResourceLimiter
-from distributed.shuffle._pickle import pickle_bytelist
-from distributed.utils import Deadline, empty_context, log_errors, nbytes
+from distributed.shuffle._pickle import pickle_bytelist, unpickle_bytestream
+from distributed.utils import Deadline, log_errors, nbytes


 class ReadWriteLock:

@@ -126,7 +126,6 @@ class DiskShardsBuffer(ShardsBuffer):
     def __init__(
         self,
         directory: str | pathlib.Path,
-        read: Callable[[pathlib.Path], tuple[Any, int]],
         memory_limiter: ResourceLimiter,
     ):
         super().__init__(

@@ -137,11 +136,10 @@ def __init__(
         self.directory = pathlib.Path(directory)
         self.directory.mkdir(exist_ok=True)
         self._closed = False
-        self._read = read
         self._directory_lock = ReadWriteLock()

     @log_errors
-    async def _process(self, id: str, shards: list[Any]) -> None:
+    async def _process(self, id: str, shards: list[object]) -> None:
         """Write one buffer to file

         This function was built to offload the disk IO, but since then we've
@@ -154,36 +152,30 @@ async def _process(self, id: str, shards: list[Any]) -> None:
         future then we should consider simplifying this considerably and
         dropping the write into communicate above.
         """
-        frames: Iterable[bytes | bytearray | memoryview]
-        if isinstance(shards[0], bytes):
-            # Manually serialized dataframes
-            frames = shards
-            serialize_meter_ctx: Any = empty_context
-        else:
-            # Unserialized numpy arrays
-            # Note: no calls to pickle_bytelist will happen until we actually start
-            # writing to disk below.
-            frames = concat(pickle_bytelist(shard) for shard in shards)
-            serialize_meter_ctx = context_meter.meter("serialize", func=thread_time)
+        nbytes_acc = 0
+
+        def pickle_and_tally() -> Iterator[bytes | memoryview]:
+            nonlocal nbytes_acc
+            for shard in shards:
+                for frame in pickle_bytelist(shard):
+                    nbytes_acc += nbytes(frame)
+                    yield frame

         with (
             self._directory_lock.read(),
             context_meter.meter("disk-write"),
-            serialize_meter_ctx,
+            context_meter.meter("serialize", func=thread_time),
         ):
-            # Consider boosting total_size a bit here to account for duplication
-            # We only need shared (i.e., read) access to the directory to write
-            # to a file inside of it.
             if self._closed:
                 raise RuntimeError("Already closed")

             with open(self.directory / str(id), mode="ab") as f:
-                f.writelines(frames)
+                f.writelines(pickle_and_tally())

             context_meter.digest_metric("disk-write", 1, "count")
-            context_meter.digest_metric("disk-write", sum(map(nbytes, frames)), "bytes")
+            context_meter.digest_metric("disk-write", nbytes_acc, "bytes")

-    def read(self, id: str) -> Any:
+    def read(self, id: str) -> list[Any]:
         """Read a complete file back into memory"""
         self.raise_on_exception()
         if not self._inputs_done:
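The rewritten _process above replaces the two serialization branches (pre-serialized dataframe bytes vs. unserialized numpy arrays) with a single lazy generator: frames are pickled only as f.writelines consumes them, and the running byte total lands in nbytes_acc without materializing the frame list. A self-contained sketch of that pattern, using plain pickle.dumps as a stand-in for pickle_bytelist (which lives in distributed/shuffle/_pickle.py and is not part of this diff); the name write_shards_sketch is made up for illustration:

import pickle
from collections.abc import Iterator


def write_shards_sketch(path: str, shards: list[object]) -> int:
    """Append pickled shards to a file, tallying bytes while writelines
    drains the generator. Stand-in for DiskShardsBuffer._process; the real
    code emits multiple frames per shard via pickle_bytelist.
    """
    nbytes_acc = 0

    def pickle_and_tally() -> Iterator[bytes]:
        nonlocal nbytes_acc
        for shard in shards:
            frame = pickle.dumps(shard, protocol=pickle.HIGHEST_PROTOCOL)
            nbytes_acc += len(frame)
            yield frame

    with open(path, mode="ab") as f:
        # Serialization happens here, frame by frame, as writelines iterates.
        f.writelines(pickle_and_tally())

    return nbytes_acc

Because the pickling happens inside the writelines call, wrapping that call in the "serialize" meter, as the hunk does, attributes serialization time to the right metric.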
@@ -210,6 +202,24 @@ def read(self, id: str) -> Any:
         else:
             raise DataUnavailable(id)

+    @staticmethod
+    def _read(path: Path) -> tuple[list[Any], int]:
+        """Open a memory-mapped file descriptor to disk, read all metadata, and unpickle
+        all arrays. This is a fast sequence of short reads interleaved with seeks.
+        Do not read in memory the actual data; the arrays' buffers will point to the
+        memory-mapped area.
+
+        The file descriptor will be automatically closed by the kernel when all the
+        returned arrays are dereferenced, which will happen after the call to
+        concatenate3.
+        """
+        with path.open(mode="r+b") as fh:
+            buffer = memoryview(mmap.mmap(fh.fileno(), 0))
+
+        # The file descriptor has *not* been closed!
+        shards = list(unpickle_bytestream(buffer))
+        return shards, buffer.nbytes
+
     async def close(self) -> None:
         await super().close()
         with self._directory_lock.write():
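The new _read above memory-maps the whole spill file and unpickles every shard from the mapping, so only pickle metadata is touched eagerly while the payload stays backed by the map. unpickle_bytestream is defined in distributed/shuffle/_pickle.py and is not shown in this commit; the sketch below approximates the read path with plain pickle.load, which copies the payload out of the mapping rather than pointing into it (the name read_shards_sketch is illustrative):

import mmap
import pickle
from pathlib import Path
from typing import Any


def read_shards_sketch(path: Path) -> tuple[list[Any], int]:
    """Memory-map a spill file and unpickle shards written back-to-back.

    Rough counterpart to DiskShardsBuffer._read above, minus the zero-copy
    out-of-band buffers that unpickle_bytestream restores.
    """
    with path.open(mode="r+b") as fh:
        # mmap duplicates the file descriptor, so closing fh here is fine.
        mm = mmap.mmap(fh.fileno(), 0)

    shards: list[Any] = []
    while mm.tell() < mm.size():
        # mmap objects expose file-like read()/readline(), so pickle.load can
        # consume the pickles one after another straight from the mapping.
        shards.append(pickle.load(mm))
    return shards, mm.size()

The real read path stays zero-copy because, as the docstring in the hunk notes, the unpickled arrays' buffers point into the memory-mapped area and the mapping is released only once they are dereferenced.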
