 import shutil
 import threading
 from collections.abc import Generator, Iterator
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from pathlib import Path
 from typing import Any
 
@@ -123,6 +123,11 @@ class DiskShardsBuffer(ShardsBuffer):
     implementation of this scheme.
     """
 
+    directory: pathlib.Path
+    _closed: bool
+    _use_raw_buffers: bool | None
+    _directory_lock: ReadWriteLock
+
     def __init__(
         self,
         directory: str | pathlib.Path,
@@ -136,6 +141,7 @@ def __init__(
         self.directory = pathlib.Path(directory)
         self.directory.mkdir(exist_ok=True)
         self._closed = False
+        self._use_raw_buffers = None
         self._directory_lock = ReadWriteLock()
 
     @log_errors
@@ -152,14 +158,23 @@ async def _process(self, id: str, shards: list[Any]) -> None:
         future then we should consider simplifying this considerably and
         dropping the write into communicate above.
         """
+        assert shards
+        if self._use_raw_buffers is None:
+            self._use_raw_buffers = isinstance(shards[0], list) and isinstance(
+                shards[0][0], (bytes, bytearray, memoryview)
+            )
+        serialize_ctx = (
+            nullcontext()
+            if self._use_raw_buffers
+            else context_meter.meter("serialize", func=thread_time)
+        )
+
         nbytes_acc = 0
 
         def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]:
             nonlocal nbytes_acc
             for shard in shards:
-                if isinstance(shard, list) and isinstance(
-                    shard[0], (bytes, bytearray, memoryview)
-                ):
+                if self._use_raw_buffers:
                     # list[bytes | bytearray | memoryview] for dataframe shuffle
                     # Shard was pre-serialized before being sent over the network.
                     nbytes_acc += sum(map(nbytes, shard))
@@ -173,7 +188,7 @@ def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]:
         with (
             self._directory_lock.read(),
             context_meter.meter("disk-write"),
-            context_meter.meter("serialize", func=thread_time),
+            serialize_ctx,
         ):
             if self._closed:
                 raise RuntimeError("Already closed")
@@ -184,7 +199,7 @@ def pickle_and_tally() -> Iterator[bytes | bytearray | memoryview]:
         context_meter.digest_metric("disk-write", 1, "count")
         context_meter.digest_metric("disk-write", nbytes_acc, "bytes")
 
-    def read(self, id: str) -> list[Any]:
+    def read(self, id: str) -> Any:
         """Read a complete file back into memory"""
         self.raise_on_exception()
         if not self._inputs_done:
@@ -211,8 +226,7 @@ def read(self, id: str) -> list[Any]:
         else:
             raise DataUnavailable(id)
 
-    @staticmethod
-    def _read(path: Path) -> tuple[list[Any], int]:
+    def _read(self, path: Path) -> tuple[Any, int]:
         """Open a memory-mapped file descriptor to disk, read all metadata, and unpickle
         all arrays. This is a fast sequence of short reads interleaved with seeks.
         Do not read in memory the actual data; the arrays' buffers will point to the
@@ -224,10 +238,14 @@ def _read(path: Path) -> tuple[list[Any], int]:
         """
         with path.open(mode="r+b") as fh:
             buffer = memoryview(mmap.mmap(fh.fileno(), 0))
-
         # The file descriptor has *not* been closed!
-        shards = list(unpickle_bytestream(buffer))
-        return shards, buffer.nbytes
+
+        assert self._use_raw_buffers is not None
+        if self._use_raw_buffers:
+            return buffer, buffer.nbytes
+        else:
+            shards = list(unpickle_bytestream(buffer))
+            return shards, buffer.nbytes
 
     async def close(self) -> None:
         await super().close()
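For context, the write path above decides once per flush whether the incoming shards are already raw serialized frames and, if so, skips the "serialize" metering entirely by substituting nullcontext(). Below is a minimal standalone sketch of that conditional-context-manager pattern; write_shards and the meter() timer are hypothetical stand-ins for the buffer's _process method and distributed's context_meter, not the real API.

import pickle
from contextlib import contextmanager, nullcontext
from time import perf_counter


@contextmanager
def meter(label: str):
    # Hypothetical stand-in for a metering context manager.
    start = perf_counter()
    try:
        yield
    finally:
        print(f"{label}: {perf_counter() - start:.6f}s")


def write_shards(path: str, shards: list) -> None:
    # Decide once, from the first shard, whether the data is already raw bytes.
    use_raw_buffers = isinstance(shards[0], list) and isinstance(
        shards[0][0], (bytes, bytearray, memoryview)
    )
    # Only account for "serialize" time when something is actually pickled here.
    serialize_ctx = nullcontext() if use_raw_buffers else meter("serialize")

    with meter("disk-write"), serialize_ctx, open(path, "ab") as fh:
        for shard in shards:
            if use_raw_buffers:
                for frame in shard:
                    fh.write(frame)  # frames were serialized upstream
            else:
                fh.write(pickle.dumps(shard))  # arbitrary objects: pickle now


write_shards("raw.bin", [[b"frame-1", b"frame-2"]])   # raw-buffer path, no serialize meter
write_shards("pickled.bin", [{"x": 1}, {"y": 2}])     # pickling path, serialize metered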
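On the read path, _read now returns the whole memory-mapped file as a single memoryview when the shards were written as raw buffers, and only falls back to unpickling otherwise. A rough illustration of the zero-copy idea, assuming a single pickled object for the non-raw case (the real helper, unpickle_bytestream, streams multiple pickled frames); read_file is a hypothetical name for this sketch.

import mmap
import pickle
from pathlib import Path
from typing import Any


def read_file(path: Path, use_raw_buffers: bool) -> Any:
    # Map the whole file; the memoryview keeps the mapping alive even after
    # the Python file object is closed.
    with path.open(mode="r+b") as fh:
        buffer = memoryview(mmap.mmap(fh.fileno(), 0))

    if use_raw_buffers:
        # Pre-serialized frames: hand the mapped bytes straight to the caller,
        # which can deserialize lazily without an extra in-memory copy.
        return buffer
    # Otherwise unpickle eagerly (simplified to one object here).
    return pickle.loads(buffer)


# Tiny round-trip under the pickled layout
p = Path("shard.bin")
p.write_bytes(pickle.dumps({"rows": [1, 2, 3]}))
print(read_file(p, use_raw_buffers=False))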