import numpy as np
import tiktoken
import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset


class StreamingDataset(IterableDataset):
    """Streaming dataset that loads and tokenizes data on the fly."""

    def __init__(self, dataset_configs, block_size=2048, batch_size=12):
        self.dataset_configs = dataset_configs
        self.block_size = block_size
        self.batch_size = batch_size
        self.enc = tiktoken.get_encoding("gpt2")

    def load_and_process_chunk(self, dataset_name, split="train"):
        # Load the requested dataset in streaming mode with its appropriate config
        if dataset_name == "openwebtext":
            dataset = load_dataset(dataset_name, split=split, streaming=True,
                                   trust_remote_code=True)
        elif dataset_name == "the_pile":
            dataset = load_dataset("the_pile", split=split, streaming=True)
        elif dataset_name == "red_pajama":
            dataset = load_dataset("togethercomputer/RedPajama-Data-1T",
                                   split=split, streaming=True)
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")

        for example in dataset:
            # Tokenize the document and append the end-of-text token
            ids = self.enc.encode_ordinary(example['text'])
            ids.append(self.enc.eot_token)
            if len(ids) >= self.block_size:
                # Yield non-overlapping chunks of exactly block_size tokens;
                # any trailing remainder shorter than block_size is dropped
                for i in range(0, len(ids) - self.block_size + 1, self.block_size):
                    yield torch.tensor(ids[i:i + self.block_size])

    def __iter__(self):
        # Build one token-block iterator per dataset, with its sampling weight
        iterators = []
        weights = []
        for config in self.dataset_configs:
            iterators.append(self.load_and_process_chunk(config['name']))
            weights.append(config['weight'])

        # Normalize weights so they sum to 1
        weights = np.array(weights) / sum(weights)

        while True:
            # Randomly select a dataset according to the weights
            dataset_idx = np.random.choice(len(iterators), p=weights)
            try:
                batch = []
                for _ in range(self.batch_size):
                    batch.append(next(iterators[dataset_idx]))
                yield torch.stack(batch)
            except StopIteration:
                # Restart the iterator if it is exhausted and resample;
                # the partially built batch is discarded
                iterators[dataset_idx] = self.load_and_process_chunk(
                    self.dataset_configs[dataset_idx]['name']
                )
                continue


# Example usage:
dataset_configs = [
    {'name': 'openwebtext', 'weight': 0.4},
    {'name': 'the_pile', 'weight': 0.3},
    {'name': 'red_pajama', 'weight': 0.3}
]
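
# A minimal sketch of driving the dataset in a training loop. Each yielded
# batch has shape (batch_size, block_size); the input/target slicing and the
# early stop below are illustrative assumptions, not part of the original script.
if __name__ == "__main__":
    train_data = StreamingDataset(dataset_configs, block_size=2048, batch_size=12)
    for step, batch in enumerate(train_data):
        x, y = batch[:, :-1], batch[:, 1:]  # shifted next-token prediction pair
        print(f"step {step}: inputs {tuple(x.shape)}, targets {tuple(y.shape)}")
        if step == 2:  # stop early; the stream is otherwise infinite
            break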