
Commit 4d61c40

chore: Update client.py & backend.py
1 parent 523eb72

2 files changed: +78 −14 lines changed

src/art/client.py

Lines changed: 44 additions & 4 deletions
@@ -2,7 +2,7 @@

 import asyncio
 import os
-from typing import AsyncIterator, Iterable, Literal, TypedDict, cast
+from typing import Any, AsyncIterator, Iterable, Literal, TypedDict, cast

 import httpx
 from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options
@@ -38,7 +38,6 @@ class DeleteCheckpointsResponse(BaseModel):
     not_found_steps: list[int]


-
 class LogResponse(BaseModel):
     success: bool

@@ -247,11 +246,12 @@ async def train(trajectory_groups: list[TrajectoryGroup]) -> None:


 class ExperimentalTrainingConfig(TypedDict, total=False):
-    learning_rate: float
+    learning_rate: float | None
+    precalculate_logprobs: bool | None


 class TrainingJob(BaseModel):
-    id: int
+    id: str
     status: str
     experimental_config: ExperimentalTrainingConfig

@@ -284,6 +284,46 @@ async def retrieve(self, training_job_id: int) -> TrainingJob:
             cast_to=TrainingJob,
         )

+    @cached_property
+    def events(self) -> TrainingJobEvents:
+        return TrainingJobEvents(cast(AsyncOpenAI, self._client))
+
+
+class TrainingJobEvent(BaseModel):
+    id: str
+    type: Literal["training_started", "gradient_step", "training_ended"]
+    data: dict[str, Any]
+
+
+class TrainingJobEventListParams(TypedDict, total=False):
+    after: str
+    limit: int
+
+
+class TrainingJobEvents(AsyncAPIResource):
+    def list(
+        self,
+        *,
+        training_job_id: str,
+        after: str | NotGiven = NOT_GIVEN,
+        limit: int | NotGiven = NOT_GIVEN,
+    ) -> AsyncPaginator[TrainingJobEvent, AsyncCursorPage[TrainingJobEvent]]:
+        return self._get_api_list(
+            f"/preview/training-jobs/{training_job_id}/events",
+            page=AsyncCursorPage[TrainingJobEvent],
+            options=make_request_options(
+                query=maybe_transform(
+                    {
+                        "after": after,
+                        "limit": limit,
+                        "training_job_id": training_job_id,
+                    },
+                    TrainingJobEventListParams,
+                ),
+            ),
+            model=TrainingJobEvent,
+        )
+

 class Client(AsyncAPIClient):
     api_key: str
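Usage note for the new events API: list() returns an openai-style AsyncPaginator, so the event stream can be consumed directly with `async for`, exactly as backend.py does below. A minimal sketch, assuming an already-constructed art.client.Client instance (the helper name print_events is illustrative):

    from art.client import Client


    async def print_events(client: Client, training_job_id: str) -> None:
        # Iterate the cursor-paginated event stream for one training job.
        async for event in client.training_jobs.events.list(
            training_job_id=training_job_id, limit=100
        ):
            print(event.type, event.data)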

src/art/serverless/backend.py

Lines changed: 34 additions & 10 deletions
@@ -1,8 +1,10 @@
 import asyncio
-from typing import TYPE_CHECKING, AsyncIterator, Literal, cast
-import os
+from typing import TYPE_CHECKING, AsyncIterator, Literal

-from art.client import Client
+from openai._types import NOT_GIVEN
+from tqdm import auto as tqdm
+
+from art.client import Client, ExperimentalTrainingConfig
 from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider

 from .. import dev
@@ -57,7 +59,6 @@ def _model_inference_name(self, model: "TrainableModel") -> str:
         assert model.entity is not None, "Model entity is required"
         return f"{model.entity}/{model.project}/{model.name}"

-
     async def _get_step(self, model: "Model") -> int:
         if model.trainable:
             assert model.id is not None, "Model ID is required"
@@ -75,6 +76,7 @@ async def _delete_checkpoints(
         benchmark_smoothing: float,
     ) -> None:
         # TODO: potentially implement benchmark smoothing
+        assert model.id is not None, "Model ID is required"
         max_metric: float | None = None
         max_step: int | None = None
         all_steps: list[int] = []
@@ -110,11 +112,12 @@ async def _log(
             print(f"Model {model.name} is not trainable; skipping logging.")
             return

+        assert model.id is not None, "Model ID is required"
+
         await self._client.checkpoints.log_trajectories(
             model_id=model.id, trajectory_groups=trajectory_groups, split=split
         )

-
     async def _train_model(
         self,
         model: "TrainableModel",
@@ -124,15 +127,36 @@ async def _train_model(
         verbose: bool = False,
     ) -> AsyncIterator[dict[str, float]]:
         assert model.id is not None, "Model ID is required"
+
         training_job = await self._client.training_jobs.create(
             model_id=model.id,
             trajectory_groups=trajectory_groups,
-            experimental_config=dict(learning_rate=config.learning_rate),
+            experimental_config=ExperimentalTrainingConfig(
+                learning_rate=config.learning_rate,
+                precalculate_logprobs=dev_config.get("precalculate_logprobs", None),
+            ),
         )
-        while training_job.status != "COMPLETED":
-            await asyncio.sleep(1)
-            training_job = await self._client.training_jobs.retrieve(training_job.id)
-            yield {"num_gradient_steps": 1}
+        after: str | None = None
+        num_gradient_steps: int | None = None
+        pbar: tqdm.tqdm | None = None
+        while True:
+            await asyncio.sleep(0.5)
+            async for event in self._client.training_jobs.events.list(
+                training_job_id=training_job.id, after=after or NOT_GIVEN
+            ):
+                if event.type == "gradient_step":
+                    assert pbar is not None and num_gradient_steps is not None
+                    pbar.update(1)
+                    pbar.set_postfix(event.data)
+                    yield {**event.data, "num_gradient_steps": num_gradient_steps}
+                elif event.type == "training_started":
+                    num_gradient_steps = event.data["num_gradient_steps"]
+                    if pbar is None:
+                        pbar = tqdm.tqdm(total=num_gradient_steps, desc="train")
+                    continue
+                elif event.type == "training_ended":
+                    return
+                after = event.id

     # ------------------------------------------------------------------
     # Experimental support for S3
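The replacement polling loop above follows a simple cursor protocol: sleep briefly, list events after the last seen event id, advance the cursor, and stop on training_ended. A stripped-down sketch of the same pattern outside the Backend class (client construction is assumed; the helper name wait_for_training is illustrative):

    import asyncio

    from openai._types import NOT_GIVEN

    from art.client import Client


    async def wait_for_training(client: Client, training_job_id: str) -> None:
        after: str | None = None
        while True:
            await asyncio.sleep(0.5)
            async for event in client.training_jobs.events.list(
                training_job_id=training_job_id, after=after or NOT_GIVEN
            ):
                if event.type == "training_ended":
                    return
                # Resume the next poll from the last processed event.
                after = event.id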
