11from __future__ import annotations
22
33import contextlib
4+ import mmap
45import pathlib
56import shutil
67import threading
7- from collections .abc import Callable , Generator , Iterable
8+ from collections .abc import Generator , Iterator
89from contextlib import contextmanager
10+ from pathlib import Path
911from typing import Any
1012
11- from toolz import concat
12-
1313from distributed .metrics import context_meter , thread_time
1414from distributed .shuffle ._buffer import ShardsBuffer
1515from distributed .shuffle ._exceptions import DataUnavailable
1616from distributed .shuffle ._limiter import ResourceLimiter
17- from distributed .shuffle ._pickle import pickle_bytelist
18- from distributed .utils import Deadline , empty_context , log_errors , nbytes
17+ from distributed .shuffle ._pickle import pickle_bytelist , unpickle_bytestream
18+ from distributed .utils import Deadline , log_errors , nbytes
1919
2020
2121class ReadWriteLock :
@@ -126,7 +126,6 @@ class DiskShardsBuffer(ShardsBuffer):
126126 def __init__ (
127127 self ,
128128 directory : str | pathlib .Path ,
129- read : Callable [[pathlib .Path ], tuple [Any , int ]],
130129 memory_limiter : ResourceLimiter ,
131130 ):
132131 super ().__init__ (
@@ -137,11 +136,10 @@ def __init__(
137136 self .directory = pathlib .Path (directory )
138137 self .directory .mkdir (exist_ok = True )
139138 self ._closed = False
140- self ._read = read
141139 self ._directory_lock = ReadWriteLock ()
142140
143141 @log_errors
144- async def _process (self , id : str , shards : list [Any ]) -> None :
142+ async def _process (self , id : str , shards : list [object ]) -> None :
145143 """Write one buffer to file
146144
147145 This function was built to offload the disk IO, but since then we've
@@ -154,36 +152,30 @@ async def _process(self, id: str, shards: list[Any]) -> None:
154152 future then we should consider simplifying this considerably and
155153 dropping the write into communicate above.
156154 """
157- frames : Iterable [bytes | bytearray | memoryview ]
158- if isinstance (shards [0 ], bytes ):
159- # Manually serialized dataframes
160- frames = shards
161- serialize_meter_ctx : Any = empty_context
162- else :
163- # Unserialized numpy arrays
164- # Note: no calls to pickle_bytelist will happen until we actually start
165- # writing to disk below.
166- frames = concat (pickle_bytelist (shard ) for shard in shards )
167- serialize_meter_ctx = context_meter .meter ("serialize" , func = thread_time )
155+ nbytes_acc = 0
156+
157+ def pickle_and_tally () -> Iterator [bytes | memoryview ]:
158+ nonlocal nbytes_acc
159+ for shard in shards :
160+ for frame in pickle_bytelist (shard ):
161+ nbytes_acc += nbytes (frame )
162+ yield frame
168163
169164 with (
170165 self ._directory_lock .read (),
171166 context_meter .meter ("disk-write" ),
172- serialize_meter_ctx ,
167+ context_meter . meter ( "serialize" , func = thread_time ) ,
173168 ):
174- # Consider boosting total_size a bit here to account for duplication
175- # We only need shared (i.e., read) access to the directory to write
176- # to a file inside of it.
177169 if self ._closed :
178170 raise RuntimeError ("Already closed" )
179171
180172 with open (self .directory / str (id ), mode = "ab" ) as f :
181- f .writelines (frames )
173+ f .writelines (pickle_and_tally () )
182174
183175 context_meter .digest_metric ("disk-write" , 1 , "count" )
184- context_meter .digest_metric ("disk-write" , sum ( map ( nbytes , frames )) , "bytes" )
176+ context_meter .digest_metric ("disk-write" , nbytes_acc , "bytes" )
185177
186- def read (self , id : str ) -> Any :
178+ def read (self , id : str ) -> list [ Any ] :
187179 """Read a complete file back into memory"""
188180 self .raise_on_exception ()
189181 if not self ._inputs_done :
@@ -210,6 +202,24 @@ def read(self, id: str) -> Any:
210202 else :
211203 raise DataUnavailable (id )
212204
205+ @staticmethod
206+ def _read (path : Path ) -> tuple [list [Any ], int ]:
207+ """Open a memory-mapped file descriptor to disk, read all metadata, and unpickle
208+ all arrays. This is a fast sequence of short reads interleaved with seeks.
209+ Do not read in memory the actual data; the arrays' buffers will point to the
210+ memory-mapped area.
211+
212+ The file descriptor will be automatically closed by the kernel when all the
213+ returned arrays are dereferenced, which will happen after the call to
214+ concatenate3.
215+ """
216+ with path .open (mode = "r+b" ) as fh :
217+ buffer = memoryview (mmap .mmap (fh .fileno (), 0 ))
218+
219+ # The file descriptor has *not* been closed!
220+ shards = list (unpickle_bytestream (buffer ))
221+ return shards , buffer .nbytes
222+
213223 async def close (self ) -> None :
214224 await super ().close ()
215225 with self ._directory_lock .write ():
0 commit comments