# spt/tokenization_spt.py
# Character-level tokenizer for the Sherlock Pretrained Transformer (SPT).
import json
import os
from typing import List, Optional, Tuple

from transformers import PreTrainedTokenizer

class SPTTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer: every character in the vocabulary is its own token."""

    def __init__(self, vocab_file=None, pad_token="#", eos_token="#", unk_token="[UNK]", **kwargs):
        # The vocabulary must exist before the base constructor runs, since
        # PreTrainedTokenizer may consult get_vocab()/_tokenize() while
        # registering the special tokens passed to it.
        self.vocab = self.load_vocab(vocab_file)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        super().__init__(pad_token=pad_token, eos_token=eos_token, unk_token=unk_token, **kwargs)

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab)

    def _tokenize(self, text):
        # Character-level tokenization: split the text into individual characters.
        return list(text)

    def _convert_token_to_id(self, token):
        # Falls back to the [UNK] id for out-of-vocabulary characters. Note that
        # the built-in default vocabulary below does not contain "[UNK]", so the
        # fallback is None unless a vocab file that defines it is supplied.
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        return self.inv_vocab.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        return ''.join(tokens)

    def build_inputs_with_special_tokens(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        # A single EOS token is appended after each sequence.
        if token_ids_1 is None:
            return token_ids_0 + [self.eos_token_id]
        return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def get_special_tokens_mask(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False) -> List[int]:
        if already_has_special_tokens:
            return [1 if token in [self.eos_token_id] else 0 for token in token_ids_0]
        if token_ids_1 is None:
            return [0] * len(token_ids_0) + [1]
        return [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]

    def create_token_type_ids_from_sequences(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None) -> List[int]:
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 1)
        return [0] * (len(token_ids_0) + 1) + [1] * (len(token_ids_1) + 1)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
        # Pure passthrough; the base class implementation already does the work.
        return super().from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        # makedirs (unlike os.mkdir) also creates intermediate directories.
        os.makedirs(save_directory, exist_ok=True)
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.vocab, ensure_ascii=False))
        return (vocab_file,)

    def load_vocab(self, vocab_file):
        # With no vocab file, fall back to the built-in default: a 94-symbol
        # character vocabulary (ids 0-93) covering whitespace, punctuation,
        # digits, ASCII letters, and a handful of accented characters.
        if vocab_file is None:
            return {
                '\n': 0, ' ': 1, '!': 2, '"': 3, '&': 4, "'": 5, '(': 6, ')': 7,
                '*': 8, ',': 9, '-': 10, '.': 11, '0': 12, '1': 13, '2': 14,
                '3': 15, '4': 16, '5': 17, '6': 18, '7': 19, '8': 20, '9': 21,
                ':': 22, ';': 23, '?': 24, 'A': 25, 'B': 26, 'C': 27, 'D': 28,
                'E': 29, 'F': 30, 'G': 31, 'H': 32, 'I': 33, 'J': 34, 'K': 35,
                'L': 36, 'M': 37, 'N': 38, 'O': 39, 'P': 40, 'Q': 41, 'R': 42,
                'S': 43, 'T': 44, 'U': 45, 'V': 46, 'W': 47, 'X': 48, 'Y': 49,
                'Z': 50, '[': 51, ']': 52, '`': 53, 'a': 54, 'b': 55, 'c': 56,
                'd': 57, 'e': 58, 'f': 59, 'g': 60, 'h': 61, 'i': 62, 'j': 63,
                'k': 64, 'l': 65, 'm': 66, 'n': 67, 'o': 68, 'p': 69, 'q': 70,
                'r': 71, 's': 72, 't': 73, 'u': 74, 'v': 75, 'w': 76, 'x': 77,
                'y': 78, 'z': 79, '£': 80, '°': 81, 'ß': 82, 'à': 83, 'â': 84,
                'è': 85, 'é': 86, 'ê': 87, 'î': 88, 'ñ': 89, 'ô': 90, 'ö': 91,
                'û': 92, 'ü': 93,
            }
        with open(vocab_file, 'r', encoding='utf-8') as f:
            return json.load(f)
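

# --- Usage sketch -------------------------------------------------------------
# A minimal, illustrative example (not part of the original file): it builds the
# tokenizer from the built-in default vocabulary and round-trips a sample string.
# The exact ids assigned to the special tokens "#" / "[UNK]" depend on how the
# installed transformers version registers added tokens, so treat the printed
# output as indicative rather than guaranteed.
if __name__ == "__main__":
    tok = SPTTokenizer()  # no vocab_file: falls back to the default 94-symbol vocab
    text = "Elementary, my dear Watson."
    ids = tok(text)["input_ids"]  # per-character ids plus a trailing EOS
    print(ids)
    print(tok.decode(ids, skip_special_tokens=True))  # round-trips to the input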