Skip to content

Commit 4d3daac

Browse files
committed
chore: Update client.py & backend.py
1 parent 126cd2d commit 4d3daac

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

src/art/client.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import AsyncIterator, Literal, TypedDict, cast
44

55
import httpx
6-
from openai import AsyncOpenAI, BaseModel, _exceptions
76
from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options
87
from openai._compat import cached_property
98
from openai._qs import Querystring
@@ -16,19 +15,27 @@
1615
from openai.resources.models import AsyncModels # noqa: F401
1716
from typing_extensions import override
1817

18+
from openai import AsyncOpenAI, BaseModel, _exceptions
19+
1920
from .trajectories import TrajectoryGroup
2021

2122

2223
class Checkpoint(BaseModel):
2324
id: str
2425
model_id: str
2526
step: int
27+
metrics: dict[str, float]
2628

2729

2830
class CheckpointListParams(TypedDict, total=False):
2931
model_id: str
3032

3133

34+
class DeleteCheckpointsResponse(BaseModel):
35+
deleted_count: int
36+
not_found_steps: list[int]
37+
38+
3239
class Checkpoints(AsyncAPIResource):
3340
async def retrieve(
3441
self, *, model_id: str, step: int | Literal["latest"]
@@ -64,6 +71,16 @@ def list(
6471
model=Checkpoint,
6572
)
6673

74+
async def delete(
75+
self, *, model_id: str, steps: list[int]
76+
) -> DeleteCheckpointsResponse:
77+
return await self._delete(
78+
f"/preview/models/{model_id}/checkpoints",
79+
body={"steps": steps},
80+
cast_to=DeleteCheckpointsResponse,
81+
options=dict(max_retries=0),
82+
)
83+
6784

6885
class Model(BaseModel):
6986
id: str
@@ -128,6 +145,7 @@ async def create(
128145
"base_model": base_model,
129146
"return_existing": return_existing,
130147
},
148+
options=dict(max_retries=0),
131149
)
132150
)
133151

@@ -229,6 +247,7 @@ async def create(
229247
],
230248
"experimental_config": experimental_config,
231249
},
250+
options=dict(max_retries=0),
232251
)
233252

234253
async def retrieve(self, training_job_id: int) -> TrainingJob:

src/art/serverless/backend.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,22 @@ async def _delete_checkpoints(
6969
benchmark: str,
7070
benchmark_smoothing: float,
7171
) -> None:
72-
raise NotImplementedError
72+
# TODO: potentially implement benchmark smoothing
73+
max_metric: float | None = None
74+
max_step: int | None = None
75+
all_steps: list[int] = []
76+
async for checkpoint in self._client.checkpoints.list(model_id=model.id):
77+
metric = checkpoint.metrics.get(benchmark, None)
78+
if metric is not None and (max_metric is None or metric > max_metric):
79+
max_metric = metric
80+
max_step = checkpoint.step
81+
all_steps.append(checkpoint.step)
82+
steps_to_delete = [step for step in all_steps[:-1] if step != max_step]
83+
if steps_to_delete:
84+
await self._client.checkpoints.delete(
85+
model_id=model.id,
86+
steps=steps_to_delete,
87+
)
7388

7489
async def _prepare_backend_for_training(
7590
self,

0 commit comments

Comments
 (0)