from typing import List, Optional, Tuple

import json
import os
import string

from transformers import PreTrainedTokenizer


class SPTTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer; "#" serves as both pad and eos token."""

    def __init__(self, vocab_file=None, **kwargs):
        # Load the vocabulary before super().__init__(): the base class may
        # consult it while registering special tokens. The default character
        # vocabulary contains neither "#" nor "[UNK]", so add them here;
        # otherwise their token ids would resolve to None.
        self.vocab = self.load_vocab(vocab_file)
        for token in ("#", "[UNK]"):
            self.vocab.setdefault(token, len(self.vocab))
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        kwargs.setdefault("pad_token", "#")
        kwargs.setdefault("eos_token", "#")
        kwargs.setdefault("unk_token", "[UNK]")
        super().__init__(**kwargs)

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab)

    def _tokenize(self, text):
        # Character-level: every character of the text is one token.
        return list(text)

    def _convert_token_to_id(self, token):
        # Unknown characters fall back to the [UNK] id.
        return self.vocab.get(token, self.vocab[self.unk_token])

    def _convert_id_to_token(self, index):
        return self.inv_vocab.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)
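
    # Illustrative round trip with the default vocabulary:
    #   _tokenize("Hi") -> ["H", "i"] -> ids [32, 62] -> "Hi"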

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # A single eos token terminates each sequence (and each pair member).
        if token_ids_1 is None:
            return token_ids_0 + [self.eos_token_id]
        return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            return [1 if token == self.eos_token_id else 0 for token in token_ids_0]
        if token_ids_1 is None:
            return [0] * len(token_ids_0) + [1]
        return [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 1)
        return [0] * (len(token_ids_0) + 1) + [1] * (len(token_ids_1) + 1)
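
    # Example with the default vocabulary, where "#" (eos) has id 94:
    #   build_inputs_with_special_tokens([32, 62]) -> [32, 62, 94]
    #   create_token_type_ids_from_sequences([32, 62], [54]) -> [0, 0, 0, 1, 1]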

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # Create the target directory (and any missing parents) if needed.
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.vocab, ensure_ascii=False))
        return (vocab_file,)
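
    # Illustrative save/reload round trip (the path is hypothetical):
    #   files = tok.save_vocabulary("spt_tokenizer")  # -> ("spt_tokenizer/vocab.json",)
    #   tok2 = SPTTokenizer(vocab_file=files[0])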

    def load_vocab(self, vocab_file):
        # Read a {character: id} mapping from JSON, or fall back to the
        # built-in default vocabulary (ids 0-93): newline, space, common
        # punctuation, digits, ASCII letters, and a few accented characters.
        if vocab_file is not None:
            with open(vocab_file, "r", encoding="utf-8") as f:
                return json.load(f)
        chars = (
            "\n !\"&'()*,-."
            + string.digits
            + ":;?"
            + string.ascii_uppercase
            + "[]`"
            + string.ascii_lowercase
            + "£°ßàâèéêîñôöûü"
        )
        return {ch: i for i, ch in enumerate(chars)}
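

# Minimal smoke test, assuming the default built-in vocabulary: run the
# module directly to round-trip a string through encode() and decode().
if __name__ == "__main__":
    tok = SPTTokenizer()
    ids = tok.encode("Hello, world!")  # per-character ids plus trailing eos
    print(ids)
    print(tok.decode(ids, skip_special_tokens=True))  # -> "Hello, world!"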