Skip to content

Commit 6a442e0

Browse files
authored
feat: Add VLM support (#412)
1 parent 64e4731 commit 6a442e0

File tree

15 files changed

+990
-39
lines changed

15 files changed

+990
-39
lines changed

dev/math-vista/math-vista.ipynb

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "46a6ad6d",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"%load_ext autoreload\n",
11+
"%autoreload 2"
12+
]
13+
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": null,
17+
"id": "96d51078",
18+
"metadata": {},
19+
"outputs": [],
20+
"source": [
21+
"%%html\n",
22+
"<style>\n",
23+
".cell-output-ipywidget-background {\n",
24+
" background-color: transparent !important;\n",
25+
"}\n",
26+
":root {\n",
27+
" --jp-widgets-color: var(--vscode-editor-foreground);\n",
28+
" --jp-widgets-font-size: var(--vscode-editor-font-size);\n",
29+
"} \n",
30+
"</style>"
31+
]
32+
},
33+
{
34+
"cell_type": "code",
35+
"execution_count": null,
36+
"id": "7dd70e04",
37+
"metadata": {},
38+
"outputs": [],
39+
"source": [
40+
"import polars as pl\n",
41+
"\n",
42+
"splits = {\n",
43+
" \"testmini\": \"data/testmini-00000-of-00001-725687bf7a18d64b.parquet\",\n",
44+
" \"test\": \"data/test-*.parquet\",\n",
45+
"}\n",
46+
"df = pl.read_parquet(\"hf://datasets/AI4Math/MathVista/\" + splits[\"testmini\"]).sample(\n",
47+
" fraction=1.0, shuffle=True, seed=42\n",
48+
")"
49+
]
50+
},
51+
{
52+
"cell_type": "code",
53+
"execution_count": null,
54+
"id": "81e02b97",
55+
"metadata": {},
56+
"outputs": [],
57+
"source": [
58+
"from typing import Iterator, TypedDict, cast\n",
59+
"\n",
60+
"\n",
61+
"class DecodedImage(TypedDict):\n",
62+
" bytes: bytes\n",
63+
"\n",
64+
"\n",
65+
"class Scenario(TypedDict):\n",
66+
" pid: int\n",
67+
" question: str\n",
68+
" answer: str\n",
69+
" image: str\n",
70+
" decoded_image: DecodedImage\n",
71+
"\n",
72+
"\n",
73+
"val_scenarios = cast(list[Scenario], df.head(64).to_dicts())\n",
74+
"train_scenarios_iter = cast(Iterator[Scenario], df.tail(-64).iter_rows(named=True))"
75+
]
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": null,
80+
"id": "9287d8fe",
81+
"metadata": {},
82+
"outputs": [],
83+
"source": [
84+
"import re\n",
85+
"\n",
86+
"import art\n",
87+
"from art.local import LocalBackend\n",
88+
"\n",
89+
"model = art.TrainableModel(\n",
90+
" name=\"002\",\n",
91+
" project=\"math-vista\",\n",
92+
" base_model=\"Qwen/Qwen2.5-VL-7B-Instruct\",\n",
93+
")\n",
94+
"backend = LocalBackend()\n",
95+
"await model.register(backend)\n",
96+
"client = model.openai_client()"
97+
]
98+
},
99+
{
100+
"cell_type": "code",
101+
"execution_count": null,
102+
"id": "c92b4b11",
103+
"metadata": {},
104+
"outputs": [],
105+
"source": [
106+
"async def rollout(scenario: Scenario) -> art.Trajectory:\n",
107+
" image_path = f\"/tmp/{scenario['image']}\"\n",
108+
"\n",
109+
" import os\n",
110+
"\n",
111+
" os.makedirs(os.path.dirname(image_path), exist_ok=True)\n",
112+
"\n",
113+
" with open(image_path, \"wb\") as f:\n",
114+
" f.write(scenario[\"decoded_image\"][\"bytes\"])\n",
115+
"\n",
116+
" trajectory = art.Trajectory(messages_and_choices=[], reward=0.0)\n",
117+
" trajectory.messages_and_choices = [\n",
118+
" {\n",
119+
" \"role\": \"user\",\n",
120+
" \"content\": [\n",
121+
" {\n",
122+
" \"type\": \"text\",\n",
123+
" \"text\": scenario[\"question\"]\n",
124+
" + \"\\n\\nNote: Provide your answer in a LaTeX box.\",\n",
125+
" },\n",
126+
" {\"type\": \"image_url\", \"image_url\": {\"url\": f\"file://{image_path}\"}},\n",
127+
" ],\n",
128+
" }\n",
129+
" ]\n",
130+
" chat_completion = await client.chat.completions.create(\n",
131+
" model=model.name, messages=trajectory.messages()\n",
132+
" )\n",
133+
" choice = chat_completion.choices[0]\n",
134+
" trajectory.messages_and_choices.append(choice)\n",
135+
" content = choice.message.content\n",
136+
" assert content is not None\n",
137+
" if matches := list(re.finditer(r\"\\\\boxed\\{(.*?)\\}\", content, re.DOTALL)):\n",
138+
" match = matches[-1]\n",
139+
" answer = match.group(1)\n",
140+
" if answer.lower() == scenario[\"answer\"].lower():\n",
141+
" trajectory.reward = 1.0\n",
142+
" return trajectory"
143+
]
144+
},
145+
{
146+
"cell_type": "code",
147+
"execution_count": null,
148+
"id": "359e530d",
149+
"metadata": {},
150+
"outputs": [],
151+
"source": [
152+
"import asyncio\n",
153+
"import itertools\n",
154+
"\n",
155+
"SCENARIOS_PER_STEP = 8\n",
156+
"TRAJECTORY_GROUP_SIZE = 8\n",
157+
"start = await model.get_step()\n",
158+
"train_scenarios_iter = itertools.cycle(train_scenarios_iter)\n",
159+
"for _ in range(start * SCENARIOS_PER_STEP):\n",
160+
" next(train_scenarios_iter)\n",
161+
"\n",
162+
"for i in range(start, 1000):\n",
163+
" train_scenarios = [next(train_scenarios_iter) for _ in range(SCENARIOS_PER_STEP)]\n",
164+
" val_trajectories, train_trajectory_groups = await asyncio.gather(\n",
165+
" art.gather_trajectories(\n",
166+
" (rollout(scenario) for scenario in val_scenarios),\n",
167+
" pbar_desc=\"gather(val)\",\n",
168+
" max_exceptions=32,\n",
169+
" ),\n",
170+
" art.gather_trajectory_groups(\n",
171+
" (\n",
172+
" art.TrajectoryGroup(\n",
173+
" rollout(scenario) for _ in range(TRAJECTORY_GROUP_SIZE)\n",
174+
" )\n",
175+
" for scenario in train_scenarios\n",
176+
" ),\n",
177+
" pbar_desc=\"gather(train)\",\n",
178+
" max_exceptions=32,\n",
179+
" ),\n",
180+
" )\n",
181+
" await model.log(val_trajectories)\n",
182+
" await model.train(train_trajectory_groups)"
183+
]
184+
}
185+
],
186+
"metadata": {
187+
"kernelspec": {
188+
"display_name": ".venv",
189+
"language": "python",
190+
"name": "python3"
191+
},
192+
"language_info": {
193+
"codemirror_mode": {
194+
"name": "ipython",
195+
"version": 3
196+
},
197+
"file_extension": ".py",
198+
"mimetype": "text/x-python",
199+
"name": "python",
200+
"nbconvert_exporter": "python",
201+
"pygments_lexer": "ipython3",
202+
"version": "3.10.13"
203+
}
204+
},
205+
"nbformat": 4,
206+
"nbformat_minor": 5
207+
}

dev/math-vista/math-vista.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import argparse
2+
import asyncio
3+
import itertools
4+
import os
5+
import re
6+
from typing import Iterator, TypedDict, cast
7+
8+
import polars as pl
9+
10+
import art
11+
from art.local import LocalBackend
12+
13+
14+
class DecodedImage(TypedDict):
15+
bytes: bytes
16+
17+
18+
class Scenario(TypedDict):
19+
pid: int
20+
question: str
21+
answer: str
22+
image: str
23+
decoded_image: DecodedImage
24+
25+
26+
async def main(model_name: str, steps: int) -> None:
27+
# Load and shuffle the dataset
28+
df = pl.read_parquet(
29+
"hf://datasets/AI4Math/MathVista/data/testmini-00000-of-00001-725687bf7a18d64b.parquet"
30+
).sample(fraction=1.0, shuffle=True, seed=42)
31+
32+
val_scenarios = cast(list[Scenario], df.head(64).to_dicts())
33+
train_scenarios_iter = cast(Iterator[Scenario], df.tail(-64).iter_rows(named=True))
34+
35+
# Initialize trainable model and backend
36+
model = art.TrainableModel(
37+
name=model_name,
38+
project="math-vista",
39+
base_model="Qwen/Qwen2.5-VL-7B-Instruct",
40+
)
41+
42+
async def rollout(scenario: Scenario) -> art.Trajectory:
43+
image_path = f"/tmp/{scenario['image']}"
44+
os.makedirs(os.path.dirname(image_path), exist_ok=True)
45+
with open(image_path, "wb") as f:
46+
f.write(scenario["decoded_image"]["bytes"])
47+
48+
trajectory = art.Trajectory(messages_and_choices=[], reward=0.0)
49+
trajectory.messages_and_choices = [
50+
{
51+
"role": "user",
52+
"content": [
53+
{
54+
"type": "text",
55+
"text": scenario["question"]
56+
+ "\n\nNote: Provide your answer in a LaTeX box.",
57+
},
58+
{"type": "image_url", "image_url": {"url": f"file://{image_path}"}},
59+
],
60+
}
61+
]
62+
63+
chat_completion = await client.chat.completions.create(
64+
model=model.name, messages=trajectory.messages()
65+
)
66+
choice = chat_completion.choices[0]
67+
trajectory.messages_and_choices.append(choice)
68+
content = choice.message.content
69+
assert content is not None
70+
71+
if matches := list(re.finditer(r"\\boxed\{(.*?)\}", content, re.DOTALL)):
72+
match = matches[-1]
73+
answer = match.group(1)
74+
if answer.lower() == scenario["answer"].lower():
75+
trajectory.reward = 1.0
76+
return trajectory
77+
78+
SCENARIOS_PER_STEP = 8
79+
TRAJECTORY_GROUP_SIZE = 8
80+
81+
with LocalBackend() as backend:
82+
await model.register(backend)
83+
client = model.openai_client()
84+
85+
start = await model.get_step()
86+
train_scenarios_iter = itertools.cycle(train_scenarios_iter)
87+
for _ in range(start * SCENARIOS_PER_STEP):
88+
next(train_scenarios_iter)
89+
90+
# Training loop
91+
for _ in range(start, steps):
92+
train_scenarios = [
93+
next(train_scenarios_iter) for _ in range(SCENARIOS_PER_STEP)
94+
]
95+
val_trajectories, train_trajectory_groups = await asyncio.gather(
96+
art.gather_trajectories(
97+
(rollout(scenario) for scenario in val_scenarios),
98+
pbar_desc="gather(val)",
99+
max_exceptions=32,
100+
),
101+
art.gather_trajectory_groups(
102+
(
103+
art.TrajectoryGroup(
104+
rollout(scenario) for _ in range(TRAJECTORY_GROUP_SIZE)
105+
)
106+
for scenario in train_scenarios
107+
),
108+
pbar_desc="gather(train)",
109+
max_exceptions=32,
110+
),
111+
)
112+
await model.log(val_trajectories)
113+
await model.train(train_trajectory_groups)
114+
115+
116+
def parse_args() -> argparse.Namespace:
117+
parser = argparse.ArgumentParser(description="Minimal MathVista trainer script")
118+
parser.add_argument(
119+
"-n",
120+
"--name",
121+
required=True,
122+
help="Run/model name to use for the TrainableModel",
123+
)
124+
parser.add_argument(
125+
"-s",
126+
"--steps",
127+
type=int,
128+
default=1000,
129+
help="Number of training steps to run",
130+
)
131+
return parser.parse_args()
132+
133+
134+
if __name__ == "__main__":
135+
args = parse_args()
136+
asyncio.run(main(args.name, args.steps))

0 commit comments

Comments
 (0)