import json
import os
from collections.abc import Iterable

import torch


class Tokenizer:
    """Character-level tokenizer that builds its vocabulary from a text corpus."""

    def __init__(self, data_path: str | None = None):
        self.config = None
        self.stoi = None  # string -> integer id
        self.itos = None  # integer id -> string
        self.vocab_size = None
        if data_path:
            self.data = self.load_data(data_path)
        else:
            self.data = None

    def from_pretrained(self, config_path: str):
        """Restore a previously saved vocabulary config from a JSON file."""
        with open(config_path, encoding='utf-8') as f:
            config = json.load(f)
        self.config = config
        for key in ('encode', 'decode', 'vocab_size'):
            if key not in config:
                raise ValueError(f"Config file must contain a '{key}' key.")
        # JSON stores all mapping keys as strings, so restore the integer types.
        self.stoi = {k: int(v) for k, v in config['encode'].items()}
        self.itos = {int(k): v for k, v in config['decode'].items()}
        self.vocab_size = config['vocab_size']
        return self

    def load_data(self, path: str) -> str:
        """Read a text corpus, build the character vocabulary, and split the
        encoded data 90/10 into train/val tensors."""
        if not os.path.exists(path):
            raise FileNotFoundError(f"File not found: {path}")
        if not path.endswith('.txt'):
            raise ValueError("File must be a text file.")
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        # The vocabulary is every unique character in the corpus, sorted.
        chars = sorted(set(text))
        vocab_size = len(chars)
        stoi = {ch: i for i, ch in enumerate(chars)}
        itos = {i: ch for i, ch in enumerate(chars)}
        self.config = {"vocab_size": vocab_size, "encode": stoi, "decode": itos}
        self.stoi = stoi
        self.itos = itos
        # Encode the full corpus and hold out the final 10% for validation.
        data = torch.tensor(self(text), dtype=torch.long)
        n = int(0.9 * len(data))
        self.train_data = data[:n]
        self.val_data = data[n:]
        self.vocab_size = vocab_size
        return text

    def __repr__(self) -> str:
        if self.config:
            return f"Tokenizer(config={self.config})"
        return "Tokenizer()"

    def __str__(self) -> str:
        return repr(self)

    def __len__(self) -> int:
        return len(self.stoi)

    def __getitem__(self, key: str) -> int:
        return self.stoi[key]

    def __contains__(self, key: str) -> bool:
        return key in self.stoi

    def __iter__(self):
        return iter(self.stoi)

    def __reversed__(self):
        return reversed(self.stoi)

    def keys(self):
        return self.stoi.keys()

    def values(self):
        return self.stoi.values()

    def items(self):
        return self.stoi.items()

    def __call__(self, *args, **kwds) -> list[int] | list[list[int]]:
        return self.encode(*args, **kwds)

    def encode(self, s: str | list[str]) -> list[int] | list[list[int]]:
        """Map a string to a list of token ids, or a list of strings to a list of id lists."""
        if isinstance(s, str):
            return [self.stoi[c] for c in s]
        elif isinstance(s, list):
            return [[self.stoi[i] for i in c] for c in s]
        else:
            raise ValueError("Input must be a string or a list of strings.")

    def decode(self, l: list[int] | list[list[int]]) -> str | list[str]:
        """Map a list of token ids back to a string, or a list of id lists back to a list of strings."""
        if isinstance(l[0], int):
            return ''.join([self.itos[i] for i in l])
        elif isinstance(l[0], Iterable):
            return [''.join([self.itos[i] for i in c]) for c in l]
        else:
            raise ValueError("Input must be a list of integers or a list of lists of integers.")

    def save_pretrained(self, path: str) -> str:
        """Write the vocabulary config to `vocab.json` inside the given directory."""
        with open(os.path.join(path, 'vocab.json'), 'w', encoding='utf-8') as f:
            json.dump(self.config, f)
        return f"Tokenizer saved at {path}."