|  | 
|  | 1 | +{ | 
|  | 2 | + "cells": [ | 
|  | 3 | +  { | 
|  | 4 | +   "cell_type": "code", | 
|  | 5 | +   "execution_count": null, | 
|  | 6 | +   "id": "46a6ad6d", | 
|  | 7 | +   "metadata": {}, | 
|  | 8 | +   "outputs": [], | 
|  | 9 | +   "source": [ | 
|  | 10 | +    "%load_ext autoreload\n", | 
|  | 11 | +    "%autoreload 2" | 
|  | 12 | +   ] | 
|  | 13 | +  }, | 
|  | 14 | +  { | 
|  | 15 | +   "cell_type": "code", | 
|  | 16 | +   "execution_count": null, | 
|  | 17 | +   "id": "96d51078", | 
|  | 18 | +   "metadata": {}, | 
|  | 19 | +   "outputs": [], | 
|  | 20 | +   "source": [ | 
|  | 21 | +    "%%html\n", | 
|  | 22 | +    "<style>\n", | 
|  | 23 | +    ".cell-output-ipywidget-background {\n", | 
|  | 24 | +    "    background-color: transparent !important;\n", | 
|  | 25 | +    "}\n", | 
|  | 26 | +    ":root {\n", | 
|  | 27 | +    "    --jp-widgets-color: var(--vscode-editor-foreground);\n", | 
|  | 28 | +    "    --jp-widgets-font-size: var(--vscode-editor-font-size);\n", | 
|  | 29 | +    "}  \n", | 
|  | 30 | +    "</style>" | 
|  | 31 | +   ] | 
|  | 32 | +  }, | 
|  | 33 | +  { | 
|  | 34 | +   "cell_type": "code", | 
|  | 35 | +   "execution_count": null, | 
|  | 36 | +   "id": "7dd70e04", | 
|  | 37 | +   "metadata": {}, | 
|  | 38 | +   "outputs": [], | 
|  | 39 | +   "source": [ | 
|  | 40 | +    "import polars as pl\n", | 
|  | 41 | +    "\n", | 
|  | 42 | +    "splits = dict(\n", | 
|  | 43 | +    "    testmini=\"data/testmini-00000-of-00001-725687bf7a18d64b.parquet\",\n", | 
|  | 44 | +    "    test=\"data/test-*.parquet\",\n", | 
|  | 45 | +    ")\n", | 
|  | 46 | +    "# Shuffle every row of `testmini` with a fixed seed so later slicing is reproducible.\n", | 
|  | 47 | +    "testmini_uri = \"hf://datasets/AI4Math/MathVista/\" + splits[\"testmini\"]\n", | 
|  | 48 | +    "df = pl.read_parquet(testmini_uri).sample(fraction=1.0, shuffle=True, seed=42)" | 
|  | 49 | +   ] | 
|  | 50 | +  }, | 
|  | 51 | +  { | 
|  | 52 | +   "cell_type": "code", | 
|  | 53 | +   "execution_count": null, | 
|  | 54 | +   "id": "81e02b97", | 
|  | 55 | +   "metadata": {}, | 
|  | 56 | +   "outputs": [], | 
|  | 57 | +   "source": [ | 
|  | 58 | +    "from typing import Iterator, TypedDict, cast\n", | 
|  | 59 | +    "\n", | 
|  | 60 | +    "\n", | 
|  | 61 | +    "class DecodedImage(TypedDict):\n", | 
|  | 62 | +    "    \"\"\"Image payload of a dataset row; only the raw `bytes` field is declared.\"\"\"\n", | 
|  | 63 | +    "\n", | 
|  | 64 | +    "    bytes: bytes\n", | 
|  | 65 | +    "\n", | 
|  | 66 | +    "\n", | 
|  | 67 | +    "class Scenario(TypedDict):\n", | 
|  | 68 | +    "    \"\"\"Dataset-row fields that this notebook reads.\"\"\"\n", | 
|  | 69 | +    "\n", | 
|  | 70 | +    "    pid: int\n", | 
|  | 71 | +    "    question: str  # prompt text sent to the model\n", | 
|  | 72 | +    "    answer: str  # reference answer the boxed model answer is compared with\n", | 
|  | 73 | +    "    image: str  # relative file path; the image is written to /tmp/<image>\n", | 
|  | 74 | +    "    decoded_image: DecodedImage\n", | 
|  | 75 | +    "\n", | 
|  | 76 | +    "\n", | 
|  | 77 | +    "# Size of the validation slice, named once so head() and tail() cannot drift apart.\n", | 
|  | 78 | +    "VAL_SIZE = 64\n", | 
|  | 79 | +    "\n", | 
|  | 80 | +    "# The first VAL_SIZE rows of `df` are held out for validation; all remaining\n", | 
|  | 81 | +    "# rows are exposed as an iterator of training scenarios.\n", | 
|  | 82 | +    "val_scenarios = cast(list[Scenario], df.head(VAL_SIZE).to_dicts())\n", | 
|  | 83 | +    "train_scenarios_iter = cast(\n", | 
|  | 84 | +    "    Iterator[Scenario], df.tail(-VAL_SIZE).iter_rows(named=True)\n", | 
|  | 85 | +    ")" | 
|  | 75 | +   ] | 
|  | 76 | +  }, | 
|  | 77 | +  { | 
|  | 78 | +   "cell_type": "code", | 
|  | 79 | +   "execution_count": null, | 
|  | 80 | +   "id": "9287d8fe", | 
|  | 81 | +   "metadata": {}, | 
|  | 82 | +   "outputs": [], | 
|  | 83 | +   "source": [ | 
|  | 84 | +    "import re\n", | 
|  | 85 | +    "\n", | 
|  | 86 | +    "import art\n", | 
|  | 87 | +    "from art.local import LocalBackend\n", | 
|  | 88 | +    "\n", | 
|  | 89 | +    "# Trainable model 002 of the math-vista project, built on Qwen/Qwen2.5-VL-7B-Instruct.\n", | 
|  | 90 | +    "model = art.TrainableModel(\n", | 
|  | 91 | +    "    base_model=\"Qwen/Qwen2.5-VL-7B-Instruct\",\n", | 
|  | 92 | +    "    project=\"math-vista\",\n", | 
|  | 93 | +    "    name=\"002\",\n", | 
|  | 94 | +    ")\n", | 
|  | 95 | +    "# Register the model with a LocalBackend, then obtain the client used for chat completions.\n", | 
|  | 96 | +    "backend = LocalBackend()\n", | 
|  | 97 | +    "await model.register(backend)\n", | 
|  | 98 | +    "client = model.openai_client()" | 
|  | 97 | +   ] | 
|  | 98 | +  }, | 
|  | 99 | +  { | 
|  | 100 | +   "cell_type": "code", | 
|  | 101 | +   "execution_count": null, | 
|  | 102 | +   "id": "c92b4b11", | 
|  | 103 | +   "metadata": {}, | 
|  | 104 | +   "outputs": [], | 
|  | 105 | +   "source": [ | 
|  | 106 | +    "def _last_boxed_answer(text: str) -> \"str | None\":\n", | 
|  | 107 | +    "    \"\"\"Return the contents of the last complete LaTeX box in `text`.\n", | 
|  | 108 | +    "\n", | 
|  | 109 | +    "    Braces are matched, so a box that holds nested groups (a fraction, for\n", | 
|  | 110 | +    "    example) is returned whole instead of being cut at its first closing\n", | 
|  | 111 | +    "    brace. Returns None when no box is both opened and closed.\n", | 
|  | 112 | +    "    \"\"\"\n", | 
|  | 113 | +    "    answer = None\n", | 
|  | 114 | +    "    for opener in re.finditer(r\"\\\\boxed\\{\", text):\n", | 
|  | 115 | +    "        depth = 1\n", | 
|  | 116 | +    "        for pos in range(opener.end(), len(text)):\n", | 
|  | 117 | +    "            if text[pos] == \"{\":\n", | 
|  | 118 | +    "                depth += 1\n", | 
|  | 119 | +    "            elif text[pos] == \"}\":\n", | 
|  | 120 | +    "                depth -= 1\n", | 
|  | 121 | +    "                if depth == 0:\n", | 
|  | 122 | +    "                    answer = text[opener.end() : pos]\n", | 
|  | 123 | +    "                    break\n", | 
|  | 124 | +    "    return answer\n", | 
|  | 125 | +    "\n", | 
|  | 126 | +    "\n", | 
|  | 127 | +    "async def rollout(scenario: Scenario) -> art.Trajectory:\n", | 
|  | 128 | +    "    \"\"\"Ask the model one scenario's question and score its boxed answer.\n", | 
|  | 129 | +    "\n", | 
|  | 130 | +    "    The scenario image is written under /tmp and handed to the model as a\n", | 
|  | 131 | +    "    file:// URL next to the question text. The trajectory's reward is 1.0\n", | 
|  | 132 | +    "    when the last boxed answer equals the reference answer, ignoring case\n", | 
|  | 133 | +    "    and surrounding whitespace, and 0.0 otherwise.\n", | 
|  | 134 | +    "    \"\"\"\n", | 
|  | 135 | +    "    import os\n", | 
|  | 136 | +    "\n", | 
|  | 137 | +    "    image_path = f\"/tmp/{scenario['image']}\"\n", | 
|  | 138 | +    "    os.makedirs(os.path.dirname(image_path), exist_ok=True)\n", | 
|  | 139 | +    "\n", | 
|  | 140 | +    "    # Several rollouts of the same scenario run concurrently and share this\n", | 
|  | 141 | +    "    # path. Write to a private temp file and rename it into place so that\n", | 
|  | 142 | +    "    # whatever resolves the file:// URL never reads a half-written image.\n", | 
|  | 143 | +    "    tmp_path = f\"{image_path}.{os.getpid()}.tmp\"\n", | 
|  | 144 | +    "    with open(tmp_path, \"wb\") as f:\n", | 
|  | 145 | +    "        f.write(scenario[\"decoded_image\"][\"bytes\"])\n", | 
|  | 146 | +    "    os.replace(tmp_path, image_path)\n", | 
|  | 147 | +    "\n", | 
|  | 148 | +    "    trajectory = art.Trajectory(messages_and_choices=[], reward=0.0)\n", | 
|  | 149 | +    "    trajectory.messages_and_choices = [\n", | 
|  | 150 | +    "        {\n", | 
|  | 151 | +    "            \"role\": \"user\",\n", | 
|  | 152 | +    "            \"content\": [\n", | 
|  | 153 | +    "                {\n", | 
|  | 154 | +    "                    \"type\": \"text\",\n", | 
|  | 155 | +    "                    \"text\": scenario[\"question\"]\n", | 
|  | 156 | +    "                    + \"\\n\\nNote: Provide your answer in a LaTeX box.\",\n", | 
|  | 157 | +    "                },\n", | 
|  | 158 | +    "                {\"type\": \"image_url\", \"image_url\": {\"url\": f\"file://{image_path}\"}},\n", | 
|  | 159 | +    "            ],\n", | 
|  | 160 | +    "        }\n", | 
|  | 161 | +    "    ]\n", | 
|  | 162 | +    "    chat_completion = await client.chat.completions.create(\n", | 
|  | 163 | +    "        model=model.name, messages=trajectory.messages()\n", | 
|  | 164 | +    "    )\n", | 
|  | 165 | +    "    choice = chat_completion.choices[0]\n", | 
|  | 166 | +    "    trajectory.messages_and_choices.append(choice)\n", | 
|  | 167 | +    "    content = choice.message.content\n", | 
|  | 168 | +    "    assert content is not None\n", | 
|  | 169 | +    "    answer = _last_boxed_answer(content)\n", | 
|  | 170 | +    "    if answer is not None:\n", | 
|  | 171 | +    "        if answer.strip().lower() == scenario[\"answer\"].strip().lower():\n", | 
|  | 172 | +    "            trajectory.reward = 1.0\n", | 
|  | 173 | +    "    return trajectory" | 
|  | 143 | +   ] | 
|  | 144 | +  }, | 
|  | 145 | +  { | 
|  | 146 | +   "cell_type": "code", | 
|  | 147 | +   "execution_count": null, | 
|  | 148 | +   "id": "359e530d", | 
|  | 149 | +   "metadata": {}, | 
|  | 150 | +   "outputs": [], | 
|  | 151 | +   "source": [ | 
|  | 152 | +    "import asyncio\n", | 
|  | 153 | +    "import itertools\n", | 
|  | 154 | +    "\n", | 
|  | 155 | +    "SCENARIOS_PER_STEP = 8\n", | 
|  | 156 | +    "TRAJECTORY_GROUP_SIZE = 8\n", | 
|  | 157 | +    "MAX_STEPS = 1000  # exclusive upper bound of the training step range\n", | 
|  | 158 | +    "MAX_EXCEPTIONS = 32  # forwarded as max_exceptions to both gather calls\n", | 
|  | 159 | +    "\n", | 
|  | 160 | +    "start = await model.get_step()\n", | 
|  | 161 | +    "# Rebuild the endless training stream from `df` instead of wrapping the\n", | 
|  | 162 | +    "# previous iterator: re-running this cell then always resumes at the position\n", | 
|  | 163 | +    "# implied by `start` rather than skipping ahead a second time. The validation\n", | 
|  | 164 | +    "# rows are the head of `df`, so everything after them is training data.\n", | 
|  | 165 | +    "train_scenarios_iter = itertools.cycle(\n", | 
|  | 166 | +    "    df.tail(-len(val_scenarios)).iter_rows(named=True)\n", | 
|  | 167 | +    ")\n", | 
|  | 168 | +    "# Fast-forward past the scenarios used by the steps that were already run.\n", | 
|  | 169 | +    "for _ in range(start * SCENARIOS_PER_STEP):\n", | 
|  | 170 | +    "    next(train_scenarios_iter)\n", | 
|  | 171 | +    "\n", | 
|  | 172 | +    "for i in range(start, MAX_STEPS):\n", | 
|  | 173 | +    "    train_scenarios = [next(train_scenarios_iter) for _ in range(SCENARIOS_PER_STEP)]\n", | 
|  | 174 | +    "    # Validation rollouts and training trajectory groups are gathered concurrently.\n", | 
|  | 175 | +    "    val_trajectories, train_trajectory_groups = await asyncio.gather(\n", | 
|  | 176 | +    "        art.gather_trajectories(\n", | 
|  | 177 | +    "            (rollout(scenario) for scenario in val_scenarios),\n", | 
|  | 178 | +    "            pbar_desc=\"gather(val)\",\n", | 
|  | 179 | +    "            max_exceptions=MAX_EXCEPTIONS,\n", | 
|  | 180 | +    "        ),\n", | 
|  | 181 | +    "        art.gather_trajectory_groups(\n", | 
|  | 182 | +    "            (\n", | 
|  | 183 | +    "                art.TrajectoryGroup(\n", | 
|  | 184 | +    "                    rollout(scenario) for _ in range(TRAJECTORY_GROUP_SIZE)\n", | 
|  | 185 | +    "                )\n", | 
|  | 186 | +    "                for scenario in train_scenarios\n", | 
|  | 187 | +    "            ),\n", | 
|  | 188 | +    "            pbar_desc=\"gather(train)\",\n", | 
|  | 189 | +    "            max_exceptions=MAX_EXCEPTIONS,\n", | 
|  | 190 | +    "        ),\n", | 
|  | 191 | +    "    )\n", | 
|  | 192 | +    "    await model.log(val_trajectories)\n", | 
|  | 193 | +    "    await model.train(train_trajectory_groups)" | 
|  | 183 | +   ] | 
|  | 184 | +  } | 
|  | 185 | + ], | 
|  | 186 | + "metadata": { | 
|  | 187 | +  "kernelspec": { | 
|  | 188 | +   "display_name": ".venv", | 
|  | 189 | +   "language": "python", | 
|  | 190 | +   "name": "python3" | 
|  | 191 | +  }, | 
|  | 192 | +  "language_info": { | 
|  | 193 | +   "codemirror_mode": { | 
|  | 194 | +    "name": "ipython", | 
|  | 195 | +    "version": 3 | 
|  | 196 | +   }, | 
|  | 197 | +   "file_extension": ".py", | 
|  | 198 | +   "mimetype": "text/x-python", | 
|  | 199 | +   "name": "python", | 
|  | 200 | +   "nbconvert_exporter": "python", | 
|  | 201 | +   "pygments_lexer": "ipython3", | 
|  | 202 | +   "version": "3.10.13" | 
|  | 203 | +  } | 
|  | 204 | + }, | 
|  | 205 | + "nbformat": 4, | 
|  | 206 | + "nbformat_minor": 5 | 
|  | 207 | +} | 
0 commit comments