55import functools
66import logging
77import os
8+ import shutil
89import socket
910import subprocess
1011import sys
1112import tempfile
1213import uuid
1314import zipfile
1415from collections .abc import Awaitable
16+ from contextlib import contextmanager
17+ from importlib .util import find_spec
18+ from io import BytesIO
1519from typing import TYPE_CHECKING , Any , Callable , ClassVar
20+ from types import ModuleType
21+ from typing import Any , Tuple
22+ from pathlib import Path
1623
1724from dask .typing import Key
1825from dask .utils import _deprecated_kwarg , funcname , tmpfile
2936 from distributed .scheduler import TaskStateState as SchedulerTaskStateState
3037 from distributed .worker import Worker
3138 from distributed .worker_state_machine import TaskStateState as WorkerTaskStateState
39+ from distributed .node import ServerNode
3240
3341logger = logging .getLogger (__name__ )
3442
@@ -1102,3 +1110,165 @@ def setup(self, worker):
11021110
11031111 def teardown (self , worker ):
11041112 self ._exit_stack .close ()
1113+
1114+
1115+ @contextmanager
1116+ def serialize_module (
1117+ module : ModuleType , exclude : Tuple [str ] = ("__pycache__" , ".DS_Store" )
1118+ ) -> Path :
1119+ module_path = Path (module .__file__ )
1120+
1121+ if module_path .stem == "__init__" :
1122+ # In case of package we serialize the whole package
1123+ module_path = module_path .parent
1124+ if "." in module .__name__ :
1125+ # TODO: the problem is that we serialize the `package.module`, as module.egg that contains module.py,
1126+ # but it should contain the whole structure of the package (package/module.py)
1127+ raise Exception (
1128+ f"Plugin supports only top-level packages or single-file modules. You provided `{ module .__name__ } `, try `{ module .__name__ .split ('.' )[0 ]} `."
1129+ )
1130+
1131+ # In case of single file we don't need to serialize anything
1132+
1133+ with tempfile .TemporaryDirectory () as tmp :
1134+ package_name = module_path .name
1135+
1136+ package_copy_path = Path (tmp ).joinpath (package_name )
1137+ if module_path .is_dir ():
1138+ copied_package = Path (
1139+ shutil .copytree (
1140+ module_path ,
1141+ package_copy_path ,
1142+ ignore = shutil .ignore_patterns (f"{ package_name } .zip" , * exclude ),
1143+ )
1144+ )
1145+ else :
1146+ copied_package = Path (shutil .copy2 (module_path , package_copy_path ))
1147+
1148+ archive_path = shutil .make_archive (
1149+ # output path including a name w/o extension
1150+ base_name = str (copied_package ),
1151+ format = "zip" ,
1152+ # chroot
1153+ root_dir = copied_package .parent ,
1154+ # Name of the directory to archive and a common prefix of all files and directories in the archive
1155+ base_dir = package_name ,
1156+ )
1157+
1158+ egg_file = shutil .move (archive_path , package_copy_path .with_suffix (".egg" ))
1159+
1160+ # zip file handler
1161+ zip = zipfile .ZipFile (egg_file )
1162+ # list available files in the container
1163+ logger .debug (
1164+ "The egg file %s contains the following files %s" ,
1165+ str (egg_file ),
1166+ str (zip .namelist ()),
1167+ )
1168+
1169+ logger .info ("Created an egg file %s from %s" , str (egg_file ), str (module_path ))
1170+
1171+ yield Path (egg_file )
1172+
1173+
1174+ class AbstractUploadModulePlugin :
1175+ def __init__ (self , module : ModuleType ):
1176+ self ._module_name = module .__name__
1177+ self ._data : bytes
1178+ self ._filepath : Path
1179+ self ._filename : str
1180+ with serialize_module (module ) as filepath :
1181+ self ._filename = filepath .name
1182+ with open (filepath , "rb" ) as f :
1183+ self ._data = f .read ()
1184+
1185+ async def _upload_file (self , node : ServerNode ):
1186+ response = await node .upload_file (self ._filename , self ._data , load = True )
1187+ assert len (self ._data ) == response ["nbytes" ]
1188+
1189+ async def _upload (self , node : ServerNode ):
1190+ import zipfile
1191+ import sys
1192+ try :
1193+ from IPython .extensions .autoreload import superreload
1194+ except ImportError :
1195+ superreload = lambda x : x
1196+
1197+ # Try to find already loaded module
1198+ module = (
1199+ sys .modules [self ._module_name ] if self ._module_name in sys .modules else None
1200+ )
1201+ # Try to find module on disk
1202+ module_spec = find_spec (self ._module_name )
1203+
1204+ if not module_spec and not module :
1205+ # If module does not exist we keep it as egg file and load it.
1206+ logger .info (
1207+ 'Uploading a new module "%s" to "%s" on %s "%s"' ,
1208+ self ._module_name ,
1209+ str (self ._filename ),
1210+ "worker" if isinstance (node , Worker ) else "scheduler" ,
1211+ node .id ,
1212+ )
1213+ await self ._upload_file (node )
1214+ return
1215+
1216+ if module :
1217+ module_path = self ._get_module_dir (module )
1218+ else :
1219+ module_path = Path (module_spec .origin )
1220+
1221+ if ".egg" in str (module_path ):
1222+ # Update the previously uploaded egg module and reload it.
1223+ logger .info (
1224+ 'Uploading an update for a previously uploaded a new module "%s" to "%s" on %s "%s"' ,
1225+ self ._module_name ,
1226+ str (self ._filename ),
1227+ "worker" if isinstance (node , Worker ) else "scheduler" ,
1228+ node .id ,
1229+ )
1230+ await self ._upload_file (node )
1231+ return
1232+
1233+ with zipfile .ZipFile (BytesIO (self ._data ), "r" ) as zip_ref :
1234+ # In case, we received egg file for module that exists on node in source code,
1235+ # we overwrite each file separately by extracting it from the egg.
1236+ logger .info (
1237+ 'Uploading an update for an existing module "%s" in "%s" on %s "%s"' ,
1238+ self ._module_name ,
1239+ str (module_path .parent ),
1240+ "worker" if isinstance (node , Worker ) else "scheduler" ,
1241+ node .id ,
1242+ )
1243+ zip_ref .extractall (module_path .parent )
1244+
1245+ # TODO: Do we really need Jupyter's `superreload` here instead of built-in Python's function?
1246+ if self ._module_name in sys .modules :
1247+ # Reload module if it is already loaded
1248+ superreload (sys .modules [self ._module_name ])
1249+
1250+ @classmethod
1251+ def _get_module_dir (cls , module : ModuleType ) -> Path :
1252+ """Get the directory of the module."""
1253+ module_path = Path (sys .modules [module .__name__ ].__file__ )
1254+
1255+ if module_path .stem == "__init__" :
1256+ # In case of package we serialize the whole package
1257+ return module_path .parent
1258+
1259+ # In case of single file we don't need to serialize anything
1260+ return module_path
1261+
1262+
1263+ class UploadModule (WorkerPlugin , AbstractUploadModulePlugin ):
1264+ name = "upload_module"
1265+
1266+ async def setup (self , worker : Worker ):
1267+ await self ._upload (worker )
1268+
1269+
1270+ class SchedulerUploadModule (SchedulerPlugin , AbstractUploadModulePlugin ):
1271+ name = "upload_module"
1272+
1273+ async def start (self , scheduler : Scheduler ) -> None :
1274+ await self ._upload (scheduler )
0 commit comments