# QuillGPT/core/tokenizers/tokenizer.py
import json
import os
from collections.abc import Iterable

import torch


class Tokenizer:
    """A character-level tokenizer mapping characters to integer ids and back."""

    def __init__(self, data_path: str | None = None):
        self.config = None
        self.stoi = None
        self.itos = None
        self.vocab_size = None
        if data_path:
            self.data = self.load_data(data_path)
        else:
            self.data = None

    def from_pretrained(self, config_path: str):
        """Load a vocabulary config saved by save_pretrained and return self."""
        with open(config_path) as f:
            config = json.load(f)
        self.config = config
        if 'encode' not in config:
            raise ValueError("Config file must contain an 'encode' key.")
        if 'decode' not in config:
            raise ValueError("Config file must contain a 'decode' key.")
        if 'vocab_size' not in config:
            raise ValueError("Config file must contain a 'vocab_size' key.")
        stoi = config['encode']
        self.stoi = {k: int(v) for k, v in stoi.items()}
        # JSON stringifies integer keys, so the decode table's keys are cast back to int.
        itos = config['decode']
        self.itos = {int(k): v for k, v in itos.items()}
        self.vocab_size = config['vocab_size']
        return self
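
    # Reload sketch (assumes a config previously written by save_pretrained;
    # "vocab.json" is whatever path that call produced):
    #   tok = Tokenizer().from_pretrained("vocab.json")
    #   tok.encode("hi")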

    def load_data(self, path: str) -> str:
        """Read a .txt corpus, build the char-level vocab, and split it 90/10 into train/val tensors."""
        if not os.path.exists(path):
            raise FileNotFoundError(f"File not found: {path}")
        if not path.endswith('.txt'):
            raise ValueError("File must be a text file.")
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        # Character-level vocabulary: every distinct character in the corpus.
        chars = sorted(set(text))
        vocab_size = len(chars)
        stoi = {ch: i for i, ch in enumerate(chars)}
        itos = {i: ch for i, ch in enumerate(chars)}
        self.config = {"vocab_size": vocab_size, "encode": stoi, "decode": itos}
        self.stoi = stoi
        self.itos = itos
        data = torch.tensor(self(text), dtype=torch.long)
        # Hold out the last 10% of the corpus for validation.
        n = int(0.9 * len(data))
        self.train_data = data[:n]
        self.val_data = data[n:]
        self.vocab_size = vocab_size
        return text
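
    # Training-data sketch ("data/input.txt" is a hypothetical corpus path):
    #   tok = Tokenizer(data_path="data/input.txt")
    #   xb = tok.train_data[:8]   # first ids of the 90% train split
    #   vb = tok.val_data[:8]     # first ids of the 10% validation split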

    def __repr__(self) -> str:
        if self.config:
            return f"Tokenizer(config={self.config})"
        return "Tokenizer()"

    def __str__(self) -> str:
        return repr(self)

    def __len__(self) -> int:
        return len(self.stoi)

    def __getitem__(self, key: str) -> int:
        return self.stoi[key]

    def __contains__(self, key: str) -> bool:
        return key in self.stoi

    def __iter__(self):
        return iter(self.stoi)

    def __reversed__(self):
        return reversed(self.stoi)

    def keys(self):
        return self.stoi.keys()

    def values(self):
        return self.stoi.values()

    def items(self):
        return self.stoi.items()
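
    # The dict-style protocol above lets a Tokenizer be inspected like a
    # mapping over its vocabulary, e.g.:
    #   len(tok), 'a' in tok, tok['a'], list(tok.items())[:3]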

    def __call__(self, *args, **kwds) -> list[int] | list[list[int]]:
        return self.encode(*args, **kwds)

    def encode(self, s: str | list[str]) -> list[int] | list[list[int]]:
        """Map a string to a list of ids, or a list of strings to a list of id lists."""
        if isinstance(s, str):
            return [self.stoi[c] for c in s]
        elif isinstance(s, list):
            return [[self.stoi[i] for i in c] for c in s]
        else:
            raise ValueError("Input must be a string or a list of strings.")

    def decode(self, l: list[int] | list[list[int]]) -> str | list[str]:
        """Map a list of ids back to a string, or a list of id lists back to strings."""
        if not l:
            return ''  # avoid IndexError when indexing l[0] on empty input
        if isinstance(l[0], int):
            return ''.join([self.itos[i] for i in l])
        elif isinstance(l[0], Iterable):
            return [''.join([self.itos[i] for i in c]) for c in l]
        else:
            raise ValueError("Input must be a list of integers or a list of lists of integers.")
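
    # Round-trip sketch (assumes every character of the probe text is in the vocab):
    #   ids = tok.encode("hello")   # -> list[int]
    #   tok.decode(ids)             # -> "hello"
    #   tok.encode(["ab", "cd"])    # -> list[list[int]]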

    def save_pretrained(self, path: str) -> str:
        """Write the vocabulary config to <path>/vocab.json."""
        # os.path.join avoids producing e.g. 'checkpointsvocab.json' when the
        # directory path lacks a trailing separator.
        with open(os.path.join(path, 'vocab.json'), 'w') as f:
            json.dump(self.config, f)
        return f"Tokenizer saved at {path}."