Skip to content

Commit f16b430

Browse files
committed
rm protobuf dependency
1 parent 6c1b13b commit f16b430

File tree

2 files changed

+73
-103
lines changed

2 files changed

+73
-103
lines changed

src/transformers/convert_slow_tokenizer.py

Lines changed: 61 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,7 +1334,6 @@ class SpmTokenizer:
13341334

13351335
def __init__(
13361336
self,
1337-
vocab_file: str,
13381337
handle_byte_fallback: bool = True,
13391338
legacy: bool = False,
13401339
add_prefix_space: bool = True,
@@ -1346,9 +1345,6 @@ def __init__(
13461345
decoder: Optional[callable] = None,
13471346
post_processor: Optional[callable] = None,
13481347
):
1349-
requires_backends(self, "protobuf")
1350-
1351-
self.vocab_file = vocab_file
13521348
self.handle_byte_fallback = handle_byte_fallback
13531349
self.legacy = legacy
13541350
self.add_prefix_space = add_prefix_space
@@ -1360,82 +1356,31 @@ def __init__(
13601356
self._pre_tokenizer_fn = pre_tokenizer
13611357
self._decoder_fn = decoder
13621358
self._post_processor_fn = post_processor
1363-
1364-
# Load the protobuf model
1365-
model_pb2 = import_protobuf()
1366-
m = model_pb2.ModelProto()
1367-
with open(vocab_file, "rb") as f:
1368-
m.ParseFromString(f.read())
1369-
self.proto = m
13701359

1371-
def vocab(self, proto):
1360+
def vocab(self):
13721361
if self._vocab_fn is not None:
1373-
return self._vocab_fn(proto)
1374-
return [(piece.piece, piece.score) for piece in proto.pieces]
1362+
return self._vocab_fn()
1363+
# Return empty vocab for training
1364+
return []
13751365

1376-
def unk_id(self, proto):
1366+
def unk_id(self):
13771367
if self._unk_id_fn is not None:
1378-
return self._unk_id_fn(proto)
1379-
return proto.trainer_spec.unk_id
1380-
1381-
def tokenizer(self, proto):
1382-
model_type = proto.trainer_spec.model_type
1383-
vocab_scores = self.vocab(proto)
1384-
1385-
if model_type == 1:
1386-
tokenizer = Tokenizer(
1387-
Unigram(
1388-
vocab_scores,
1389-
unk_id=self.unk_id(proto),
1390-
byte_fallback=self.handle_byte_fallback,
1391-
)
1392-
)
1393-
elif model_type == 2:
1394-
_, merges = SentencePieceExtractor(self.vocab_file).extract(vocab_scores)
1395-
bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
1396-
tokenizer = Tokenizer(
1397-
BPE(
1398-
bpe_vocab,
1399-
merges,
1400-
unk_token=proto.trainer_spec.unk_piece,
1401-
fuse_unk=True,
1402-
byte_fallback=self.handle_byte_fallback,
1403-
dropout=None,
1404-
)
1405-
)
1406-
else:
1407-
raise Exception(
1408-
"You're trying to run a `Unigram` model but your file was trained with a different algorithm"
1409-
)
1368+
return self._unk_id_fn()
1369+
return 0 # Default unk_id
14101370

1411-
# Add special tokens
1412-
spm_added_tokens = [
1413-
(id, p.piece, p.type == 3 or p.piece in self.special_tokens)
1414-
for id, p in enumerate(proto.pieces)
1415-
if p.type in [3, 4]
1416-
]
1417-
tokenizer.add_tokens(
1418-
[
1419-
AddedToken(token, normalized=False, special=special)
1420-
for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
1421-
]
1422-
)
1423-
1424-
return tokenizer
1371+
def tokenizer(self):
1372+
# Always create empty trainable tokenizer
1373+
minimal_vocab = [("<unk>", 0.0)]
1374+
return Tokenizer(Unigram(minimal_vocab, unk_id=self.unk_id(), byte_fallback=self.handle_byte_fallback))
14251375

1426-
def normalizer(self, proto):
1376+
def normalizer(self):
14271377
if self._normalizer_fn is not None:
1428-
return self._normalizer_fn(proto)
1429-
1430-
precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
1378+
return self._normalizer_fn()
14311379
_normalizers = [
14321380
normalizers.Strip(left=False, right=True),
14331381
normalizers.Replace(Regex(" {2,}"), "▁"),
14341382
]
1435-
if not precompiled_charsmap:
1436-
return normalizers.Sequence(_normalizers)
1437-
else:
1438-
return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)
1383+
return normalizers.Sequence(_normalizers)
14391384

14401385
def pre_tokenizer(self, replacement, add_prefix_space):
14411386
if self._pre_tokenizer_fn is not None:
@@ -1457,11 +1402,11 @@ def post_processor(self):
14571402
return None
14581403

14591404
def create_tokenizer(self) -> Tokenizer:
1460-
"""Create and return the configured tokenizer."""
1461-
tokenizer = self.tokenizer(self.proto)
1405+
"""Create and return the configured empty trainable tokenizer."""
1406+
tokenizer = self.tokenizer()
14621407

14631408
# Tokenizer assemble
1464-
normalizer = self.normalizer(self.proto)
1409+
normalizer = self.normalizer()
14651410
if normalizer is not None:
14661411
tokenizer.normalizer = normalizer
14671412

@@ -1483,6 +1428,50 @@ def create_tokenizer(self) -> Tokenizer:
14831428
## NOTE: LLaMA-specific converter moved to `models/llama/tokenization_llama_fast.py`.
14841429
## The slow->fast conversion for LLaMA is now handled directly in the fast file.
14851430

1431+
class LlamaConverter(SpmConverter):
1432+
handle_byte_fallback = True
1433+
1434+
def vocab(self, proto):
1435+
vocab = [
1436+
(self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
1437+
(self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
1438+
(self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
1439+
]
1440+
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
1441+
return vocab
1442+
1443+
def unk_id(self, proto):
1444+
unk_id = 0
1445+
return unk_id
1446+
1447+
def decoder(self, replacement, add_prefix_space):
1448+
sequence = [
1449+
decoders.Replace("▁", " "),
1450+
decoders.ByteFallback(),
1451+
decoders.Fuse(),
1452+
]
1453+
if add_prefix_space:
1454+
sequence += [decoders.Strip(content=" ", left=1)]
1455+
return decoders.Sequence(sequence)
1456+
1457+
def normalizer(self, proto):
1458+
if getattr(self.original_tokenizer, "legacy", True):
1459+
sequence = []
1460+
if getattr(self.original_tokenizer, "add_prefix_space", True):
1461+
sequence += [normalizers.Prepend(prepend="▁")]
1462+
sequence += [normalizers.Replace(pattern=" ", content="▁")]
1463+
return normalizers.Sequence(sequence)
1464+
return None # non-legacy, no normalizer
1465+
1466+
def pre_tokenizer(self, replacement, add_prefix_space):
1467+
if not getattr(self.original_tokenizer, "legacy", True): # non-legacy, we need a replace
1468+
prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
1469+
return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
1470+
return None
1471+
1472+
def post_processor(self):
1473+
# the processor is defined in the LlamaTokenizerFast class.
1474+
return None
14861475

14871476
class MarkupLMConverter(Converter):
14881477
def converted(self) -> Tokenizer:

src/transformers/models/llama/tokenization_llama_fast.py

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -97,21 +97,24 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
9797
```python
9898
>>> from transformers import LlamaTokenizerFast
9999
100-
>>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=True, from_slow=True)
100+
>>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=True, from_scratch=True)
101101
>>> tokenizer.encode("Hello <s>.") # 869 is '▁.'
102102
[1, 15043, 29871, 1, 869]
103103
```
104104
- `legacy=False`:
105105
```python
106106
>>> from transformers import LlamaTokenizerFast
107107
108-
>>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=False, from_slow=True)
108+
>>> tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", legacy=False, from_scratch=True)
109109
>>> tokenizer.encode("Hello <s>.") # 29889 is '.'
110110
[1, 15043, 29871, 1, 29889]
111111
```
112112
Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
113113
add_prefix_space (`bool`, *optional*):
114114
Whether or not the tokenizer should automatically add a prefix space
115+
from_scratch (`bool`, *optional*, defaults to `False`):
116+
Whether to create an empty trainable tokenizer from scratch. When `True`, creates a minimal tokenizer
117+
with only basic special tokens that can be trained on new data.
115118
"""
116119

117120
vocab_files_names = VOCAB_FILES_NAMES
@@ -130,53 +133,32 @@ def __init__(
130133
add_bos_token=True,
131134
add_eos_token=False,
132135
use_default_system_prompt=False,
133-
legacy=None,
136+
legacy=False,
134137
add_prefix_space=None,
135138
**kwargs,
136139
):
137-
if legacy is None:
138-
logger.warning_once(
139-
f"You are using the default legacy behaviour of the {self.__class__}. This is"
140-
" expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
141-
" If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
142-
" means, and thoroughly read the reason why this was added as explained in"
143-
" https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file"
144-
" you can ignore this message."
145-
)
146-
legacy = True
147-
self.legacy = False
148-
legacy = False
140+
self.legacy = legacy
149141

150142
# Set add_prefix_space attribute for use in override methods
151143
self.add_prefix_space = add_prefix_space if add_prefix_space is not None else True
152144

153-
# Handle from_slow parameter - when True, force SpmTokenizer path even if tokenizer.json exists
154-
from_slow = kwargs.pop("from_slow", False)
145+
# Handle from_scratch parameter - when True, create empty trainable tokenizer
146+
from_scratch = kwargs.pop("from_scratch", False)
155147

156-
# Handle tokenizer creation
157-
if tokenizer_file is not None and not from_slow:
158-
# Load from existing tokenizer.json file (unless from_slow=True)
148+
if tokenizer_file is not None and not from_scratch:
159149
from tokenizers import Tokenizer as TokenizerFast
160150
fast_tokenizer = TokenizerFast.from_file(tokenizer_file)
161-
elif vocab_file is not None:
162-
# Create LLaMA-specific tokenizer using SpmTokenizer
163-
# This path is used when:
164-
# 1. vocab_file is provided and no tokenizer_file
165-
# 2. from_slow=True (forces SpmTokenizer path even if tokenizer.json exists)
151+
else:
166152
spm_tokenizer = SpmTokenizer(
167-
vocab_file=vocab_file,
168153
handle_byte_fallback=True,
169154
legacy=legacy,
170155
add_prefix_space=add_prefix_space if add_prefix_space is not None else True,
171156
vocab=self._vocab,
172-
#unk_id=self._unk_id,
173157
normalizer=self._normalizer,
174158
pre_tokenizer=self._pre_tokenizer,
175159
decoder=self._decoder,
176160
)
177161
fast_tokenizer = spm_tokenizer.create_tokenizer()
178-
else:
179-
raise ValueError("Either tokenizer_file or vocab_file must be provided")
180162

181163
# Initialize the base class with the fast tokenizer
182164
super().__init__(
@@ -198,15 +180,14 @@ def __init__(
198180
self.use_default_system_prompt = use_default_system_prompt
199181
self.vocab_file = vocab_file
200182

201-
def _vocab(self, proto):
183+
def _vocab(self):
202184
"""Vocabulary handling for this tokenizer."""
203185
# First 3 special pieces are fixed for LLaMA
204186
vocab = [
205187
("<unk>", 0.0),
206188
("<s>", 0.0),
207189
("</s>", 0.0),
208190
]
209-
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
210191
return vocab
211192

212193
def _decoder(self, replacement, add_prefix_space):

0 commit comments

Comments (0)