 )
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
+from pickle import PickleBuffer
 from typing import TYPE_CHECKING, Any
 
-from toolz import concat, first, second
+from toolz import first, second
 from tornado.ioloop import IOLoop
 
 import dask
 from distributed.core import PooledRPCCall
 from distributed.exceptions import Reschedule
 from distributed.metrics import context_meter
+from distributed.protocol.utils import pack_frames_prelude
 from distributed.shuffle._core import (
     NDIndex,
     ShuffleId,
 )
 from distributed.shuffle._exceptions import DataUnavailable
 from distributed.shuffle._limiter import ResourceLimiter
+from distributed.shuffle._pickle import pickle_bytelist
 from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin
-from distributed.sizeof import sizeof
+from distributed.utils import nbytes
 
 logger = logging.getLogger("distributed.shuffle")
 if TYPE_CHECKING:
@@ -297,36 +300,51 @@ def _construct_graph(self) -> _T_LowLevelGraph:
 def split_by_worker(
     df: pd.DataFrame,
     column: str,
-    worker_for: pd.Series,
-) -> dict[str, pd.DataFrame]:
-    """Split data into many horizontal slices, partitioned by destination worker"""
-    nrows = len(df)
-
-    # (cudf support) Avoid pd.Series
-    constructor = df._constructor_sliced
-    assert isinstance(constructor, type)
-    if type(worker_for) is not constructor:
-        worker_for = constructor(worker_for)
-
-    df = df.merge(
-        right=worker_for,
-        left_on=column,
-        right_index=True,
-        how="inner",
-    )
-    out = dict(split_by_partition(df, "_workers", drop_column=True))
-    assert sum(map(len, out.values())) == nrows
-    return out
-
+    drop_column: bool,
+    worker_for: dict[int, str],
+    input_part_id: int,
+) -> dict[str, tuple[int, list[tuple[int, list[PickleBuffer]]]]]:
+    """Split data into many horizontal slices, partitioned by destination worker,
+    and serialize them once.
+
+    Returns
+    -------
+    {worker addr: (input_part_id, [(output_part_id, buffers), ...]), ...}
+
+    where buffers is a list of
+
+    [
+        PickleBuffer(pickle bytes)  # includes input_part_id
+        buffer,
+        buffer,
+        ...
+    ]
+
+    **Notes**
+
+    - The pickle header, which is a bytes object, is wrapped in PickleBuffer so
+      that it's not unnecessarily deep-copied when it's deserialized by the network
+      stack.
+    - We are not delegating serialization to the network stack because (1) it's quicker
+      with plain pickle and (2) we want to avoid deserializing everything on receive()
+      only to re-serialize it again immediately afterwards when writing it to disk.
+      So we serialize it once now and deserialize it once after reading back from disk.
+
+    See Also
+    --------
+    distributed.protocol.serialize._deserialize_bytes
+    distributed.protocol.serialize._deserialize_picklebuffer
+    """
+    out: defaultdict[str, list[tuple[int, list[PickleBuffer]]]] = defaultdict(list)
 
-def split_by_partition(
-    df: pd.DataFrame, column: str, drop_column: bool
-) -> Iterator[tuple[Any, pd.DataFrame]]:
-    """Split data into many horizontal slices, partitioned by final partition"""
-    for k, group in df.groupby(column, observed=True):
+    for output_part_id, part in df.groupby(column, observed=False):
+        assert isinstance(output_part_id, int)
         if drop_column:
-            del group[column]
-        yield k, group
+            del part[column]
+        frames = pickle_bytelist((input_part_id, part), prelude=False)
+        out[worker_for[output_part_id]].append((output_part_id, frames))
+
+    return {k: (input_part_id, v) for k, v in out.items()}
 
 
 class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]):
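For context on the serialize-once layout described in the split_by_worker docstring above: the frames list is simply a protocol-5 pickle whose out-of-band buffers are kept as separate frames. The standalone sketch below approximates what pickle_bytelist(obj, prelude=False) is assumed to produce; the real helper lives in distributed.shuffle._pickle and is not shown in this diff.

# Illustrative sketch, not part of this commit.
import pickle
from pickle import PickleBuffer

import numpy as np
import pandas as pd


def pickle_to_frames(obj: object) -> list[PickleBuffer]:
    # Protocol 5 hands every out-of-band buffer (e.g. the numpy blocks backing
    # a DataFrame) to buffer_callback instead of copying it into the pickle stream.
    buffers: list[PickleBuffer] = []
    header = pickle.dumps(obj, protocol=5, buffer_callback=buffers.append)
    # Wrap the header bytes in PickleBuffer too, so the receiving side can treat
    # all frames uniformly without deep-copying the header.
    return [PickleBuffer(header), *buffers]


def unpickle_from_frames(frames: list[PickleBuffer]) -> object:
    header, *buffers = frames
    return pickle.loads(header.raw(), buffers=buffers)


part = pd.DataFrame({"x": np.arange(10), "_partitions": 3})
frames = pickle_to_frames((42, part))  # (input_part_id, shard)
input_part_id, roundtripped = unpickle_from_frames(frames)
assert input_part_id == 42
pd.testing.assert_frame_equal(roundtripped, part)

Because the out-of-band buffers are typically zero-copy views over the shard's existing memory, the shard is serialized exactly once on the send side and deserialized exactly once after being read back from disk, as the docstring notes.
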
@@ -376,7 +394,7 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]):
     column: str
     meta: pd.DataFrame
     partitions_of: dict[str, list[int]]
-    worker_for: pd.Series
+    worker_for: dict[int, str]
     drop_column: bool
 
     def __init__(
@@ -399,8 +417,6 @@ def __init__(
         drop_column: bool,
         loop: IOLoop,
     ):
-        import pandas as pd
-
         super().__init__(
             id=id,
             run_id=run_id,
@@ -422,55 +438,81 @@ def __init__(
         for part, addr in worker_for.items():
             partitions_of[addr].append(part)
         self.partitions_of = dict(partitions_of)
-        self.worker_for = pd.Series(worker_for, name="_workers").astype("category")
+        self.worker_for = worker_for
         self.drop_column = drop_column
 
-    async def _receive(self, data: list[tuple[int, pd.DataFrame]]) -> None:
+    async def _receive(
+        # See split_by_worker to understand annotation of data.
+        # PickleBuffer objects may have been converted to bytearray by the
+        # pickle roundtrip that is done by _core.py when buffers are too small
+        self,
+        data: list[
+            tuple[int, list[tuple[int, list[PickleBuffer | bytes | bytearray]]]]
+        ],
+    ) -> None:
         self.raise_if_closed()
 
-        filtered = []
-        for partition_id, part in data:
-            if partition_id not in self.received:
-                filtered.append((partition_id, part))
-                self.received.add(partition_id)
-                self.total_recvd += sizeof(part)
-        del data
-        if not filtered:
+        to_write: defaultdict[
+            NDIndex, list[bytes | bytearray | memoryview]
+        ] = defaultdict(list)
+
+        for input_part_id, parts in data:
+            if input_part_id not in self.received:
+                self.received.add(input_part_id)
+                for output_part_id, frames in parts:
+                    frames_raw = [
+                        frame.raw() if isinstance(frame, PickleBuffer) else frame
+                        for frame in frames
+                    ]
+                    self.total_recvd += sum(map(nbytes, frames_raw))
+                    to_write[output_part_id,] += [
+                        pack_frames_prelude(frames_raw),
+                        *frames_raw,
+                    ]
+
+        if not to_write:
             return
         try:
-            groups = await self.offload(self._repartition_buffers, filtered)
-            del filtered
-            await self._write_to_disk(groups)
+            await self._write_to_disk(to_write)
         except Exception as e:
             self._exception = e
             raise
 
-    def _repartition_buffers(
-        self, data: list[tuple[int, pd.DataFrame]]
-    ) -> dict[NDIndex, list[tuple[int, pd.DataFrame]]]:
-        out: dict[NDIndex, list[tuple[int, pd.DataFrame]]] = defaultdict(list)
-
-        for input_part_id, part in data:
-            groups = split_by_partition(part, self.column, self.drop_column)
-            for output_part_id, part in groups:
-                out[output_part_id,].append((input_part_id, part))
-
-        assert sum(len(part) for _, part in data) == sum(
-            len(part) for parts in out.values() for _, part in parts
-        )
-        return out
-
     def _shard_partition(
         self,
         data: pd.DataFrame,
         partition_id: int,
-        **kwargs: Any,
-    ) -> dict[str, tuple[int, pd.DataFrame]]:
-        out = split_by_worker(data, self.column, self.worker_for)
-        nbytes = sum(map(sizeof, out.values()))
-        context_meter.digest_metric("p2p-shards", nbytes, "bytes")
-        context_meter.digest_metric("p2p-shards", len(out), "count")
-        return {k: (partition_id, s) for k, s in out.items()}
+        # See split_by_worker to understand annotation
+    ) -> dict[str, tuple[int, list[tuple[int, list[PickleBuffer]]]]]:
+        out = split_by_worker(
+            df=data,
+            column=self.column,
+            drop_column=self.drop_column,
+            worker_for=self.worker_for,
+            input_part_id=partition_id,
+        )
+
+        # Log metrics
+        # Note: more metrics for this function call are logged by _core.add_partition()
+        overhead_nbytes = 0
+        buffers_nbytes = 0
+        shards_count = 0
+        buffers_count = 0
+        for _, shards in out.values():
+            shards_count += len(shards)
+            for _, frames in shards:
+                # frames = [pickle bytes, buffer, buffer, ...]
+                buffers_count += len(frames) - 2
+                overhead_nbytes += frames[0].raw().nbytes
+                buffers_nbytes += sum(frame.raw().nbytes for frame in frames[1:])
+
+        context_meter.digest_metric("p2p-shards-overhead", overhead_nbytes, "bytes")
+        context_meter.digest_metric("p2p-shards-buffers", buffers_nbytes, "bytes")
+        context_meter.digest_metric("p2p-shards-buffers", buffers_count, "count")
+        context_meter.digest_metric("p2p-shards", shards_count, "count")
+        # End log metrics
+
+        return out
 
     def _get_output_partition(
         self,
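A note on the pack_frames_prelude(frames_raw) call in _receive above: each shard's frames are prefixed with a small header recording how many frames follow and how long each one is, so that many shards appended to the same per-output-partition byte stream can be split apart again when the file is read back. The snippet below is a simplified, self-contained stand-in for that framing; the actual wire format of distributed.protocol.utils.pack_frames_prelude and of the disk read path may differ in detail.

# Simplified sketch of length-prefixed framing, not part of this commit.
import struct


def pack_prelude(frames: list[bytes]) -> bytes:
    # Header: number of frames, then the length of each frame (uint64 each).
    return struct.pack(f"<{len(frames) + 1}Q", len(frames), *map(len, frames))


def unpack_one_shard(blob: memoryview) -> tuple[list[memoryview], memoryview]:
    # Split one length-prefixed shard off the front of blob; return its frames
    # and the unread remainder.
    (n,) = struct.unpack_from("<Q", blob)
    lengths = struct.unpack_from(f"<{n}Q", blob, 8)
    offset = 8 * (n + 1)
    frames = []
    for length in lengths:
        frames.append(blob[offset : offset + length])
        offset += length
    return frames, blob[offset:]


shard_a = [b"pickle-header-a", b"buffer-a1", b"buffer-a2"]
shard_b = [b"pickle-header-b", b"buffer-b1"]
# What _receive accumulates in to_write[output_part_id,] before writing to disk
blob = memoryview(
    b"".join([pack_prelude(shard_a), *shard_a, pack_prelude(shard_b), *shard_b])
)
frames, rest = unpack_one_shard(blob)
assert [bytes(f) for f in frames] == shard_a
frames, rest = unpack_one_shard(rest)
assert [bytes(f) for f in frames] == shard_b
assert rest.nbytes == 0

Keeping the prelude and the raw frames as separate list items, rather than joining them, presumably lets them reach the disk buffer without first being copied into one contiguous bytes object.
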
@@ -488,8 +530,8 @@ def _get_output_partition(
                 result = self.meta.drop(columns=self.column)
             return result
 
-        # [[(input_partition_id, part), (...), ...], [...]] -> [part, ...]
-        shards = list(map(second, sorted(concat(parts), key=first)))
+        # [(input_partition_id, part), ...] -> [part, ...]
+        shards = list(map(second, sorted(parts, key=first)))
         # Actually load memory-mapped buffers into memory and close the file
         # descriptors
         return pd.concat(shards, copy=True)
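The sorted(parts, key=first) / map(second, ...) step above orders the shards by the input partition they came from before concatenating, which keeps the row order of the output partition deterministic regardless of arrival order. A toy illustration, assuming parts has already been read back from disk and unpickled into (input_partition_id, DataFrame) pairs:

# Toy illustration, not part of this commit.
import pandas as pd
from toolz import first, second

parts = [
    (2, pd.DataFrame({"x": [20, 21]})),
    (0, pd.DataFrame({"x": [0, 1]})),
    (1, pd.DataFrame({"x": [10, 11]})),
]
shards = list(map(second, sorted(parts, key=first)))
result = pd.concat(shards, copy=True)
assert result["x"].tolist() == [0, 1, 10, 11, 20, 21]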