-
Notifications
You must be signed in to change notification settings - Fork 860
Open
Description
import argparse
from dataclasses import dataclass, asdict
import sky
from sky import jobs
@dataclass
class HyperParamConfig:
    """One hyperparameter combination to submit as a separate job.

    The fields are converted to a dict with ``dataclasses.asdict`` and passed
    to the task as environment variables, so each field name doubles as the
    environment variable name.
    """

    # Identifier for the run; also used to build the job name.
    run_name: str
    # Batch size for this run.
    batch_size: int
    # Learning rate for this run.
    learning_rate: float
def run(pool: str):
    """Submit one job per hyperparameter configuration to a pool.

    For every configuration, a fresh task is loaded from ``train.yaml``, the
    configuration's fields are set as task environment variables, and the
    task is launched through ``jobs.launch`` under the name
    ``sky-task-<run_name>``.

    Args:
        pool: Name of the pool passed to ``jobs.launch``.
    """
    # NOTE(review): every configuration uses the same learning rate (2e-15);
    # confirm this value is intended.
    configs = [
        HyperParamConfig("run-v1", 8, 2e-15),
        HyperParamConfig("run-v2", 16, 2e-15),
        HyperParamConfig("run-v3", 32, 2e-15),
    ]
    for config in configs:
        # Load a new task per configuration so environment overrides from one
        # run never carry over into the next.
        task = sky.Task.from_yaml("train.yaml")
        # NOTE(review): asdict() yields int/float values here; confirm that
        # update_envs accepts non-string values.
        task.update_envs(asdict(config))
        jobs.launch(task, name=f"sky-task-{config.run_name}", pool=pool)
        print(f"Submitted hyperparameter tuning for {config}")
if __name__ == "__main__":
    # Command-line entry point: read the required pool name and submit the
    # hyperparameter sweep to it.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pool",
        type=str,
        required=True,
        help="SkyPilot pool name",
    )
    run(parser.parse_args().pool)
The second sky jobs launch fails with a traceback saying that the API server version is <12. On further investigation, it seems that the remote API server version is set properly for the first launch, but for subsequent launches it is reset to None.
Metadata
Metadata
Assignees
Labels
No labels