Skip to content

Commit b765b33

Browse files
cg505 and cblmemo authored
[jobs] record process timestamp to protect against reboot/pid reuse (#7847)
* [jobs] record process timestamp to protect against reboot/pid reuse
* add test
* lint
* bump SKYLET_VERSION
* add migration
* address review comments, clean up
* lint
* fix tests and backwards compatibility for really old jobs
* lint
* Apply suggestion from @cblmemo

Co-authored-by: Tian Xia <[email protected]>

---------

Co-authored-by: Tian Xia <[email protected]>
1 parent 1a8f65d commit b765b33

File tree

12 files changed

+417
-127
lines changed

12 files changed

+417
-127
lines changed

sky/client/sdk.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from sky.client import common as client_common
3333
from sky.client import oauth as oauth_lib
3434
from sky.jobs import scheduler
35+
from sky.jobs import utils as managed_job_utils
3536
from sky.schemas.api import responses
3637
from sky.server import common as server_common
3738
from sky.server import rest
@@ -2347,15 +2348,17 @@ def api_stop() -> None:
23472348
with filelock.FileLock(
23482349
os.path.expanduser(constants.API_SERVER_CREATION_LOCK_PATH)):
23492350
try:
2350-
with open(os.path.expanduser(scheduler.JOB_CONTROLLER_PID_PATH),
2351-
'r',
2352-
encoding='utf-8') as f:
2353-
pids = f.read().split('\n')[:-1]
2354-
for pid in pids:
2355-
if subprocess_utils.is_process_alive(int(pid.strip())):
2356-
subprocess_utils.kill_children_processes(
2357-
parent_pids=[int(pid.strip())], force=True)
2358-
os.remove(os.path.expanduser(scheduler.JOB_CONTROLLER_PID_PATH))
2351+
records = scheduler.get_controller_process_records()
2352+
if records is not None:
2353+
for record in records:
2354+
try:
2355+
if managed_job_utils.controller_process_alive(
2356+
record, quiet=False):
2357+
subprocess_utils.kill_children_processes(
2358+
parent_pids=[record.pid], force=True)
2359+
except (psutil.NoSuchProcess, psutil.ZombieProcess):
2360+
continue
2361+
os.remove(os.path.expanduser(scheduler.JOB_CONTROLLER_PID_PATH))
23592362
except FileNotFoundError:
23602363
# its fine we will create it
23612364
pass

sky/jobs/controller.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sky import exceptions
2121
from sky import sky_logging
2222
from sky import skypilot_config
23+
from sky.adaptors import common as adaptors_common
2324
from sky.backends import backend_utils
2425
from sky.backends import cloud_vm_ray_backend
2526
from sky.data import data_utils
@@ -43,6 +44,11 @@
4344
from sky.utils import status_lib
4445
from sky.utils import ux_utils
4546

47+
if typing.TYPE_CHECKING:
48+
import psutil
49+
else:
50+
psutil = adaptors_common.LazyImport('psutil')
51+
4652
logger = sky_logging.init_logger('sky.jobs.controller')
4753

4854
_background_tasks: Set[asyncio.Task] = set()
@@ -813,6 +819,7 @@ def __init__(self, controller_uuid: str) -> None:
813819
self._starting_signal = asyncio.Condition(lock=self._job_tasks_lock)
814820

815821
self._pid = os.getpid()
822+
self._pid_started_at = psutil.Process(self._pid).create_time()
816823

817824
async def _cleanup(self, job_id: int, pool: Optional[str] = None):
818825
"""Clean up the cluster(s) and storages.
@@ -930,9 +937,9 @@ async def run_job_loop(self,
930937
assert ctx is not None, 'Context is not initialized'
931938
ctx.redirect_log(pathlib.Path(log_file))
932939

933-
logger.info('Starting job loop for %s', job_id)
934-
logger.info(' log_file=%s', log_file)
935-
logger.info(' pool=%s', pool)
940+
logger.info(f'Starting job loop for {job_id}')
941+
logger.info(f' log_file={log_file}')
942+
logger.info(f' pool={pool}')
936943
logger.info(f'From controller {self._controller_uuid}')
937944
logger.info(f' pid={self._pid}')
938945

@@ -1099,7 +1106,7 @@ async def cancel_job(self):
10991106

11001107
async def monitor_loop(self):
11011108
"""Monitor the job loop."""
1102-
logger.info(f'Starting monitor loop for pid {os.getpid()}...')
1109+
logger.info(f'Starting monitor loop for pid {self._pid}...')
11031110

11041111
while True:
11051112
async with self._job_tasks_lock:
@@ -1132,7 +1139,7 @@ async def monitor_loop(self):
11321139
# Check if there are any jobs that are waiting to launch
11331140
try:
11341141
waiting_job = await managed_job_state.get_waiting_job_async(
1135-
pid=-os.getpid())
1142+
pid=self._pid, pid_started_at=self._pid_started_at)
11361143
except Exception as e: # pylint: disable=broad-except
11371144
logger.error(f'Failed to get waiting job: {e}')
11381145
await asyncio.sleep(5)

sky/jobs/scheduler.py

Lines changed: 80 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
import shutil
5050
import sys
5151
import typing
52-
from typing import Set
52+
from typing import List, Optional, Set
5353
import uuid
5454

5555
import filelock
@@ -114,6 +114,71 @@
114114
CURRENT_HASH = os.path.expanduser('~/.sky/wheels/current_sky_wheel_hash')
115115

116116

117+
def _parse_controller_pid_entry(
118+
entry: str) -> Optional[state.ControllerPidRecord]:
119+
entry = entry.strip()
120+
if not entry:
121+
return None
122+
# The entry should be like <pid>,<started_at>
123+
# pid is an integer, started_at is a float
124+
# For backwards compatibility, we also support just <pid>
125+
entry_parts = entry.split(',')
126+
if len(entry_parts) == 2:
127+
[raw_pid, raw_started_at] = entry_parts
128+
elif len(entry_parts) == 1:
129+
# Backwards compatibility, pre-#7847
130+
# TODO(cooperc): Remove for 0.13.0
131+
raw_pid = entry_parts[0]
132+
raw_started_at = None
133+
else:
134+
# Unknown format
135+
return None
136+
137+
try:
138+
pid = int(raw_pid)
139+
except ValueError:
140+
return None
141+
142+
started_at: Optional[float] = None
143+
if raw_started_at:
144+
try:
145+
started_at = float(raw_started_at)
146+
except ValueError:
147+
started_at = None
148+
return state.ControllerPidRecord(pid=pid, started_at=started_at)
149+
150+
151+
def get_controller_process_records(
152+
) -> Optional[List[state.ControllerPidRecord]]:
153+
"""Return recorded controller processes if the file can be read."""
154+
if not os.path.exists(JOB_CONTROLLER_PID_PATH):
155+
# If the file doesn't exist, it means the controller server is not
156+
# running, so we return an empty list
157+
return []
158+
try:
159+
with open(JOB_CONTROLLER_PID_PATH, 'r', encoding='utf-8') as f:
160+
lines = f.read().splitlines()
161+
except (FileNotFoundError, OSError):
162+
return None
163+
164+
records: List[state.ControllerPidRecord] = []
165+
for line in lines:
166+
record = _parse_controller_pid_entry(line)
167+
if record is not None:
168+
records.append(record)
169+
return records
170+
171+
172+
def _append_controller_pid_record(pid: int,
173+
started_at: Optional[float]) -> None:
174+
# Note: started_at is a float, but converting to a string will not lose any
175+
# precision. See https://docs.python.org/3/tutorial/floatingpoint.html and
176+
# https://github.com/python/cpython/issues/53583
177+
entry = str(pid) if started_at is None else f'{pid},{started_at}'
178+
with open(JOB_CONTROLLER_PID_PATH, 'a', encoding='utf-8') as f:
179+
f.write(entry + '\n')
180+
181+
117182
@annotations.lru_cache(scope='global')
118183
def get_number_of_controllers() -> int:
119184
"""Returns the number of controllers that should be running.
@@ -180,36 +245,21 @@ def start_controller() -> None:
180245
logger.info(f'Running controller with command: {run_cmd}')
181246

182247
pid = subprocess_utils.launch_new_process_tree(run_cmd, log_output=log_path)
183-
with open(JOB_CONTROLLER_PID_PATH, 'a', encoding='utf-8') as f:
184-
f.write(str(pid) + '\n')
248+
pid_started_at = psutil.Process(pid).create_time()
249+
_append_controller_pid_record(pid, pid_started_at)
185250

186251

187-
def get_alive_controllers() -> typing.Optional[int]:
188-
if not os.path.exists(JOB_CONTROLLER_PID_PATH):
189-
# if the file doesn't exist, it means the controller server is not
190-
# running, so we return 0
191-
return 0
192-
193-
try:
194-
with open(JOB_CONTROLLER_PID_PATH, 'r', encoding='utf-8') as f:
195-
pids = f.read().split('\n')[:-1]
196-
except OSError:
197-
# if the file is corrupted, or any issues with reading it, we just
198-
# return None to be safe and not over start
252+
def get_alive_controllers() -> Optional[int]:
253+
records = get_controller_process_records()
254+
if records is None:
255+
# If we cannot read the file reliably, avoid starting extra controllers.
199256
return None
257+
if not records:
258+
return 0
200259

201260
alive = 0
202-
for pid in pids:
203-
try:
204-
# TODO(luca) there is a chance that the process that is alive is
205-
# not the same controller process. a better solution is to also
206-
# include a random UUID with each controller and store that in the
207-
# db as well/in the command that spawns it.
208-
if subprocess_utils.is_process_alive(int(pid.strip())):
209-
alive += 1
210-
except ValueError:
211-
# if the pid is not an integer, let's assume it's alive to not
212-
# over start new processes
261+
for record in records:
262+
if managed_job_utils.controller_process_alive(record, quiet=False):
213263
alive += 1
214264
return alive
215265

@@ -280,10 +330,11 @@ def submit_job(job_id: int, dag_yaml_path: str, original_user_yaml_path: str,
280330
281331
The user hash should be set (e.g. via SKYPILOT_USER_ID) before calling this.
282332
"""
283-
controller_pid = state.get_job_controller_pid(job_id)
284-
if controller_pid is not None:
333+
controller_process = state.get_job_controller_process(job_id)
334+
if controller_process is not None:
285335
# why? TODO(cooperc): figure out why this is needed, fix it, and remove
286-
if managed_job_utils.controller_process_alive(controller_pid, job_id):
336+
if managed_job_utils.controller_process_alive(controller_process,
337+
job_id):
287338
# This can happen when HA recovery runs for some reason but the job
288339
# controller is still alive.
289340
logger.warning(f'Job {job_id} is still alive, skipping submission')

0 commit comments

Comments
 (0)