Proposal
- Create an `@art.rollout` decorator that wraps a rollout function and automatically constructs a trajectory (potentially with multiple histories), similar to how `@weave.op` automatically wraps an LLM-enabled function, records all function calls, and then reports a trace.
- Allow rollout functions to access the current trajectory through some kind of `get_current_trajectory()` helper function.
- Store completion ids on messages so that a specific history can be accessed and manipulated using `trajectory.get_history(completion_id)`.
  - Useful when adding tool messages after executing a tool.
- Also create a `gather_trajectory` helper function that calls a rollout function decorated with `@art.rollout` and returns the generated trajectory.
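A minimal sketch of how the decorator, `get_current_trajectory()`, and `gather_trajectory` could fit together, assuming the active trajectory is tracked in a `contextvars.ContextVar` so that concurrent rollouts don't share state. `Trajectory` here is a hypothetical stand-in rather than `art.Trajectory`, and `fake_llm` is a hypothetical placeholder for an instrumented LLM client.

```python
import asyncio
import contextvars
import functools
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Optional, TypeVar

T = TypeVar("T")


@dataclass
class Trajectory:
    # Hypothetical stand-in for art.Trajectory: a flat list of recorded completions.
    completions: list[dict[str, Any]] = field(default_factory=list)


# Trajectory of the rollout currently running in this context (None outside a rollout).
_current: contextvars.ContextVar[Optional[Trajectory]] = contextvars.ContextVar(
    "current_trajectory", default=None
)


def get_current_trajectory() -> Optional[Trajectory]:
    """Return the trajectory being built by the enclosing rollout, if any."""
    return _current.get()


def rollout(fn: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
    """Hypothetical @rollout: run fn inside a trajectory but return fn's own result."""

    @functools.wraps(fn)
    async def wrapper(*args: Any, **kwargs: Any) -> T:
        if _current.get() is not None:
            # A collector (e.g. gather_trajectory) already set one up: record into it.
            return await fn(*args, **kwargs)
        # Production flow: give the call a throwaway trajectory so recording still works.
        token = _current.set(Trajectory())
        try:
            return await fn(*args, **kwargs)
        finally:
            _current.reset(token)

    return wrapper


async def gather_trajectory(coro: Awaitable[Any]) -> Trajectory:
    """Hypothetical helper: await a @rollout call and return what it recorded."""
    traj = Trajectory()
    token = _current.set(traj)
    try:
        await coro
    finally:
        _current.reset(token)
    return traj


async def fake_llm(prompt: str) -> str:
    # Stand-in for an instrumented LLM call that logs itself into the active trajectory.
    reply = prompt.upper()
    traj = get_current_trajectory()
    if traj is not None:
        traj.completions.append({"prompt": prompt, "reply": reply})
    return reply


@rollout
async def get_summary(text: str) -> str:
    return await fake_llm(f"Summarize: {text}")


async def main() -> None:
    print(await get_summary("hello"))  # production flow: just the string
    traj = await gather_trajectory(get_summary("hello"))  # training flow
    print(traj.completions)


if __name__ == "__main__":
    asyncio.run(main())
```

Because a coroutine's body doesn't start running until it is awaited, `gather_trajectory` can accept the *result* of calling the decorated function and still set the context variable before any completion is recorded.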
It's worth taking a close look at `@weave.op`. We may even want to integrate with Weave or wrap their decorator, since they've already done the integration work to read completions from many LLM clients.
Messy ideas in proposal doc.
Caveats
The `@art.rollout` decorator will need to automatically determine whether LLM completions belong to the same history or to separate histories.
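One possible heuristic for this caveat, assuming chat-style prompts: a completion continues an existing history when its prompt starts with every message already recorded in that history; otherwise it starts a new one. Each assistant reply is tagged with a `completion_id` so that `get_history(completion_id)` can find it later. `Trajectory`, `History`, and `add_completion` are hypothetical stand-ins, not ART's types.

```python
from dataclasses import dataclass, field
from typing import Any

Message = dict[str, Any]


def _strip_ids(messages: list[Message]) -> list[Message]:
    # Drop the bookkeeping key so stored messages compare equal to what callers resend.
    return [{k: v for k, v in m.items() if k != "completion_id"} for m in messages]


@dataclass
class History:
    messages: list[Message] = field(default_factory=list)


@dataclass
class Trajectory:
    histories: list[History] = field(default_factory=list)

    def add_completion(self, prompt: list[Message], reply: Message, completion_id: str) -> History:
        """Attach a completion to the history it continues, or start a new history."""
        tagged = {**reply, "completion_id": completion_id}
        for history in self.histories:
            n = len(history.messages)
            # Same history if the prompt begins with everything recorded so far.
            if n <= len(prompt) and _strip_ids(history.messages) == prompt[:n]:
                history.messages.extend(prompt[n:])
                history.messages.append(tagged)
                return history
        history = History(messages=[*prompt, tagged])
        self.histories.append(history)
        return history

    def get_history(self, completion_id: str) -> History:
        """Find the history containing the reply produced by completion_id."""
        for history in self.histories:
            if any(m.get("completion_id") == completion_id for m in history.messages):
                return history
        raise KeyError(completion_id)
```

For the tool-call case, a caller could append a tool result with `traj.get_history(completion_id).messages.append(...)`; the next completion that resends that conversation still matches the same history. Prefix matching treats a retry or fork (same prompt, different continuation) as a separate history, which may or may not be the behavior we want.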
Example
Our current rollout functions require the user to initialize an `art.Trajectory` object and add messages to it, like so:
```python
async def get_summary(model: art.Model, scenario: Scenario) -> art.Trajectory:
    traj = art.Trajectory(
        messages_and_choices=[
            {
                "role": "system",
                "content": f"Summarize: {scenario.text}"
            },
        ]
    )
    completion = await client.chat.completions.create(
        model=model.name,
        messages=traj.messages()
    )
    traj.messages_and_choices.append(completion.choices[0])
    return traj
```
However, this makes our rollout functions verbose (because they have to initialize and update the trajectories) and difficult to use elsewhere in the codebase (because they don't return the processed type that the rollout function was meant to generate).
By decorating the function with `@art.rollout` and returning the summary as a plain string, we can make the code much cleaner:
```python
@art.rollout
async def get_summary(model: art.Model, scenario: Scenario) -> str:
    completion = await client.chat.completions.create(
        model=model.name,
        messages=[
            {
                "role": "system",
                "content": f"Summarize: {scenario.text}"
            },
        ]
    )
    return completion.choices[0].message.content
```
Used in a production flow:

```python
async def caller():
    summary = await get_summary(model, scenario)
    print(summary)
```

Used in a training flow:

```python
trajectory = await gather_trajectory(get_summary(model, scenario))
```