Commit b9725c5

Serialize dataframes manually
1 parent 37119a9 commit b9725c5

File tree

9 files changed: +150 -109 lines

distributed/protocol/serialize.py

Lines changed: 11 additions & 0 deletions
@@ -839,6 +839,17 @@ def _deserialize_memoryview(header, frames):
     return out
 
 
+@dask_serialize.register(PickleBuffer)
+def _serialize_picklebuffer(obj):
+    return _serialize_memoryview(obj.raw())
+
+
+@dask_deserialize.register(PickleBuffer)
+def _deserialize_picklebuffer(header, frames):
+    out = _deserialize_memoryview(header, frames)
+    return PickleBuffer(out)
+
+
 #########################
 # Descend into __dict__ #
 #########################
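The two handlers registered above only need to move the underlying buffer: PickleBuffer.raw() exposes a zero-copy memoryview that the existing memoryview path can ship, and the receiving side wraps the frame back into a PickleBuffer. A minimal stdlib-only sketch of that roundtrip, without the dask_serialize registry:

    from pickle import PickleBuffer

    frame = PickleBuffer(bytearray(b"shard bytes"))
    view = frame.raw()             # zero-copy memoryview over the same memory
    restored = PickleBuffer(view)  # mirrors what _deserialize_picklebuffer does
    assert bytes(restored.raw()) == b"shard bytes"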

distributed/shuffle/_core.py

Lines changed: 3 additions & 3 deletions
@@ -37,7 +37,6 @@
 from distributed.shuffle._exceptions import ShuffleClosedError
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.shuffle._memory import MemoryShardsBuffer
-from distributed.sizeof import safe_sizeof as sizeof
 from distributed.utils import run_in_executor_with_context, sync
 from distributed.utils_comm import retry
 
@@ -215,7 +214,7 @@ async def send(
             # and unpickle it on the other side.
             # Performance tests informing the size threshold:
             # https://github.com/dask/distributed/pull/8318
-            shards_or_bytes: list | bytes = pickle.dumps(shards)
+            shards_or_bytes: list | bytes = pickle.dumps(shards, protocol=5)
         else:
             shards_or_bytes = shards
 
@@ -334,6 +333,7 @@ def add_partition(
         if self.transferred:
             raise RuntimeError(f"Cannot add more partitions to {self}")
         # Log metrics both in the "execute" and in the "p2p" contexts
+        context_meter.digest_metric("p2p-partitions", 1, "count")
         with self._capture_metrics("foreground"):
             with (
                 context_meter.meter("p2p-shard-partition-noncpu"),
@@ -509,7 +509,7 @@ def _mean_shard_size(shards: Iterable) -> int:
         if not isinstance(shard, int):
             # This also asserts that shard is a Buffer and that we didn't forget
             # a container or metadata type above
-            size += sizeof(shard)
+            size += memoryview(shard).nbytes
             count += 1
             if count == 10:
                 break
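Because the shards reaching _mean_shard_size are now raw buffers rather than arbitrary objects, their size can be read exactly from the buffer protocol instead of being estimated. An illustrative stdlib-only example of the replacement expression:

    frames = [b"pickle header", bytearray(1024), memoryview(bytes(512))]
    exact = sum(memoryview(f).nbytes for f in frames)
    assert exact == 13 + 1024 + 512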

distributed/shuffle/_disk.py

Lines changed: 14 additions & 5 deletions
@@ -139,7 +139,7 @@ def __init__(
         self._directory_lock = ReadWriteLock()
 
     @log_errors
-    async def _process(self, id: str, shards: list[object]) -> None:
+    async def _process(self, id: str, shards: list[Any]) -> None:
         """Write one buffer to file
 
         This function was built to offload the disk IO, but since then we've
@@ -154,12 +154,21 @@ async def _process(self, id: str, shards: list[Any]) -> None:
         """
         nbytes_acc = 0
 
-        def pickle_and_tally() -> Iterator[bytes | memoryview]:
+        def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]:
             nonlocal nbytes_acc
             for shard in shards:
-                for frame in pickle_bytelist(shard):
-                    nbytes_acc += nbytes(frame)
-                    yield frame
+                if isinstance(shard, list) and isinstance(
+                    shard[0], (bytes, bytearray, memoryview)
+                ):
+                    # list[bytes | bytearray | memoryview] for dataframe shuffle
+                    # Shard was pre-serialized before being sent over the network.
+                    nbytes_acc += sum(map(nbytes, shard))
+                    yield from shard
+                else:
+                    # tuple[NDIndex, ndarray] for array rechunk
+                    frames = [s.raw() for s in pickle_bytelist(shard)]
+                    nbytes_acc += sum(frame.nbytes for frame in frames)
+                    yield from frames
 
         with (
             self._directory_lock.read(),
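The new branch distinguishes dataframe shards, which arrive as ready-made frame lists, from array-rechunk shards, which still need pickling. A simplified standalone sketch of the same decision (frames_for is a hypothetical helper; it omits the length prelude that the real pickle_bytelist prepends):

    from __future__ import annotations

    import pickle
    from collections.abc import Iterator
    from typing import Any


    def frames_for(shard: Any) -> Iterator[bytes | bytearray | memoryview]:
        if isinstance(shard, list) and isinstance(shard[0], (bytes, bytearray, memoryview)):
            # Dataframe shuffle: frames were already produced upstream.
            yield from shard
        else:
            # Array rechunk: serialize here, keeping large buffers out-of-band.
            buffers: list[pickle.PickleBuffer] = []
            head = pickle.dumps(shard, protocol=5, buffer_callback=buffers.append)
            yield head
            yield from (b.raw() for b in buffers)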

distributed/shuffle/_memory.py

Lines changed: 1 addition & 2 deletions
@@ -20,8 +20,6 @@ def __init__(self) -> None:
 
     @log_errors
     async def _process(self, id: str, shards: list[Any]) -> None:
-        # TODO: This can be greatly simplified, there's no need for
-        # background threads at all.
         self._shards[id].extend(shards)
 
     def read(self, id: str) -> Any:
@@ -39,6 +37,7 @@ def read(self, id: str) -> Any:
         data = []
         while shards:
             shard = shards.pop()
+            # TODO unpickle dataframes
             data.append(shard)
 
         return data

distributed/shuffle/_pickle.py

Lines changed: 5 additions & 6 deletions
@@ -7,7 +7,7 @@
 from distributed.protocol.utils import pack_frames_prelude, unpack_frames
 
 
-def pickle_bytelist(obj: object) -> list[bytes | memoryview]:
+def pickle_bytelist(obj: object, prelude: bool = True) -> list[pickle.PickleBuffer]:
     """Variant of :func:`serialize_bytelist`, that doesn't support compression, locally
     defined classes, or any of its other fancy features but runs 10x faster for numpy
     arrays
@@ -18,11 +18,10 @@ def pickle_bytelist(obj: object) -> list[bytes | memoryview]:
     unpickle_bytestream
     """
     frames: list = []
-    pik = pickle.dumps(
-        obj, protocol=5, buffer_callback=lambda pb: frames.append(pb.raw())
-    )
-    frames.insert(0, pik)
-    frames.insert(0, pack_frames_prelude(frames))
+    pik = pickle.dumps(obj, protocol=5, buffer_callback=frames.append)
+    frames.insert(0, pickle.PickleBuffer(pik))
+    if prelude:
+        frames.insert(0, pickle.PickleBuffer(pack_frames_prelude(frames)))
     return frames
 
 
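The core of pickle_bytelist is protocol-5 out-of-band pickling: buffer_callback collects the large buffers as PickleBuffer objects instead of copying them into the pickle stream, and pickle.loads(..., buffers=...) reattaches them on the way back. A small sketch of that roundtrip, assuming numpy (the case the docstring calls out):

    import pickle

    import numpy as np

    obj = (0, np.arange(10))  # e.g. (input_part_id, shard)
    buffers: list[pickle.PickleBuffer] = []
    head = pickle.dumps(obj, protocol=5, buffer_callback=buffers.append)
    # `head` is the small pickle stream; `buffers` holds zero-copy views of the array data.
    part_id, arr = pickle.loads(head, buffers=[b.raw() for b in buffers])
    assert part_id == 0 and arr.tolist() == list(range(10))

Wrapping the header itself in PickleBuffer, as the new code does, keeps the whole frame list buffer-like, so later stages can pass it around without extra copies.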

distributed/shuffle/_shuffle.py

Lines changed: 112 additions & 70 deletions
@@ -14,9 +14,10 @@
 )
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
+from pickle import PickleBuffer
 from typing import TYPE_CHECKING, Any
 
-from toolz import concat, first, second
+from toolz import first, second
 from tornado.ioloop import IOLoop
 
 import dask
@@ -28,6 +29,7 @@
 from distributed.core import PooledRPCCall
 from distributed.exceptions import Reschedule
 from distributed.metrics import context_meter
+from distributed.protocol.utils import pack_frames_prelude
 from distributed.shuffle._core import (
     NDIndex,
     ShuffleId,
@@ -40,8 +42,9 @@
 )
 from distributed.shuffle._exceptions import DataUnavailable
 from distributed.shuffle._limiter import ResourceLimiter
+from distributed.shuffle._pickle import pickle_bytelist
 from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
-from distributed.sizeof import sizeof
+from distributed.utils import nbytes
 
 logger = logging.getLogger("distributed.shuffle")
 if TYPE_CHECKING:
@@ -297,36 +300,51 @@ def _construct_graph(self) -> _T_LowLevelGraph:
 def split_by_worker(
     df: pd.DataFrame,
     column: str,
-    worker_for: pd.Series,
-) -> dict[str, pd.DataFrame]:
-    """Split data into many horizontal slices, partitioned by destination worker"""
-    nrows = len(df)
-
-    # (cudf support) Avoid pd.Series
-    constructor = df._constructor_sliced
-    assert isinstance(constructor, type)
-    if type(worker_for) is not constructor:
-        worker_for = constructor(worker_for)
-
-    df = df.merge(
-        right=worker_for,
-        left_on=column,
-        right_index=True,
-        how="inner",
-    )
-    out = dict(split_by_partition(df, "_workers", drop_column=True))
-    assert sum(map(len, out.values())) == nrows
-    return out
-
+    drop_column: bool,
+    worker_for: dict[int, str],
+    input_part_id: int,
+) -> dict[str, tuple[int, list[tuple[int, list[PickleBuffer]]]]]:
+    """Split data into many horizontal slices, partitioned by destination worker,
+    and serialize them once.
+
+    Returns
+    -------
+    {worker addr: (input_part_id, [(output_part_id, buffers), ...]), ...}
+
+    where buffers is a list of
+
+    [
+        PickleBuffer(pickle bytes)  # includes input_part_id
+        buffer,
+        buffer,
+        ...
+    ]
+
+    **Notes**
+
+    - The pickle header, which is a bytes object, is wrapped in PickleBuffer so
+      that it's not unnecessarily deep-copied when it's deserialized by the network
+      stack.
+    - We are not delegating serialization to the network stack because (1) it's quicker
+      with plain pickle and (2) we want to avoid deserializing everything on receive()
+      only to re-serialize it again immediately afterwards when writing it to disk.
+      So we serialize it once now and deserialize it once after reading back from disk.
+
+    See Also
+    --------
+    distributed.protocol.serialize._deserialize_bytes
+    distributed.protocol.serialize._deserialize_picklebuffer
+    """
+    out: defaultdict[str, list[tuple[int, list[PickleBuffer]]]] = defaultdict(list)
 
-def split_by_partition(
-    df: pd.DataFrame, column: str, drop_column: bool
-) -> Iterator[tuple[Any, pd.DataFrame]]:
-    """Split data into many horizontal slices, partitioned by final partition"""
-    for k, group in df.groupby(column, observed=True):
+    for output_part_id, part in df.groupby(column, observed=False):
+        assert isinstance(output_part_id, int)
         if drop_column:
-            del group[column]
-        yield k, group
+            del part[column]
+        frames = pickle_bytelist((input_part_id, part), prelude=False)
+        out[worker_for[output_part_id]].append((output_part_id, frames))
+
+    return {k: (input_part_id, v) for k, v in out.items()}
 
 
 class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]):
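A condensed sketch of the approach the new split_by_worker takes, assuming pandas and approximating the project's pickle_bytelist with plain protocol-5 pickling (split_and_pickle is a hypothetical name):

    import pickle
    from collections import defaultdict

    import pandas as pd


    def split_and_pickle(
        df: pd.DataFrame, column: str, worker_for: dict[int, str], input_part_id: int
    ) -> dict:
        out = defaultdict(list)
        for output_part_id, part in df.groupby(column, observed=False):
            buffers: list[pickle.PickleBuffer] = []
            head = pickle.dumps(
                (input_part_id, part), protocol=5, buffer_callback=buffers.append
            )
            frames = [pickle.PickleBuffer(head), *buffers]
            out[worker_for[output_part_id]].append((output_part_id, frames))
        return {addr: (input_part_id, shards) for addr, shards in out.items()}

Each shard is pickled exactly once at split time; the receiving worker can write the frames straight to disk and only unpickles them when assembling the output partition.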
@@ -376,7 +394,7 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]):
     column: str
     meta: pd.DataFrame
     partitions_of: dict[str, list[int]]
-    worker_for: pd.Series
+    worker_for: dict[int, str]
     drop_column: bool
 
     def __init__(
@@ -399,8 +417,6 @@ def __init__(
         drop_column: bool,
         loop: IOLoop,
     ):
-        import pandas as pd
-
         super().__init__(
             id=id,
             run_id=run_id,
@@ -422,55 +438,81 @@ def __init__(
         for part, addr in worker_for.items():
             partitions_of[addr].append(part)
         self.partitions_of = dict(partitions_of)
-        self.worker_for = pd.Series(worker_for, name="_workers").astype("category")
+        self.worker_for = worker_for
         self.drop_column = drop_column
 
-    async def _receive(self, data: list[tuple[int, pd.DataFrame]]) -> None:
+    async def _receive(
+        # See split_by_worker to understand annotation of data.
+        # PickleBuffer objects may have been converted to bytearray by the
+        # pickle roundtrip that is done by _core.py when buffers are too small
+        self,
+        data: list[
+            tuple[int, list[tuple[int, list[PickleBuffer | bytes | bytearray]]]]
+        ],
+    ) -> None:
         self.raise_if_closed()
 
-        filtered = []
-        for partition_id, part in data:
-            if partition_id not in self.received:
-                filtered.append((partition_id, part))
-                self.received.add(partition_id)
-                self.total_recvd += sizeof(part)
-        del data
-        if not filtered:
+        to_write: defaultdict[
+            NDIndex, list[bytes | bytearray | memoryview]
+        ] = defaultdict(list)
+
+        for input_part_id, parts in data:
+            if input_part_id not in self.received:
+                self.received.add(input_part_id)
+                for output_part_id, frames in parts:
+                    frames_raw = [
+                        frame.raw() if isinstance(frame, PickleBuffer) else frame
+                        for frame in frames
+                    ]
+                    self.total_recvd += sum(map(nbytes, frames_raw))
+                    to_write[output_part_id,] += [
+                        pack_frames_prelude(frames_raw),
+                        *frames_raw,
+                    ]
+
+        if not to_write:
             return
         try:
-            groups = await self.offload(self._repartition_buffers, filtered)
-            del filtered
-            await self._write_to_disk(groups)
+            await self._write_to_disk(to_write)
         except Exception as e:
            self._exception = e
            raise
 
-    def _repartition_buffers(
-        self, data: list[tuple[int, pd.DataFrame]]
-    ) -> dict[NDIndex, list[tuple[int, pd.DataFrame]]]:
-        out: dict[NDIndex, list[tuple[int, pd.DataFrame]]] = defaultdict(list)
-
-        for input_part_id, part in data:
-            groups = split_by_partition(part, self.column, self.drop_column)
-            for output_part_id, part in groups:
-                out[output_part_id,].append((input_part_id, part))
-
-        assert sum(len(part) for _, part in data) == sum(
-            len(part) for parts in out.values() for _, part in parts
-        )
-        return out
-
     def _shard_partition(
         self,
         data: pd.DataFrame,
         partition_id: int,
-        **kwargs: Any,
-    ) -> dict[str, tuple[int, pd.DataFrame]]:
-        out = split_by_worker(data, self.column, self.worker_for)
-        nbytes = sum(map(sizeof, out.values()))
-        context_meter.digest_metric("p2p-shards", nbytes, "bytes")
-        context_meter.digest_metric("p2p-shards", len(out), "count")
-        return {k: (partition_id, s) for k, s in out.items()}
+        # See split_by_worker to understand annotation
+    ) -> dict[str, tuple[int, list[tuple[int, list[PickleBuffer]]]]]:
+        out = split_by_worker(
+            df=data,
+            column=self.column,
+            drop_column=self.drop_column,
+            worker_for=self.worker_for,
+            input_part_id=partition_id,
+        )
+
+        # Log metrics
+        # Note: more metrics for this function call are logged by _core.add_partition()
+        overhead_nbytes = 0
+        buffers_nbytes = 0
+        shards_count = 0
+        buffers_count = 0
+        for _, shards in out.values():
+            shards_count += len(shards)
+            for _, frames in shards:
+                # frames = [pickle bytes, buffer, buffer, ...]
+                buffers_count += len(frames) - 2
+                overhead_nbytes += frames[0].raw().nbytes
+                buffers_nbytes += sum(frame.raw().nbytes for frame in frames[1:])
+
+        context_meter.digest_metric("p2p-shards-overhead", overhead_nbytes, "bytes")
+        context_meter.digest_metric("p2p-shards-buffers", buffers_nbytes, "bytes")
+        context_meter.digest_metric("p2p-shards-buffers", buffers_count, "count")
+        context_meter.digest_metric("p2p-shards", shards_count, "count")
+        # End log metrics
+
+        return out
 
     def _get_output_partition(
         self,
@@ -488,8 +530,8 @@ def _get_output_partition(
             result = self.meta.drop(columns=self.column)
             return result
 
-        # [[(input_partition_id, part), (...), ...], [...]] -> [part, ...]
-        shards = list(map(second, sorted(concat(parts), key=first)))
+        # [(input_partition_id, part), ...] -> [part, ...]
+        shards = list(map(second, sorted(parts, key=first)))
         # Actually load memory-mapped buffers into memory and close the file
         # descriptors
         return pd.concat(shards, copy=True)
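On the read-back side each shard is a run of frames: the per-shard prelude written by _receive, the pickle header, and the out-of-band buffers. Once the frames have been split apart again, restoring an output partition is one zero-copy loads per shard plus a sort by input partition id. A hypothetical sketch of that final step, assuming pandas and taking the already re-split frame lists as input:

    import pickle

    import pandas as pd


    def restore_output_partition(shard_frame_lists: list[list[bytes]]) -> pd.DataFrame:
        # Each element is [pickle header, buffer, buffer, ...] for one (input_part_id, part).
        parts = [pickle.loads(frames[0], buffers=frames[1:]) for frames in shard_frame_lists]
        parts.sort(key=lambda item: item[0])  # deterministic order by input partition id
        return pd.concat([part for _, part in parts], copy=True)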
