1- from __future__ import annotations
2-
3- import asyncio
41import os
5- from typing import Any , AsyncIterator , Iterable , Literal , TypedDict , cast
2+ from typing import Any , Iterable , Literal , TypedDict , cast
63
74import httpx
8- from openai import AsyncOpenAI , BaseModel , _exceptions
95from openai ._base_client import AsyncAPIClient , AsyncPaginator , make_request_options
106from openai ._compat import cached_property
117from openai ._qs import Querystring
1814from openai .resources .models import AsyncModels # noqa: F401
1915from typing_extensions import override
2016
17+ from openai import AsyncOpenAI , BaseModel , _exceptions
18+
2119from .trajectories import TrajectoryGroup
2220
2321
22+ class Model (BaseModel ):
23+ id : str
24+ entity : str
25+ project : str
26+ name : str
27+ base_model : str
28+
29+
2430class Checkpoint (BaseModel ):
2531 id : str
26- model_id : str
2732 step : int
2833 metrics : dict [str , float ]
2934
3035
3136class CheckpointListParams (TypedDict , total = False ):
32- model_id : str
37+ after : str
38+ limit : int
39+ order : Literal ["asc" , "desc" ]
3340
3441
3542class DeleteCheckpointsResponse (BaseModel ):
3643 deleted_count : int
3744 not_found_steps : list [int ]
3845
3946
40- class LogResponse (BaseModel ):
41- success : bool
47+ class ExperimentalTrainingConfig (TypedDict , total = False ):
48+ learning_rate : float | None
49+ precalculate_logprobs : bool | None
4250
4351
44- class Checkpoints (AsyncAPIResource ):
45- async def retrieve (
46- self , * , model_id : str , step : int | Literal ["latest" ]
47- ) -> Checkpoint :
48- return await self ._get (
49- f"/preview/models/{ model_id } /checkpoints/{ step } " ,
50- cast_to = Checkpoint ,
51- )
52+ class TrainingJob (BaseModel ):
53+ id : str
5254
53- def list (
55+
56+ class TrainingJobEventListParams (TypedDict , total = False ):
57+ after : str
58+ limit : int
59+
60+
61+ class TrainingJobEvent (BaseModel ):
62+ id : str
63+ type : Literal [
64+ "training_started" , "gradient_step" , "training_ended" , "training_failed"
65+ ]
66+ data : dict [str , Any ]
67+
68+
69+ class Models (AsyncAPIResource ):
70+ async def create (
5471 self ,
5572 * ,
56- after : str | NotGiven = NOT_GIVEN ,
57- limit : int | NotGiven = NOT_GIVEN ,
58- model_id : str ,
59- ) -> AsyncPaginator [Checkpoint , AsyncCursorPage [Checkpoint ]]:
60- return self ._get_api_list (
61- f"/preview/models/{ model_id } /checkpoints" ,
62- page = AsyncCursorPage [Checkpoint ],
63- options = make_request_options (
64- # extra_headers=extra_headers,
65- # extra_query=extra_query,
66- # extra_body=extra_body,
67- # timeout=timeout,
68- query = maybe_transform (
69- {
70- "after" : after ,
71- "limit" : limit ,
72- },
73- CheckpointListParams ,
74- ),
75- ),
76- model = Checkpoint ,
77- )
78-
79- async def delete (
80- self , * , model_id : str , steps : Iterable [int ]
81- ) -> DeleteCheckpointsResponse :
82- return await self ._delete (
83- f"/preview/models/{ model_id } /checkpoints" ,
84- body = {"steps" : steps },
85- cast_to = DeleteCheckpointsResponse ,
86- options = dict (max_retries = 0 ),
73+ entity : str | None = None ,
74+ project : str | None = None ,
75+ name : str | None = None ,
76+ base_model : str ,
77+ return_existing : bool = False ,
78+ ) -> Model :
79+ return await self ._post (
80+ "/preview/models" ,
81+ cast_to = Model ,
82+ body = {
83+ "entity" : entity ,
84+ "project" : project ,
85+ "name" : name ,
86+ "base_model" : base_model ,
87+ "return_existing" : return_existing ,
88+ },
8789 )
8890
89- async def log_trajectories (
91+ async def log (
9092 self ,
9193 * ,
9294 model_id : str ,
9395 trajectory_groups : list [TrajectoryGroup ],
94- split : str = "val" ,
95- ) -> LogResponse :
96+ split : str ,
97+ ) -> None :
9698 return await self ._post (
9799 f"/preview/models/{ model_id } /log" ,
98100 body = {
@@ -103,156 +105,47 @@ async def log_trajectories(
103105 ],
104106 "split" : split ,
105107 },
106- cast_to = LogResponse ,
107- options = dict (max_retries = 0 ),
108+ cast_to = type (None ),
108109 )
109110
110-
111- class Model (BaseModel ):
112- id : str
113- entity : str
114- project : str
115- name : str
116- base_model : str
117-
118- async def get_step (self ) -> int :
119- raise NotImplementedError
120-
121- async def train (self , trajectory_groups : list [TrajectoryGroup ]) -> None :
122- raise NotImplementedError
123-
124-
125- class ModelListParams (TypedDict , total = False ):
126- after : str
127- """A cursor for use in pagination.
128-
129- `after` is an object ID that defines your place in the list. For instance, if
130- you make a list request and receive 100 objects, ending with obj_foo, your
131- subsequent call can include after=obj_foo in order to fetch the next page of the
132- list.
133- """
134-
135- limit : int
136- """A limit on the number of objects to be returned.
137-
138- Limit can range between 1 and 100, and the default is 20.
139- """
140-
141- # order: Literal["asc", "desc"]
142- # """Sort order by the `created_at` timestamp of the objects.
143-
144- # `asc` for ascending order and `desc` for descending order.
145- # """
146-
147- entity : str
148- project : str
149- name : str
150- base_model : str
111+ @cached_property
112+ def checkpoints (self ) -> "Checkpoints" :
113+ return Checkpoints (cast (AsyncOpenAI , self ._client ))
151114
152115
153- class Models (AsyncAPIResource ):
154- async def create (
155- self ,
156- * ,
157- entity : str | None = None ,
158- project : str | None = None ,
159- name : str | None = None ,
160- base_model : str ,
161- return_existing : bool = False ,
162- ) -> Model :
163- return self ._patch_model (
164- await self ._post (
165- "/preview/models" ,
166- cast_to = Model ,
167- body = {
168- "entity" : entity ,
169- "project" : project ,
170- "name" : name ,
171- "base_model" : base_model ,
172- "return_existing" : return_existing ,
173- },
174- options = dict (max_retries = 0 ),
175- )
176- )
177-
178- async def list (
116+ class Checkpoints (AsyncAPIResource ):
117+ def list (
179118 self ,
180119 * ,
181120 after : str | NotGiven = NOT_GIVEN ,
182121 limit : int | NotGiven = NOT_GIVEN ,
183- # order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
184- entity : str | NotGiven = NOT_GIVEN ,
185- project : str | NotGiven = NOT_GIVEN ,
186- name : str | NotGiven = NOT_GIVEN ,
187- base_model : str | NotGiven = NOT_GIVEN ,
188- # # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
189- # # The extra values given here take precedence over values defined on the client or passed to this method.
190- # extra_headers: Headers | None = None,
191- # extra_query: Query | None = None,
192- # extra_body: Body | None = None,
193- # timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
194- ) -> AsyncIterator [Model ]:
195- """
196- Lists the currently available models, and provides basic information about each
197- one such as the owner and availability.
198- """
199- async for model in self ._get_api_list (
200- "/preview/models" ,
201- page = AsyncCursorPage [Model ],
122+ model_id : str ,
123+ order : Literal ["asc" , "desc" ] | NotGiven = NOT_GIVEN ,
124+ ) -> AsyncPaginator [Checkpoint , AsyncCursorPage [Checkpoint ]]:
125+ return self ._get_api_list (
126+ f"/preview/models/{ model_id } /checkpoints" ,
127+ page = AsyncCursorPage [Checkpoint ],
202128 options = make_request_options (
203- # extra_headers=extra_headers,
204- # extra_query=extra_query,
205- # extra_body=extra_body,
206- # timeout=timeout,
207129 query = maybe_transform (
208130 {
209131 "after" : after ,
210132 "limit" : limit ,
211- # "order": order,
212- "entity" : entity ,
213- "project" : project ,
214- "name" : name ,
215- "base_model" : base_model ,
133+ "order" : order ,
216134 },
217- ModelListParams ,
135+ CheckpointListParams ,
218136 ),
219137 ),
220- model = Model ,
221- ):
222- yield self ._patch_model (model )
223-
224- def _patch_model (self , model : Model ) -> Model :
225- """Patch model instance with async method implementations."""
226-
227- async def get_step () -> int :
228- return 0
229-
230- model .get_step = get_step
231-
232- async def train (trajectory_groups : list [TrajectoryGroup ]) -> None :
233- training_job = await cast ("Client" , self ._client ).training_jobs .create (
234- model_id = model .id ,
235- trajectory_groups = trajectory_groups ,
236- )
237- while training_job .status != "COMPLETED" :
238- await asyncio .sleep (1 )
239- training_job = await cast (
240- "Client" , self ._client
241- ).training_jobs .retrieve (training_job .id )
242-
243- model .train = train
244- return model
245-
246-
247- class ExperimentalTrainingConfig (TypedDict , total = False ):
248- learning_rate : float | None
249- precalculate_logprobs : bool | None
250-
138+ model = Checkpoint ,
139+ )
251140
252- class TrainingJob (BaseModel ):
253- id : str
254- status : str
255- experimental_config : ExperimentalTrainingConfig
141+ async def delete (
142+ self , * , model_id : str , steps : Iterable [int ]
143+ ) -> DeleteCheckpointsResponse :
144+ return await self ._delete (
145+ f"/preview/models/{ model_id } /checkpoints" ,
146+ body = {"steps" : steps },
147+ cast_to = DeleteCheckpointsResponse ,
148+ )
256149
257150
258151class TrainingJobs (AsyncAPIResource ):
@@ -269,38 +162,18 @@ async def create(
269162 body = {
270163 "model_id" : model_id ,
271164 "trajectory_groups" : [
272- trajectory_group .model_dump ()
165+ trajectory_group .model_dump (mode = "json" )
273166 for trajectory_group in trajectory_groups
274167 ],
275168 "experimental_config" : experimental_config ,
276169 },
277- options = dict (max_retries = 0 ),
278- )
279-
280- async def retrieve (self , training_job_id : int ) -> TrainingJob :
281- return await self ._get (
282- f"/preview/training-jobs/{ training_job_id } " ,
283- cast_to = TrainingJob ,
284170 )
285171
286172 @cached_property
287- def events (self ) -> TrainingJobEvents :
173+ def events (self ) -> " TrainingJobEvents" :
288174 return TrainingJobEvents (cast (AsyncOpenAI , self ._client ))
289175
290176
291- class TrainingJobEvent (BaseModel ):
292- id : str
293- type : Literal [
294- "training_started" , "gradient_step" , "training_ended" , "training_failed"
295- ]
296- data : dict [str , Any ]
297-
298-
299- class TrainingJobEventListParams (TypedDict , total = False ):
300- after : str
301- limit : int
302-
303-
304177class TrainingJobEvents (AsyncAPIResource ):
305178 def list (
306179 self ,
@@ -317,7 +190,6 @@ def list(
317190 {
318191 "after" : after ,
319192 "limit" : limit ,
320- "training_job_id" : training_job_id ,
321193 },
322194 TrainingJobEventListParams ,
323195 ),
@@ -341,18 +213,15 @@ def __init__(
341213 self .api_key = api_key
342214 super ().__init__ (
343215 version = __version__ ,
344- base_url = base_url or "http ://0.0.0.0:8000 /v1" ,
216+ base_url = base_url or "https ://api.training.wandb.ai /v1" ,
345217 _strict_response_validation = False ,
218+ max_retries = 0 ,
346219 )
347220
348221 @cached_property
349222 def models (self ) -> Models :
350223 return Models (cast (AsyncOpenAI , self ))
351224
352- @cached_property
353- def checkpoints (self ) -> Checkpoints :
354- return Checkpoints (cast (AsyncOpenAI , self ))
355-
356225 @cached_property
357226 def training_jobs (self ) -> TrainingJobs :
358227 return TrainingJobs (cast (AsyncOpenAI , self ))
0 commit comments