
Add @art.rollout decorator to gather trajectories #347

@arcticfly

Description


Proposal

  • Create an @art.rollout decorator which wraps a rollout function and constructs a trajectory (potentially with multiple histories) automatically, similar to how @weave.op automatically wraps an LLM-enabled function and records all function calls then reports a trace.
  • Allow rollout functions to access the current trajectory through a helper such as get_current_trajectory().
  • Store completion ids on messages so that a specific history can be accessed and manipulated via trajectory.get_history(completion_id).
    • This is useful for appending tool messages after a tool has been executed.
  • Also create a gather_trajectory helper function that calls a rollout function decorated with @art.rollout and returns the generated trajectory.

It's worth taking a close look at @weave.op. We may even want to integrate with Weave or wrap its decorator, since they've already done the work of capturing completions from many LLM clients.

Messy ideas are in the proposal doc.

Caveats

The @art.rollout decorator will need to automatically determine whether LLM completions belong to the same history or to separate histories.
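One possible heuristic, offered as an assumption rather than a settled design: a new completion continues an existing history when its request resends that history as a prefix (chat APIs are stateless, so a follow-up turn normally resends the earlier messages); otherwise it starts a new history. The sketch also stores each completion id on its reply and indexes histories by it, so trajectory.get_history(completion_id) can find the right history when appending tool results. `Trajectory`, `History`, `record_completion` and the role/content comparison are all hypothetical.

```python
from dataclasses import dataclass, field
from typing import Any, Optional

Message = dict[str, Any]


def _key(message: Message) -> tuple:
    # Compare messages by role and content only (an assumption; tool calls and
    # other fields are ignored), so the completion_id attached below does not
    # stop a stored reply from matching its resent copy.
    return (message.get("role"), message.get("content"))


@dataclass
class History:
    messages: list[Message] = field(default_factory=list)


@dataclass
class Trajectory:
    histories: list[History] = field(default_factory=list)
    _by_completion: dict[str, History] = field(default_factory=dict, repr=False)

    def record_completion(self, completion_id: str, request: list[Message], reply: Message) -> History:
        """File one LLM call under an existing history or start a new one."""
        history = self._find_continuation(request)
        if history is None:
            history = History()
            self.histories.append(history)
        # Append only the request messages this history hasn't seen yet, then the reply.
        history.messages.extend(request[len(history.messages):])
        history.messages.append({**reply, "completion_id": completion_id})
        self._by_completion[completion_id] = history
        return history

    def get_history(self, completion_id: str) -> History:
        return self._by_completion[completion_id]

    def _find_continuation(self, request: list[Message]) -> Optional[History]:
        # A request continues a history if it resends that history as a prefix;
        # prefer the longest such history.
        request_keys = [_key(m) for m in request]
        best: Optional[History] = None
        for history in self.histories:
            n = len(history.messages)
            if n <= len(request) and [_key(m) for m in history.messages] == request_keys[:n]:
                if best is None or n > len(best.messages):
                    best = history
        return best


traj = Trajectory()
system = {"role": "system", "content": "Summarize the user's text."}
user = {"role": "user", "content": "Long text..."}
first = {"role": "assistant", "content": "A summary."}

traj.record_completion("c1", [system, user], first)
# The second turn resends the conversation, so it joins the same history.
traj.record_completion(
    "c2",
    [system, user, first, {"role": "user", "content": "Shorter?"}],
    {"role": "assistant", "content": "Summary."},
)
# An unrelated call (e.g. a judge) starts a separate history.
traj.record_completion(
    "c3",
    [{"role": "user", "content": "Grade: Summary."}],
    {"role": "assistant", "content": "7/10"},
)
# After executing a tool, look up the right history by completion id.
traj.get_history("c3").messages.append({"role": "tool", "content": "tool output"})
```

Under this rule, sampling several completions from the same prompt yields separate histories, because later requests do not resend the earlier replies.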

Example

Our current rollout functions require the user to initialize and add messages to an art.Trajectory object, like so:

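```python
import art
from openai import AsyncOpenAI

client = AsyncOpenAI()  # assumed: an OpenAI-compatible client that serves `model`


# `Scenario` is a project-defined input type with a `text` field (not shown).
async def get_summary(model: art.Model, scenario: Scenario) -> art.Trajectory:
    # The rollout has to build the trajectory by hand...
    traj = art.Trajectory(
        messages_and_choices=[
            {
                "role": "system",
                "content": f"Summarize: {scenario.text}"
            },
        ]
    )

    completion = await client.chat.completions.create(
        model=model.name,
        messages=traj.messages()
    )

    # ...and append every choice it receives.
    traj.messages_and_choices.append(completion.choices[0])

    return traj
```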

However, this makes our rollout functions verbose (they must initialize and update the trajectory themselves) and hard to reuse elsewhere in the codebase (they return an art.Trajectory rather than the value they exist to produce, such as the summary string).

Decorating the function with @art.rollout and returning the summary as a plain string makes the code much cleaner:

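```python
import art
from openai import AsyncOpenAI

client = AsyncOpenAI()  # assumed: an OpenAI-compatible client that serves `model`


@art.rollout  # proposed decorator (does not exist yet)
async def get_summary(model: art.Model, scenario: Scenario) -> str:
    # No trajectory bookkeeping: the decorator would record this completion.
    completion = await client.chat.completions.create(
        model=model.name,
        messages=[
            {
                "role": "system",
                "content": f"Summarize: {scenario.text}"
            },
        ]
    )

    return completion.choices[0].message.content
```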

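Used in a production flow:

```python
async def caller():
    # The decorated function returns the summary string directly.
    summary = await get_summary(model, scenario)
    print(summary)
```

Used in a training flow:

```python
# gather_trajectory (proposed) awaits the call and returns the recorded trajectory.
trajectory = await gather_trajectory(get_summary(model, scenario))
```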
