import numpy as np
import tiktoken
import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset
class StreamingDataset(IterableDataset):
    """Streaming dataset that loads and tokenizes data on the fly."""

    def __init__(self, dataset_configs, block_size=2048, batch_size=12):
        self.dataset_configs = dataset_configs
        self.block_size = block_size
        self.batch_size = batch_size
        self.enc = tiktoken.get_encoding("gpt2")
    def load_and_process_chunk(self, dataset_name, split="train"):
        # Load the requested dataset in streaming mode
        if dataset_name == "openwebtext":
            dataset = load_dataset(dataset_name, split=split, streaming=True, trust_remote_code=True)
        elif dataset_name == "the_pile":
            dataset = load_dataset("the_pile", split=split, streaming=True)
        elif dataset_name == "red_pajama":
            dataset = load_dataset("togethercomputer/RedPajama-Data-1T", split=split, streaming=True)
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")

        for example in dataset:
            # Tokenize the document and append the end-of-text token
            ids = self.enc.encode_ordinary(example['text'])
            ids.append(self.enc.eot_token)
            # Documents shorter than block_size are skipped
            if len(ids) >= self.block_size:
                # Yield non-overlapping chunks of block_size tokens
                for i in range(0, len(ids) - self.block_size + 1, self.block_size):
                    yield torch.tensor(ids[i:i + self.block_size])
    def __iter__(self):
        # Build one iterator per dataset and collect the sampling weights
        iterators = []
        weights = []
        for config in self.dataset_configs:
            iterators.append(self.load_and_process_chunk(config['name']))
            weights.append(config['weight'])
        # Normalize weights so they sum to 1
        weights = np.array(weights) / sum(weights)

        while True:
            # Randomly select a dataset based on the normalized weights
            dataset_idx = np.random.choice(len(iterators), p=weights)
            try:
                batch = []
                for _ in range(self.batch_size):
                    batch.append(next(iterators[dataset_idx]))
                yield torch.stack(batch)
            except StopIteration:
                # Restart the iterator if it is exhausted; any partial batch is dropped
                iterators[dataset_idx] = self.load_and_process_chunk(self.dataset_configs[dataset_idx]['name'])
                continue
# Example usage:
dataset_configs = [
    {'name': 'openwebtext', 'weight': 0.4},
    {'name': 'the_pile', 'weight': 0.3},
    {'name': 'red_pajama', 'weight': 0.3},
]
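
# A minimal sketch of how the class above might be consumed, assuming the
# example configs defined here; block_size and batch_size values are
# illustrative, not prescriptive. Since the dataset already yields stacked
# (batch_size, block_size) tensors, the DataLoader is given batch_size=None
# so it does not re-batch the samples.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = StreamingDataset(dataset_configs, block_size=2048, batch_size=12)
    loader = DataLoader(dataset, batch_size=None)

    for step, batch in enumerate(loader):
        # batch is a (batch_size, block_size) tensor of GPT-2 token ids
        print(step, batch.shape)
        if step >= 2:
            break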