Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 111 additions & 3 deletions src/art/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
import inspect
import os
from typing import Any, Iterable, Literal, TypedDict, cast
from typing import (
Any,
AsyncIterable,
Awaitable,
Callable,
Iterable,
Literal,
ParamSpec,
TypedDict,
TypeVar,
cast,
overload,
)

import httpx
import tenacity
from openai import AsyncOpenAI, BaseModel, _exceptions
from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options
from openai._compat import cached_property
Expand All @@ -11,12 +25,103 @@
from openai._utils import is_mapping, maybe_transform
from openai._version import __version__
from openai.pagination import AsyncCursorPage
from openai.resources.files import AsyncFiles # noqa: F401
from openai.resources.models import AsyncModels # noqa: F401
from typing_extensions import override

from .trajectories import TrajectoryGroup

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")


@overload
def retry_status_codes(
fn: Callable[P, AsyncPaginator[R, AsyncCursorPage[R]]],
) -> Callable[P, AsyncIterable[R]]: ...


@overload
def retry_status_codes(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...


@overload
def retry_status_codes(fn: Callable[P, R]) -> Callable[P, R]: ...


def retry_status_codes(
fn: (
Callable[P, R]
| Callable[P, Awaitable[R]]
| Callable[P, AsyncPaginator[R, AsyncCursorPage[R]]]
),
) -> Callable[P, R | AsyncIterable[R]] | Callable[P, Awaitable[R]]:
def is_retryable_status(exc: BaseException) -> bool:
if isinstance(exc, _exceptions.APIStatusError):
response = exc.response
if response is not None:
status = response.status_code
return status in {429, *range(500, 600)}
return False

stop = tenacity.stop_after_attempt(3)
wait = tenacity.wait_random_exponential(multiplier=0.5, max=2.0)
retry = tenacity.retry_if_exception(is_retryable_status)
reraise = True

async def retrying_awaitable(awaitable_fn: Callable[[], Awaitable[T]]) -> T:
async for attempt in tenacity.AsyncRetrying(
stop=stop,
wait=wait,
retry=retry,
reraise=reraise,
):
with attempt:
return await awaitable_fn()

# Unreachable if tenacity produces at least one attempt
raise RuntimeError("retry attempt sequence unexpectedly exhausted")

if inspect.iscoroutinefunction(fn):
async_fn = cast(Callable[P, Awaitable[R]], fn)

async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return await retrying_awaitable(lambda: async_fn(*args, **kwargs))

return async_wrapper

async def retrying_async_iterable(
async_paginator: AsyncPaginator[R, AsyncCursorPage[R]],
) -> AsyncIterable[R]:
page = await retrying_awaitable(lambda: async_paginator)
for item in page._get_page_items():
yield item
while page.has_next_page():
page = await retrying_awaitable(lambda: page.get_next_page())
for item in page._get_page_items():
yield item

sync_fn = cast(Callable[P, R], fn)

def sync_or_async_paginator_wrapper(
*args: P.args, **kwargs: P.kwargs
) -> R | AsyncIterable[R]:
for attempt in tenacity.Retrying(
stop=stop,
wait=wait,
retry=retry,
reraise=reraise,
):
with attempt:
result = sync_fn(*args, **kwargs)
if isinstance(result, AsyncPaginator):
return retrying_async_iterable(result)
return result

# Unreachable if tenacity produces at least one attempt
raise RuntimeError("retry attempt sequence unexpectedly exhausted")

return sync_or_async_paginator_wrapper


class Model(BaseModel):
id: str
Expand Down Expand Up @@ -113,6 +218,7 @@ def checkpoints(self) -> "Checkpoints":


class Checkpoints(AsyncAPIResource):
@retry_status_codes
def list(
self,
*,
Expand All @@ -137,6 +243,7 @@ def list(
model=Checkpoint,
)

@retry_status_codes
async def delete(
self, *, model_id: str, steps: Iterable[int]
) -> DeleteCheckpointsResponse:
Expand Down Expand Up @@ -174,6 +281,7 @@ def events(self) -> "TrainingJobEvents":


class TrainingJobEvents(AsyncAPIResource):
@retry_status_codes
def list(
self,
*,
Expand Down