diff --git a/README.md b/README.md
index ef8ae19..0d49939 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,7 @@ You can find the Gemma models on the Hugging Face Hub, Kaggle, Google Cloud Vert
 * [CodeGemma](CodeGemma/README.md)
 * [PaliGemma](PaliGemma/README.md)
 * [Workshops and technical talks](Workshops/README.md)
+* [Research](Research/): Notebooks for research-focused models
 * [Showcase complex end-to-end use cases](Demos/README.md)
 * [Gemma on Google Cloud](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/open-models) : GCP open models has additional notebooks for using Gemma
diff --git a/Research/[T5Gemma]Example.ipynb b/Research/[T5Gemma]Example.ipynb
new file mode 100644
index 0000000..4587544
--- /dev/null
+++ b/Research/[T5Gemma]Example.ipynb
@@ -0,0 +1,1076 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2025 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0O690N_JP3vQ"
      },
      "source": [
        "# T5Gemma Example\n",
        "\n",
        "\n",
        " \n",
        "
\n", + " Run in Google Colab\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RP8VIh0ze87W" + }, + "source": [ + "We present [T5Gemma (aka encoder-decoder Gemma)](https://arxiv.org/abs/2504.06225), a family of encoder-decoder large langauge models, developed by adapting pretrained decoder-only models into encoder-decoder.\n", + "\n", + "T5Gemma includes pretrained and instruction-tuned variants, each with two groups of scales:\n", + "* [Gemma 2 scale](https://ai.google.dev/gemma/docs/core/model_card_2): 2B-2B, 9B-2B, and 9B-9B\n", + "* [T5 scale](https://arxiv.org/abs/1910.10683): Small, Base, Large, and XL. An additional ML scale model is added which is in-between T5 Large and T5 XL.\n", + "\n", + "In this notebook, we walk you through how to sampling (and tuning) with T5Gemma Small using Flax and Huggingface.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vfwOEx5Q1hbI" + }, + "source": [ + "# Huggingface" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Bsllv8MOzXoC" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for transformers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "!pip install -q git+https://github.com/huggingface/transformers.git\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "x36NV-4m0Fb-" + }, + "source": [ + "## HF login" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MQBH5vy40DKv" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "692dfae98c5049ebb58c496cca8a6ab5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='
(()=>{ if (customElements.get('treescope-container') === undefined) { class TreescopeContainer extends HTMLElement { constructor() { super(); this.attachShadow({mode: \"open\"}); this.defns = {}; this.state = {}; } } customElements.define(\"treescope-container\", TreescopeContainer); } if (customElements.get('treescope-run-here') === undefined) { class RunHere extends HTMLElement { constructor() { super() } connectedCallback() { const run = child => { const fn = new Function(child.textContent); child.textContent = \"\"; fn.call(this); this.remove(); }; const child = this.querySelector(\"script\"); if (child) { run(child); } else { new MutationObserver(()=>{ run(this.querySelector(\"script\")); }).observe(this, {childList: true}); } } } customElements.define(\"treescope-run-here\", RunHere); } })();
" + ], + "text/plain": [ + "{\n", + " 'decoder_input': # np.ndarray int64(8, 200) [≥0, ≤235_382] zero:1_395 nonzero:205\n", + " array([[ 2, 9776, 235290, ..., 0, 0, 0],\n", + " [ 2, 74259, 28399, ..., 0, 0, 0],\n", + " [ 2, 179284, 1804, ..., 0, 0, 0],\n", + " ...,\n", + " [ 2, 138447, 21392, ..., 0, 0, 0],\n", + " [ 2, 6151, 69182, ..., 0, 0, 0],\n", + " [ 2, 235277, 235303, ..., 0, 0, 0]])\n", + " ,\n", + " 'encoder_input': ,\n", + " 'input': ,\n", + " 'loss_mask': ,\n", + " 'target': ,\n", + "}" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import dataclasses\n", + "from etils import enp\n", + "from grain import python as grain\n", + "\n", + "@dataclasses.dataclass(kw_only=True, frozen=True)\n", + "class Deconly2EncDecPreprocessor(grain.MapTransform):\n", + "\n", + " in_input: kd.kontext.Key # \"input\"\n", + " in_target: kd.kontext.Key # \"target\"\n", + " in_loss_mask: kd.kontext.Key # \"loss_mask\"\n", + "\n", + " out_encoder_input: kd.kontext.Key # \"encoder_input\"\n", + " out_decoder_input: kd.kontext.Key # \"decoder_input\"\n", + " out_target: kd.kontext.Key # \"target\"\n", + " out_loss_mask: kd.kontext.Key # \"loss_mask\"\n", + "\n", + " pad_id: int = 0\n", + " max_len: int | None = None\n", + "\n", + "\n", + " def map(self, element):\n", + " \"\"\"Preprocess converting deconly example to encoder-decoder.\n", + "\n", + " Example:\n", + " Deconly:\n", + " Input: A B C 1 2 3\n", + " Target: A B C 1 2 3 \n", + " Loss Mask: 0 0 0 1 1 1 1\n", + "\n", + " ==>\n", + " Encoder-Decoder:\n", + " Encoder Input: A B C\n", + " Decoder Input: 1 2 3\n", + " Target: 1 2 3 \n", + " Loss Mask: 1 1 1 1\n", + " Args:\n", + " element: input single example in a dictionary format.\n", + " Returns:\n", + " A dictionary of preprocessed examples for encoder-decoder modeling.\n", + " \"\"\"\n", + " # Extract the values from the `dict` example.\n", + " deconly_input = kd.kontext.get_by_path(element, self.in_input)\n", + " deconly_target = kd.kontext.get_by_path(element, self.in_target)\n", + " deconly_loss_mask = kd.kontext.get_by_path(element, self.in_loss_mask)\n", + "\n", + " xnp = enp.lazy.get_xnp(deconly_input, strict=False)\n", + "\n", + " deconly_target = deconly_target[..., 0]\n", + " deconly_loss_mask = deconly_loss_mask[..., 0]\n", + " deconly_input_mask = deconly_input != self.pad_id\n", + " seq_len = deconly_input.shape[0]\n", + "\n", + " # Encoder input tokens\n", + " # Encoder mask -> positions -> gather input tokens from positions\n", + " # [1, 1, 1, 0, 0, 0, 0]\n", + " encdec_encoder_input_mask = xnp.logical_and(\n", + " ~deconly_loss_mask,\n", + " deconly_input_mask,\n", + " ).astype(xnp.int32)\n", + " # We didn't subtract it by 1 due to skipping \n", + " # [1, 2, 3, 0, 0, 0, 0]\n", + " encdec_encoder_input_positions = xnp.cumsum(\n", + " encdec_encoder_input_mask, axis=-1\n", + " ) * encdec_encoder_input_mask\n", + " # To avoid input-only errors\n", + " encdec_encoder_input_positions *= (\n", + " encdec_encoder_input_positions < seq_len\n", + " ).astype(xnp.int32)\n", + " # [A, B, C, 0, 0, 0, 0]\n", + " encdec_encoder_input_tokens = xnp.take_along_axis(\n", + " deconly_input, encdec_encoder_input_positions, axis=-1\n", + " ) * encdec_encoder_input_mask\n", + "\n", + " # Decoder input tokens\n", + " # Decoder mask -> positions -> move to beginning by sorting -> gather tokens\n", + " # [3]\n", + " num_encoder_tokens = xnp.sum(\n", + " encdec_encoder_input_mask, axis=-1, keepdims=True\n", + " )\n", + " # [0, 0, 0, 1, 1, 1, 1]\n", + " 
encdec_decoder_mask = xnp.logical_and(\n",
    "        deconly_loss_mask,\n",
    "        deconly_input_mask,\n",
    "    ).astype(xnp.int32)\n",
    "    # [0, 0, 0, 1, 2, 3, 4]\n",
    "    encdec_decoder_positions = xnp.cumsum(\n",
    "        encdec_decoder_mask, axis=-1\n",
    "    ) * encdec_decoder_mask\n",
    "    # Invalid tokens are set to seq_len+1\n",
    "    # [8, 8, 8, 1, 2, 3, 4]\n",
    "    encdec_decoder_positions += (1 - encdec_decoder_mask) * (seq_len + 1)\n",
    "    # After sorting, all valid tokens are moved to the beginning, in order\n",
    "    # [1, 2, 3, 4, 8, 8, 8]\n",
    "    encdec_decoder_positions = xnp.sort(\n",
    "        encdec_decoder_positions, axis=-1\n",
    "    )\n",
    "    # Valid tokens should have positions <= seq_len\n",
    "    # [1, 1, 1, 1, 0, 0, 0]\n",
    "    encdec_decoder_mask = (\n",
    "        encdec_decoder_positions <= seq_len).astype(xnp.int32)\n",
    "    # [4, 5, 6, 7, 11, 11, 11]\n",
    "    encdec_decoder_positions += num_encoder_tokens\n",
    "    # [3, 4, 5, 6, 0, 0, 0]\n",
    "    encdec_decoder_target_positions = (\n",
    "        encdec_decoder_positions - 1\n",
    "    ) * encdec_decoder_mask\n",
    "    # The first decoder input token is now <bos>\n",
    "    # [0, 4, 5, 6, 0, 0, 0]\n",
    "    encdec_decoder_input_positions = xnp.pad(\n",
    "        encdec_decoder_positions,\n",
    "        ((1, 0)),\n",
    "        'constant',\n",
    "        constant_values=0,\n",
    "    )[:-1]\n",
    "    encdec_decoder_input_positions *= encdec_decoder_mask\n",
    "\n",
    "    # [<bos>, 1, 2, 3, 0, 0, 0]\n",
    "    encdec_decoder_input_tokens = xnp.take_along_axis(\n",
    "        deconly_input, encdec_decoder_input_positions, axis=-1\n",
    "    ) * encdec_decoder_mask\n",
    "    # [1, 2, 3, <eos>, 0, 0, 0]\n",
    "    encdec_decoder_target_tokens = xnp.take_along_axis(\n",
    "        deconly_target, encdec_decoder_target_positions, axis=-1\n",
    "    ) * encdec_decoder_mask\n",
    "\n",
    "    max_len = self.max_len\n",
    "    if max_len is None:\n",
    "      max_len = seq_len\n",
    "\n",
    "    # Add the fields to the output `dict`.\n",
    "    # Equivalent to `element[self.out_input] = ...`\n",
    "    kd.kontext.set_by_path(\n",
    "        element,\n",
    "        self.out_encoder_input,\n",
    "        encdec_encoder_input_tokens[:max_len],\n",
    "    )\n",
    "    kd.kontext.set_by_path(\n",
    "        element,\n",
    "        self.out_decoder_input,\n",
    "        encdec_decoder_input_tokens[:max_len],\n",
    "    )\n",
    "    kd.kontext.set_by_path(\n",
    "        element,\n",
    "        self.out_target,\n",
    "        encdec_decoder_target_tokens[:max_len, None],\n",
    "    )\n",
    "    kd.kontext.set_by_path(\n",
    "        element,\n",
    "        self.out_loss_mask,\n",
    "        encdec_decoder_mask[:max_len, None],\n",
    "    )\n",
    "    return element\n",
    "\n",
    "ds = kd.data.py.Tfds(\n",
    "    name='mtnt/en-fr',\n",
    "    split='train',\n",
    "    shuffle=True,\n",
    "    batch_size=8,\n",
    "    transforms=[\n",
    "        # Create the model inputs/targets/loss_mask.\n",
    "        gm.data.Seq2SeqTask(\n",
    "            # Select which fields from the dataset to use.\n",
    "            # https://www.tensorflow.org/datasets/catalog/mtnt\n",
    "            in_prompt='src',\n",
    "            in_response='dst',\n",
    "            # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}\n",
    "            out_input='input',\n",
    "            out_target='target',\n",
    "            out_target_mask='loss_mask',\n",
    "            tokenizer=preset.tokenizer,\n",
    "            # Padding parameters\n",
    "            max_length=200,\n",
    "            truncate=True,\n",
    "        ),\n",
    "        Deconly2EncDecPreprocessor(\n",
    "            in_input='input',\n",
    "            in_target='target',\n",
    "            in_loss_mask='loss_mask',\n",
    "            out_encoder_input='encoder_input',\n",
    "            out_decoder_input='decoder_input',\n",
    "            out_target='target',\n",
    "            out_loss_mask='loss_mask',\n",
    "            pad_id=preset.tokenizer.special_tokens.PAD,\n",
    "            
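# Must match the Seq2SeqTask max_length above.\n",
    "            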
max_len=200,\n",
    "        ),\n",
    "    ],\n",
    ")\n",
    "\n",
    "ex = ds[0]\n",
    "\n",
    "treescope.show(ex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qj2QY1T1t_LC"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "decoder_input\n",
      "```\n",
      "<bos>Est-ce que les femmes passent encore la nuit avant le mariage dans une autre chambre que celle de leur fiancée ???<end_of_turn>```\n",
      "\n",
      "encoder_input\n",
      "```\n",
      "<start_of_turn>user\n",
      "Do woman still spend the night before their wedding away from their fiancee???<end_of_turn>\n",
      "<start_of_turn>model\n",
      "```\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for k in ex:\n",
    "  if k in ['decoder_input', 'encoder_input']:\n",
    "    print(f\"{k}\\n```\\n{preset.tokenizer.decode(ex[k][0])}```\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L_ND9CJDlcSy"
   },
   "source": [
    "### Trainer\n",
    "\n",
    "Built with [kauldron](https://kauldron.readthedocs.io/en/latest/), following the Gemma finetuning examples.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jMb2KpHR7vo5"
   },
   "outputs": [],
   "source": [
    "loss = kd.losses.SoftmaxCrossEntropyWithIntLabels(\n",
    "    logits=\"preds.logits\",\n",
    "    labels=\"batch.target\",\n",
    "    mask=\"batch.loss_mask\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Bv854FDSn7Z-"
   },
   "outputs": [],
   "source": [
    "model = preset.config.make(\n",
    "    \"transformer\",\n",
    "    input_tokens=\"batch.encoder_input\",\n",
    "    target_tokens=\"batch.decoder_input\",\n",
    ")\n",
    "\n",
    "checkpoint = preset.get_checkpoint_from_kaggle(\n",
    "    t5gemma.CKPTType.IT,\n",
    "    t5gemma.PretrainType.PREFIXLM,\n",
    ")\n",
    "\n",
    "trainer = kd.train.Trainer(\n",
    "    seed=42,  # The seed of enlightenment\n",
    "    workdir='/tmp/ckpts',\n",
    "    # Dataset\n",
    "    train_ds=ds,\n",
    "    # Model\n",
    "    model=model,\n",
    "    # Params\n",
    "    init_transform=gm.ckpts.LoadCheckpoint(checkpoint),\n",
    "    # Training parameters\n",
    "    num_train_steps=500,\n",
    "    train_losses={\"loss\": loss},\n",
    "    optimizer=optax.adafactor(learning_rate=1e-4),\n",
    "    sharding=kd.sharding.ShardingStrategy(\n",
    "        ds=kd.sharding.FIRST_DIM,\n",
    "        params=kd.sharding.FSDPSharding(),\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xvIDsFPz75GT"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting training loop at step 0\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2c5f9397503141a7bcabb2eed5f20383",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "train:   0%|          | 0/501 [00:00<?, ?it/s]\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "state, aux = trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "sampler = t5gemma.Sampler(\n",
    "    model=model,\n",
    "    params=state.params,\n",
    "    tokenizer=preset.tokenizer,\n",
    ")\n",
    "\n",
    "output = sampler.sample('<start_of_turn>user\\nHello! My next holidays are in Paris.<end_of_turn>\\n<start_of_turn>model\\n')\n",
    "\n",
    "print(output)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "name": "[T5Gemma]Example.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}