import torch
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoTokenizer


def train_tokenizer(texts, vocab_size=50000, min_frequency=2):
    # Start from the GPT-2 tokenizer and retrain its vocabulary on our corpus.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer = tokenizer.train_new_from_iterator(
        texts, vocab_size=vocab_size, min_frequency=min_frequency
    )
    # GPT-2 ships without a pad token; add one so padded batches can be built.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    tokenizer.save_pretrained("./tokenizer")
    return tokenizer


def load_tokenizer():
    # Reload the tokenizer saved by train_tokenizer().
    tokenizer = AutoTokenizer.from_pretrained("./tokenizer")
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    return tokenizer


class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        # Truncate/pad every example to max_length so batches stack cleanly.
        encodings = self.tokenizer(
            text, truncation=True, padding='max_length', max_length=self.max_length
        )
        return torch.tensor(encodings['input_ids'])


def get_dataloader(dataset_name, config_name, tokenizer, max_length, batch_size):
    dataset = load_dataset(dataset_name, config_name)
    # Small debug slice; remove the [:50] to train on the full split with the full vocab size.
    texts = dataset['train']['text'][:50]
    dataset = TextDataset(texts, tokenizer, max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader
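

# Example usage: a minimal sketch of how the pieces above fit together.
# The dataset name ("wikitext", "wikitext-2-raw-v1") and the hyperparameters
# below are illustrative assumptions, not values fixed by this module.
if __name__ == "__main__":
    # Assumed example corpus with a 'text' column in its 'train' split.
    raw = load_dataset("wikitext", "wikitext-2-raw-v1")

    # Train a new tokenizer on the raw text and save it to ./tokenizer.
    tokenizer = train_tokenizer(raw['train']['text'], vocab_size=50000, min_frequency=2)

    # Build a DataLoader over the (debug-sliced) training split.
    dataloader = get_dataloader(
        "wikitext", "wikitext-2-raw-v1", tokenizer, max_length=128, batch_size=8
    )

    batch = next(iter(dataloader))
    print(batch.shape)  # expected: (batch_size, max_length)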