|
24 | 24 | from typing import Any, Optional, Union |
25 | 25 |
|
26 | 26 | import tokenizers.pre_tokenizers as pre_tokenizers_fast |
| 27 | +from tokenizers import processors |
27 | 28 | from tokenizers import Encoding as EncodingFast |
28 | 29 | from tokenizers import Tokenizer as TokenizerFast |
29 | 30 | from tokenizers.decoders import Decoder as DecoderFast |
@@ -237,6 +238,57 @@ def can_save_slow_tokenizer(self) -> bool: |
237 | 238 | else: |
238 | 239 | return True |
239 | 240 |
|
    def update_post_processor(self):
        """
        Updates the underlying post processor with the current `bos_token` and `eos_token`.

        Rebuilds a `tokenizers.processors.TemplateProcessing` from the current
        `bos_token`/`eos_token` and the `add_bos_token`/`add_eos_token` flags, then
        installs it on the backend tokenizer (`self._tokenizer`), chaining with any
        post processor already present.

        Raises:
            ValueError: If `add_bos_token` is `True` but `bos_token` is `None`, or
                `add_eos_token` is `True` but `eos_token` is `None`.
        """
        bos = self.bos_token
        bos_token_id = self.bos_token_id
        if bos is None and self.add_bos_token:
            raise ValueError("add_bos_token = True but bos_token = None")

        eos = self.eos_token
        eos_token_id = self.eos_token_id
        if eos is None and self.add_eos_token:
            raise ValueError("add_eos_token = True but eos_token = None")

        # TemplateProcessing template syntax: "<token>:<type_id>"; $A/$B stand for the
        # first/second input sequence, type_id 0/1 mark the first/second segment.
        single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
        pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"

        # Special tokens referenced by the template must be registered with their ids.
        special_tokens = []
        if self.add_bos_token:
            special_tokens.append((bos, bos_token_id))
        if self.add_eos_token:
            special_tokens.append((eos, eos_token_id))

        new_processor = processors.TemplateProcessing(
            single=single, pair=pair, special_tokens=special_tokens
        )
        # Keep any pre-existing post processor (e.g. ByteLevel) by chaining instead of
        # overwriting it outright.
        # NOTE(review): when the post processor is already a `Sequence`, `+=` appends a
        # fresh TemplateProcessing on every call (each add_bos_token/add_eos_token
        # setter invocation) without removing the one added previously — confirm this
        # accumulation is intended rather than replacing the stale template.
        # NOTE(review): `+=` on a tokenizers `Sequence` needs a tokenizers release that
        # supports concatenation on post-processor sequences — verify the pinned
        # minimum `tokenizers` version.
        if isinstance(self._tokenizer.post_processor, processors.Sequence):
            self._tokenizer.post_processor += [new_processor]
        elif self._tokenizer.post_processor is not None:
            self._tokenizer.post_processor = processors.Sequence([self._tokenizer.post_processor, new_processor])
        else:
            self._tokenizer.post_processor = new_processor
| 273 | + |
| 274 | + @property |
| 275 | + def add_eos_token(self): |
| 276 | + return self._add_eos_token |
| 277 | + |
| 278 | + @property |
| 279 | + def add_bos_token(self): |
| 280 | + return self._add_bos_token |
| 281 | + |
| 282 | + @add_eos_token.setter |
| 283 | + def add_eos_token(self, value): |
| 284 | + self._add_eos_token = value |
| 285 | + self.update_post_processor() |
| 286 | + |
| 287 | + @add_bos_token.setter |
| 288 | + def add_bos_token(self, value): |
| 289 | + self._add_bos_token = value |
| 290 | + self.update_post_processor() |
| 291 | + |
240 | 292 | @property |
241 | 293 | def vocab_size(self) -> int: |
242 | 294 | """ |
|
0 commit comments