
Conversation

@itazap
Collaborator

@itazap itazap commented Sep 17, 2025

Tokenization

Just as we moved towards a single backend library for model definition, we want Tokenizer to be a lot more intuitive.
With v5, you can now initialize an empty LlamaTokenizer and train it directly on your new task!

Defining a new tokenizer object should be as simple as this:

from transformers import TokenizersBackend, generate_merges
from tokenizers import pre_tokenizers, Tokenizer
from tokenizers.models import BPE


class Llama5Tokenizer(TokenizersBackend):
    def __init__(self, unk_token="<unk>", bos_token="<s>", eos_token="</s>", vocab=None, merges=None):
        if vocab is None:
            self._vocab = {
                str(unk_token): 0,
                str(bos_token): 1,
                str(eos_token): 2,
            }
        else:
            self._vocab = vocab

        if merges is not None:
            self._merges = merges
        else:
            # Merges are generated from the vocabulary, excluding the special tokens
            filtered_vocab = {
                token: idx
                for token, idx in self._vocab.items()
                if token not in (str(unk_token), str(bos_token), str(eos_token))
            }
            self._merges = generate_merges(filtered_vocab)

        self._tokenizer = Tokenizer(
            BPE(vocab=self._vocab, merges=self._merges, fuse_unk=True)
        )
        self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(
            replacement="▁", prepend_scheme="first", split=False
        )
        super().__init__(
            tokenizer_object=self._tokenizer,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
        )

And now if you call Llama5Tokenizer() you just get an empty, trainable tokenizer that follows the definition of the authors of Llama5 (it does not exist yet 😉).

The above is the main motivation behind refactoring tokenization: we want people to instantiate a tokenizer just like they would a model, empty or not, and with exactly what its authors defined.
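
Training such a blank tokenizer on a new corpus could look roughly like the sketch below (assumptions: the blank tokenizer keeps the existing train_new_from_iterator API of tokenizers-backed tokenizers, and LlamaTokenizer stands in for the hypothetical Llama5Tokenizer):

from transformers import LlamaTokenizer

# Blank, trainable tokenizer following the Llama definition
tokenizer = LlamaTokenizer()

# Any iterator of raw text works as training data
corpus = ["The quick brown fox", "jumps over the lazy dog"]

# train_new_from_iterator already exists on tokenizers-backed tokenizers;
# we assume it carries over to TokenizersBackend in v5
new_tokenizer = tokenizer.train_new_from_iterator(corpus, vocab_size=64)
print(new_tokenizer.tokenize("quick fox"))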

Non-tokenizers

If your tokenizer is not common, or you just don't want to rely on sentencepiece or tokenizers, you can import the PythonBackend (previously PreTrainedTokenizer), which carries all the API and logic for added tokens, encoding and decoding with them, etc.

If you want even fewer features, you can use the common PreTrainedTokenizerBase mixin, which mostly defines the transformers tokenizer API: encode, decode, vocab_size, get_vocab, convert_tokens_to_ids, convert_ids_to_tokens, from_pretrained, save_pretrained, etc.
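
As an illustration, a tiny whitespace tokenizer built on the Python backend might look like this minimal sketch (it assumes PythonBackend keeps the hook methods of the former PreTrainedTokenizer: _tokenize, _convert_token_to_id, _convert_id_to_token, convert_tokens_to_string):

from transformers import PythonBackend  # previously PreTrainedTokenizer

class WhitespaceTokenizer(PythonBackend):
    def __init__(self, vocab=None, unk_token="<unk>", **kwargs):
        self._unk = str(unk_token)
        self._vocab = vocab or {self._unk: 0}
        self._ids_to_tokens = {i: t for t, i in self._vocab.items()}
        super().__init__(unk_token=unk_token, **kwargs)

    @property
    def vocab_size(self):
        return len(self._vocab)

    def get_vocab(self):
        return dict(self._vocab)

    def _tokenize(self, text):
        # Naive whitespace split; real tokenizers do much more here
        return text.split()

    def _convert_token_to_id(self, token):
        return self._vocab.get(token, self._vocab[self._unk])

    def _convert_id_to_token(self, index):
        return self._ids_to_tokens.get(index, self._unk)

    def convert_tokens_to_string(self, tokens):
        return " ".join(tokens)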

Backend Architecture Changes

Moving away from "slow" vs "fast" tokenizers:

Previously, transformers maintained two parallel implementations for many tokenizers:

  • "Slow" tokenizers (tokenization_<model>.py) - Python-based implementations, often using SentencePiece as the backend.
  • "Fast" tokenizers (tokenization_<model>_fast.py) - Rust-based implementations using the 🤗 tokenizers library.

In v5, we consolidate to a single tokenizer file per model: tokenization_<model>.py. This file will use the most appropriate backend available:

  1. TokenizersBackend (preferred): Rust-based tokenizers from the 🤗 tokenizers library. In general its performance is better, and it offers a lot more features that are commonly adopted across the ecosystem, like handling additional tokens, easily updating the state of the tokenizer, automatic parallelisation, etc.
  2. SentencePieceBackend: For models requiring SentencePiece
  3. PythonBackend: Pure Python implementations
  4. MistralCommonBackend: Relies on MistralCommon's tokenization library (previously MistralCommonTokenizer).

The AutoTokenizer automatically selects the appropriate backend based on available files and dependencies. This is transparent: you continue to use AutoTokenizer.from_pretrained() as before. This keeps transformers future-proof and modular, making it easy to support new backends.
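
Usage itself does not change; for example:

from transformers import AutoTokenizer

# Same call as before; the backend (TokenizersBackend, SentencePieceBackend, ...)
# is picked automatically from the files and dependencies available
tokenizer = AutoTokenizer.from_pretrained("gpt2")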

API Changes

1. Direct tokenizer initialization with vocab and merges:

In v5, you can now initialize tokenizers directly with vocabulary and merges, enabling training custom tokenizers from scratch:

# v5: Initialize a blank tokenizer for training
from transformers import LlamaTokenizer

# Create a tokenizer with custom vocabulary and merges
vocab = {"<unk>": 0, "<s>": 1, "</s>": 2, "hello": 3, "world": 4}
merges = [("h", "e"), ("l", "l"), ("o", " ")]

tokenizer = LlamaTokenizer(vocab=vocab, merges=merges)

# Or initialize a blank tokenizer to train on your own dataset
tokenizer = LlamaTokenizer()  # Creates a blank Llama-like tokenizer

But you can no longer pass a vocab file, since that use-case is covered by from_pretrained.
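
Roughly, the migration for local files looks like this (paths are placeholders):

# v4: a vocab file could be passed directly
tokenizer = LlamaTokenizer(vocab_file="path/to/tokenizer.model")

# v5: pretrained files go through from_pretrained instead
tokenizer = LlamaTokenizer.from_pretrained("path/to/tokenizer_dir")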

2. Simplified decoding API:

The batch_decode method has been unified with decode. Both single and batch decoding now use the same method:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-small") 
inputs = ["hey how are you?", "fine"]
tokenizer.decode(tokenizer.encode(inputs))

Gives:

- 'hey how are you?</s> fine</s>'
+ ['hey how are you?</s>', 'fine</s>']

This is mostly because people get a list[list[int]] out of generate; they would then call decode on it (mirroring encode) and get:

   ...: tokenizer.decode([[1,2], [1,4]])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 4
      2 tokenizer = AutoTokenizer.from_pretrained("t5-small") 
      3 inputs = ["hey how are you?", "fine"]
----> 4 tokenizer.decode([[1,2], [1,4]])

File /raid/arthur/transformers/src/transformers/tokenization_utils_base.py:3948, in PreTrainedTokenizerBase.decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
   3945 # Convert inputs to python lists
   3946 token_ids = to_py_obj(token_ids)
-> 3948 return self._decode(
   3949     token_ids=token_ids,
   3950     skip_special_tokens=skip_special_tokens,
   3951     clean_up_tokenization_spaces=clean_up_tokenization_spaces,
   3952     **kwargs,
   3953 )

File /raid/arthur/transformers/src/transformers/tokenization_utils_fast.py:682, in PreTrainedTokenizerFast._decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
    680 if isinstance(token_ids, int):
    681     token_ids = [token_ids]
--> 682 text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
    684 clean_up_tokenization_spaces = (
    685     clean_up_tokenization_spaces
    686     if clean_up_tokenization_spaces is not None
    687     else self.clean_up_tokenization_spaces
    688 )
    689 if clean_up_tokenization_spaces:

TypeError: argument 'ids': 'list' object cannot be interpreted as an integer

3. Unified encoding API:

encode_plus is deprecated → call the tokenizer directly via __call__.
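
In practice this is a one-line change:

# v4
enc = tokenizer.encode_plus("hey how are you?", return_tensors="pt")

# v5: call the tokenizer directly
enc = tokenizer("hey how are you?", return_tensors="pt")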

4. apply_chat_template returns BatchEncoding:

Previously, apply_chat_template returned input_ids for backward compatibility. In v5, it now consistently returns a BatchEncoding dict like other tokenizer methods:

# v5
messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there!"}
]

# Now returns BatchEncoding with input_ids, attention_mask, etc.
outputs = tokenizer.apply_chat_template(messages, return_tensors="pt")
print(outputs.keys())  # dict_keys(['input_ids', 'attention_mask'])

5. Removed legacy configuration file saving:

  • special_tokens_map.json - special tokens are now stored in tokenizer_config.json.
  • added_tokens.json - added tokens are now stored in tokenizer.json.
  • added_tokens_decoder is only stored when there is no tokenizer.json.

When loading older tokenizers, these files are still read for backward compatibility, but new saves use the consolidated format.
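
Concretely, a fresh save only writes the consolidated files (a sketch; the exact file list depends on the model):

import os
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
tokenizer.save_pretrained("./t5_tok")
print(sorted(os.listdir("./t5_tok")))
# special_tokens_map.json and added_tokens.json are no longer written;
# special tokens now live in tokenizer_config.json, added tokens in tokenizer.json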

@itazap itazap changed the title from "rm slow tokenizer llama" to "rm slow tokenizers" Sep 19, 2025
@itazap itazap force-pushed the one_tokenizer branch 4 times, most recently from af77c18 to dc0611f on September 25, 2025 11:31
@itazap itazap marked this pull request as draft September 30, 2025 09:03
Collaborator

@ArthurZucker ArthurZucker left a comment


Nice!

Comment on lines 2452 to 2457
    @require_tokenizers
    def test_added_token_are_matched_longest_first(self):
        if not self.test_slow_tokenizer:
            self.skipTest(reason="This test is only for slow tokenizers")

        tokenizers = self.get_tokenizers(fast=False)
Collaborator


should be moved to sentencepiece as well

Collaborator Author


yes it's in test_sentencepiece_backend_mixin.py

Comment on lines 2831 to 2883
words = ["Wonderful", "no", "inspiration", "example", "with", "subtoken"]
text = " ".join(words)
batch_size = 3

encoding = tokenizer_r.encode_plus(text, add_special_tokens=False)

batch_encoding = tokenizer_r([text] * batch_size, add_special_tokens=False)
num_tokens = len(encoding["input_ids"])

last_word_index = len(words) - 1
last_token_index = num_tokens - 1
last_batch_index = batch_size - 1
last_char_index = len(text) - 1

# words, tokens
self.assertEqual(len(encoding.words(0)), num_tokens)
self.assertEqual(max(encoding.words(0)), last_word_index)
self.assertEqual(min(encoding.words(0)), 0)
self.assertEqual(len(batch_encoding.words(last_batch_index)), num_tokens)
self.assertEqual(max(batch_encoding.words(last_batch_index)), last_word_index)
self.assertEqual(min(batch_encoding.words(last_batch_index)), 0)
self.assertEqual(len(encoding.tokens(0)), num_tokens)

# Assert token_to_word
self.assertEqual(encoding.token_to_word(0), 0)
self.assertEqual(encoding.token_to_word(0, 0), 0)
self.assertEqual(encoding.token_to_word(last_token_index), last_word_index)
self.assertEqual(encoding.token_to_word(0, last_token_index), last_word_index)
self.assertEqual(batch_encoding.token_to_word(1, 0), 0)
self.assertEqual(batch_encoding.token_to_word(0, last_token_index), last_word_index)
self.assertEqual(batch_encoding.token_to_word(last_batch_index, last_token_index), last_word_index)

# Assert word_to_tokens
self.assertEqual(encoding.word_to_tokens(0).start, 0)
self.assertEqual(encoding.word_to_tokens(0, 0).start, 0)
self.assertEqual(encoding.word_to_tokens(last_word_index).end, last_token_index + 1)
self.assertEqual(encoding.word_to_tokens(0, last_word_index).end, last_token_index + 1)
self.assertEqual(batch_encoding.word_to_tokens(1, 0).start, 0)
self.assertEqual(batch_encoding.word_to_tokens(0, last_word_index).end, last_token_index + 1)
self.assertEqual(
batch_encoding.word_to_tokens(last_batch_index, last_word_index).end, last_token_index + 1
)

# Assert token_to_chars
self.assertEqual(encoding.token_to_chars(0).start, 0)
self.assertEqual(encoding.token_to_chars(0, 0).start, 0)
self.assertEqual(encoding.token_to_chars(last_token_index).end, last_char_index + 1)
self.assertEqual(encoding.token_to_chars(0, last_token_index).end, last_char_index + 1)
self.assertEqual(batch_encoding.token_to_chars(1, 0).start, 0)
self.assertEqual(batch_encoding.token_to_chars(0, last_token_index).end, last_char_index + 1)
self.assertEqual(
batch_encoding.token_to_chars(last_batch_index, last_token_index).end, last_char_index + 1
)
Collaborator


indeed rust takes care of these by itself, the other part can be tested in the sentencepiece file

Collaborator Author


seems like we only tested tokenizer_r aka rust here, and spiece / slow never supported the tokens_to_chars, word_to_tokens, etc.

Comment on lines 2720 to 2721
self.skipTest(
reason="This test is now in TokenizersBackendTesterMixin - it tests tokenizers-backend API, not transformers code"
Collaborator


indeed thats for tokenizers overlay

Comment on lines -3757 to -3758
# Check the changes
for token in special_tokens_list:
Collaborator


and this can go to trash as well

@itazap itazap mentioned this pull request Oct 10, 2025
@itazap itazap requested a review from ArthurZucker October 14, 2025 08:16
Collaborator

@ArthurZucker ArthurZucker left a comment


A good start:

    def __init__(self, vocab, merges):
        tokenizer = Tokenizer(
            BPE(
                vocab=vocab,
                merges=merges,
                dropout=None,
                unk_token=None,
                continuing_subword_prefix="",
                end_of_word_suffix="",
                fuse_unk=False,
                byte_fallback=False,
            )
        )

        tokenizer.normalizer = normalizers.NFC()

        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.Split(
                    Regex(
                        r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
                    ),
                    behavior="isolated",
                    invert=False,
                ),
                pre_tokenizers.ByteLevel(
                    add_prefix_space=getattr(self.original_tokenizer, "add_prefix_space", False),
                    use_regex=False,
                ),
            ]
        )

        tokenizer.decoder = decoders.ByteLevel()
        tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

        self.tokenizer = tokenizer

Ideally I think we can even just do this, without defining the functions separately.
The only upside would have been that we can use modular for less copy pasting, but it's so small that I want to have this explicit, without extra abstraction!

Comment on lines 1149 to 1152
logger.info(
    "Falling back to PreTrainedSentencePieceTokenizer since tokenizer.model file was found "
    "but no config or tokenizer class could be determined."
)
Collaborator


IDK if we want to fallback here! I think if tokenizer.json is not found -> we convert tokenizer.model to tokenizer.json, unless user enforces sentencepiece

Collaborator Author


enforce by passing something like tokenizer_backend="sentencepiece", for example?

Comment on lines 125 to 126
def _tokenizer(self) -> Tokenizer:
    return Tokenizer(Unigram(self._vocab_scores, unk_id=self._unk_id(), byte_fallback=True))
Collaborator


yep that's good, tho I think we might want to abstract

def _model(self) -> Model:
    return Unigram(...)

def _decoder(self, replacement=None, add_prefix_space=None):
    return decoders.Sequence([decoders.Replace("▁", " "), decoders.ByteFallback(), decoders.Fuse()])

Collaborator


and then finally a function that shows how we build the final tokenizer. I think we want __init__ to make self.tokenizer = Tokenizer(model=self._model(), decoder=self._decoder, etc)

"""Tokenizer configuration for this tokenizer."""
return Tokenizer(BPE(vocab=self._vocab, merges=self._merges, fuse_unk=True, byte_fallback=True, dropout=None))

def _vocab(self):
Collaborator


should be Initial vocab or something

Collaborator


no legacy in general! (we most probably want to hide non-good defaults) so the super class will support changing this, but the real llama tokenizer does not use legacy

Comment on lines 168 to 170
def _normalizer(self):
    """Normalizer configuration for this tokenizer."""
    return normalizers.NFC()
Collaborator


nice

Collaborator

@ArthurZucker ArthurZucker left a comment


Second part of the review, very nice work on unbloating already

Collaborator


very nice

self._special_tokens_map["additional_special_tokens"] = [] # BC default to empty list

# Directly set hidden values to allow init with tokens not yet in vocab
for key in list(kwargs.keys()):
Collaborator


we can keep this as a TODO, but with the new logic that was added, we already have the self.xxx_token and self.xxx_token_id so IDK if additional_special_tokens is even useful. Let's leave it for later anyways

Comment on lines 1121 to 1128
if not isinstance(value, (list, tuple)) or not all(isinstance(t, (str, AddedToken)) for t in value):
    raise ValueError(f"Tokens {value} for key {key} should all be str or AddedToken instances")
new_tokens = [
    (AddedToken(t, rstrip=False, lstrip=False, normalized=False, special=True) if isinstance(t, str) else t)
    for t in value
    if replace_additional_special_tokens or str(t) not in self.additional_special_tokens
]
if replace_additional_special_tokens and new_tokens:
Collaborator


I would kind of want to get rid of this and put it only in spm, because tokenizers just supports tokenizer.special_tokens which gives all special tokens -> duplicated info with the additional special tokens

return all_toks
seen = set()
all_toks = []
for value in self.special_tokens_map.values():
Collaborator


same here, would leave as abstract and rely on tokenizers's special_tokens attr if we can!

Comment on lines 1875 to 1887
@classmethod
def convert_added_tokens(cls, obj: Union[AddedToken, Any], save=False, add_type_field=True):
    if isinstance(obj, dict) and "__type" in obj and obj["__type"] == "AddedToken":
        obj.pop("__type")
        return AddedToken(**obj)
    if isinstance(obj, AddedToken) and save:
        obj = obj.__getstate__()
        if add_type_field:
            obj["__type"] = "AddedToken"
        else:
            # Don't save "special" for previous tokenizers
            obj.pop("special")
    return obj
Collaborator


I don't remember why we use this one? Only for SPM, no?

) -> BatchEncoding:
# Input validation (from _call_one)
def _is_valid_text_input(t):
Collaborator


I think (but I might be wrong here) that tokenizers does the typechecking itself as well

    self.assertEqual(tokens, EXPECTED_TOKENS)

def test_integration_expected_token_ids(self):
    for tok in self.tokenizers:
        self.assertEqual(tok.encode(input_string), expected_token_ids)
Collaborator


this is just missing a decode test

Collaborator


overall LGTM!

@itazap itazap requested a review from ArthurZucker October 14, 2025 13:56
str(unk_token): 3,
}

self._merges = merges if merges is not None else generate_merges(self._vocab)
Collaborator


you actually should never generate merges out of the bos/pad/eos/unk tokens! so the merge generation should happen before

Collaborator Author


is it all special tokens or just these 4? in convert_slow_tokenizer it currently indexes the vocab[3:]

Comment on lines 112 to 113
self.add_tokens(list(self.all_special_tokens), special_tokens=True)
self.update_post_processor()
Collaborator


both can probably be called from the TokenizerBackend class wdyt? As in we are adding the post processor thing to all of them, and that already by default special tokens need to be added?

sub_texts = "".join(sub_texts)

return sub_texts.replace(SPIECE_UNDERLINE, " ")
self._post_init()
Collaborator


you can also just call

self.add_tokens(list(self.all_special_tokens), special_tokens=True)

but adding tokens has historically been done in the super call!

@ArthurZucker ArthurZucker added for_v5? Core: Tokenization Internals of the library; Tokenization. labels Nov 27, 2025
@ArthurZucker ArthurZucker merged commit 05c0e1d into main Nov 27, 2025
18 of 24 checks passed
@ArthurZucker ArthurZucker deleted the one_tokenizer branch November 27, 2025 18:24