21 changes: 12 additions & 9 deletions sky/client/sdk.py
@@ -32,6 +32,7 @@
from sky.client import common as client_common
from sky.client import oauth as oauth_lib
from sky.jobs import scheduler
from sky.jobs import utils as managed_job_utils
from sky.schemas.api import responses
from sky.server import common as server_common
from sky.server import rest
@@ -2347,15 +2348,17 @@ def api_stop() -> None:
with filelock.FileLock(
os.path.expanduser(constants.API_SERVER_CREATION_LOCK_PATH)):
try:
with open(os.path.expanduser(scheduler.JOB_CONTROLLER_PID_PATH),
'r',
encoding='utf-8') as f:
pids = f.read().split('\n')[:-1]
for pid in pids:
if subprocess_utils.is_process_alive(int(pid.strip())):
subprocess_utils.kill_children_processes(
parent_pids=[int(pid.strip())], force=True)
os.remove(os.path.expanduser(scheduler.JOB_CONTROLLER_PID_PATH))
records = scheduler.get_controller_process_records()
if records is not None:
for record in records:
try:
if managed_job_utils.controller_process_alive(
record, quiet=False):
subprocess_utils.kill_children_processes(
parent_pids=[record.pid], force=True)
except (psutil.NoSuchProcess, psutil.ZombieProcess):
continue
os.remove(os.path.expanduser(scheduler.JOB_CONTROLLER_PID_PATH))
except FileNotFoundError:
# It's fine; we will create it later.
pass
17 changes: 12 additions & 5 deletions sky/jobs/controller.py
@@ -20,6 +20,7 @@
from sky import exceptions
from sky import sky_logging
from sky import skypilot_config
from sky.adaptors import common as adaptors_common
from sky.backends import backend_utils
from sky.backends import cloud_vm_ray_backend
from sky.data import data_utils
@@ -43,6 +44,11 @@
from sky.utils import status_lib
from sky.utils import ux_utils

if typing.TYPE_CHECKING:
import psutil
else:
psutil = adaptors_common.LazyImport('psutil')

logger = sky_logging.init_logger('sky.jobs.controller')

_background_tasks: Set[asyncio.Task] = set()
@@ -813,6 +819,7 @@ def __init__(self, controller_uuid: str) -> None:
self._starting_signal = asyncio.Condition(lock=self._job_tasks_lock)

self._pid = os.getpid()
self._pid_started_at = psutil.Process(self._pid).create_time()

async def _cleanup(self, job_id: int, pool: Optional[str] = None):
"""Clean up the cluster(s) and storages.
@@ -930,9 +937,9 @@ async def run_job_loop(self,
assert ctx is not None, 'Context is not initialized'
ctx.redirect_log(pathlib.Path(log_file))

logger.info('Starting job loop for %s', job_id)
logger.info(' log_file=%s', log_file)
logger.info(' pool=%s', pool)
logger.info(f'Starting job loop for {job_id}')
logger.info(f' log_file={log_file}')
logger.info(f' pool={pool}')
logger.info(f'From controller {self._controller_uuid}')
logger.info(f' pid={self._pid}')

@@ -1099,7 +1106,7 @@ async def cancel_job(self):

async def monitor_loop(self):
"""Monitor the job loop."""
logger.info(f'Starting monitor loop for pid {os.getpid()}...')
logger.info(f'Starting monitor loop for pid {self._pid}...')

while True:
async with self._job_tasks_lock:
@@ -1132,7 +1139,7 @@ async def monitor_loop(self):
# Check if there are any jobs that are waiting to launch
try:
waiting_job = await managed_job_state.get_waiting_job_async(
pid=-os.getpid())
pid=self._pid, pid_started_at=self._pid_started_at)
Review comment (Collaborator):
Seems like we originally used a negative pid, but now it is all positive. Is this expected? Is there any backward compatibility that needs to be done?

Reply (Collaborator, Author):
Backwards compatibility is handled in the other places in the PR that use the PID.

except Exception as e: # pylint: disable=broad-except
logger.error(f'Failed to get waiting job: {e}')
await asyncio.sleep(5)
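For context on the psutil = adaptors_common.LazyImport('psutil') pattern added in sky/jobs/controller.py above: a minimal sketch of how such a lazy-import proxy can work, assuming a simple __getattr__-based wrapper (illustrative only, not SkyPilot's actual adaptors_common implementation):

import importlib
from typing import Any, Optional


class LazyImportSketch:
    """Illustrative stand-in for adaptors_common.LazyImport: the real import
    is deferred until the first attribute access, so modules that declare the
    proxy stay cheap to import."""

    def __init__(self, module_name: str) -> None:
        self._module_name = module_name
        self._module: Optional[Any] = None

    def __getattr__(self, name: str) -> Any:
        # Only reached for attributes not found on the proxy itself,
        # i.e. attributes of the wrapped module.
        if self._module is None:
            self._module = importlib.import_module(self._module_name)
        return getattr(self._module, name)


# Hypothetical usage mirroring the snippet above: the cost of importing
# psutil is only paid when an attribute such as psutil.Process is first used.
psutil = LazyImportSketch('psutil')
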
109 changes: 80 additions & 29 deletions sky/jobs/scheduler.py
@@ -49,7 +49,7 @@
import shutil
import sys
import typing
from typing import Set
from typing import List, Optional, Set
import uuid

import filelock
@@ -114,6 +114,71 @@
CURRENT_HASH = os.path.expanduser('~/.sky/wheels/current_sky_wheel_hash')


def _parse_controller_pid_entry(
entry: str) -> Optional[state.ControllerPidRecord]:
entry = entry.strip()
if not entry:
return None
# The entry should be like <pid>,<started_at>
# pid is an integer, started_at is a float
# For backwards compatibility, we also support just <pid>
entry_parts = entry.split(',')
if len(entry_parts) == 2:
[raw_pid, raw_started_at] = entry_parts
elif len(entry_parts) == 1:
# Backwards compatibility, pre-#7847
# TODO(cooperc): Remove for 0.13.0
raw_pid = entry_parts[0]
raw_started_at = None
else:
# Unknown format
return None

try:
pid = int(raw_pid)
except ValueError:
return None

started_at: Optional[float] = None
if raw_started_at:
try:
started_at = float(raw_started_at)
except ValueError:
started_at = None
return state.ControllerPidRecord(pid=pid, started_at=started_at)


def get_controller_process_records(
) -> Optional[List[state.ControllerPidRecord]]:
"""Return recorded controller processes if the file can be read."""
if not os.path.exists(JOB_CONTROLLER_PID_PATH):
# If the file doesn't exist, it means the controller server is not
# running, so we return an empty list
return []
try:
with open(JOB_CONTROLLER_PID_PATH, 'r', encoding='utf-8') as f:
lines = f.read().splitlines()
except (FileNotFoundError, OSError):
return None

records: List[state.ControllerPidRecord] = []
for line in lines:
record = _parse_controller_pid_entry(line)
if record is not None:
records.append(record)
return records


def _append_controller_pid_record(pid: int,
started_at: Optional[float]) -> None:
# Note: started_at is a float, but converting to a string will not lose any
# precision. See https://docs.python.org/3/tutorial/floatingpoint.html and
# https://github.com/python/cpython/issues/53583
entry = str(pid) if started_at is None else f'{pid},{started_at}'
with open(JOB_CONTROLLER_PID_PATH, 'a', encoding='utf-8') as f:
f.write(entry + '\n')

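For reference, with made-up values, the PID file maintained by the helpers above holds one entry per controller process; legacy single-field lines written before this PR still parse (started_at becomes None), and the str()/float() round trip mentioned in the comment is exact:

# Hypothetical JOB_CONTROLLER_PID_PATH contents (values invented):
#   12345,1718000000.123456    new format: <pid>,<started_at>
#   23456                      legacy format: <pid> only
started_at = 1718000000.123456
# Python's str() of a float is the shortest repr that round-trips exactly,
# so writing and re-reading started_at loses no precision.
assert float(str(started_at)) == started_at
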

@annotations.lru_cache(scope='global')
def get_number_of_controllers() -> int:
"""Returns the number of controllers that should be running.
@@ -180,36 +245,21 @@ def start_controller() -> None:
logger.info(f'Running controller with command: {run_cmd}')

pid = subprocess_utils.launch_new_process_tree(run_cmd, log_output=log_path)
with open(JOB_CONTROLLER_PID_PATH, 'a', encoding='utf-8') as f:
f.write(str(pid) + '\n')
pid_started_at = psutil.Process(pid).create_time()
_append_controller_pid_record(pid, pid_started_at)


def get_alive_controllers() -> typing.Optional[int]:
if not os.path.exists(JOB_CONTROLLER_PID_PATH):
# if the file doesn't exist, it means the controller server is not
# running, so we return 0
return 0

try:
with open(JOB_CONTROLLER_PID_PATH, 'r', encoding='utf-8') as f:
pids = f.read().split('\n')[:-1]
except OSError:
# if the file is corrupted, or any issues with reading it, we just
# return None to be safe and not over start
def get_alive_controllers() -> Optional[int]:
records = get_controller_process_records()
if records is None:
# If we cannot read the file reliably, avoid starting extra controllers.
return None
if not records:
return 0

alive = 0
for pid in pids:
try:
# TODO(luca) there is a chance that the process that is alive is
# not the same controller process. a better solution is to also
# include a random UUID with each controller and store that in the
# db as well/in the command that spawns it.
if subprocess_utils.is_process_alive(int(pid.strip())):
alive += 1
except ValueError:
# if the pid is not an integer, let's assume it's alive to not
# over start new processes
for record in records:
if managed_job_utils.controller_process_alive(record, quiet=False):
alive += 1
return alive

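The actual liveness check, sky.jobs.utils.controller_process_alive, is not part of this diff; a rough sketch of what the new started_at field makes possible might look like the following (hypothetical helper and tolerance, not the real implementation):

import psutil

from sky.jobs import state


def record_matches_live_process(record: state.ControllerPidRecord) -> bool:
    """Hypothetical check: the pid must exist and, if started_at was
    recorded, the process create time must match it, which guards against
    pid reuse. Legacy records (started_at=None) fall back to a plain
    pid-alive check."""
    try:
        proc = psutil.Process(record.pid)
        if record.started_at is not None:
            # Small tolerance for rounding during serialization (assumed).
            return abs(proc.create_time() - record.started_at) < 1.0
        return proc.is_running() and proc.status() != psutil.STATUS_ZOMBIE
    except (psutil.NoSuchProcess, psutil.ZombieProcess):
        return False
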
@@ -280,10 +330,11 @@ def submit_job(job_id: int, dag_yaml_path: str, original_user_yaml_path: str,

The user hash should be set (e.g. via SKYPILOT_USER_ID) before calling this.
"""
controller_pid = state.get_job_controller_pid(job_id)
if controller_pid is not None:
controller_process = state.get_job_controller_process(job_id)
if controller_process is not None:
# why? TODO(cooperc): figure out why this is needed, fix it, and remove
if managed_job_utils.controller_process_alive(controller_pid, job_id):
if managed_job_utils.controller_process_alive(controller_process,
job_id):
# This can happen when HA recovery runs for some reason but the job
# controller is still alive.
logger.warning(f'Job {job_id} is still alive, skipping submission')