# ARC125m / data / streaming_dataset.py
# torinriley - update (3c8aa4a)
import os
import numpy as np
import tiktoken
from datasets import load_dataset, concatenate_datasets, interleave_datasets
from torch.utils.data import IterableDataset
import torch
class StreamingDataset(IterableDataset):
    """Iterable dataset that streams, tokenizes, and batches text on the fly.

    Each configured corpus is opened with ``load_dataset(..., streaming=True)``,
    every example's ``'text'`` field is encoded with the GPT-2 tiktoken
    encoding, an end-of-text token is appended, and the token ids are cut into
    non-overlapping windows of ``block_size`` tokens. Iterating the dataset
    yields an endless sequence of stacked batches; each batch is drawn from a
    single corpus chosen at random according to the configured weights.

    Documents shorter than ``block_size`` (after the end-of-text token is
    appended) produce no chunk, and trailing tokens that do not fill a whole
    block are dropped.

    NOTE(review): nothing here shards the stream per DataLoader worker, so
    with ``num_workers > 1`` each worker presumably replays the same data --
    confirm against how this class is consumed.
    """

    # Short corpus name -> (path passed to load_dataset, extra keyword args).
    _SOURCES = {
        "openwebtext": ("openwebtext", {"trust_remote_code": True}),
        "the_pile": ("the_pile", {}),
        "red_pajama": ("togethercomputer/RedPajama-Data-1T", {}),
    }

    def __init__(self, dataset_configs, block_size=2048, batch_size=12):
        """Store the sampling configuration and build the tokenizer.

        Args:
            dataset_configs: Sequence of dicts, each with a ``'name'`` key
                (one of the supported corpus names) and a ``'weight'`` key
                (relative sampling weight; weights are normalized later).
            block_size: Number of tokens in every yielded chunk.
            batch_size: Number of chunks stacked into every yielded batch.
        """
        super().__init__()
        self.dataset_configs = dataset_configs
        self.block_size = block_size
        self.batch_size = batch_size
        self.enc = tiktoken.get_encoding("gpt2")

    def _chunk_ids(self, ids):
        """Yield consecutive non-overlapping ``block_size`` windows of *ids*."""
        size = self.block_size
        # The range is empty when len(ids) < size, so short documents yield
        # nothing; a trailing remainder shorter than `size` is dropped.
        for start in range(0, len(ids) - size + 1, size):
            yield torch.tensor(ids[start:start + size], dtype=torch.long)

    def load_and_process_chunk(self, dataset_name, split="train"):
        """Stream one corpus and yield 1-D token tensors of ``block_size``.

        Args:
            dataset_name: ``"openwebtext"``, ``"the_pile"`` or ``"red_pajama"``.
            split: Dataset split handed to ``load_dataset``.

        Yields:
            ``torch.long`` tensors holding ``block_size`` token ids each.

        Raises:
            ValueError: If *dataset_name* is not a supported corpus. Because
                this is a generator, the error surfaces on the first
                ``next()`` call rather than at call time.
        """
        try:
            path, extra_kwargs = self._SOURCES[dataset_name]
        except KeyError:
            raise ValueError(
                f"Unknown dataset name {dataset_name!r}; "
                f"expected one of {sorted(self._SOURCES)}"
            ) from None

        dataset = load_dataset(path, split=split, streaming=True, **extra_kwargs)
        eot_token = self.enc.eot_token
        for example in dataset:
            ids = self.enc.encode_ordinary(example['text'])
            ids.append(eot_token)
            yield from self._chunk_ids(ids)

    def _next_chunk(self, iterators, idx):
        """Return the next chunk of stream *idx*, restarting it if exhausted.

        The restarted generator replaces ``iterators[idx]`` in place. A stream
        that is still empty right after a restart can never produce data, so
        that case raises instead of looping forever.
        """
        try:
            return next(iterators[idx])
        except StopIteration:
            name = self.dataset_configs[idx]['name']
            iterators[idx] = self.load_and_process_chunk(name)
        try:
            return next(iterators[idx])
        except StopIteration:
            raise RuntimeError(
                f"Dataset {name!r} produced no chunk of {self.block_size} tokens"
            ) from None

    def __iter__(self):
        """Yield an endless stream of ``(batch_size, block_size)`` batches.

        For every batch one corpus is sampled according to the normalized
        weights and ``batch_size`` chunks are taken from it. When the corpus
        runs out mid-batch its stream is restarted and the batch is completed
        from the new pass, so no already-read chunk is thrown away.

        Raises:
            ValueError: If no dataset is configured, or the weights are
                negative or do not sum to a positive finite value.
            RuntimeError: If a corpus yields no chunk even after a restart.
        """
        configs = self.dataset_configs
        if not configs:
            raise ValueError("dataset_configs must contain at least one dataset")

        weights = np.asarray([config['weight'] for config in configs], dtype=float)
        total = weights.sum()
        if (weights < 0).any() or not np.isfinite(total) or total <= 0:
            raise ValueError(
                "dataset weights must be non-negative and sum to a positive value"
            )
        probabilities = weights / total

        # Generators are lazy: nothing is downloaded until a stream is sampled.
        iterators = [self.load_and_process_chunk(config['name']) for config in configs]

        while True:
            dataset_idx = np.random.choice(len(iterators), p=probabilities)
            batch = [
                self._next_chunk(iterators, dataset_idx)
                for _ in range(self.batch_size)
            ]
            yield torch.stack(batch)
# Example usage: one entry per corpus, pairing its short name with the
# relative weight used when sampling which corpus feeds the next batch.
dataset_configs = [
    {'name': corpus, 'weight': share}
    for corpus, share in (
        ('openwebtext', 0.4),
        ('the_pile', 0.3),
        ('red_pajama', 0.3),
    )
]