Skip to content

Commit 980684f

Browse files
authored
Add get_guided_completion_params and use in tic tac toe self play (#147)
* Add `get_guided_completion_params` and use in tic tac toe self play * Rename shadowmaster to teacher
1 parent 87bb5d6 commit 980684f

File tree

3 files changed

+92
-25
lines changed

3 files changed

+92
-25
lines changed

examples/tic_tac_toe_self_play/rollout.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
render_board,
1717
unwrap_move,
1818
)
19+
from art.guided_completion import get_guided_completion_params
1920

2021
load_dotenv()
2122

@@ -36,7 +37,7 @@ async def get_agent_move(
3637
game: TicTacToeGame,
3738
player_state: PlayerState,
3839
model: art.Model,
39-
shadowmaster: art.Model | None = None,
40+
teacher: art.Model | None = None,
4041
predestined_move: str | None = None,
4142
) -> str:
4243
assert isinstance(model.config, ModelConfig)
@@ -46,21 +47,20 @@ async def get_agent_move(
4647

4748
messages = player_state.trajectory.messages()
4849
try:
49-
if shadowmaster and not predestined_move:
50-
assert isinstance(shadowmaster.config, ModelConfig)
51-
shadowmaster_client = shadowmaster.openai_client()
52-
shadowmaster_completion = await shadowmaster_client.chat.completions.create(
53-
model=shadowmaster.get_inference_name(),
50+
guided_choice = None
51+
if teacher and not predestined_move:
52+
assert isinstance(teacher.config, ModelConfig)
53+
teacher_client = teacher.openai_client()
54+
teacher_completion = await teacher_client.chat.completions.create(
55+
model=teacher.get_inference_name(),
5456
messages=messages,
5557
max_completion_tokens=2000
56-
if shadowmaster.config.requires_reasoning
58+
if teacher.config.requires_reasoning
5759
else 100,
58-
reasoning_effort="low"
59-
if shadowmaster.config.requires_reasoning
60-
else None,
60+
reasoning_effort="low" if teacher.config.requires_reasoning else None,
6161
temperature=1.0,
6262
)
63-
predestined_move = shadowmaster_completion.choices[0].message.content
63+
guided_choice, _, _ = get_guided_completion_params(teacher_completion)
6464

6565
client = model.openai_client()
6666
completion = await client.chat.completions.create(
@@ -69,7 +69,7 @@ async def get_agent_move(
6969
max_completion_tokens=2000 if model.config.requires_reasoning else 100,
7070
reasoning_effort="low" if model.config.requires_reasoning else None,
7171
temperature=1.0,
72-
extra_body={"guided_choice": [predestined_move]}
72+
extra_body={"guided_choice": guided_choice}
7373
if predestined_move and model.trainable
7474
else None,
7575
)
@@ -102,8 +102,8 @@ def record_first_move_metrics(trajectory: art.Trajectory, square: str) -> None:
102102
class TicTacToeScenario(BaseModel):
103103
step: int
104104
split: str
105-
x_shadowmaster: art.Model | None = None
106-
o_shadowmaster: art.Model | None = None
105+
x_teacher: art.Model | None = None
106+
o_teacher: art.Model | None = None
107107
initial_move: str | None = None
108108

109109

@@ -154,16 +154,14 @@ async def rollout(
154154
for symbol in ["x", "o"]:
155155
model = x_model if symbol == "x" else o_model
156156
player_state = player_states[symbol]
157-
shadowmaster = (
158-
scenario.x_shadowmaster if symbol == "x" else scenario.o_shadowmaster
159-
)
157+
teacher = scenario.x_teacher if symbol == "x" else scenario.o_teacher
160158

161159
try:
162160
square = await get_agent_move(
163161
game=game,
164162
player_state=player_state,
165163
model=model,
166-
shadowmaster=shadowmaster,
164+
teacher=teacher,
167165
predestined_move=scenario.initial_move
168166
if move_number == 0
169167
else None,
@@ -214,9 +212,7 @@ async def rollout(
214212
messages = messages[:-1]
215213

216214
model = x_model if symbol == "x" else o_model
217-
shadowmaster = (
218-
scenario.x_shadowmaster if symbol == "x" else scenario.o_shadowmaster
219-
)
215+
teacher = scenario.x_teacher if symbol == "x" else scenario.o_teacher
220216
try:
221217
reported_win = (
222218
trajectory.metrics["win"] if "win" in trajectory.metrics else -1
@@ -236,7 +232,7 @@ async def rollout(
236232
"reward": str(trajectory.reward),
237233
"invalid_move": str(player_state.invalid_move),
238234
"symbol": symbol,
239-
"shadowmaster": shadowmaster.name if shadowmaster else "",
235+
"teacher": teacher.name if teacher else "",
240236
"initial_move": unwrap_move(scenario.initial_move)
241237
if scenario.initial_move
242238
else "",

examples/tic_tac_toe_self_play/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
CLUSTER_NAME = "art4"
2121
PROJECT_NAME = "tic-tac-toe"
2222
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
23-
MODEL_NAME = "llama-8b-shadowmaster-001"
23+
MODEL_NAME = "llama-8b-student-001"
2424

2525

2626
async def main():
@@ -96,8 +96,8 @@ async def main():
9696
scenario=TicTacToeScenario(
9797
step=i,
9898
split="train",
99-
x_shadowmaster=o4_mini if j % 4 == 0 else None,
100-
o_shadowmaster=o4_mini if j % 4 == 1 else None,
99+
x_teacher=o4_mini if j % 4 == 0 else None,
100+
o_teacher=o4_mini if j % 4 == 1 else None,
101101
# ensure we learn how to play against all 9 possible opening moves
102102
initial_move=possible_moves[j % 9] if j < 63 else None,
103103
),

src/art/guided_completion.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from openai.types.chat.chat_completion import ChatCompletion
2+
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
3+
from openai.types.chat.chat_completion_tool_choice_option_param import (
4+
ChatCompletionToolChoiceOptionParam,
5+
)
6+
from pydantic import create_model
7+
from typing import Literal, Tuple, Iterable, List
8+
from copy import deepcopy
9+
import json
10+
11+
12+
def freeze_tool_schema(tool: dict, fixed_args: dict) -> ChatCompletionToolParam:
13+
"""
14+
Return a clone of *tool* whose parameters schema permits *only* `fixed_args`.
15+
Each field is cast to typing.Literal[value] so Pydantic emits an
16+
enum-of-one in the JSON schema, which vLLM's `guided_json` accepts.
17+
"""
18+
fields = {k: (Literal[v], ...) for k, v in fixed_args.items()}
19+
FrozenModel = create_model(
20+
f"{tool['function']['name'].title()}FrozenArgs", **fields
21+
)
22+
23+
locked = deepcopy(tool)
24+
locked["function"]["parameters"] = FrozenModel.model_json_schema()
25+
return locked
26+
27+
28+
def get_guided_completion_params(
29+
completion: ChatCompletion,
30+
base_tools: Iterable[ChatCompletionToolParam] | None = None,
31+
) -> Tuple[
32+
List[str] | None,
33+
ChatCompletionToolChoiceOptionParam | None,
34+
ChatCompletionToolParam | None,
35+
]:
36+
"""
37+
Given a completion from a teacher model, returns chat completion params that can be used to guide a student model's response.
38+
Useful for RL-based distillation.
39+
40+
When guiding the student model's completion, remember to set `num_scheduler_steps` to 1.
41+
42+
Args:
43+
completion: The completion of a teacher model
44+
base_tools: The base tools available to the teacher model
45+
46+
Returns a tuple of (guided_choice, tool_choice, tool_params).
47+
"""
48+
guided_choice, tool_choice, tool_params = None, None, None
49+
50+
if (
51+
completion.choices[0].message.tool_calls
52+
and len(completion.choices[0].message.tool_calls) > 0
53+
):
54+
tool_call = completion.choices[0].message.tool_calls[0]
55+
if not tool_call:
56+
raise ValueError("No tool call found in completion")
57+
if base_tools is None:
58+
raise ValueError("No base tools provided")
59+
tool_name = tool_call.function.name
60+
tool_choice = {
61+
"type": "function", # ← must call it
62+
"function": {"name": tool_name},
63+
}
64+
chosen_tool = next(t for t in base_tools if t["function"]["name"] == tool_name)
65+
tool_params = [
66+
freeze_tool_schema(chosen_tool, json.loads(tool_call.function.arguments))
67+
]
68+
else:
69+
content = completion.choices[0].message.content
70+
guided_choice = [content]
71+
return (guided_choice, tool_choice, tool_params)

0 commit comments

Comments
 (0)