Skip to content

Commit 2bca1fe

Browse files
committed
Implement module upload plugin (#8698)
1 parent 3d2685c commit 2bca1fe

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed

distributed/diagnostics/plugin.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,21 @@
55
import functools
66
import logging
77
import os
8+
import shutil
89
import socket
910
import subprocess
1011
import sys
1112
import tempfile
1213
import uuid
1314
import zipfile
1415
from collections.abc import Awaitable
16+
from contextlib import contextmanager
17+
from importlib.util import find_spec
18+
from io import BytesIO
1519
from typing import TYPE_CHECKING, Any, Callable, ClassVar
20+
from types import ModuleType
21+
from typing import Any, Tuple
22+
from pathlib import Path
1623

1724
from dask.typing import Key
1825
from dask.utils import _deprecated_kwarg, funcname, tmpfile
@@ -29,6 +36,7 @@
2936
from distributed.scheduler import TaskStateState as SchedulerTaskStateState
3037
from distributed.worker import Worker
3138
from distributed.worker_state_machine import TaskStateState as WorkerTaskStateState
39+
from distributed.node import ServerNode
3240

3341
logger = logging.getLogger(__name__)
3442

@@ -1102,3 +1110,165 @@ def setup(self, worker):
11021110

11031111
def teardown(self, worker):
11041112
self._exit_stack.close()
1113+
1114+
1115+
@contextmanager
1116+
def serialize_module(
1117+
module: ModuleType, exclude: Tuple[str] = ("__pycache__", ".DS_Store")
1118+
) -> Path:
1119+
module_path = Path(module.__file__)
1120+
1121+
if module_path.stem == "__init__":
1122+
# In case of package we serialize the whole package
1123+
module_path = module_path.parent
1124+
if "." in module.__name__:
1125+
# TODO: the problem is that we serialize the `package.module`, as module.egg that contains module.py,
1126+
# but it should contain the whole structure of the package (package/module.py)
1127+
raise Exception(
1128+
f"Plugin supports only top-level packages or single-file modules. You provided `{module.__name__}`, try `{module.__name__.split('.')[0]}`."
1129+
)
1130+
1131+
# In case of single file we don't need to serialize anything
1132+
1133+
with tempfile.TemporaryDirectory() as tmp:
1134+
package_name = module_path.name
1135+
1136+
package_copy_path = Path(tmp).joinpath(package_name)
1137+
if module_path.is_dir():
1138+
copied_package = Path(
1139+
shutil.copytree(
1140+
module_path,
1141+
package_copy_path,
1142+
ignore=shutil.ignore_patterns(f"{package_name}.zip", *exclude),
1143+
)
1144+
)
1145+
else:
1146+
copied_package = Path(shutil.copy2(module_path, package_copy_path))
1147+
1148+
archive_path = shutil.make_archive(
1149+
# output path including a name w/o extension
1150+
base_name=str(copied_package),
1151+
format="zip",
1152+
# chroot
1153+
root_dir=copied_package.parent,
1154+
# Name of the directory to archive and a common prefix of all files and directories in the archive
1155+
base_dir=package_name,
1156+
)
1157+
1158+
egg_file = shutil.move(archive_path, package_copy_path.with_suffix(".egg"))
1159+
1160+
# zip file handler
1161+
zip = zipfile.ZipFile(egg_file)
1162+
# list available files in the container
1163+
logger.debug(
1164+
"The egg file %s contains the following files %s",
1165+
str(egg_file),
1166+
str(zip.namelist()),
1167+
)
1168+
1169+
logger.info("Created an egg file %s from %s", str(egg_file), str(module_path))
1170+
1171+
yield Path(egg_file)
1172+
1173+
1174+
class AbstractUploadModulePlugin:
1175+
def __init__(self, module: ModuleType):
1176+
self._module_name = module.__name__
1177+
self._data: bytes
1178+
self._filepath: Path
1179+
self._filename: str
1180+
with serialize_module(module) as filepath:
1181+
self._filename = filepath.name
1182+
with open(filepath, "rb") as f:
1183+
self._data = f.read()
1184+
1185+
async def _upload_file(self, node: ServerNode):
1186+
response = await node.upload_file(self._filename, self._data, load=True)
1187+
assert len(self._data) == response["nbytes"]
1188+
1189+
async def _upload(self, node: ServerNode):
1190+
import zipfile
1191+
import sys
1192+
try:
1193+
from IPython.extensions.autoreload import superreload
1194+
except ImportError:
1195+
superreload = lambda x: x
1196+
1197+
# Try to find already loaded module
1198+
module = (
1199+
sys.modules[self._module_name] if self._module_name in sys.modules else None
1200+
)
1201+
# Try to find module on disk
1202+
module_spec = find_spec(self._module_name)
1203+
1204+
if not module_spec and not module:
1205+
# If module does not exist we keep it as egg file and load it.
1206+
logger.info(
1207+
'Uploading a new module "%s" to "%s" on %s "%s"',
1208+
self._module_name,
1209+
str(self._filename),
1210+
"worker" if isinstance(node, Worker) else "scheduler",
1211+
node.id,
1212+
)
1213+
await self._upload_file(node)
1214+
return
1215+
1216+
if module:
1217+
module_path = self._get_module_dir(module)
1218+
else:
1219+
module_path = Path(module_spec.origin)
1220+
1221+
if ".egg" in str(module_path):
1222+
# Update the previously uploaded egg module and reload it.
1223+
logger.info(
1224+
'Uploading an update for a previously uploaded a new module "%s" to "%s" on %s "%s"',
1225+
self._module_name,
1226+
str(self._filename),
1227+
"worker" if isinstance(node, Worker) else "scheduler",
1228+
node.id,
1229+
)
1230+
await self._upload_file(node)
1231+
return
1232+
1233+
with zipfile.ZipFile(BytesIO(self._data), "r") as zip_ref:
1234+
# In case, we received egg file for module that exists on node in source code,
1235+
# we overwrite each file separately by extracting it from the egg.
1236+
logger.info(
1237+
'Uploading an update for an existing module "%s" in "%s" on %s "%s"',
1238+
self._module_name,
1239+
str(module_path.parent),
1240+
"worker" if isinstance(node, Worker) else "scheduler",
1241+
node.id,
1242+
)
1243+
zip_ref.extractall(module_path.parent)
1244+
1245+
# TODO: Do we really need Jupyter's `superreload` here instead of built-in Python's function?
1246+
if self._module_name in sys.modules:
1247+
# Reload module if it is already loaded
1248+
superreload(sys.modules[self._module_name])
1249+
1250+
@classmethod
1251+
def _get_module_dir(cls, module: ModuleType) -> Path:
1252+
"""Get the directory of the module."""
1253+
module_path = Path(sys.modules[module.__name__].__file__)
1254+
1255+
if module_path.stem == "__init__":
1256+
# In case of package we serialize the whole package
1257+
return module_path.parent
1258+
1259+
# In case of single file we don't need to serialize anything
1260+
return module_path
1261+
1262+
1263+
class UploadModule(WorkerPlugin, AbstractUploadModulePlugin):
1264+
name = "upload_module"
1265+
1266+
async def setup(self, worker: Worker):
1267+
await self._upload(worker)
1268+
1269+
1270+
class SchedulerUploadModule(SchedulerPlugin, AbstractUploadModulePlugin):
1271+
name = "upload_module"
1272+
1273+
async def start(self, scheduler: Scheduler) -> None:
1274+
await self._upload(scheduler)

0 commit comments

Comments
 (0)