Commit 6c1b13b

rm slow

1 parent 6254bb4 · commit 6c1b13b

3 files changed: +226 additions, −462 deletions

src/transformers/convert_slow_tokenizer.py

Lines changed: 147 additions & 34 deletions
@@ -1326,51 +1326,163 @@ def decoder(self, replacement, add_prefix_space):
         )
 
 
-class LlamaConverter(SpmConverter):
-    handle_byte_fallback = True
+class SpmTokenizer:
+    """
+    Base SentencePiece tokenizer that can be instantiated with model-specific arguments.
+    This replaces the converter pattern with direct instantiation.
+    """
+
+    def __init__(
+        self,
+        vocab_file: str,
+        handle_byte_fallback: bool = True,
+        legacy: bool = False,
+        add_prefix_space: bool = True,
+        special_tokens: Optional[dict] = None,
+        vocab: Optional[callable] = None,
+        unk_id: Optional[callable] = None,
+        normalizer: Optional[callable] = None,
+        pre_tokenizer: Optional[callable] = None,
+        decoder: Optional[callable] = None,
+        post_processor: Optional[callable] = None,
+    ):
+        requires_backends(self, "protobuf")
+
+        self.vocab_file = vocab_file
+        self.handle_byte_fallback = handle_byte_fallback
+        self.legacy = legacy
+        self.add_prefix_space = add_prefix_space
+        self.special_tokens = special_tokens or {}
+        # Store user-provided callables under private names to avoid clashing with methods
+        self._vocab_fn = vocab
+        self._unk_id_fn = unk_id
+        self._normalizer_fn = normalizer
+        self._pre_tokenizer_fn = pre_tokenizer
+        self._decoder_fn = decoder
+        self._post_processor_fn = post_processor
+
+        # Load the protobuf model
+        model_pb2 = import_protobuf()
+        m = model_pb2.ModelProto()
+        with open(vocab_file, "rb") as f:
+            m.ParseFromString(f.read())
+        self.proto = m
 
     def vocab(self, proto):
-        vocab = [
-            (self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
-            (self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
-            (self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
-        ]
-        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
-        return vocab
+        if self._vocab_fn is not None:
+            return self._vocab_fn(proto)
+        return [(piece.piece, piece.score) for piece in proto.pieces]
 
     def unk_id(self, proto):
-        unk_id = 0
-        return unk_id
+        if self._unk_id_fn is not None:
+            return self._unk_id_fn(proto)
+        return proto.trainer_spec.unk_id
 
-    def decoder(self, replacement, add_prefix_space):
-        sequence = [
-            decoders.Replace("▁", " "),
-            decoders.ByteFallback(),
-            decoders.Fuse(),
+    def tokenizer(self, proto):
+        model_type = proto.trainer_spec.model_type
+        vocab_scores = self.vocab(proto)
+
+        if model_type == 1:
+            tokenizer = Tokenizer(
+                Unigram(
+                    vocab_scores,
+                    unk_id=self.unk_id(proto),
+                    byte_fallback=self.handle_byte_fallback,
+                )
+            )
+        elif model_type == 2:
+            _, merges = SentencePieceExtractor(self.vocab_file).extract(vocab_scores)
+            bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
+            tokenizer = Tokenizer(
+                BPE(
+                    bpe_vocab,
+                    merges,
+                    unk_token=proto.trainer_spec.unk_piece,
+                    fuse_unk=True,
+                    byte_fallback=self.handle_byte_fallback,
+                    dropout=None,
+                )
+            )
+        else:
+            raise Exception(
+                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
+            )
+
+        # Add special tokens
+        spm_added_tokens = [
+            (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
+            for id, p in enumerate(proto.pieces)
+            if p.type in [3, 4]
         ]
-        if add_prefix_space:
-            sequence += [decoders.Strip(content=" ", left=1)]
-        return decoders.Sequence(sequence)
+        tokenizer.add_tokens(
+            [
+                AddedToken(token, normalized=False, special=special)
+                for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
+            ]
+        )
+
+        return tokenizer
 
     def normalizer(self, proto):
-        if getattr(self.original_tokenizer, "legacy", True):
-            sequence = []
-            if getattr(self.original_tokenizer, "add_prefix_space", True):
-                sequence += [normalizers.Prepend(prepend="▁")]
-            sequence += [normalizers.Replace(pattern=" ", content="▁")]
-            return normalizers.Sequence(sequence)
-        return None  # non-legacy, no normalizer
+        if self._normalizer_fn is not None:
+            return self._normalizer_fn(proto)
+
+        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
+        _normalizers = [
+            normalizers.Strip(left=False, right=True),
+            normalizers.Replace(Regex(" {2,}"), "▁"),
+        ]
+        if not precompiled_charsmap:
+            return normalizers.Sequence(_normalizers)
+        else:
+            return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)
 
     def pre_tokenizer(self, replacement, add_prefix_space):
-        if not getattr(self.original_tokenizer, "legacy", True):  # non-legacy, we need a replace
-            prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
-            return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
-        return None
+        if self._pre_tokenizer_fn is not None:
+            return self._pre_tokenizer_fn(replacement, add_prefix_space)
+
+        prepend_scheme = _get_prepend_scheme(add_prefix_space, self)
+        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
+
+    def decoder(self, replacement, add_prefix_space):
+        if self._decoder_fn is not None:
+            return self._decoder_fn(replacement, add_prefix_space)
+
+        prepend_scheme = _get_prepend_scheme(add_prefix_space, self)
+        return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
 
     def post_processor(self):
-        # the processor is defined in the LlamaTokenizerFast class.
+        if self._post_processor_fn is not None:
+            return self._post_processor_fn()
         return None
 
+    def create_tokenizer(self) -> Tokenizer:
+        """Create and return the configured tokenizer."""
+        tokenizer = self.tokenizer(self.proto)
+
+        # Assemble the tokenizer
+        normalizer = self.normalizer(self.proto)
+        if normalizer is not None:
+            tokenizer.normalizer = normalizer
+
+        replacement = "▁"
+        add_prefix_space = self.add_prefix_space
+
+        pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
+        if pre_tokenizer is not None:
+            tokenizer.pre_tokenizer = pre_tokenizer
+
+        tokenizer.decoder = self.decoder(replacement, add_prefix_space)
+        post_processor = self.post_processor()
+        if post_processor:
+            tokenizer.post_processor = post_processor
+
+        return tokenizer
+
+
+## NOTE: LLaMA-specific converter moved to `models/llama/tokenization_llama_fast.py`.
+## The slow->fast conversion for LLaMA is now handled directly in the fast file.
+
 
 class MarkupLMConverter(Converter):
     def converted(self) -> Tokenizer:
@@ -1700,10 +1812,11 @@ def converted(self) -> Tokenizer:
     "XLNetTokenizer": XLNetConverter,
     "SplinterTokenizer": SplinterConverter,
     "XGLMTokenizer": XGLMConverter,
-    "LlamaTokenizer": LlamaConverter,
-    "CodeLlamaTokenizer": LlamaConverter,
+    # LLaMA converters moved into fast file; slow->fast conversion is handled there.
+    # "LlamaTokenizer": LlamaConverter,
+    # "CodeLlamaTokenizer": LlamaConverter,
     "GemmaTokenizer": GemmaConverter,
-    "Phi3Tokenizer": LlamaConverter,
+    # "Phi3Tokenizer": LlamaConverter,
 }
 
 
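For context, here is a minimal usage sketch of the direct-instantiation path that this diff introduces. The import location follows the file modified above (src/transformers/convert_slow_tokenizer.py); the SentencePiece model path "tokenizer.model" and the final print are illustrative assumptions, not part of the commit.

from transformers.convert_slow_tokenizer import SpmTokenizer

# Build a fast `tokenizers.Tokenizer` straight from a SentencePiece model file.
# "tokenizer.model" is a hypothetical local path.
spm = SpmTokenizer(
    vocab_file="tokenizer.model",
    handle_byte_fallback=True,
    add_prefix_space=True,
)
fast_tokenizer = spm.create_tokenizer()

# Quick sanity check (illustrative): encode a string with the assembled tokenizer.
print(fast_tokenizer.encode("Hello world").tokens)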
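And a hypothetical sketch of how a model-specific fast tokenizer file (for example, the LLaMA one referenced in the NOTE above) might supply the callable hooks that the deleted LlamaConverter used to hard-code. The pinned control tokens, unk_id=0, and the TemplateProcessing post-processor below only loosely mirror the removed converter and are assumptions, not the actual contents of tokenization_llama_fast.py.

from tokenizers import processors

from transformers.convert_slow_tokenizer import SpmTokenizer


def llama_style_vocab(proto):
    # Pin the first three ids to the usual control pieces and keep the rest from the model
    # (mirrors the deleted LlamaConverter.vocab; the token strings are assumed).
    vocab = [("<unk>", 0.0), ("<s>", 0.0), ("</s>", 0.0)]
    vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
    return vocab


def llama_style_post_processor():
    # Illustrative stand-in: prepend <s> (assumed id 1) to every sequence.
    return processors.TemplateProcessing(
        single="<s> $A",
        pair="<s> $A <s> $B",
        special_tokens=[("<s>", 1)],
    )


spm = SpmTokenizer(
    vocab_file="tokenizer.model",  # hypothetical path
    vocab=llama_style_vocab,
    unk_id=lambda proto: 0,  # the removed converter always returned 0
    post_processor=llama_style_post_processor,
)
tokenizer = spm.create_tokenizer()

The same hook mechanism covers normalizer, pre_tokenizer, and decoder, so model-specific behaviour can live next to the fast tokenizer class instead of in this module.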