|
3 | 3 | from typing import AsyncIterator, Literal, TypedDict, cast |
4 | 4 |
|
5 | 5 | import httpx |
6 | | -from openai import AsyncOpenAI, BaseModel, _exceptions |
7 | 6 | from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options |
8 | 7 | from openai._compat import cached_property |
9 | 8 | from openai._qs import Querystring |
|
16 | 15 | from openai.resources.models import AsyncModels # noqa: F401 |
17 | 16 | from typing_extensions import override |
18 | 17 |
|
| 18 | +from openai import AsyncOpenAI, BaseModel, _exceptions |
| 19 | + |
19 | 20 | from .trajectories import TrajectoryGroup |
20 | 21 |
|
21 | 22 |
|
22 | 23 | class Checkpoint(BaseModel): |
23 | 24 | id: str |
24 | 25 | model_id: str |
25 | 26 | step: int |
| 27 | + metrics: dict[str, float] |
26 | 28 |
|
27 | 29 |
|
28 | 30 | class CheckpointListParams(TypedDict, total=False): |
29 | 31 | model_id: str |
30 | 32 |
|
31 | 33 |
|
| 34 | +class DeleteCheckpointsResponse(BaseModel): |
| 35 | + deleted_count: int |
| 36 | + not_found_steps: list[int] |
| 37 | + |
| 38 | + |
32 | 39 | class Checkpoints(AsyncAPIResource): |
33 | 40 | async def retrieve( |
34 | 41 | self, *, model_id: str, step: int | Literal["latest"] |
@@ -64,6 +71,16 @@ def list( |
64 | 71 | model=Checkpoint, |
65 | 72 | ) |
66 | 73 |
|
| 74 | + async def delete( |
| 75 | + self, *, model_id: str, steps: list[int] |
| 76 | + ) -> DeleteCheckpointsResponse: |
| 77 | + return await self._delete( |
| 78 | + f"/preview/models/{model_id}/checkpoints", |
| 79 | + body={"steps": steps}, |
| 80 | + cast_to=DeleteCheckpointsResponse, |
| 81 | + options=dict(max_retries=0), |
| 82 | + ) |
| 83 | + |
67 | 84 |
|
68 | 85 | class Model(BaseModel): |
69 | 86 | id: str |
@@ -128,6 +145,7 @@ async def create( |
128 | 145 | "base_model": base_model, |
129 | 146 | "return_existing": return_existing, |
130 | 147 | }, |
| 148 | + options=dict(max_retries=0), |
131 | 149 | ) |
132 | 150 | ) |
133 | 151 |
|
@@ -229,6 +247,7 @@ async def create( |
229 | 247 | ], |
230 | 248 | "experimental_config": experimental_config, |
231 | 249 | }, |
| 250 | + options=dict(max_retries=0), |
232 | 251 | ) |
233 | 252 |
|
234 | 253 | async def retrieve(self, training_job_id: int) -> TrainingJob: |
|
0 commit comments