
Commit dc0611f

llama

1 parent d5e56bb · commit dc0611f

4 files changed: +405 -41 lines changed

src/transformers/create_fast_tokenizer.py

Lines changed: 55 additions & 1 deletion
@@ -19,6 +19,7 @@

 from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers
 from tokenizers.models import BPE, Unigram
+from .utils import is_protobuf_available, is_sentencepiece_available, logging, requires_backends


 def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
@@ -31,6 +32,54 @@ def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
     return prepend_scheme


+def generate_merges(vocab, vocab_scores):
+    reverse = vocab_scores is not None
+    vocab_scores = dict(vocab_scores) if reverse else vocab
+
+    merges = []
+    for merge, piece_score in vocab_scores.items():
+        local = []
+        for index in range(1, len(merge)):
+            piece_l, piece_r = merge[:index], merge[index:]
+            if piece_l in vocab and piece_r in vocab:
+                local.append((piece_l, piece_r, piece_score))
+        local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
+        merges.extend(local)
+
+    merges = sorted(merges, key=lambda val: (val[2], len(val[0]), len(val[1])), reverse=reverse)
+    merges = [(val[0], val[1]) for val in merges]
+    return merges
+
+
+class SentencePieceExtractor:
+    """
+    Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
+    """
+
+    def __init__(self, model: str):
+        requires_backends(self, "sentencepiece")
+        from sentencepiece import SentencePieceProcessor
+
+        self.sp = SentencePieceProcessor()
+        self.sp.Load(model)
+
+    def extract(self, vocab_scores=None) -> tuple[dict[str, int], list[tuple]]:
+        """
+        By default will return vocab and merges with respect to their order, by sending `vocab_scores` we're going to
+        order the merges with respect to the piece scores instead.
+        """
+        sp = self.sp
+        vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
+
+        # let's get the vocab_scores
+        vocab_scores = {sp.id_to_piece(i): sp.get_score(i) for i in range(sp.GetPieceSize())}
+
+        merges = generate_merges(vocab, vocab_scores)
+
+        return vocab, merges
+
+
+
 class SpmTokenizer:
     """
     Base SentencePiece tokenizer that can be instantiated with model-specific arguments.
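To make the ordering behaviour of `generate_merges` concrete, here is a small self-contained sketch. The toy pieces and scores are invented for illustration, and the import path simply assumes the `create_fast_tokenizer` module touched by this commit is importable:

```python
from transformers.create_fast_tokenizer import generate_merges  # module touched by this commit

# piece -> id, shaped like the vocab SentencePieceExtractor.extract() builds
vocab = {"h": 0, "e": 1, "l": 2, "o": 3, "he": 4, "ll": 5, "hell": 6, "hello": 7}
# piece -> score; passing scores makes generate_merges rank pairs by score
vocab_scores = {"he": -1.0, "ll": -2.0, "hell": -3.0, "hello": -4.0}

# Every multi-character piece is split into each (left, right) pair whose halves
# are both in the vocab; the resulting pairs are then ordered by piece score.
merges = generate_merges(vocab, vocab_scores)
print(merges)  # [('h', 'e'), ('l', 'l'), ('he', 'll'), ('hell', 'o')]
```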
@@ -48,6 +97,7 @@ def __init__(
         pre_tokenizer: Optional[callable] = None,
         decoder: Optional[callable] = None,
         post_processor: Optional[callable] = None,
+        tokenizer: Optional[callable] = None,
     ):
         self.handle_byte_fallback = handle_byte_fallback
         self.legacy = legacy
@@ -60,6 +110,7 @@ def __init__(
         self._pre_tokenizer_fn = pre_tokenizer
         self._decoder_fn = decoder
         self._post_processor_fn = post_processor
+        self._tokenizer_fn = tokenizer

     def vocab(self):
         if self._vocab_fn is not None:
@@ -107,7 +158,10 @@ def post_processor(self):

     def create_tokenizer(self) -> Tokenizer:
         """Create and return the configured empty trainable tokenizer."""
-        tokenizer = self.tokenizer()
+        if self._tokenizer_fn is not None:
+            tokenizer = self._tokenizer_fn()
+        else:
+            tokenizer = self.tokenizer()

         # Tokenizer assemble
         normalizer = self.normalizer()
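For readers skimming the diff, the effect of the new `tokenizer` argument is easiest to see in isolation. The sketch below mirrors only the dispatch added to `create_tokenizer()`; the `BackendBuilder` class and its `default_tokenizer` helper are illustrative stand-ins, not the actual `SpmTokenizer` API:

```python
from typing import Callable, Optional

from tokenizers import Tokenizer
from tokenizers.models import BPE


class BackendBuilder:
    """Stand-in for SpmTokenizer, reduced to the new override hook."""

    def __init__(self, tokenizer: Optional[Callable[[], Tokenizer]] = None):
        self._tokenizer_fn = tokenizer  # same attribute name as in the diff

    def default_tokenizer(self) -> Tokenizer:
        # Stand-in for the class default: an empty byte-fallback BPE model.
        return Tokenizer(BPE(fuse_unk=True, byte_fallback=True))

    def create_tokenizer(self) -> Tokenizer:
        # Same pattern as the diff: a caller-supplied callable wins over the default.
        if self._tokenizer_fn is not None:
            return self._tokenizer_fn()
        return self.default_tokenizer()


backend = BackendBuilder(tokenizer=lambda: Tokenizer(BPE())).create_tokenizer()
```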
Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from shutil import copyfile
from typing import Optional

from tokenizers import processors
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers
from tokenizers.models import BPE, Unigram

from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging, requires_backends
from ...create_fast_tokenizer import _get_prepend_scheme


logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
 that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information."""
# fmt: on


class LlamaTokenizer(PreTrainedTokenizerFast):
    """
    Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.

    This uses notably ByteFallback and no normalization.

    ```python
    >>> from transformers import LlamaTokenizer

    >>> tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    >>> tokenizer.encode("Hello this is a test")
    [1, 15043, 445, 338, 263, 1243]
    ```

    If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
    call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
    values of the first token and final token of an encoded sequence will not be correct). For more details, check out
    the [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.


    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`, *optional*):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        tokenizer_file (`str`, *optional*):
            [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
            contains everything needed to load the tokenizer.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to clean up spaces after decoding; cleanup consists in removing potential artifacts like
            extra spaces.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add a `bos_token` at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add an `eos_token` at the end of sequences.
        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
            Whether or not the default system prompt for Llama should be used.
        legacy (`bool`, *optional*):
            Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
            and #25224 which includes fixes to properly handle tokens that appear after special tokens.
            Make sure to also set `from_slow` to `True`.
            A simple example:

            - `legacy=True`:
            ```python
            >>> from transformers import LlamaTokenizer

            >>> tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=True, from_scratch=True)
            >>> tokenizer.encode("Hello <s>.")  # 869 is '▁.'
            [1, 15043, 29871, 1, 869]
            ```
            - `legacy=False`:
            ```python
            >>> from transformers import LlamaTokenizer

            >>> tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False, from_scratch=True)
            >>> tokenizer.encode("Hello <s>.")  # 29889 is '.'
            [1, 15043, 29871, 1, 29889]
            ```
            Check out the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
        add_prefix_space (`bool`, *optional*):
            Whether or not the tokenizer should automatically add a prefix space.
        from_scratch (`bool`, *optional*, defaults to `False`):
            Whether to create an empty trainable tokenizer from scratch. When `True`, creates a minimal tokenizer
            with only basic special tokens that can be trained on new data.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    slow_tokenizer_class = None  # No slow tokenizer class needed
    padding_side = "left"
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        clean_up_tokenization_spaces=False,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        add_bos_token=True,
        add_eos_token=False,
        use_default_system_prompt=False,
        legacy=False,
        add_prefix_space=None,
        vocab=None,
        merges=None,
        **kwargs,
    ):
        self.legacy = legacy

        # Set add_prefix_space attribute for use in override methods
        self.add_prefix_space = add_prefix_space if add_prefix_space is not None else True

        self._vocab = vocab
        self._merges = merges

        # Prepare base-class construction helpers
        metaspace_override = None
        tokenizer_backend_config = None
        if tokenizer_file is None:
            tokenizer_backend_config = {
                "type": "spm",
                "handle_byte_fallback": True,
                "legacy": legacy,
                "add_prefix_space": add_prefix_space if add_prefix_space is not None else True,
                "vocab": self._vocab,
                "normalizer": self._normalizer,
                "pre_tokenizer": self._pre_tokenizer,
                "decoder": self._decoder,
                "tokenizer": self._tokenizer,
            }

        # Initialize the base class which will build the backend tokenizer
        super().__init__(
            tokenizer_file=tokenizer_file,
            tokenizer_backend_config=tokenizer_backend_config,
            metaspace_override=metaspace_override,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            use_default_system_prompt=use_default_system_prompt,
            add_prefix_space=add_prefix_space,
            legacy=legacy,
            **kwargs,
        )

        # TODO: how to do this cleanly? Need to trigger re-adding special tokens after setting the normalizer in Tokenizers
        self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="first", split=False)
        self._tokenizer.normalizer = None  # normalizers.Sequence([normalizers.Prepend("▁"), normalizers.Replace(pattern=" ", content="▁")])
        self.add_tokens([AddedToken(token, special=True) for token in self.all_special_tokens])

        self._add_bos_token = add_bos_token
        self._add_eos_token = add_eos_token
        self.update_post_processor()

        self.use_default_system_prompt = use_default_system_prompt
        self.vocab_file = vocab_file


    def _tokenizer(self):
        """Tokenizer configuration for this tokenizer."""
        return Tokenizer(BPE(vocab=self._vocab, merges=self._merges, fuse_unk=True, byte_fallback=True, dropout=None))

    def _vocab(self):
        """Vocabulary handling for this tokenizer."""
        # First 3 special pieces are fixed for LLaMA
        vocab = [
            ("<unk>", 0.0),
            ("<s>", 0.0),
            ("</s>", 0.0),
        ]
        return vocab

    def _decoder(self, replacement, add_prefix_space):
        """Decoder configuration for this tokenizer."""
        sequence = [
            decoders.Replace("▁", " "),
            decoders.ByteFallback(),
            decoders.Fuse(),
        ]
        if add_prefix_space:
            sequence += [decoders.Strip(content=" ", left=1)]
        return decoders.Sequence(sequence)

    def _normalizer(self):
        """Normalizer configuration for this tokenizer."""
        if self.legacy:
            sequence = []
            if self.add_prefix_space:
                sequence += [normalizers.Prepend(prepend="▁")]
            sequence += [normalizers.Replace(pattern=" ", content="▁")]
            return normalizers.Sequence(sequence)
        return None

    def _pre_tokenizer(self, replacement, add_prefix_space):
        """Pre-tokenizer configuration for this tokenizer."""
        if not self.legacy:
            prepend_scheme = _get_prepend_scheme(add_prefix_space, self)
            return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
        return None


__all__ = ["LlamaTokenizer"]
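Because the new `LlamaTokenizer` can be built without a `tokenizer.model` or `tokenizer.json`, a from-scratch instance is directly trainable through its `tokenizers` backend. A rough usage sketch, assuming this branch is installed and that the base-class handling of `tokenizer_backend_config` (the part of the commit not shown here) is in place:

```python
from tokenizers.trainers import BpeTrainer

from transformers import LlamaTokenizer  # the fast, from-scratch-capable class above

# No vocab_file / tokenizer_file: the empty trainable backend is assembled from
# the _vocab / _normalizer / _pre_tokenizer / _decoder / _tokenizer hooks.
tok = LlamaTokenizer()

# Train the underlying `tokenizers.Tokenizer` (a byte-fallback BPE model) directly.
trainer = BpeTrainer(vocab_size=512, special_tokens=["<unk>", "<s>", "</s>"])
corpus = ["Hello this is a test", "Hello world", "this is another test sentence"]
tok.backend_tokenizer.train_from_iterator(corpus, trainer=trainer)

print(tok.backend_tokenizer.get_vocab_size())
print(tok.encode("Hello this is a test"))
```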
