diff --git a/README.md b/README.md
index ef8ae19..0d49939 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,7 @@ You can find the Gemma models on the Hugging Face Hub, Kaggle, Google Cloud Vert
* [CodeGemma](CodeGemma/README.md)
* [PaliGemma](PaliGemma/README.md)
* [Workshops and technical talks](Workshops/README.md)
+* [Research](Research/): Notebooks for research-focused models
* [Showcase complex end-to-end use cases](Demos/README.md)
* [Gemma on Google Cloud](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/open-models) : GCP open models has additional notebooks for using Gemma
diff --git a/Research/[T5Gemma]Example.ipynb b/Research/[T5Gemma]Example.ipynb
new file mode 100644
index 0000000..4587544
--- /dev/null
+++ b/Research/[T5Gemma]Example.ipynb
@@ -0,0 +1,1076 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Tce3stUlHN0L"
+ },
+ "source": [
+ "##### Copyright 2025 Google LLC."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "tuOe1ymfHZPu"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0O690N_JP3vQ"
+ },
+ "source": [
+    "# T5Gemma Example"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RP8VIh0ze87W"
+ },
+ "source": [
+    "We present [T5Gemma (aka encoder-decoder Gemma)](https://arxiv.org/abs/2504.06225), a family of encoder-decoder large language models developed by adapting pretrained decoder-only models into encoder-decoder models.\n",
+ "\n",
+ "T5Gemma includes pretrained and instruction-tuned variants, each with two groups of scales:\n",
+ "* [Gemma 2 scale](https://ai.google.dev/gemma/docs/core/model_card_2): 2B-2B, 9B-2B, and 9B-9B\n",
+    "* [T5 scale](https://arxiv.org/abs/1910.10683): Small, Base, Large, and XL, plus an additional ML size that sits between T5 Large and T5 XL.\n",
+ "\n",
+    "In this notebook, we walk you through sampling (and finetuning) with T5Gemma Small using Flax and Hugging Face.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vfwOEx5Q1hbI"
+ },
+ "source": [
+    "# Hugging Face"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Bsllv8MOzXoC"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ " Building wheel for transformers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install -q git+https://github.com/huggingface/transformers.git\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "x36NV-4m0Fb-"
+ },
+ "source": [
+ "## HF login"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "MQBH5vy40DKv"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "692dfae98c5049ebb58c496cca8a6ab5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+       "VBox(children=(HTML(value=''), ...))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "from huggingface_hub import notebook_login\n",
+    "\n",
+    "notebook_login()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Load the model\n",
+    "\n",
+    "Load a T5Gemma checkpoint with the `transformers` seq2seq classes. The checkpoint id below is an assumption for illustration; pick the T5Gemma variant you want from the Hugging Face Hub."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
+    "\n",
+    "# Checkpoint id assumed for illustration.\n",
+    "model_id = 'google/t5gemma-s-s-prefixlm-it'\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
+    "model = AutoModelForSeq2SeqLM.from_pretrained(model_id)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Sampling"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chat_template = '<start_of_turn>user\\n{user_input}<end_of_turn>\\n<start_of_turn>model\\n'\n",
+ "prompt = chat_template.format(\n",
+ " user_input='Tell me an unknown interesting biology fact about the brain.'\n",
+ ")\n",
+ "\n",
+ "input_ids = tokenizer(prompt, return_tensors=\"pt\")\n",
+ "output = model.generate(**input_ids, max_new_tokens=128)\n",
+ "\n",
+ "print(tokenizer.decode(output[0], skip_special_tokens=True))\n"
+ ]
+ },
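+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Since T5Gemma is an encoder-decoder model, `generate` encodes the prompt once and then decodes autoregressively. As a minimal sketch, you can swap greedy decoding for nucleus sampling; the knob values below are illustrative, not tuned recommendations."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prompt = chat_template.format(\n",
+    "    user_input='Explain in one sentence why the sky is blue.'\n",
+    ")\n",
+    "inputs = tokenizer(prompt, return_tensors='pt')\n",
+    "\n",
+    "# Nucleus sampling instead of greedy decoding (values are illustrative).\n",
+    "output = model.generate(\n",
+    "    **inputs,\n",
+    "    max_new_tokens=64,\n",
+    "    do_sample=True,\n",
+    "    top_p=0.95,\n",
+    "    temperature=0.7,\n",
+    ")\n",
+    "print(tokenizer.decode(output[0], skip_special_tokens=True))"
+   ]
+  },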
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "out6yK8hhQSf"
+ },
+ "source": [
+ "# Flax\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "OgphmBO0ow6r"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m61.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m486.6/486.6 kB\u001b[0m \u001b[31m32.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m55.4/55.4 kB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m400.4/400.4 kB\u001b[0m \u001b[31m31.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m65.3/65.3 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m101.8/101.8 kB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m47.3/47.3 kB\u001b[0m \u001b[31m3.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m111.0/111.0 kB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m76.7/76.7 kB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.0/63.0 MB\u001b[0m \u001b[31m20.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m644.9/644.9 MB\u001b[0m \u001b[31m1.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m319.9/319.9 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m224.9/224.9 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m159.9/159.9 kB\u001b[0m \u001b[31m13.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.5/57.5 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.5/24.5 MB\u001b[0m \u001b[31m81.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m119.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.1/5.1 MB\u001b[0m \u001b[31m98.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.3/49.3 kB\u001b[0m \u001b[31m4.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m147.8/147.8 kB\u001b[0m \u001b[31m12.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.7/13.7 MB\u001b[0m \u001b[31m124.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.9/7.9 MB\u001b[0m \u001b[31m127.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.9/1.9 MB\u001b[0m \u001b[31m74.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.4/44.4 kB\u001b[0m \u001b[31m3.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.5/4.5 MB\u001b[0m \u001b[31m107.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m313.6/313.6 kB\u001b[0m \u001b[31m19.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m253.6/253.6 kB\u001b[0m \u001b[31m20.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m394.3/394.3 kB\u001b[0m \u001b[31m27.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m62.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m86.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.6/6.6 MB\u001b[0m \u001b[31m117.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m224.5/224.5 kB\u001b[0m \u001b[31m17.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m72.5/72.5 kB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Building wheel for gemma (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ " Building wheel for sqlalchemy (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install -q git+https://github.com/google-deepmind/gemma.git\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "l5IYD-_Jxf8G"
+ },
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "TZQiYQy7EJe3"
+ },
+ "outputs": [],
+ "source": [
+ "from etils import ecolab\n",
+ "import os\n",
+ "import optax\n",
+ "import treescope\n",
+ "import kagglehub\n",
+ "\n",
+ "\n",
+ "from kauldron import kd\n",
+ "from gemma import gm\n",
+ "\n",
+ "from gemma.research import t5gemma\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "D4A5IVRxxkcs"
+ },
+ "source": [
+ "## Kaggle login"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Z2R5XZlPxjtl"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "9b5ec9395c7740b5a98d8f410796ca9c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+       "VBox(children=(HTML(value=''), ...))"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "kagglehub.login()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Sampling"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Assumes `preset` (a T5Gemma preset from `gemma.research.t5gemma`) and\n",
+    "# `sampler` were built beforehand, mirroring the finetuning section below:\n",
+    "#   sampler = t5gemma.Sampler(\n",
+    "#       model=model, params=params, tokenizer=preset.tokenizer)\n",
+    "chat_template = '<start_of_turn>user\\n{user_input}<end_of_turn>\\n<start_of_turn>model\\n'\n",
+    "output = sampler.sample(\n",
+ " chat_template.format(\n",
+ " user_input='Tell me an unknown interesting biology fact about the brain.'\n",
+ " ),\n",
+ " max_new_tokens=32,\n",
+ ")\n",
+ "\n",
+    "print(output)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GDe3ZYLERFun"
+ },
+ "source": [
+ "## Finetuning\n",
+ "\n",
+    "A simple example of finetuning the encoder-decoder model for machine translation."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "sK10MO1SqkMy"
+ },
+ "source": [
+ "### Preprocessor\n",
+ "\n",
+    "Convert examples from the decoder-only format to the encoder-decoder format. A toy sanity check of the conversion follows the implementation below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qf3-uXF6n2e0"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:absl:Variant folder /root/tensorflow_datasets/mtnt/en-fr/1.0.0 has no dataset_info.json\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0...\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "28cbe3a0eb23486fb6735541ccaee5ba",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Dl Completed...: 0 url [00:00, ? url/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f277b57566254d8ba511854982be79f1",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Dl Size...: 0 MiB [00:00, ? MiB/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "180d8d3d715f492aa8d4761d3aebd6b9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Extraction completed...: 0 file [00:00, ? file/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "156e9ec5fc264904881f766e37272432",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generating splits...: 0%| | 0/3 [00:00, ? splits/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "baddb8c2266d4c1aa8e29ce4c1f59456",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generating train examples...: 0 examples [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d8bf0d46c344423eac75bb783b64e28d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Shuffling /root/tensorflow_datasets/mtnt/en-fr/incomplete.2PGJ2G_1.0.0/mtnt-train.array_record*...: 0%| …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b4f88587eee34b2ba5724218d11a577f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generating test examples...: 0 examples [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "9c9d6e34ac374a34b0dfaf963806240c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Shuffling /root/tensorflow_datasets/mtnt/en-fr/incomplete.2PGJ2G_1.0.0/mtnt-test.array_record*...: 0%| …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "00cc593e4cb14116997e9987e68f3a89",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generating valid examples...: 0 examples [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f8331beada5f4323af40eb221de3d091",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Shuffling /root/tensorflow_datasets/mtnt/en-fr/incomplete.2PGJ2G_1.0.0/mtnt-valid.array_record*...: 0%| …"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+       ""
+ ],
+ "text/plain": [
+ "{\n",
+ " 'decoder_input': # np.ndarray int64(8, 200) [≥0, ≤235_382] zero:1_395 nonzero:205\n",
+ " array([[ 2, 9776, 235290, ..., 0, 0, 0],\n",
+ " [ 2, 74259, 28399, ..., 0, 0, 0],\n",
+ " [ 2, 179284, 1804, ..., 0, 0, 0],\n",
+ " ...,\n",
+ " [ 2, 138447, 21392, ..., 0, 0, 0],\n",
+ " [ 2, 6151, 69182, ..., 0, 0, 0],\n",
+ " [ 2, 235277, 235303, ..., 0, 0, 0]])\n",
+ " ,\n",
+    "  'encoder_input': <np.ndarray int64(8, 200)>,\n",
+    "  'input': <np.ndarray int64(8, 200)>,\n",
+    "  'loss_mask': <np.ndarray int32(8, 200, 1)>,\n",
+    "  'target': <np.ndarray int64(8, 200, 1)>,\n",
+ "}"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import dataclasses\n",
+ "from etils import enp\n",
+ "from grain import python as grain\n",
+ "\n",
+ "@dataclasses.dataclass(kw_only=True, frozen=True)\n",
+ "class Deconly2EncDecPreprocessor(grain.MapTransform):\n",
+ "\n",
+ " in_input: kd.kontext.Key # \"input\"\n",
+ " in_target: kd.kontext.Key # \"target\"\n",
+ " in_loss_mask: kd.kontext.Key # \"loss_mask\"\n",
+ "\n",
+ " out_encoder_input: kd.kontext.Key # \"encoder_input\"\n",
+ " out_decoder_input: kd.kontext.Key # \"decoder_input\"\n",
+ " out_target: kd.kontext.Key # \"target\"\n",
+ " out_loss_mask: kd.kontext.Key # \"loss_mask\"\n",
+ "\n",
+ " pad_id: int = 0\n",
+ " max_len: int | None = None\n",
+ "\n",
+ "\n",
+ " def map(self, element):\n",
+    "    \"\"\"Converts a decoder-only example to the encoder-decoder format.\n",
+    "\n",
+    "    Example:\n",
+    "      Decoder-only:\n",
+    "        Input: <bos> A B C 1 2 3\n",
+    "        Target: A B C 1 2 3 <eos>\n",
+    "        Loss Mask: 0 0 0 1 1 1 1\n",
+    "\n",
+    "      ==>\n",
+    "      Encoder-Decoder:\n",
+    "        Encoder Input: A B C\n",
+    "        Decoder Input: <bos> 1 2 3\n",
+    "        Target: 1 2 3 <eos>\n",
+    "        Loss Mask: 1 1 1 1\n",
+    "\n",
+    "    Args:\n",
+    "      element: a single input example as a dictionary.\n",
+    "\n",
+    "    Returns:\n",
+    "      A dictionary with the preprocessed example for encoder-decoder modeling.\n",
+    "    \"\"\"\n",
+ " # Extract the values from the `dict` example.\n",
+ " deconly_input = kd.kontext.get_by_path(element, self.in_input)\n",
+ " deconly_target = kd.kontext.get_by_path(element, self.in_target)\n",
+ " deconly_loss_mask = kd.kontext.get_by_path(element, self.in_loss_mask)\n",
+ "\n",
+ " xnp = enp.lazy.get_xnp(deconly_input, strict=False)\n",
+ "\n",
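+    "    # Drop the trailing feature axis that Seq2SeqTask adds to target/loss_mask.\n",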
+ " deconly_target = deconly_target[..., 0]\n",
+ " deconly_loss_mask = deconly_loss_mask[..., 0]\n",
+ " deconly_input_mask = deconly_input != self.pad_id\n",
+ " seq_len = deconly_input.shape[0]\n",
+ "\n",
+ " # Encoder input tokens\n",
+ " # Encoder mask -> positions -> gather input tokens from positions\n",
+ " # [1, 1, 1, 0, 0, 0, 0]\n",
+ " encdec_encoder_input_mask = xnp.logical_and(\n",
+ " ~deconly_loss_mask,\n",
+ " deconly_input_mask,\n",
+ " ).astype(xnp.int32)\n",
+    "    # We don't subtract 1 from the cumsum so the leading <bos> is skipped.\n",
+ " # [1, 2, 3, 0, 0, 0, 0]\n",
+ " encdec_encoder_input_positions = xnp.cumsum(\n",
+ " encdec_encoder_input_mask, axis=-1\n",
+ " ) * encdec_encoder_input_mask\n",
+    "    # Guard against out-of-range gather positions for input-only examples.\n",
+ " encdec_encoder_input_positions *= (\n",
+ " encdec_encoder_input_positions < seq_len\n",
+ " ).astype(xnp.int32)\n",
+ " # [A, B, C, 0, 0, 0, 0]\n",
+ " encdec_encoder_input_tokens = xnp.take_along_axis(\n",
+ " deconly_input, encdec_encoder_input_positions, axis=-1\n",
+ " ) * encdec_encoder_input_mask\n",
+ "\n",
+ " # Decoder input tokens\n",
+ " # Decoder mask -> positions -> move to beginning by sorting -> gather tokens\n",
+ " # [3]\n",
+ " num_encoder_tokens = xnp.sum(\n",
+ " encdec_encoder_input_mask, axis=-1, keepdims=True\n",
+ " )\n",
+ " # [0, 0, 0, 1, 1, 1, 1]\n",
+ " encdec_decoder_mask = xnp.logical_and(\n",
+ " deconly_loss_mask,\n",
+ " deconly_input_mask,\n",
+ " ).astype(xnp.int32)\n",
+ " # [0, 0, 0, 1, 2, 3, 4]\n",
+ " encdec_decoder_positions = xnp.cumsum(\n",
+ " encdec_decoder_mask, axis=-1\n",
+ " ) * encdec_decoder_mask\n",
+ " # Invalid tokens are set to seq_len+1\n",
+ " # [8, 8, 8, 1, 2, 3, 4]\n",
+ " encdec_decoder_positions += (1 - encdec_decoder_mask) * (seq_len+1)\n",
+ " # After sorting, all valid tokens are put into the beginning in order\n",
+ " # [1, 2, 3, 4, 8, 8, 8]\n",
+ " encdec_decoder_positions = xnp.sort(\n",
+ " encdec_decoder_positions, axis=-1\n",
+ " )\n",
+ " # Valid tokens should have positions <= seq_len\n",
+ " # [1, 1, 1, 1, 0, 0, 0]\n",
+ " encdec_decoder_mask = (\n",
+ " encdec_decoder_positions <= seq_len).astype(xnp.int32)\n",
+ " # [4, 5, 6, 7, 11, 11, 11]\n",
+ " encdec_decoder_positions += num_encoder_tokens\n",
+ " # [3, 4, 5, 6, 0, 0, 0]\n",
+ " encdec_decoder_target_positions = (\n",
+ " encdec_decoder_positions - 1\n",
+ " ) * encdec_decoder_mask\n",
+    "    # The first decoder-input token becomes <bos> (position 0 of the input).\n",
+ " # [0, 4, 5, 6, 0, 0, 0]\n",
+ " encdec_decoder_input_positions = xnp.pad(\n",
+ " encdec_decoder_positions,\n",
+ " ((1, 0)),\n",
+ " 'constant',\n",
+ " constant_values=0,\n",
+ " )[:-1]\n",
+ " encdec_decoder_input_positions *= encdec_decoder_mask\n",
+ "\n",
+    "    # [<bos>, 1, 2, 3, 0, 0, 0]\n",
+ " encdec_decoder_input_tokens = xnp.take_along_axis(\n",
+ " deconly_input, encdec_decoder_input_positions, axis=-1\n",
+ " ) * encdec_decoder_mask\n",
+    "    # [1, 2, 3, <eos>, 0, 0, 0]\n",
+ " encdec_decoder_target_tokens = xnp.take_along_axis(\n",
+ " deconly_target, encdec_decoder_target_positions, axis=-1\n",
+ " ) * encdec_decoder_mask\n",
+ "\n",
+ " max_len = self.max_len\n",
+ " if max_len is None:\n",
+ " max_len = seq_len\n",
+ "\n",
+ " # Add the fields to the output `dict`.\n",
+ " # Equivalent to `element[self.out_input] = ...`\n",
+ " kd.kontext.set_by_path(\n",
+ " element,\n",
+ " self.out_encoder_input,\n",
+ " encdec_encoder_input_tokens[:max_len],\n",
+ " )\n",
+ " kd.kontext.set_by_path(\n",
+ " element,\n",
+ " self.out_decoder_input,\n",
+ " encdec_decoder_input_tokens[:max_len],\n",
+ " )\n",
+ " kd.kontext.set_by_path(\n",
+ " element,\n",
+ " self.out_target,\n",
+ " encdec_decoder_target_tokens[:max_len, None],\n",
+ " )\n",
+ " kd.kontext.set_by_path(\n",
+ " element,\n",
+ " self.out_loss_mask,\n",
+ " encdec_decoder_mask[:max_len, None],\n",
+ " )\n",
+ " return element\n",
+ "\n",
+ "ds = kd.data.py.Tfds(\n",
+ " name='mtnt/en-fr',\n",
+ " split='train',\n",
+ " shuffle=True,\n",
+ " batch_size=8,\n",
+ " transforms=[\n",
+ " # Create the model inputs/targets/loss_mask.\n",
+ " gm.data.Seq2SeqTask(\n",
+ " # Select which field from the dataset to use.\n",
+ " # https://www.tensorflow.org/datasets/catalog/mtnt\n",
+ " in_prompt='src',\n",
+ " in_response='dst',\n",
+ " # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}\n",
+ " out_input='input',\n",
+ " out_target='target',\n",
+ " out_target_mask='loss_mask',\n",
+ " tokenizer=preset.tokenizer,\n",
+ " # Padding parameters\n",
+ " max_length=200,\n",
+ " truncate=True,\n",
+ " ),\n",
+ " Deconly2EncDecPreprocessor(\n",
+ " in_input='input',\n",
+ " in_target='target',\n",
+ " in_loss_mask='loss_mask',\n",
+ " out_encoder_input='encoder_input',\n",
+ " out_decoder_input='decoder_input',\n",
+ " out_target='target',\n",
+ " out_loss_mask='loss_mask',\n",
+ " pad_id=preset.tokenizer.special_tokens.PAD,\n",
+ " max_len=200,\n",
+ " ),\n",
+ " ],\n",
+ ")\n",
+ "\n",
+ "ex = ds[0]\n",
+ "\n",
+ "treescope.show(ex)"
+ ]
+ },
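+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a sanity check, here is a toy run of the preprocessor that mirrors the docstring example. The token ids are made up for illustration: `2` plays the role of `<bos>`, `1` of `<eos>`, `10..12` are the prompt tokens `A B C`, and `20..22` the response tokens `1 2 3`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "toy = {\n",
+    "    # <bos> A B C 1 2 3\n",
+    "    'input': np.array([2, 10, 11, 12, 20, 21, 22]),\n",
+    "    # A B C 1 2 3 <eos> (input shifted left), with a trailing feature axis.\n",
+    "    'target': np.array([10, 11, 12, 20, 21, 22, 1])[:, None],\n",
+    "    # Loss only on the response tokens.\n",
+    "    'loss_mask': np.array([0, 0, 0, 1, 1, 1, 1], dtype=bool)[:, None],\n",
+    "}\n",
+    "\n",
+    "toy_preprocessor = Deconly2EncDecPreprocessor(\n",
+    "    in_input='input',\n",
+    "    in_target='target',\n",
+    "    in_loss_mask='loss_mask',\n",
+    "    out_encoder_input='encoder_input',\n",
+    "    out_decoder_input='decoder_input',\n",
+    "    out_target='target',\n",
+    "    out_loss_mask='loss_mask',\n",
+    ")\n",
+    "out = toy_preprocessor.map(dict(toy))\n",
+    "\n",
+    "print(out['encoder_input'])    # [10 11 12  0  0  0  0] -> A B C\n",
+    "print(out['decoder_input'])    # [ 2 20 21 22  0  0  0] -> <bos> 1 2 3\n",
+    "print(out['target'][:, 0])     # [20 21 22  1  0  0  0] -> 1 2 3 <eos>\n",
+    "print(out['loss_mask'][:, 0])  # [1 1 1 1 0 0 0]\n"
+   ]
+  },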
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "qj2QY1T1t_LC"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "decoder_input\n",
+ "```\n",
+ "Est-ce que les femmes passent encore la nuit avant le mariage dans une autre chambre que celle de leur fiancée ???```\n",
+ "\n",
+ "encoder_input\n",
+ "```\n",
+    "<start_of_turn>user\n",
+    "Do woman still spend the night before their wedding away from their fiancee???<end_of_turn>\n",
+    "<start_of_turn>model\n",
+ "```\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "for k in ex:\n",
+ " if k in ['decoder_input', 'encoder_input']:\n",
+ " print(f\"{k}\\n```\\n{preset.tokenizer.decode(ex[k][0])}```\\n\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "L_ND9CJDlcSy"
+ },
+ "source": [
+ "### Trainer\n",
+ "\n",
+    "Based on [kauldron](https://kauldron.readthedocs.io/en/latest/), following the Gemma finetuning setup.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jMb2KpHR7vo5"
+ },
+ "outputs": [],
+ "source": [
+ "loss = kd.losses.SoftmaxCrossEntropyWithIntLabels(\n",
+ " logits=\"preds.logits\",\n",
+ " labels=\"batch.target\",\n",
+ " mask=\"batch.loss_mask\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Bv854FDSn7Z-"
+ },
+ "outputs": [],
+ "source": [
+ "model = preset.config.make(\n",
+ " \"transformer\",\n",
+ " input_tokens=\"batch.encoder_input\",\n",
+ " target_tokens=\"batch.decoder_input\",\n",
+ ")\n",
+ "\n",
+ "checkpoint = preset.get_checkpoint_from_kaggle(\n",
+ " t5gemma.CKPTType.IT,\n",
+ " t5gemma.PretrainType.PREFIXLM,\n",
+ ")\n",
+ "\n",
+ "trainer = kd.train.Trainer(\n",
+ " seed=42, # The seed of enlightenment\n",
+ " workdir='/tmp/ckpts',\n",
+ " # Dataset\n",
+ " train_ds=ds,\n",
+ " # Model\n",
+ " model=model,\n",
+ " # Params\n",
+ " init_transform=gm.ckpts.LoadCheckpoint(checkpoint),\n",
+ " # Training parameters\n",
+ " num_train_steps=500,\n",
+ " train_losses={\"loss\": loss},\n",
+ " optimizer=optax.adafactor(learning_rate=1e-4),\n",
+ " sharding=kd.sharding.ShardingStrategy(\n",
+ " ds=kd.sharding.FIRST_DIM,\n",
+ " params=kd.sharding.FSDPSharding(),\n",
+ " )\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "xvIDsFPz75GT"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Starting training loop at step 0\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2c5f9397503141a7bcabb2eed5f20383",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "train: 0%| | 0/501 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "state, aux = trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "x54YaAteRV94"
+ },
+ "source": [
+ "### Sampling"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yM0l9EnPMdHf"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Bonjour ! Je vais faire des vacances en France.\n"
+ ]
+ }
+ ],
+ "source": [
+ "sampler = t5gemma.Sampler(\n",
+ " model=model,\n",
+ " params=state.params,\n",
+ " tokenizer=preset.tokenizer,\n",
+ ")\n",
+ "\n",
+    "output = sampler.sample('<start_of_turn>user\\nHello! My next holidays are in Paris.<end_of_turn>\\n<start_of_turn>model\\n')\n",
+ "\n",
+ "print(output)"
+ ]
+  },
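+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick check of the finetuned model, you can sample a few more prompts with the same chat template; the prompts below are arbitrary examples."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chat_template = '<start_of_turn>user\\n{user_input}<end_of_turn>\\n<start_of_turn>model\\n'\n",
+    "\n",
+    "for user_input in [\n",
+    "    'The weather is great today.',\n",
+    "    'Where is the nearest train station?',\n",
+    "]:\n",
+    "    print(sampler.sample(\n",
+    "        chat_template.format(user_input=user_input),\n",
+    "        max_new_tokens=64,\n",
+    "    ))\n"
+   ]
+  }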
+ ],
+ "metadata": {
+ "accelerator": "TPU",
+ "colab": {
+ "name": "[T5Gemma]Example.ipynb",
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}