Commit 676b308

Encapsulate serialization
1 parent 39d4112 commit 676b308

2 files changed: +82 -23 lines changed

distributed/shuffle/_pickle.py

Lines changed: 76 additions & 2 deletions
@@ -1,11 +1,16 @@
 from __future__ import annotations

 import pickle
-from collections.abc import Iterator
-from typing import Any
+from collections.abc import Iterable, Iterator
+from typing import TYPE_CHECKING, Any
+
+from toolz import first

 from distributed.protocol.utils import pack_frames_prelude, unpack_frames

+if TYPE_CHECKING:
+    import pandas as pd
+

 def pickle_bytelist(obj: object, prelude: bool = True) -> list[pickle.PickleBuffer]:
     """Variant of :func:`serialize_bytelist`, that doesn't support compression, locally
@@ -39,3 +44,72 @@ def unpickle_bytestream(b: bytes | bytearray | memoryview) -> Iterator[Any]:
         if remainder.nbytes == 0:
             break
         b = remainder
+
+
+def pickle_dataframe_shard(
+    input_part_id: int,
+    shard: pd.DataFrame,
+) -> list[pickle.PickleBuffer]:
+    """Optimized pickler for pandas DataFrames. Discard all unnecessary metadata
+    (like the columns header).
+
+    Parameters:
+        input_part_id: input partition id; shard: pandas DataFrame shard to pickle
+    """
+    return pickle_bytelist(
+        (input_part_id, shard.index, *shard._mgr.blocks), prelude=False
+    )
+
+
+def unpickle_and_concat_dataframe_shards(
+    parts: Iterable[Any], meta: pd.DataFrame
+) -> pd.DataFrame:
+    """Optimized unpickler for pandas DataFrames.
+
+    Parameters
+    ----------
+    parts:
+        output of ``unpickle_bytestream(b)``, where b is the memory-mapped blob of
+        pickled data which is the concatenation of the outputs of
+        :func:`pickle_dataframe_shard` in arbitrary order
+    meta:
+        DataFrame header
+
+    Returns
+    -------
+    Reconstructed output shard, sorted by input partition ID
+
+    **Roundtrip example**
+
+    .. code-block:: python
+
+        import random
+        import pandas as pd
+
+        df = pd.DataFrame(...)  # Input partition
+        meta = df.iloc[:0]
+        shards = df.iloc[0:10], df.iloc[10:20], ...
+        frames = [pickle_dataframe_shard(i, shard) for i, shard in enumerate(shards)]
+        random.shuffle(frames)  # Simulate the frames arriving in arbitrary order
+        frames = [f for fs in frames for f in fs]  # Flatten
+        blob = bytearray(b"".join(frames))  # Simulate disk roundtrip
+        parts = unpickle_bytestream(blob)
+        df2 = unpickle_and_concat_dataframe_shards(parts, meta)
+
+    """
+    import pandas as pd
+    from pandas.core.internals import BlockManager
+
+    # [(input_part_id, index, *blocks), ...]
+    parts = sorted(parts, key=first)
+    shards = []
+    for _, idx, *blocks in parts:
+        axes = [meta.columns, idx]
+        df = pd.DataFrame._from_mgr(  # type: ignore[attr-defined]
+            BlockManager(blocks, axes, verify_integrity=False), axes
+        )
+        shards.append(df)
+
+    # Actually load memory-mapped buffers into memory and close the file
+    # descriptors
+    return pd.concat(shards, copy=True)
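For orientation, here is a hedged, self-contained sketch (not part of the commit) of how the new helpers round-trip a toy DataFrame. The toy data is illustrative; also note that pickle_dataframe_shard itself passes prelude=False, so this sketch instead frames each (input_part_id, index, *blocks) tuple through pickle_bytelist's default prelude, which lets unpickle_bytestream split the concatenated blob back into parts.

    import pandas as pd

    from distributed.shuffle._pickle import (
        pickle_bytelist,
        unpickle_and_concat_dataframe_shards,
        unpickle_bytestream,
    )

    df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [1.0, 2.0, 3.0, 4.0]})  # toy input partition
    meta = df.iloc[:0]  # header-only DataFrame, kept separately from the shards
    shards = [df.iloc[0:2], df.iloc[2:4]]

    # Pickle (input_part_id, index, *blocks) per shard, each with a frame-length prelude
    frames = [
        buf
        for i, shard in enumerate(shards)
        for buf in pickle_bytelist((i, shard.index, *shard._mgr.blocks))
    ]
    blob = bytearray(b"".join(frames))  # stand-in for the concatenated on-disk blob

    parts = unpickle_bytestream(blob)  # yields (input_part_id, index, *blocks) tuples
    df2 = unpickle_and_concat_dataframe_shards(parts, meta)
    pd.testing.assert_frame_equal(df, df2)

These tuples are what split_by_worker now emits through pickle_dataframe_shard and what _get_output_partition hands to unpickle_and_concat_dataframe_shards in the second file below.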

distributed/shuffle/_shuffle.py

Lines changed: 6 additions & 21 deletions
@@ -17,7 +17,6 @@
 from pickle import PickleBuffer
 from typing import TYPE_CHECKING, Any

-from toolz import first
 from tornado.ioloop import IOLoop

 import dask
@@ -42,7+41,10 @@
 )
 from distributed.shuffle._exceptions import DataUnavailable
 from distributed.shuffle._limiter import ResourceLimiter
-from distributed.shuffle._pickle import pickle_bytelist
+from distributed.shuffle._pickle import (
+    pickle_dataframe_shard,
+    unpickle_and_concat_dataframe_shards,
+)
 from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
 from distributed.utils import nbytes

@@ -335,9 +337,7 @@ def split_by_worker(
         assert isinstance(output_part_id, int)
         if drop_column:
             del part[column]
-        frames = pickle_bytelist(
-            (input_part_id, part.index, *part._mgr.blocks), prelude=False
-        )
+        frames = pickle_dataframe_shard(input_part_id, part)
         out[worker_for[output_part_id]].append((output_part_id, frames))

     return {k: (input_part_id, v) for k, v in out.items()}
@@ -516,9 +516,6 @@ def _get_output_partition(
         key: Key,
         **kwargs: Any,
     ) -> pd.DataFrame:
-        import pandas as pd
-        from pandas.core.internals import BlockManager
-
         meta = self.meta.copy()
         if self.drop_column:
             meta = self.meta.drop(columns=self.column)
@@ -528,19 +525,7 @@ def _get_output_partition(
         except DataUnavailable:
             return meta

-        # [(input_part_id, index, *blocks), ...]
-        parts = sorted(parts, key=first)
-        shards = []
-        for _, idx, *blocks in parts:
-            axes = [meta.columns, idx]
-            df = pd.DataFrame._from_mgr(  # type: ignore[attr-defined]
-                BlockManager(blocks, axes, verify_integrity=False), axes
-            )
-            shards.append(df)
-
-        # Actually load memory-mapped buffers into memory and close the file
-        # descriptors
-        return pd.concat(shards, copy=True)
+        return unpickle_and_concat_dataframe_shards(parts, meta)

     def _get_assigned_worker(self, id: int) -> str:
         return self.worker_for[id]
