Spaces:
Sleeping
Sleeping
File size: 29,938 Bytes
2260825 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 |
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities to convert slow tokenizers into their fast tokenizer counterparts.
All the conversions are grouped here to keep the SentencePiece dependencies out of the fast tokenizer files and
to make our dependency on SentencePiece optional.
"""
from typing import Dict, List, Tuple
from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import BPE, Unigram, WordPiece
from .file_utils import requires_backends
class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece

    Loads a SentencePiece model file and exposes its vocabulary together with the merges that can be reconstructed
    from that vocabulary.
    """

    def __init__(self, model: str):
        """
        Args:
            model (:obj:`str`): Path handed to ``SentencePieceProcessor.Load``.
        """
        requires_backends(self, "sentencepiece")
        from sentencepiece import SentencePieceProcessor

        self.sp = SentencePieceProcessor()
        self.sp.Load(model)

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        """
        Return the vocabulary and the merges reconstructed from it.

        A merge ``(left, right)`` is emitted for every pair of vocabulary pieces whose concatenation is itself a
        vocabulary piece. Merges are ordered by the id of the merged piece; merges producing the same piece keep the
        vocabulary order of ``left``, then of ``right``.

        Returns:
            A ``(vocab, merges)`` tuple where ``vocab`` maps each piece to its id and ``merges`` is a list of
            ``(left, right)`` tuples.
        """
        sp = self.sp
        vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}

        # Rank of every piece in the vocab's iteration order, used to break ties between merges of the same piece.
        position = {piece: rank for rank, piece in enumerate(vocab)}

        # Split every piece at every position instead of testing all (left, right) pairs of the vocabulary: this
        # finds the same merges without the quadratic scan. The two boundary cuts (empty left / empty right) only
        # match when the empty string is itself a piece, which is what the pairwise scan would have produced too.
        candidates = []
        for merged, piece_id in vocab.items():
            for cut in range(len(merged) + 1):
                left, right = merged[:cut], merged[cut:]
                # Membership test rather than truthiness of the id, so a merged piece with id 0 is not dropped.
                if left in vocab and right in vocab:
                    candidates.append((piece_id, position[left], position[right], left, right))

        candidates.sort(key=lambda candidate: candidate[:3])
        merges = [(left, right) for _, _, _, left, right in candidates]
        return vocab, merges
def check_number_comma(piece: str) -> bool:
    """
    Tell whether ``piece`` is anything other than a "digit followed by a comma" ending.

    Returns:
        ``False`` only when ``piece`` has at least two characters, ends with ``","`` and the character right before
        that comma is a digit; ``True`` in every other case.
    """
    if len(piece) < 2:
        return True
    before_last, last = piece[-2], piece[-1]
    return not (last == "," and before_last.isdigit())
class Converter:
    """Base class of all converters: holds the slow tokenizer and defines the ``converted`` entry point."""

    def __init__(self, original_tokenizer):
        """
        Args:
            original_tokenizer: The slow tokenizer instance that subclasses read their configuration from.
        """
        self.original_tokenizer = original_tokenizer

    def converted(self) -> Tokenizer:
        """Return the fast backend tokenizer; must be implemented by subclasses."""
        raise NotImplementedError
class BertConverter(Converter):
    """Converter producing a WordPiece backend with the ``[CLS] A [SEP] (B [SEP])`` layout."""

    def converted(self) -> Tokenizer:
        """
        Assemble the fast tokenizer from the wrapped slow tokenizer.

        Normalization flags are read from the slow tokenizer's ``basic_tokenizer`` when it has one; otherwise they
        are all disabled.
        """
        slow = self.original_tokenizer
        fast = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        # Defaults used when the slow tokenizer carries no ``basic_tokenizer``
        handle_chinese_chars, strip_accents, lowercase = False, False, False
        if hasattr(slow, "basic_tokenizer"):
            basic = slow.basic_tokenizer
            handle_chinese_chars = basic.tokenize_chinese_chars
            strip_accents = basic.strip_accents
            lowercase = basic.do_lower_case

        fast.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls, sep = str(slow.cls_token), str(slow.sep_token)
        special_tokens = [(cls, slow.cls_token_id), (sep, slow.sep_token_id)]
        fast.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=special_tokens,
        )
        fast.decoder = decoders.WordPiece(prefix="##")
        return fast
class FunnelConverter(Converter):
    """Converter producing a WordPiece backend whose leading ``cls`` token gets token type id 2."""

    def converted(self) -> Tokenizer:
        """
        Assemble the fast tokenizer from the wrapped slow tokenizer.

        Normalization flags are read from the slow tokenizer's ``basic_tokenizer`` when it has one; otherwise they
        are all disabled.
        """
        slow = self.original_tokenizer
        fast = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        # Defaults used when the slow tokenizer carries no ``basic_tokenizer``
        handle_chinese_chars, strip_accents, lowercase = False, False, False
        if hasattr(slow, "basic_tokenizer"):
            basic = slow.basic_tokenizer
            handle_chinese_chars = basic.tokenize_chinese_chars
            strip_accents = basic.strip_accents
            lowercase = basic.do_lower_case

        fast.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls, sep = str(slow.cls_token), str(slow.sep_token)
        special_tokens = [(cls, slow.cls_token_id), (sep, slow.sep_token_id)]
        fast.post_processor = processors.TemplateProcessing(
            # token_type_id is 2 for the cls token of the Funnel transformer
            single=f"{cls}:2 $A:0 {sep}:0",
            pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=special_tokens,
        )
        fast.decoder = decoders.WordPiece(prefix="##")
        return fast
class MPNetConverter(Converter):
    """Converter producing a WordPiece backend that separates sequence pairs with two ``sep`` tokens."""

    def converted(self) -> Tokenizer:
        """
        Assemble the fast tokenizer from the wrapped slow tokenizer.

        Normalization flags are read from the slow tokenizer's ``basic_tokenizer`` when it has one; otherwise they
        are all disabled.
        """
        slow = self.original_tokenizer
        fast = Tokenizer(WordPiece(slow.vocab, unk_token=str(slow.unk_token)))

        # Defaults used when the slow tokenizer carries no ``basic_tokenizer``
        handle_chinese_chars, strip_accents, lowercase = False, False, False
        if hasattr(slow, "basic_tokenizer"):
            basic = slow.basic_tokenizer
            handle_chinese_chars = basic.tokenize_chinese_chars
            strip_accents = basic.strip_accents
            lowercase = basic.do_lower_case

        fast.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()

        cls, sep = str(slow.cls_token), str(slow.sep_token)
        special_tokens = [(cls, slow.cls_token_id), (sep, slow.sep_token_id)]
        fast.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            # MPNet uses two [SEP] tokens between the two sequences
            pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=special_tokens,
        )
        fast.decoder = decoders.WordPiece(prefix="##")
        return fast
class OpenAIGPTConverter(Converter):
    """Converter producing a BPE backend that uses the ``</w>`` end-of-word suffix."""

    def converted(self) -> Tokenizer:
        """Assemble the fast tokenizer from the slow tokenizer's ``encoder`` and ``bpe_ranks``."""
        slow = self.original_tokenizer
        vocab = slow.encoder
        merges = list(slow.bpe_ranks.keys())
        unk = str(slow.unk_token)

        model = BPE(
            vocab=vocab,
            merges=merges,
            dropout=None,
            unk_token=unk,
            end_of_word_suffix="</w>",
            fuse_unk=False,
        )
        fast = Tokenizer(model)

        # Register the unknown token as special only when the vocabulary actually contains it
        if fast.token_to_id(unk) is not None:
            fast.add_special_tokens([unk])

        fast.normalizer = normalizers.BertNormalizer(lowercase=True)
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        fast.decoder = decoders.BPEDecoder(suffix="</w>")
        return fast
class GPT2Converter(Converter):
    """Converter producing a byte-level BPE backend without any special-token template."""

    def converted(self) -> Tokenizer:
        """Assemble the fast tokenizer from the slow tokenizer's ``encoder`` and ``bpe_ranks``."""
        slow = self.original_tokenizer
        model = BPE(
            vocab=slow.encoder,
            merges=list(slow.bpe_ranks.keys()),
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="",
            fuse_unk=False,
        )
        fast = Tokenizer(model)

        fast.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        fast.decoder = decoders.ByteLevel()
        fast.post_processor = processors.ByteLevel(trim_offsets=False)
        return fast
class HerbertConverter(Converter):
    """Converter producing a BPE backend with ``</w>`` suffix and BERT-style post-processing."""

    def converted(self) -> Tokenizer:
        """
        Assemble the fast tokenizer from the slow tokenizer's ``encoder`` and ``bpe_ranks``.

        The first merge entry is dropped when its left element contains the ``#version:`` marker.
        """
        slow = self.original_tokenizer
        suffix = "</w>"
        vocab = slow.encoder
        merges = list(slow.bpe_ranks.keys())

        # The first entry may be a "#version:" header rather than a real merge
        if "#version:" in merges[0][0]:
            del merges[0]

        model = BPE(
            vocab,
            merges,
            dropout=None,
            unk_token=slow.unk_token,
            end_of_word_suffix=suffix,
        )
        fast = Tokenizer(model)

        fast.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
        fast.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
        fast.decoder = decoders.BPEDecoder(suffix=suffix)
        fast.post_processor = processors.BertProcessing(
            sep=(slow.sep_token, slow.sep_token_id),
            cls=(slow.cls_token, slow.cls_token_id),
        )
        return fast
class RobertaConverter(Converter):
    """Converter producing a byte-level BPE backend with ``RobertaProcessing`` post-processing."""

    def converted(self) -> Tokenizer:
        """Assemble the fast tokenizer from the slow tokenizer's ``encoder`` and ``bpe_ranks``."""
        slow = self.original_tokenizer
        model = BPE(
            vocab=slow.encoder,
            merges=list(slow.bpe_ranks.keys()),
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="",
            fuse_unk=False,
        )
        fast = Tokenizer(model)

        fast.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        fast.decoder = decoders.ByteLevel()

        sep_pair = (slow.sep_token, slow.sep_token_id)
        cls_pair = (slow.cls_token, slow.cls_token_id)
        fast.post_processor = processors.RobertaProcessing(
            sep=sep_pair,
            cls=cls_pair,
            add_prefix_space=slow.add_prefix_space,
            trim_offsets=True,  # True by default on Roberta (historical)
        )
        return fast
class RoFormerConverter(Converter):
    """Converter producing a WordPiece backend that pre-tokenizes with the custom ``JiebaPreTokenizer``."""

    def converted(self) -> Tokenizer:
        """
        Assemble the fast tokenizer from the wrapped slow tokenizer.

        Accent stripping and lowercasing are read from the slow tokenizer's ``basic_tokenizer`` when it has one;
        Chinese character handling is always disabled in the normalizer.
        """
        from .models.roformer.tokenization_utils import JiebaPreTokenizer

        slow = self.original_tokenizer
        vocab = slow.vocab
        fast = Tokenizer(WordPiece(vocab, unk_token=str(slow.unk_token)))

        # Defaults used when the slow tokenizer carries no ``basic_tokenizer``
        strip_accents, lowercase = False, False
        if hasattr(slow, "basic_tokenizer"):
            basic = slow.basic_tokenizer
            strip_accents = basic.strip_accents
            lowercase = basic.do_lower_case

        fast.normalizer = normalizers.BertNormalizer(
            clean_text=True,
            handle_chinese_chars=False,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        fast.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))

        cls, sep = str(slow.cls_token), str(slow.sep_token)
        special_tokens = [(cls, slow.cls_token_id), (sep, slow.sep_token_id)]
        fast.post_processor = processors.TemplateProcessing(
            single=f"{cls}:0 $A:0 {sep}:0",
            pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
            special_tokens=special_tokens,
        )
        fast.decoder = decoders.WordPiece(prefix="##")
        return fast
class DebertaConverter(Converter):
    """Converter producing a byte-level BPE backend wrapped in a ``[CLS]``/``[SEP]`` template."""

    def converted(self) -> Tokenizer:
        """
        Assemble the fast tokenizer from the slow tokenizer's ``encoder`` and ``bpe_ranks``.

        Every position of the template, including the second sequence, is assigned token type id 0.
        """
        slow = self.original_tokenizer
        model = BPE(
            vocab=slow.encoder,
            merges=list(slow.bpe_ranks.keys()),
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="",
            fuse_unk=False,
        )
        fast = Tokenizer(model)

        fast.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        fast.decoder = decoders.ByteLevel()

        special_tokens = [
            ("[CLS]", slow.convert_tokens_to_ids("[CLS]")),
            ("[SEP]", slow.convert_tokens_to_ids("[SEP]")),
        ]
        fast.post_processor = processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:0 [SEP]:0",
            special_tokens=special_tokens,
        )
        return fast
class SpmConverter(Converter):
    """
    Base converter for slow tokenizers backed by a SentencePiece model file.

    The model file referenced by ``original_tokenizer.vocab_file`` is parsed into a ``ModelProto`` stored on
    ``self.proto``. Subclasses customize the result by overriding ``vocab``, ``unk_id``, ``normalizer``,
    ``pre_tokenizer`` and ``post_processor``.
    """

    def __init__(self, *args):
        """
        Args:
            *args: Forwarded unchanged to :class:`Converter` (the slow tokenizer instance).
        """
        requires_backends(self, "protobuf")

        super().__init__(*args)

        from .utils import sentencepiece_model_pb2 as model_pb2

        m = model_pb2.ModelProto()
        with open(self.original_tokenizer.vocab_file, "rb") as f:
            m.ParseFromString(f.read())
        self.proto = m

    def vocab(self, proto):
        """Return the vocabulary as a list of ``(piece, score)`` tuples, in the order of ``proto.pieces``."""
        return [(piece.piece, piece.score) for piece in proto.pieces]

    def unk_id(self, proto):
        """Return the id of the unknown token, as recorded in ``proto.trainer_spec``."""
        return proto.trainer_spec.unk_id

    def tokenizer(self, proto):
        """
        Build the bare :class:`~tokenizers.Tokenizer` (model only) for ``proto``.

        ``trainer_spec.model_type`` 1 yields a ``Unigram`` model and 2 a ``BPE`` model whose merges are rebuilt with
        :class:`SentencePieceExtractor`.

        Raises:
            Exception: If ``model_type`` is neither 1 nor 2.
        """
        model_type = proto.trainer_spec.model_type
        vocab = self.vocab(proto)
        unk_id = self.unk_id(proto)

        if model_type == 1:
            tokenizer = Tokenizer(Unigram(vocab, unk_id))
        elif model_type == 2:
            _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
            # BPE ids are the positions in the (possibly overridden) vocab; scores are not used here
            bpe_vocab = {word: i for i, (word, _) in enumerate(vocab)}
            tokenizer = Tokenizer(
                BPE(
                    bpe_vocab,
                    merges,
                    unk_token=proto.trainer_spec.unk_piece,
                    fuse_unk=True,
                )
            )
        else:
            raise Exception(
                "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
            )

        return tokenizer

    def normalizer(self, proto):
        """
        Return the normalizer sequence for ``proto``.

        The precompiled character map is applied first when the model provides a non-empty one; runs of two or more
        spaces are always collapsed into a single space.
        """
        steps = []
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        if precompiled_charsmap:
            steps.append(normalizers.Precompiled(precompiled_charsmap))
        steps.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(steps)

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Return the ``Metaspace`` pre-tokenizer configured with the given arguments."""
        return pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)

    def post_processor(self):
        """Return the post-processor to install, or ``None`` (the default here) to install none."""
        return None

    def converted(self) -> Tokenizer:
        """Assemble model, normalizer, pre-tokenizer, decoder and optional post-processor into a fast tokenizer."""
        tokenizer = self.tokenizer(self.proto)

        # Tokenizer assemble
        tokenizer.normalizer = self.normalizer(self.proto)

        replacement = "▁"
        add_prefix_space = True
        tokenizer.pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
        tokenizer.decoder = decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)

        post_processor = self.post_processor()
        if post_processor:
            tokenizer.post_processor = post_processor

        return tokenizer
class AlbertConverter(SpmConverter):
    """SentencePiece converter with quote/accent/case normalization and a ``[CLS]``/``[SEP]`` template."""

    def vocab(self, proto):
        """
        Return ``(piece, score)`` tuples, lowering by 100 the score of pieces ending with a digit followed by a
        comma (those for which ``check_number_comma`` is false).
        """
        return [
            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto):
        """
        Return the normalizer sequence.

        Accent stripping and lowercasing follow the slow tokenizer's ``keep_accents`` and ``do_lower_case``
        attributes. The precompiled character map is applied only when the model provides a non-empty one.
        """
        list_normalizers = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
        ]
        if not self.original_tokenizer.keep_accents:
            list_normalizers.append(normalizers.NFKD())
            list_normalizers.append(normalizers.StripAccents())
        if self.original_tokenizer.do_lower_case:
            list_normalizers.append(normalizers.Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        # Same guard as the base SentencePiece converter: skip the step when the model ships no character map
        if precompiled_charsmap:
            list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
        list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(list_normalizers)

    def post_processor(self):
        """Return the ``[CLS] A [SEP] (B [SEP])`` template; the second sequence gets token type id 1."""
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=[
                ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
                ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
            ],
        )
class BarthezConverter(SpmConverter):
    """SentencePiece converter with a fixed unknown-token id and a ``<s>``/``</s>`` template."""

    def unk_id(self, proto):
        """Return the unknown-token id, which is always 3 here regardless of ``proto``."""
        return 3

    def post_processor(self):
        """Return the ``<s> A </s> (</s> B </s>)`` template."""
        slow = self.original_tokenizer
        special_tokens = [
            ("<s>", slow.convert_tokens_to_ids("<s>")),
            ("</s>", slow.convert_tokens_to_ids("</s>")),
        ]
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=special_tokens,
        )
class CamembertConverter(SpmConverter):
    """SentencePiece converter that prepends five fixed entries to the vocabulary and appends ``<mask>``."""

    def vocab(self, proto):
        """Return the fixed leading entries, then ``proto.pieces[1:]``, then ``<mask>``."""
        # We down-grade the original SentencePiece by -100 to avoid using it and use our added token instead
        leading = [
            ("<s>NOTUSED", 0.0),
            ("<pad>", 0.0),
            ("</s>NOTUSED", 0.0),
            ("<unk>", 0.0),
            ("<unk>NOTUSED", -100),
        ]
        model_pieces = [(piece.piece, piece.score) for piece in proto.pieces[1:]]
        return leading + model_pieces + [("<mask>", 0.0)]

    def unk_id(self, proto):
        """Return 3, the position of ``<unk>`` in the list built by ``vocab``."""
        return 3

    def post_processor(self):
        """Return the ``<s> A </s> (</s> B </s>)`` template."""
        slow = self.original_tokenizer
        special_tokens = [
            ("<s>", slow.convert_tokens_to_ids("<s>")),
            ("</s>", slow.convert_tokens_to_ids("</s>")),
        ]
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=special_tokens,
        )
class MBartConverter(SpmConverter):
    """SentencePiece converter that appends 25 language codes and ``<mask>`` to the vocabulary."""

    def vocab(self, proto):
        """Return four fixed special tokens, then ``proto.pieces[3:]``, the language codes and ``<mask>``."""
        language_codes = (
            "ar_AR",
            "cs_CZ",
            "de_DE",
            "en_XX",
            "es_XX",
            "et_EE",
            "fi_FI",
            "fr_XX",
            "gu_IN",
            "hi_IN",
            "it_IT",
            "ja_XX",
            "kk_KZ",
            "ko_KR",
            "lt_LT",
            "lv_LV",
            "my_MM",
            "ne_NP",
            "nl_XX",
            "ro_RO",
            "ru_RU",
            "si_LK",
            "tr_TR",
            "vi_VN",
            "zh_CN",
        )
        entries = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        entries.extend((code, 0.0) for code in language_codes)
        entries.append(("<mask>", 0.0))
        return entries

    def unk_id(self, proto):
        """Return 3, the position of ``<unk>`` in the list built by ``vocab``."""
        return 3

    def post_processor(self):
        """Return the template that appends ``</s> en_XX`` after the sequence(s)."""
        slow = self.original_tokenizer
        special_tokens = [
            ("en_XX", slow.convert_tokens_to_ids("en_XX")),
            ("</s>", slow.convert_tokens_to_ids("</s>")),
        ]
        return processors.TemplateProcessing(
            single="$A </s> en_XX",
            pair="$A $B </s> en_XX",
            special_tokens=special_tokens,
        )
class MBart50Converter(SpmConverter):
    """SentencePiece converter that appends 52 language codes and ``<mask>`` to the vocabulary."""

    def vocab(self, proto):
        """Return four fixed special tokens, then ``proto.pieces[3:]``, the language codes and ``<mask>``."""
        # fmt: off
        language_codes = (
            "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX", "gu_IN", "hi_IN", "it_IT",
            "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK",
            "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE",
            "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE", "sw_KE", "ta_IN",
            "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK", "xh_ZA", "gl_ES", "sl_SI",
        )
        # fmt: on
        entries = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[3:])
        entries.extend((code, 0.0) for code in language_codes)
        entries.append(("<mask>", 0.0))
        return entries

    def unk_id(self, proto):
        """Return 3, the position of ``<unk>`` in the list built by ``vocab``."""
        return 3

    def post_processor(self):
        """Return the template that puts ``en_XX`` before and ``</s>`` after the sequence(s)."""
        slow = self.original_tokenizer
        special_tokens = [
            ("en_XX", slow.convert_tokens_to_ids("en_XX")),
            ("</s>", slow.convert_tokens_to_ids("</s>")),
        ]
        return processors.TemplateProcessing(
            single="en_XX $A </s>",
            pair="en_XX $A $B </s>",
            special_tokens=special_tokens,
        )
class XLMRobertaConverter(SpmConverter):
    """SentencePiece converter with four fixed leading tokens, a trailing ``<mask>`` and a ``<s>``/``</s>`` template."""

    def vocab(self, proto):
        """Return four fixed special tokens, then ``proto.pieces[3:]``, then ``<mask>``."""
        leading = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        model_pieces = [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return leading + model_pieces + [("<mask>", 0.0)]

    def unk_id(self, proto):
        """Return 3, the position of ``<unk>`` in the list built by ``vocab``."""
        return 3

    def post_processor(self):
        """Return the ``<s> A </s> (</s> B </s>)`` template."""
        slow = self.original_tokenizer
        special_tokens = [
            ("<s>", slow.convert_tokens_to_ids("<s>")),
            ("</s>", slow.convert_tokens_to_ids("</s>")),
        ]
        return processors.TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> </s> $B </s>",
            special_tokens=special_tokens,
        )
class XLNetConverter(SpmConverter):
    """SentencePiece converter with quote/accent/case normalization and a trailing ``<sep> <cls>`` template."""

    def vocab(self, proto):
        """
        Return ``(piece, score)`` tuples, lowering by 100 the score of pieces ending with a digit followed by a
        comma (those for which ``check_number_comma`` is false).
        """
        return [
            (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto):
        """
        Return the normalizer sequence.

        Accent stripping and lowercasing follow the slow tokenizer's ``keep_accents`` and ``do_lower_case``
        attributes. The precompiled character map is applied only when the model provides a non-empty one.
        """
        list_normalizers = [
            normalizers.Replace("``", '"'),
            normalizers.Replace("''", '"'),
        ]
        if not self.original_tokenizer.keep_accents:
            list_normalizers.append(normalizers.NFKD())
            list_normalizers.append(normalizers.StripAccents())
        if self.original_tokenizer.do_lower_case:
            list_normalizers.append(normalizers.Lowercase())

        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        # Same guard as the base SentencePiece converter: skip the step when the model ships no character map
        if precompiled_charsmap:
            list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
        list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
        return normalizers.Sequence(list_normalizers)

    def post_processor(self):
        """Return the ``A <sep> (B <sep>) <cls>`` template; the final ``<cls>`` gets token type id 2."""
        return processors.TemplateProcessing(
            single="$A:0 <sep>:0 <cls>:2",
            pair="$A:0 <sep>:0 $B:1 <sep>:1 <cls>:2",
            special_tokens=[
                ("<sep>", self.original_tokenizer.convert_tokens_to_ids("<sep>")),
                ("<cls>", self.original_tokenizer.convert_tokens_to_ids("<cls>")),
            ],
        )
class ReformerConverter(SpmConverter):
    """SentencePiece converter that uses the ``SpmConverter`` behavior without any override."""
class BertGenerationConverter(SpmConverter):
    """SentencePiece converter that uses the ``SpmConverter`` behavior without any override."""
class PegasusConverter(SpmConverter):
    """SentencePiece converter whose vocabulary starts with the slow tokenizer's reserved tokens."""

    def vocab(self, proto):
        """
        Return the vocabulary as ``(piece, score)`` tuples.

        The list starts with the pad and eos tokens, optionally followed by the sentence mask token and the mask
        token (the latter only when its id is below ``offset``), then ``<unk_i>`` placeholders for ``i`` from 2 up
        to ``offset`` (exclusive), and finally ``proto.pieces[2:]``.
        """
        slow = self.original_tokenizer
        entries = [
            (slow.pad_token, 0.0),
            (slow.eos_token, 0.0),
        ]

        if slow.mask_token_sent is not None:
            entries.append((slow.mask_token_sent, 0.0))

        if slow.mask_token is not None and slow.mask_token_id < slow.offset:
            entries.append((slow.mask_token, 0.0))

        entries.extend((f"<unk_{i}>", -100.0) for i in range(2, slow.offset))
        entries.extend((piece.piece, piece.score) for piece in proto.pieces[2:])
        return entries

    def unk_id(self, proto):
        """Return the model's unknown-token id shifted by the slow tokenizer's ``offset``."""
        return self.original_tokenizer.offset + proto.trainer_spec.unk_id

    def pre_tokenizer(self, replacement, add_prefix_space):
        """Return a sequence that splits on whitespace first and then applies ``Metaspace``."""
        metaspace = pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
        return pre_tokenizers.Sequence([pre_tokenizers.WhitespaceSplit(), metaspace])

    def post_processor(self):
        """Return the template that appends the eos token after the sequence(s)."""
        eos = self.original_tokenizer.eos_token
        eos_entry = (eos, self.original_tokenizer.eos_token_id)
        return processors.TemplateProcessing(
            single=["$A", eos],
            pair=["$A", "$B", eos],
            special_tokens=[eos_entry],
        )
class T5Converter(SpmConverter):
    """SentencePiece converter that appends the ``<extra_id_*>`` tokens and a trailing ``</s>`` template."""

    def vocab(self, proto):
        """Return all model pieces followed by ``<extra_id_i>`` entries in descending order of ``i``."""
        extra_id_count = self.original_tokenizer._extra_ids
        model_pieces = [(piece.piece, piece.score) for piece in proto.pieces]
        extra_tokens = [(f"<extra_id_{i}>", 0.0) for i in reversed(range(extra_id_count))]
        return model_pieces + extra_tokens

    def post_processor(self):
        """Return the template that appends ``</s>`` after each sequence."""
        eos_entry = ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>"))
        return processors.TemplateProcessing(
            single=["$A", "</s>"],
            pair=["$A", "</s>", "$B", "</s>"],
            special_tokens=[eos_entry],
        )
class BigBirdConverter(SpmConverter):
    """SentencePiece converter with a ``[CLS]``/``[SEP]`` template."""

    def post_processor(self):
        """Return the ``[CLS] A [SEP] (B [SEP])`` template; the second sequence gets token type id 1."""
        slow = self.original_tokenizer
        special_tokens = [
            ("[CLS]", slow.convert_tokens_to_ids("[CLS]")),
            ("[SEP]", slow.convert_tokens_to_ids("[SEP]")),
        ]
        return processors.TemplateProcessing(
            single="[CLS]:0 $A:0 [SEP]:0",
            pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
            special_tokens=special_tokens,
        )
class CLIPConverter(Converter):
    """Converter producing a byte-level BPE backend that uses the ``</w>`` end-of-word suffix."""

    def converted(self) -> Tokenizer:
        """Assemble the fast tokenizer from the slow tokenizer's ``encoder`` and ``bpe_ranks``."""
        slow = self.original_tokenizer
        model = BPE(
            vocab=slow.encoder,
            merges=list(slow.bpe_ranks.keys()),
            dropout=None,
            continuing_subword_prefix="",
            end_of_word_suffix="</w>",
            fuse_unk=False,
        )
        fast = Tokenizer(model)

        fast.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=slow.add_prefix_space)
        fast.decoder = decoders.ByteLevel()
        fast.post_processor = processors.ByteLevel(trim_offsets=False)
        return fast
# Registry mapping the class name of a slow tokenizer to the converter class used to build its fast backend.
# `convert_slow_tokenizer` looks entries up by `transformer_tokenizer.__class__.__name__`, so keys must match the
# class names exactly. Several tokenizer classes share one converter (for instance `BertConverter`).
# The key order is user-visible: it is printed in the error raised when no converter is found.
SLOW_TO_FAST_CONVERTERS = {
"AlbertTokenizer": AlbertConverter,
"BartTokenizer": RobertaConverter,
"BarthezTokenizer": BarthezConverter,
"BertTokenizer": BertConverter,
"BigBirdTokenizer": BigBirdConverter,
"CamembertTokenizer": CamembertConverter,
"CLIPTokenizer": CLIPConverter,
"ConvBertTokenizer": BertConverter,
"DebertaTokenizer": DebertaConverter,
"DistilBertTokenizer": BertConverter,
"DPRReaderTokenizer": BertConverter,
"DPRQuestionEncoderTokenizer": BertConverter,
"DPRContextEncoderTokenizer": BertConverter,
"ElectraTokenizer": BertConverter,
"FunnelTokenizer": FunnelConverter,
"GPT2Tokenizer": GPT2Converter,
"HerbertTokenizer": HerbertConverter,
"LayoutLMTokenizer": BertConverter,
"LongformerTokenizer": RobertaConverter,
"LEDTokenizer": RobertaConverter,
"LxmertTokenizer": BertConverter,
"MBartTokenizer": MBartConverter,
"MBart50Tokenizer": MBart50Converter,
"MPNetTokenizer": MPNetConverter,
"MobileBertTokenizer": BertConverter,
"OpenAIGPTTokenizer": OpenAIGPTConverter,
"PegasusTokenizer": PegasusConverter,
"ReformerTokenizer": ReformerConverter,
"RetriBertTokenizer": BertConverter,
"RobertaTokenizer": RobertaConverter,
"RoFormerTokenizer": RoFormerConverter,
"SqueezeBertTokenizer": BertConverter,
"T5Tokenizer": T5Converter,
"XLMRobertaTokenizer": XLMRobertaConverter,
"XLNetTokenizer": XLNetConverter,
}
def convert_slow_tokenizer(transformer_tokenizer) -> Tokenizer:
    """
    Build the fast backend tokenizer matching a slow tokenizer instance.

    The converter is selected from ``SLOW_TO_FAST_CONVERTERS`` using the class name of ``transformer_tokenizer``.

    Args:
        transformer_tokenizer (:class:`~transformers.tokenization_utils_base.PreTrainedTokenizer`):
            Instance of a slow tokenizer to convert into the backend tokenizer for
            :class:`~transformers.tokenization_utils_base.PreTrainedTokenizerFast`.

    Return:
        An instance of :class:`~tokenizers.Tokenizer` to be used as the backend tokenizer of a
        :class:`~transformers.tokenization_utils_base.PreTrainedTokenizerFast`

    Raises:
        ValueError: If no converter is registered for the class name of ``transformer_tokenizer``.
    """
    tokenizer_class_name = transformer_tokenizer.__class__.__name__

    if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS:
        converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
        return converter_class(transformer_tokenizer).converted()

    raise ValueError(
        f"An instance of tokenizer class {tokenizer_class_name} cannot be converted in a Fast tokenizer instance. "
        f"No converter was found. Currently available slow->fast convertors: {list(SLOW_TO_FAST_CONVERTERS.keys())}"
    )
|