Skip to content

Commit 4deec26

Browse files
committed
fix format
1 parent 9fcea8a commit 4deec26

File tree

4 files changed

+12
-8
lines changed

4 files changed

+12
-8
lines changed

src/art/client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any, AsyncIterator, Iterable, Literal, TypedDict, cast
66

77
import httpx
8+
from openai import AsyncOpenAI, BaseModel, _exceptions
89
from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options
910
from openai._compat import cached_property
1011
from openai._qs import Querystring
@@ -17,8 +18,6 @@
1718
from openai.resources.models import AsyncModels # noqa: F401
1819
from typing_extensions import override
1920

20-
from openai import AsyncOpenAI, BaseModel, _exceptions
21-
2221
from .trajectories import TrajectoryGroup
2322

2423

@@ -291,7 +290,9 @@ def events(self) -> TrainingJobEvents:
291290

292291
class TrainingJobEvent(BaseModel):
293292
id: str
294-
type: Literal["training_started", "gradient_step", "training_ended", "training_failed"]
293+
type: Literal[
294+
"training_started", "gradient_step", "training_ended", "training_failed"
295+
]
295296
data: dict[str, Any]
296297

297298

src/art/gather.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,8 @@ def record_metrics(context: "GatherContext", trajectory: Trajectory) -> None:
190190
if logprobs:
191191
# TODO: probably shouldn't average this
192192
trajectory.metrics["completion_tokens"] = sum(
193-
len(l.content or l.refusal or []) for l in logprobs # noqa: E741
193+
len(l.content or l.refusal or [])
194+
for l in logprobs # noqa: E741
194195
) / len(logprobs)
195196
context.metric_sums["reward"] += trajectory.reward # type: ignore
196197
context.metric_divisors["reward"] += 1

src/art/openai.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,9 @@ def update_chat_completion(
128128
choice.message.tool_calls[tool_call.index].id = tool_call.id
129129
if tool_call.function:
130130
if tool_call.function.name:
131-
choice.message.tool_calls[tool_call.index].function.name = (
132-
tool_call.function.name
133-
)
131+
choice.message.tool_calls[
132+
tool_call.index
133+
].function.name = tool_call.function.name
134134
if tool_call.function.arguments:
135135
choice.message.tool_calls[
136136
tool_call.index

src/art/serverless/backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,9 @@ async def _train_model(
157157
elif event.type == "training_ended":
158158
return
159159
elif event.type == "training_failed":
160-
error_message = event.data.get("error_message", "Training failed with an unknown error")
160+
error_message = event.data.get(
161+
"error_message", "Training failed with an unknown error"
162+
)
161163
raise RuntimeError(f"Training job failed: {error_message}")
162164
after = event.id
163165

0 commit comments

Comments
 (0)