Skip to content

Commit c0b9806

Browse files
committed
refactor: Clean up client.py & backend.py
1 parent c0365f0 commit c0b9806

File tree

2 files changed

+103
-238
lines changed

2 files changed

+103
-238
lines changed

src/art/client.py

Lines changed: 82 additions & 213 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
from __future__ import annotations
2-
3-
import asyncio
41
import os
5-
from typing import Any, AsyncIterator, Iterable, Literal, TypedDict, cast
2+
from typing import Any, Iterable, Literal, TypedDict, cast
63

74
import httpx
8-
from openai import AsyncOpenAI, BaseModel, _exceptions
95
from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options
106
from openai._compat import cached_property
117
from openai._qs import Querystring
@@ -18,81 +14,87 @@
1814
from openai.resources.models import AsyncModels # noqa: F401
1915
from typing_extensions import override
2016

17+
from openai import AsyncOpenAI, BaseModel, _exceptions
18+
2119
from .trajectories import TrajectoryGroup
2220

2321

22+
class Model(BaseModel):
23+
id: str
24+
entity: str
25+
project: str
26+
name: str
27+
base_model: str
28+
29+
2430
class Checkpoint(BaseModel):
2531
id: str
26-
model_id: str
2732
step: int
2833
metrics: dict[str, float]
2934

3035

3136
class CheckpointListParams(TypedDict, total=False):
32-
model_id: str
37+
after: str
38+
limit: int
39+
order: Literal["asc", "desc"]
3340

3441

3542
class DeleteCheckpointsResponse(BaseModel):
3643
deleted_count: int
3744
not_found_steps: list[int]
3845

3946

40-
class LogResponse(BaseModel):
41-
success: bool
47+
class ExperimentalTrainingConfig(TypedDict, total=False):
48+
learning_rate: float | None
49+
precalculate_logprobs: bool | None
4250

4351

44-
class Checkpoints(AsyncAPIResource):
45-
async def retrieve(
46-
self, *, model_id: str, step: int | Literal["latest"]
47-
) -> Checkpoint:
48-
return await self._get(
49-
f"/preview/models/{model_id}/checkpoints/{step}",
50-
cast_to=Checkpoint,
51-
)
52+
class TrainingJob(BaseModel):
53+
id: str
5254

53-
def list(
55+
56+
class TrainingJobEventListParams(TypedDict, total=False):
57+
after: str
58+
limit: int
59+
60+
61+
class TrainingJobEvent(BaseModel):
62+
id: str
63+
type: Literal[
64+
"training_started", "gradient_step", "training_ended", "training_failed"
65+
]
66+
data: dict[str, Any]
67+
68+
69+
class Models(AsyncAPIResource):
70+
async def create(
5471
self,
5572
*,
56-
after: str | NotGiven = NOT_GIVEN,
57-
limit: int | NotGiven = NOT_GIVEN,
58-
model_id: str,
59-
) -> AsyncPaginator[Checkpoint, AsyncCursorPage[Checkpoint]]:
60-
return self._get_api_list(
61-
f"/preview/models/{model_id}/checkpoints",
62-
page=AsyncCursorPage[Checkpoint],
63-
options=make_request_options(
64-
# extra_headers=extra_headers,
65-
# extra_query=extra_query,
66-
# extra_body=extra_body,
67-
# timeout=timeout,
68-
query=maybe_transform(
69-
{
70-
"after": after,
71-
"limit": limit,
72-
},
73-
CheckpointListParams,
74-
),
75-
),
76-
model=Checkpoint,
77-
)
78-
79-
async def delete(
80-
self, *, model_id: str, steps: Iterable[int]
81-
) -> DeleteCheckpointsResponse:
82-
return await self._delete(
83-
f"/preview/models/{model_id}/checkpoints",
84-
body={"steps": steps},
85-
cast_to=DeleteCheckpointsResponse,
86-
options=dict(max_retries=0),
73+
entity: str | None = None,
74+
project: str | None = None,
75+
name: str | None = None,
76+
base_model: str,
77+
return_existing: bool = False,
78+
) -> Model:
79+
return await self._post(
80+
"/preview/models",
81+
cast_to=Model,
82+
body={
83+
"entity": entity,
84+
"project": project,
85+
"name": name,
86+
"base_model": base_model,
87+
"return_existing": return_existing,
88+
},
8789
)
8890

89-
async def log_trajectories(
91+
async def log(
9092
self,
9193
*,
9294
model_id: str,
9395
trajectory_groups: list[TrajectoryGroup],
94-
split: str = "val",
95-
) -> LogResponse:
96+
split: str,
97+
) -> None:
9698
return await self._post(
9799
f"/preview/models/{model_id}/log",
98100
body={
@@ -103,156 +105,47 @@ async def log_trajectories(
103105
],
104106
"split": split,
105107
},
106-
cast_to=LogResponse,
107-
options=dict(max_retries=0),
108+
cast_to=type(None),
108109
)
109110

110-
111-
class Model(BaseModel):
112-
id: str
113-
entity: str
114-
project: str
115-
name: str
116-
base_model: str
117-
118-
async def get_step(self) -> int:
119-
raise NotImplementedError
120-
121-
async def train(self, trajectory_groups: list[TrajectoryGroup]) -> None:
122-
raise NotImplementedError
123-
124-
125-
class ModelListParams(TypedDict, total=False):
126-
after: str
127-
"""A cursor for use in pagination.
128-
129-
`after` is an object ID that defines your place in the list. For instance, if
130-
you make a list request and receive 100 objects, ending with obj_foo, your
131-
subsequent call can include after=obj_foo in order to fetch the next page of the
132-
list.
133-
"""
134-
135-
limit: int
136-
"""A limit on the number of objects to be returned.
137-
138-
Limit can range between 1 and 100, and the default is 20.
139-
"""
140-
141-
# order: Literal["asc", "desc"]
142-
# """Sort order by the `created_at` timestamp of the objects.
143-
144-
# `asc` for ascending order and `desc` for descending order.
145-
# """
146-
147-
entity: str
148-
project: str
149-
name: str
150-
base_model: str
111+
@cached_property
112+
def checkpoints(self) -> "Checkpoints":
113+
return Checkpoints(cast(AsyncOpenAI, self._client))
151114

152115

153-
class Models(AsyncAPIResource):
154-
async def create(
155-
self,
156-
*,
157-
entity: str | None = None,
158-
project: str | None = None,
159-
name: str | None = None,
160-
base_model: str,
161-
return_existing: bool = False,
162-
) -> Model:
163-
return self._patch_model(
164-
await self._post(
165-
"/preview/models",
166-
cast_to=Model,
167-
body={
168-
"entity": entity,
169-
"project": project,
170-
"name": name,
171-
"base_model": base_model,
172-
"return_existing": return_existing,
173-
},
174-
options=dict(max_retries=0),
175-
)
176-
)
177-
178-
async def list(
116+
class Checkpoints(AsyncAPIResource):
117+
def list(
179118
self,
180119
*,
181120
after: str | NotGiven = NOT_GIVEN,
182121
limit: int | NotGiven = NOT_GIVEN,
183-
# order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
184-
entity: str | NotGiven = NOT_GIVEN,
185-
project: str | NotGiven = NOT_GIVEN,
186-
name: str | NotGiven = NOT_GIVEN,
187-
base_model: str | NotGiven = NOT_GIVEN,
188-
# # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
189-
# # The extra values given here take precedence over values defined on the client or passed to this method.
190-
# extra_headers: Headers | None = None,
191-
# extra_query: Query | None = None,
192-
# extra_body: Body | None = None,
193-
# timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
194-
) -> AsyncIterator[Model]:
195-
"""
196-
Lists the currently available models, and provides basic information about each
197-
one such as the owner and availability.
198-
"""
199-
async for model in self._get_api_list(
200-
"/preview/models",
201-
page=AsyncCursorPage[Model],
122+
model_id: str,
123+
order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
124+
) -> AsyncPaginator[Checkpoint, AsyncCursorPage[Checkpoint]]:
125+
return self._get_api_list(
126+
f"/preview/models/{model_id}/checkpoints",
127+
page=AsyncCursorPage[Checkpoint],
202128
options=make_request_options(
203-
# extra_headers=extra_headers,
204-
# extra_query=extra_query,
205-
# extra_body=extra_body,
206-
# timeout=timeout,
207129
query=maybe_transform(
208130
{
209131
"after": after,
210132
"limit": limit,
211-
# "order": order,
212-
"entity": entity,
213-
"project": project,
214-
"name": name,
215-
"base_model": base_model,
133+
"order": order,
216134
},
217-
ModelListParams,
135+
CheckpointListParams,
218136
),
219137
),
220-
model=Model,
221-
):
222-
yield self._patch_model(model)
223-
224-
def _patch_model(self, model: Model) -> Model:
225-
"""Patch model instance with async method implementations."""
226-
227-
async def get_step() -> int:
228-
return 0
229-
230-
model.get_step = get_step
231-
232-
async def train(trajectory_groups: list[TrajectoryGroup]) -> None:
233-
training_job = await cast("Client", self._client).training_jobs.create(
234-
model_id=model.id,
235-
trajectory_groups=trajectory_groups,
236-
)
237-
while training_job.status != "COMPLETED":
238-
await asyncio.sleep(1)
239-
training_job = await cast(
240-
"Client", self._client
241-
).training_jobs.retrieve(training_job.id)
242-
243-
model.train = train
244-
return model
245-
246-
247-
class ExperimentalTrainingConfig(TypedDict, total=False):
248-
learning_rate: float | None
249-
precalculate_logprobs: bool | None
250-
138+
model=Checkpoint,
139+
)
251140

252-
class TrainingJob(BaseModel):
253-
id: str
254-
status: str
255-
experimental_config: ExperimentalTrainingConfig
141+
async def delete(
142+
self, *, model_id: str, steps: Iterable[int]
143+
) -> DeleteCheckpointsResponse:
144+
return await self._delete(
145+
f"/preview/models/{model_id}/checkpoints",
146+
body={"steps": steps},
147+
cast_to=DeleteCheckpointsResponse,
148+
)
256149

257150

258151
class TrainingJobs(AsyncAPIResource):
@@ -269,38 +162,18 @@ async def create(
269162
body={
270163
"model_id": model_id,
271164
"trajectory_groups": [
272-
trajectory_group.model_dump()
165+
trajectory_group.model_dump(mode="json")
273166
for trajectory_group in trajectory_groups
274167
],
275168
"experimental_config": experimental_config,
276169
},
277-
options=dict(max_retries=0),
278-
)
279-
280-
async def retrieve(self, training_job_id: int) -> TrainingJob:
281-
return await self._get(
282-
f"/preview/training-jobs/{training_job_id}",
283-
cast_to=TrainingJob,
284170
)
285171

286172
@cached_property
287-
def events(self) -> TrainingJobEvents:
173+
def events(self) -> "TrainingJobEvents":
288174
return TrainingJobEvents(cast(AsyncOpenAI, self._client))
289175

290176

291-
class TrainingJobEvent(BaseModel):
292-
id: str
293-
type: Literal[
294-
"training_started", "gradient_step", "training_ended", "training_failed"
295-
]
296-
data: dict[str, Any]
297-
298-
299-
class TrainingJobEventListParams(TypedDict, total=False):
300-
after: str
301-
limit: int
302-
303-
304177
class TrainingJobEvents(AsyncAPIResource):
305178
def list(
306179
self,
@@ -317,7 +190,6 @@ def list(
317190
{
318191
"after": after,
319192
"limit": limit,
320-
"training_job_id": training_job_id,
321193
},
322194
TrainingJobEventListParams,
323195
),
@@ -341,18 +213,15 @@ def __init__(
341213
self.api_key = api_key
342214
super().__init__(
343215
version=__version__,
344-
base_url=base_url or "http://0.0.0.0:8000/v1",
216+
base_url=base_url or "https://api.training.wandb.ai/v1",
345217
_strict_response_validation=False,
218+
max_retries=0,
346219
)
347220

348221
@cached_property
349222
def models(self) -> Models:
350223
return Models(cast(AsyncOpenAI, self))
351224

352-
@cached_property
353-
def checkpoints(self) -> Checkpoints:
354-
return Checkpoints(cast(AsyncOpenAI, self))
355-
356225
@cached_property
357226
def training_jobs(self) -> TrainingJobs:
358227
return TrainingJobs(cast(AsyncOpenAI, self))

0 commit comments

Comments
 (0)