
Commit 8a732a7

Don't send columns for every shard
1 parent b9725c5 commit 8a732a7
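
In short: before this change, every pickled dataframe shard carried its own copy of the columns header; after it, only the index and the block values are serialized per shard, and the columns are re-attached from the shuffle's meta on the output side. A rough standalone illustration of the overhead this avoids (the 50-column frame and the size comparison are illustrative assumptions, not taken from the commit):

import pickle

import pandas as pd

# A wide but tiny shard: the repeated columns header dominates the payload.
df = pd.DataFrame({f"col_{i}": range(3) for i in range(50)})

full = len(pickle.dumps(df))                           # columns pickled with every shard
slim = len(pickle.dumps((df.index, *df._mgr.blocks)))  # index + blocks only, no columns
assert slim < full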

Showing 5 changed files with 128 additions and 37 deletions.

distributed/shuffle/_core.py

Lines changed: 1 addition & 1 deletion
@@ -296,7 +296,7 @@ def fail(self, exception: Exception) -> None:
         if not self.closed:
             self._exception = exception

-    def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
+    def _read_from_disk(self, id: NDIndex) -> Any:
         self.raise_if_closed()
         return self._disk_buffer.read("_".join(str(i) for i in id))

distributed/shuffle/_disk.py

Lines changed: 29 additions & 11 deletions
@@ -6,7 +6,7 @@
 import shutil
 import threading
 from collections.abc import Generator, Iterator
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from pathlib import Path
 from typing import Any

@@ -123,6 +123,11 @@ class DiskShardsBuffer(ShardsBuffer):
     implementation of this scheme.
     """

+    directory: pathlib.Path
+    _closed: bool
+    _use_raw_buffers: bool | None
+    _directory_lock: ReadWriteLock
+
     def __init__(
         self,
         directory: str | pathlib.Path,
@@ -136,6 +141,7 @@ def __init__(
         self.directory = pathlib.Path(directory)
         self.directory.mkdir(exist_ok=True)
         self._closed = False
+        self._use_raw_buffers = None
         self._directory_lock = ReadWriteLock()

     @log_errors
@@ -152,14 +158,23 @@ async def _process(self, id: str, shards: list[Any]) -> None:
         future then we should consider simplifying this considerably and
         dropping the write into communicate above.
         """
+        assert shards
+        if self._use_raw_buffers is None:
+            self._use_raw_buffers = isinstance(shards[0], list) and isinstance(
+                shards[0][0], (bytes, bytearray, memoryview)
+            )
+        serialize_ctx = (
+            nullcontext()
+            if self._use_raw_buffers
+            else context_meter.meter("serialize", func=thread_time)
+        )
+
         nbytes_acc = 0

         def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]:
             nonlocal nbytes_acc
             for shard in shards:
-                if isinstance(shard, list) and isinstance(
-                    shard[0], (bytes, bytearray, memoryview)
-                ):
+                if self._use_raw_buffers:
                     # list[bytes | bytearray | memoryview] for dataframe shuffle
                     # Shard was pre-serialized before being sent over the network.
                     nbytes_acc += sum(map(nbytes, shard))
@@ -173,7 +188,7 @@ def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]:
         with (
             self._directory_lock.read(),
             context_meter.meter("disk-write"),
-            context_meter.meter("serialize", func=thread_time),
+            serialize_ctx,
         ):
             if self._closed:
                 raise RuntimeError("Already closed")
@@ -184,7 +199,7 @@ def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]:
         context_meter.digest_metric("disk-write", 1, "count")
         context_meter.digest_metric("disk-write", nbytes_acc, "bytes")

-    def read(self, id: str) -> list[Any]:
+    def read(self, id: str) -> Any:
         """Read a complete file back into memory"""
         self.raise_on_exception()
         if not self._inputs_done:
@@ -211,8 +226,7 @@ def read(self, id: str) -> list[Any]:
         else:
             raise DataUnavailable(id)

-    @staticmethod
-    def _read(path: Path) -> tuple[list[Any], int]:
+    def _read(self, path: Path) -> tuple[Any, int]:
         """Open a memory-mapped file descriptor to disk, read all metadata, and unpickle
         all arrays. This is a fast sequence of short reads interleaved with seeks.
         Do not read in memory the actual data; the arrays' buffers will point to the
@@ -224,10 +238,14 @@ def _read(path: Path) -> tuple[list[Any], int]:
         """
         with path.open(mode="r+b") as fh:
             buffer = memoryview(mmap.mmap(fh.fileno(), 0))
-
         # The file descriptor has *not* been closed!
-        shards = list(unpickle_bytestream(buffer))
-        return shards, buffer.nbytes
+
+        assert self._use_raw_buffers is not None
+        if self._use_raw_buffers:
+            return buffer, buffer.nbytes
+        else:
+            shards = list(unpickle_bytestream(buffer))
+            return shards, buffer.nbytes

     async def close(self) -> None:
         await super().close()
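
The serialize_ctx introduced above is plain conditional use of a context manager: when the incoming shards are already raw frames there is nothing to serialize, so the serialization meter is swapped for a no-op nullcontext(). A minimal sketch of the pattern, with a stand-in timer instead of distributed's context_meter:

import time
from contextlib import contextmanager, nullcontext


@contextmanager
def meter(label):
    # Stand-in for context_meter.meter(...); only illustrates the pattern.
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"{label}: {time.perf_counter() - start:.6f}s")


use_raw_buffers = True  # decided once, from the first shard received
serialize_ctx = nullcontext() if use_raw_buffers else meter("serialize")

with meter("disk-write"), serialize_ctx:
    pass  # write the shards to disk here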

distributed/shuffle/_pickle.py

Lines changed: 71 additions & 1 deletion
@@ -2,10 +2,15 @@

 import pickle
 from collections.abc import Iterator
-from typing import Any
+from typing import TYPE_CHECKING, Any
+
+from toolz import first

 from distributed.protocol.utils import pack_frames_prelude, unpack_frames

+if TYPE_CHECKING:
+    import pandas as pd
+

 def pickle_bytelist(obj: object, prelude: bool = True) -> list[pickle.PickleBuffer]:
     """Variant of :func:`serialize_bytelist`, that doesn't support compression, locally
@@ -39,3 +44,68 @@ def unpickle_bytestream(b: bytes | bytearray | memoryview) -> Iterator[Any]:
         if remainder.nbytes == 0:
             break
         b = remainder
+
+
+def pickle_dataframe_shard(
+    input_part_id: int,
+    shard: pd.DataFrame,
+) -> list[pickle.PickleBuffer]:
+    """Optimized pickler for pandas Dataframes. Discard all unnecessary metadata
+    (like the columns header).
+
+    Parameters:
+        shard: pandas DataFrame shard, serialized together with its input_part_id
+    """
+    return pickle_bytelist(
+        (input_part_id, shard.index, *shard._mgr.blocks), prelude=False
+    )
+
+
+def unpickle_and_concat_dataframe_shards(
+    b: bytes | bytearray | memoryview, meta: pd.DataFrame
+) -> pd.DataFrame:
+    """Optimized unpickler for pandas Dataframes.
+
+    Parameters
+    ----------
+    b:
+        raw buffer, containing the concatenation of the outputs of
+        :func:`pickle_dataframe_shard`, in arbitrary order
+    meta:
+        DataFrame header
+
+    Returns
+    -------
+    Reconstructed output shard, sorted by input partition ID
+
+    **Roundtrip example**
+
+    >>> import random
+    >>> import pandas as pd
+    >>> from toolz import concat
+
+    >>> df = pd.DataFrame(...)  # Input partition
+    >>> meta = df.iloc[:0].copy()
+    >>> shards = df.iloc[0:10], df.iloc[10:20], ...
+    >>> frames = [pickle_dataframe_shard(i, shard) for i, shard in enumerate(shards)]
+    >>> random.shuffle(frames)  # Simulate the frames arriving in arbitrary order
+    >>> blob = bytearray(b"".join(concat(frames)))  # Simulate disk roundtrip
+    >>> df2 = unpickle_and_concat_dataframe_shards(blob, meta)
+    """
+    import pandas as pd
+    from pandas.core.internals import BlockManager
+
+    parts = list(unpickle_bytestream(b))
+    # [(input_part_id, index, *blocks), ...]
+    parts.sort(key=first)
+    shards = []
+    for _, idx, *blocks in parts:
+        axes = [meta.columns, idx]
+        df = pd.DataFrame._from_mgr(  # type: ignore[attr-defined]
+            BlockManager(blocks, axes, verify_integrity=False), axes
+        )
+        shards.append(df)
+
+    # Actually load memory-mapped buffers into memory and close the file
+    # descriptors
+    return pd.concat(shards, copy=True)
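
The two new helpers boil down to pickling only (input_part_id, index, *blocks) per shard and rebuilding each frame with the columns taken from meta. A self-contained sketch of that roundtrip using plain pickle and the same private pandas internals the helpers rely on (_mgr, BlockManager and _from_mgr are pandas implementation details and may change between versions):

import pickle

import pandas as pd
from pandas.core.internals import BlockManager

df = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
meta = df.iloc[:0]  # empty frame carrying only the columns header

# Serialize the index and the blocks; the columns are *not* part of the payload.
payload = pickle.dumps((0, df.index, *df._mgr.blocks))

# Reconstruct on the output side, re-attaching the columns from meta.
_part_id, idx, *blocks = pickle.loads(payload)
axes = [meta.columns, idx]
df2 = pd.DataFrame._from_mgr(BlockManager(blocks, axes, verify_integrity=False), axes)

pd.testing.assert_frame_equal(df, df2)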

distributed/shuffle/_shuffle.py

Lines changed: 14 additions & 23 deletions
@@ -17,7 +17,6 @@
 from pickle import PickleBuffer
 from typing import TYPE_CHECKING, Any

-from toolz import first, second
 from tornado.ioloop import IOLoop

 import dask
@@ -42,7 +41,10 @@
 )
 from distributed.shuffle._exceptions import DataUnavailable
 from distributed.shuffle._limiter import ResourceLimiter
-from distributed.shuffle._pickle import pickle_bytelist
+from distributed.shuffle._pickle import (
+    pickle_dataframe_shard,
+    unpickle_and_concat_dataframe_shards,
+)
 from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
 from distributed.utils import nbytes

@@ -311,14 +313,8 @@ def split_by_worker(
     -------
     {worker addr: (input_part_id, [(output_part_id, buffers), ...]), ...}

-    where buffers is a list of
-
-    [
-        PickleBuffer(pickle bytes) # includes input_part_id
-        buffer,
-        buffer,
-        ...
-    ]
+    where buffers is the serialized output (pickle bytes, buffer, buffer, ...) of
+    (input_part_id, index, *blocks)

     **Notes**

@@ -341,7 +337,7 @@ def split_by_worker(
         assert isinstance(output_part_id, int)
         if drop_column:
             del part[column]
-        frames = pickle_bytelist((input_part_id, part), prelude=False)
+        frames = pickle_dataframe_shard(input_part_id, part)
         out[worker_for[output_part_id]].append((output_part_id, frames))

     return {k: (input_part_id, v) for k, v in out.items()}
@@ -520,21 +516,16 @@ def _get_output_partition(
         key: Key,
         **kwargs: Any,
     ) -> pd.DataFrame:
-        import pandas as pd
+        meta = self.meta.copy()
+        if self.drop_column:
+            meta = self.meta.drop(columns=self.column)

         try:
-            parts = self._read_from_disk((partition_id,))
+            buffer = self._read_from_disk((partition_id,))
         except DataUnavailable:
-            result = self.meta.copy()
-            if self.drop_column:
-                result = self.meta.drop(columns=self.column)
-            return result
-
-        # [(input_partition_id, part), ...]] -> [part, ...]
-        shards = list(map(second, sorted(parts, key=first)))
-        # Actually load memory-mapped buffers into memory and close the file
-        # descriptors
-        return pd.concat(shards, copy=True)
+            return meta
+
+        return unpickle_and_concat_dataframe_shards(buffer, meta)

     def _get_assigned_worker(self, id: int) -> str:
         return self.worker_for[id]
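
The updated split_by_worker docstring describes each buffers entry as (pickle bytes, buffer, buffer, ...). That is the layout produced by pickle protocol 5 with out-of-band buffers: one frame of pickle bytes followed by the raw data buffers. A small sketch of that layout, using numpy arrays as stand-ins for the index and blocks:

import pickle

import numpy as np

buffers: list[pickle.PickleBuffer] = []
payload = (3, np.arange(5), np.ones(5))  # stand-in for (input_part_id, index, *blocks)
head = pickle.dumps(payload, protocol=5, buffer_callback=buffers.append)

frames = [head, *buffers]  # [pickle bytes, buffer, buffer, ...]
assert isinstance(frames[0], bytes)
assert all(isinstance(f, pickle.PickleBuffer) for f in frames[1:])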

distributed/shuffle/tests/test_core.py

Lines changed: 13 additions & 1 deletion
@@ -1,5 +1,7 @@
 from __future__ import annotations

+from pickle import PickleBuffer
+
 import pytest

 from distributed.shuffle._core import _mean_shard_size
@@ -12,7 +14,17 @@ def test_mean_shard_size():
     # Don't fully iterate over large collections
     assert _mean_shard_size([b"12" * n for n in range(1000)]) == 9
     # Support any Buffer object
-    assert _mean_shard_size([b"12", bytearray(b"1234"), memoryview(b"123456")]) == 4
+    assert (
+        _mean_shard_size(
+            [
+                b"12",
+                bytearray(b"1234"),
+                memoryview(b"123456"),
+                PickleBuffer(b"12345678"),
+            ]
+        )
+        == 5
+    )
     # Recursion into lists or tuples; ignore int
     assert _mean_shard_size([(1, 2, [3, b"123456"])]) == 6
     # Don't blindly call sizeof() on unexpected objects
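
PickleBuffer implements the buffer protocol, so its size can be measured the same way as bytes, bytearray or memoryview; the new expected value is simply the mean of the four payload sizes, (2 + 4 + 6 + 8) / 4 == 5. A quick check of that arithmetic:

from pickle import PickleBuffer

sizes = [
    len(b"12"),
    len(bytearray(b"1234")),
    memoryview(b"123456").nbytes,
    memoryview(PickleBuffer(b"12345678")).nbytes,
]
assert sizes == [2, 4, 6, 8]
assert sum(sizes) / len(sizes) == 5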
