Skip to content

Commit 585f903

Browse files
committed
feat: Implement MathVista training script for image-based question answering
* Added a new script to train a model using image and question pairs from the MathVista dataset.
* Integrated asynchronous processing for efficient training and trajectory logging.
* Enhanced image handling by saving decoded images to a temporary directory for model input.
* Improved argument parsing for customizable training runs.
1 parent 01cd44f commit 585f903

File tree

6 files changed

+150
-16
lines changed

6 files changed

+150
-16
lines changed

dev/math-vista/math-vista.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import argparse
2+
import asyncio
3+
import itertools
4+
import os
5+
import re
6+
from typing import Iterator, TypedDict, cast
7+
8+
import polars as pl
9+
10+
import art
11+
from art.local import LocalBackend
12+
13+
14+
class DecodedImage(TypedDict):
    """Image payload carried by a dataset row."""

    # Raw file contents of the image; the rollout writes these bytes to disk
    # unchanged before pointing the model at the resulting file.
    bytes: bytes
16+
17+
18+
class Scenario(TypedDict):
    """One MathVista row as consumed by the rollout function."""

    # Row identifier — presumably the MathVista problem id; confirm against the dataset card.
    pid: int
    # Question text sent to the model as the text part of the user message.
    question: str
    # Reference answer; compared case-insensitively with the model's boxed answer.
    answer: str
    # Relative image path; the rollout stores the decoded image under /tmp/<image>.
    image: str
    # Image file contents to be written to that path.
    decoded_image: DecodedImage
24+
25+
26+
# Locates every "\boxed{" opener in a model response.
_BOXED_OPEN = re.compile(r"\\boxed\{")


def extract_boxed_answer(text: str) -> "str | None":
    r"""Return the contents of the last complete ``\boxed{...}`` in *text*.

    Braces inside the box are matched by depth, so nested LaTeX such as
    ``\boxed{\frac{1}{2}}`` is returned whole instead of being cut at the
    first ``}``. If the last opener is never closed, earlier boxes are tried.

    Args:
        text: The model response to search.

    Returns:
        The text between the braces of the last balanced box, or ``None``
        when *text* contains no complete box.
    """
    for opener in reversed(list(_BOXED_OPEN.finditer(text))):
        begin = opener.end()
        depth = 0
        for idx in range(begin, len(text)):
            char = text[idx]
            if char == "{":
                depth += 1
            elif char == "}":
                if depth == 0:
                    return text[begin:idx]
                depth -= 1
    return None


def answer_reward(content: str, expected: str) -> float:
    r"""Score a model response against the reference answer.

    Args:
        content: Full text of the model response.
        expected: Reference answer from the dataset.

    Returns:
        ``1.0`` when the last ``\boxed{...}`` answer equals *expected*
        ignoring case and surrounding whitespace, otherwise ``0.0``.
    """
    answer = extract_boxed_answer(content)
    if answer is None:
        return 0.0
    return 1.0 if answer.strip().lower() == expected.strip().lower() else 0.0


async def main(model_name: str, steps: int) -> None:
    """Train a vision-language model on MathVista image/question pairs.

    The dataset is shuffled with a fixed seed; the first 64 rows are used for
    validation and the remaining rows are cycled through for training. Each
    step gathers validation trajectories and grouped training trajectories
    concurrently, logs the former and trains on the latter.

    Args:
        model_name: Name given to the ``art.TrainableModel``.
        steps: Step index at which training stops; the loop runs from the
            model's current step up to (not including) this value.
    """
    # Load and shuffle the dataset
    df = pl.read_parquet(
        "hf://datasets/AI4Math/MathVista/data/testmini-00000-of-00001-725687bf7a18d64b.parquet"
    ).sample(fraction=1.0, shuffle=True, seed=42)

    val_scenarios = cast(list[Scenario], df.head(64).to_dicts())
    # tail(-64) keeps every row except the first 64 used for validation.
    train_scenarios_iter = cast(Iterator[Scenario], df.tail(-64).iter_rows(named=True))

    # Initialize trainable model and backend
    model = art.TrainableModel(
        name=model_name,
        project="math-vista",
        base_model="Qwen/Qwen2.5-VL-7B-Instruct",
    )

    async def rollout(scenario: Scenario) -> art.Trajectory:
        """Ask the model one question about one image and score its answer."""
        image_path = f"/tmp/{scenario['image']}"
        # Many rollouts share a scenario (and therefore a file). Write the
        # image once, and publish it with an atomic rename so a concurrent
        # reader never observes a truncated or half-written file.
        if not os.path.exists(image_path):
            os.makedirs(os.path.dirname(image_path), exist_ok=True)
            partial_path = f"{image_path}.partial"
            with open(partial_path, "wb") as f:
                f.write(scenario["decoded_image"]["bytes"])
            os.replace(partial_path, image_path)

        trajectory = art.Trajectory(messages_and_choices=[], reward=0.0)
        # NOTE(review): the messages are assigned after construction rather
        # than passed to the constructor — presumably deliberate; keep as is.
        trajectory.messages_and_choices = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": scenario["question"]
                        + "\n\nNote: Provide your answer in a LaTeX box.",
                    },
                    {"type": "image_url", "image_url": {"url": f"file://{image_path}"}},
                ],
            }
        ]

        # `client` is bound further down, inside the backend context, before
        # any rollout coroutine is created.
        chat_completion = await client.chat.completions.create(
            model=model.name, messages=trajectory.messages()
        )
        choice = chat_completion.choices[0]
        trajectory.messages_and_choices.append(choice)
        content = choice.message.content
        assert content is not None

        trajectory.reward = answer_reward(content, scenario["answer"])
        return trajectory

    SCENARIOS_PER_STEP = 8
    TRAJECTORY_GROUP_SIZE = 8

    with LocalBackend() as backend:
        await model.register(backend)
        client = model.openai_client()

        start = await model.get_step()
        train_scenarios_iter = itertools.cycle(train_scenarios_iter)
        # When resuming, skip the scenarios that earlier steps consumed.
        for _ in range(start * SCENARIOS_PER_STEP):
            next(train_scenarios_iter)

        # Training loop
        for _ in range(start, steps):
            train_scenarios = [
                next(train_scenarios_iter) for _ in range(SCENARIOS_PER_STEP)
            ]
            # Validation and training rollouts are gathered concurrently.
            val_trajectories, train_trajectory_groups = await asyncio.gather(
                art.gather_trajectories(
                    (rollout(scenario) for scenario in val_scenarios),
                    pbar_desc="gather(val)",
                    max_exceptions=32,
                ),
                art.gather_trajectory_groups(
                    (
                        art.TrajectoryGroup(
                            rollout(scenario) for _ in range(TRAJECTORY_GROUP_SIZE)
                        )
                        for scenario in train_scenarios
                    ),
                    pbar_desc="gather(train)",
                    max_exceptions=32,
                ),
            )
            await model.log(val_trajectories)
            await model.train(train_trajectory_groups)
114+
115+
116+
def parse_args() -> argparse.Namespace:
    """Parse the trainer's command-line flags.

    Returns:
        A namespace with ``name`` (required run/model name) and ``steps``
        (integer number of training steps, 1000 when omitted).
    """
    cli = argparse.ArgumentParser(description="Minimal MathVista trainer script")
    # Each entry is (flags, keyword settings) for one add_argument call.
    option_specs = (
        (
            ("-n", "--name"),
            {
                "required": True,
                "help": "Run/model name to use for the TrainableModel",
            },
        ),
        (
            ("-s", "--steps"),
            {
                "type": int,
                "default": 1000,
                "help": "Number of training steps to run",
            },
        ),
    )
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
132+
133+
134+
if __name__ == "__main__":
    # Script entry point: read the CLI flags, then run the async trainer to completion.
    args = parse_args()
    asyncio.run(main(model_name=args.name, steps=args.steps))

src/art/local/backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -498,9 +498,9 @@ async def _train_model(
498498
num_gradient_steps = int(
499499
result.pop("num_gradient_steps", estimated_gradient_steps)
500500
)
501-
assert (
502-
num_gradient_steps == estimated_gradient_steps
503-
), f"num_gradient_steps {num_gradient_steps} != estimated_gradient_steps {estimated_gradient_steps}"
501+
assert num_gradient_steps == estimated_gradient_steps, (
502+
f"num_gradient_steps {num_gradient_steps} != estimated_gradient_steps {estimated_gradient_steps}"
503+
)
504504
results.append(result)
505505
yield {**result, "num_gradient_steps": num_gradient_steps}
506506
pbar.update(1)

src/art/preprocessing/tokenize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,9 @@ def tokenize_trajectory(
202202
assistant_mask[start:end] = [1] * len(content_token_ids)
203203
else:
204204
choice = message
205-
assert (
206-
choice.logprobs or allow_training_without_logprobs
207-
), "Chat completion choices must have logprobs"
205+
assert choice.logprobs or allow_training_without_logprobs, (
206+
"Chat completion choices must have logprobs"
207+
)
208208
if not choice.logprobs:
209209
continue
210210
token_logprobs = choice.logprobs.content or choice.logprobs.refusal or []

src/art/unsloth/service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,9 @@ async def train(
185185
for task in done:
186186
result = task.result()
187187
# If `result` is `None`, the training task finished somehow.
188-
assert (
189-
result is not None
190-
), "The training task should never finish."
188+
assert result is not None, (
189+
"The training task should never finish."
190+
)
191191
self.results_queue.task_done()
192192
if warmup:
193193
from .state import gc_and_empty_cuda_cache

src/art/unsloth/train.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ def compute_loss(
114114
next_input_ids = shift_tensor(inputs["tokens"], 0)
115115
chunk_size = _config.get("logprob_calculation_chunk_size", 1024)
116116
# Assert that sequence length is evenly divisible by the chunk size
117-
assert (
118-
seq_len % chunk_size == 0
119-
), f"Sequence length ({seq_len}) must be evenly divisible by chunk size ({chunk_size})"
117+
assert seq_len % chunk_size == 0, (
118+
f"Sequence length ({seq_len}) must be evenly divisible by chunk size ({chunk_size})"
119+
)
120120
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
121121
forward_kwargs = {}
122122
if "pixel_values" in inputs:
@@ -371,9 +371,7 @@ def _calculate_logprobs(
371371
chunk_logits = torch.matmul(chunk_hs, lm_head_t) # [B, chunk_size, V]
372372
chunk_selected_logits = torch.gather(
373373
chunk_logits, dim=-1, index=chunk_input_ids.unsqueeze(-1)
374-
).squeeze(
375-
-1
376-
) # [B, chunk_size]
374+
).squeeze(-1) # [B, chunk_size]
377375
chunk_logsumexp = torch.logsumexp(chunk_logits, dim=-1) # [B, chunk_size]
378376
log_probs[:, i : i + chunk_size] = chunk_selected_logits - chunk_logsumexp
379377

src/art/utils/trajectory_logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, cast, Iterator
2+
from typing import Any, Iterator, cast
33

44
import yaml
55

0 commit comments

Comments
 (0)