@@ -1326,51 +1326,163 @@ def decoder(self, replacement, add_prefix_space):
        )
 
 
-class LlamaConverter(SpmConverter):
-    handle_byte_fallback = True
+class SpmTokenizer:
+    """
+    Base SentencePiece tokenizer that can be instantiated with model-specific arguments.
+    This replaces the converter pattern with direct instantiation.
+    """
+
+    def __init__(
+        self,
+        vocab_file: str,
+        handle_byte_fallback: bool = True,
+        legacy: bool = False,
+        add_prefix_space: bool = True,
+        special_tokens: Optional[dict] = None,
+        vocab: Optional[callable] = None,
+        unk_id: Optional[callable] = None,
+        normalizer: Optional[callable] = None,
+        pre_tokenizer: Optional[callable] = None,
+        decoder: Optional[callable] = None,
+        post_processor: Optional[callable] = None,
+    ):
+        requires_backends(self, "protobuf")
+
+        self.vocab_file = vocab_file
+        self.handle_byte_fallback = handle_byte_fallback
+        self.legacy = legacy
+        self.add_prefix_space = add_prefix_space
+        self.special_tokens = special_tokens or {}
+        # Store user-provided callables under private names to avoid clashing with methods
+        self._vocab_fn = vocab
+        self._unk_id_fn = unk_id
+        self._normalizer_fn = normalizer
+        self._pre_tokenizer_fn = pre_tokenizer
+        self._decoder_fn = decoder
+        self._post_processor_fn = post_processor
+
+        # Load the protobuf model
+        model_pb2 = import_protobuf()
+        m = model_pb2.ModelProto()
+        with open(vocab_file, "rb") as f:
+            m.ParseFromString(f.read())
+        self.proto = m
 
     def vocab(self, proto):
-        vocab = [
-            (self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
-            (self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
-            (self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
-        ]
-        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
-        return vocab
+        if self._vocab_fn is not None:
+            return self._vocab_fn(proto)
+        return [(piece.piece, piece.score) for piece in proto.pieces]
 
     def unk_id(self, proto):
-        unk_id = 0
-        return unk_id
+        if self._unk_id_fn is not None:
+            return self._unk_id_fn(proto)
+        return proto.trainer_spec.unk_id
 
-    def decoder(self, replacement, add_prefix_space):
-        sequence = [
-            decoders.Replace("▁", " "),
-            decoders.ByteFallback(),
-            decoders.Fuse(),
+    def tokenizer(self, proto):
+        model_type = proto.trainer_spec.model_type
+        vocab_scores = self.vocab(proto)
+
+        if model_type == 1:
+            tokenizer = Tokenizer(
+                Unigram(
+                    vocab_scores,
+                    unk_id=self.unk_id(proto),
+                    byte_fallback=self.handle_byte_fallback,
+                )
+            )
+        elif model_type == 2:
+            _, merges = SentencePieceExtractor(self.vocab_file).extract(vocab_scores)
+            bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
+            tokenizer = Tokenizer(
+                BPE(
+                    bpe_vocab,
+                    merges,
+                    unk_token=proto.trainer_spec.unk_piece,
+                    fuse_unk=True,
+                    byte_fallback=self.handle_byte_fallback,
+                    dropout=None,
+                )
+            )
+        else:
+            raise Exception(
+                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
+            )
+
+        # Add special tokens
+        spm_added_tokens = [
+            (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
+            for id, p in enumerate(proto.pieces)
+            if p.type in [3, 4]
         ]
-        if add_prefix_space:
-            sequence += [decoders.Strip(content=" ", left=1)]
-        return decoders.Sequence(sequence)
+        tokenizer.add_tokens(
+            [
+                AddedToken(token, normalized=False, special=special)
+                for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
+            ]
+        )
+
+        return tokenizer
 
     def normalizer(self, proto):
-        if getattr(self.original_tokenizer, "legacy", True):
-            sequence = []
-            if getattr(self.original_tokenizer, "add_prefix_space", True):
-                sequence += [normalizers.Prepend(prepend="▁")]
-            sequence += [normalizers.Replace(pattern=" ", content="▁")]
-            return normalizers.Sequence(sequence)
-        return None  # non-legacy, no normalizer
+        if self._normalizer_fn is not None:
+            return self._normalizer_fn(proto)
+
+        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
+        _normalizers = [
+            normalizers.Strip(left=False, right=True),
+            normalizers.Replace(Regex(" {2,}"), "▁"),
+        ]
+        if not precompiled_charsmap:
+            return normalizers.Sequence(_normalizers)
+        else:
+            return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)
 
     def pre_tokenizer(self, replacement, add_prefix_space):
-        if not getattr(self.original_tokenizer, "legacy", True):  # non-legacy, we need a replace
-            prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
-            return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
-        return None
+        if self._pre_tokenizer_fn is not None:
+            return self._pre_tokenizer_fn(replacement, add_prefix_space)
+
+        prepend_scheme = _get_prepend_scheme(add_prefix_space, self)
+        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
+
+    def decoder(self, replacement, add_prefix_space):
+        if self._decoder_fn is not None:
+            return self._decoder_fn(replacement, add_prefix_space)
+
+        prepend_scheme = _get_prepend_scheme(add_prefix_space, self)
+        return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
 
     def post_processor(self):
-        # the processor is defined in the LlamaTokenizerFast class.
+        if self._post_processor_fn is not None:
+            return self._post_processor_fn()
         return None
 
+    def create_tokenizer(self) -> Tokenizer:
+        """Create and return the configured tokenizer."""
+        tokenizer = self.tokenizer(self.proto)
+
+        # Assemble the tokenizer: normalizer, pre-tokenizer, decoder, post-processor
+        normalizer = self.normalizer(self.proto)
+        if normalizer is not None:
+            tokenizer.normalizer = normalizer
+
+        replacement = "▁"
+        add_prefix_space = self.add_prefix_space
+
+        pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
+        if pre_tokenizer is not None:
+            tokenizer.pre_tokenizer = pre_tokenizer
+
+        tokenizer.decoder = self.decoder(replacement, add_prefix_space)
+        post_processor = self.post_processor()
+        if post_processor:
+            tokenizer.post_processor = post_processor
+
+        return tokenizer
+
+
+## NOTE: LLaMA-specific converter moved to `models/llama/tokenization_llama_fast.py`.
+## The slow->fast conversion for LLaMA is now handled directly in the fast file.
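+##
+## A minimal usage sketch (illustrative only; the file path and argument values are
+## assumptions, not taken from any particular checkpoint):
+##
+##     spm = SpmTokenizer(
+##         "tokenizer.model",          # path to a SentencePiece model file (assumed)
+##         handle_byte_fallback=True,
+##         add_prefix_space=True,
+##     )
+##     fast_tokenizer = spm.create_tokenizer()  # returns a `tokenizers.Tokenizer`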
+
 
 class MarkupLMConverter(Converter):
     def converted(self) -> Tokenizer:
@@ -1700,10 +1812,11 @@ def converted(self) -> Tokenizer:
     "XLNetTokenizer": XLNetConverter,
     "SplinterTokenizer": SplinterConverter,
     "XGLMTokenizer": XGLMConverter,
-    "LlamaTokenizer": LlamaConverter,
-    "CodeLlamaTokenizer": LlamaConverter,
+    # LLaMA converters moved into fast file; slow->fast conversion is handled there.
+    # "LlamaTokenizer": LlamaConverter,
+    # "CodeLlamaTokenizer": LlamaConverter,
     "GemmaTokenizer": GemmaConverter,
-    "Phi3Tokenizer": LlamaConverter,
+    # "Phi3Tokenizer": LlamaConverter,
 }
 
 