import numpy as np
import tiktoken
import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset

class StreamingDataset(IterableDataset):
    """Streaming dataset that loads and processes data on the fly"""
    
    def __init__(self, dataset_configs, block_size=2048, batch_size=12):
        self.dataset_configs = dataset_configs
        self.block_size = block_size
        self.batch_size = batch_size
        self.enc = tiktoken.get_encoding("gpt2")
        
    def load_and_process_chunk(self, dataset_name, split="train"):
        # Load each corpus as a streaming dataset so nothing is downloaded up front
        if dataset_name == "openwebtext":
            dataset = load_dataset(dataset_name, split=split, streaming=True, trust_remote_code=True)
        elif dataset_name == "the_pile":
            dataset = load_dataset("the_pile", split=split, streaming=True)
        elif dataset_name == "red_pajama":
            dataset = load_dataset("togethercomputer/RedPajama-Data-1T", split=split, streaming=True)
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")
        for example in dataset:
            # Tokenize with the GPT-2 BPE and append the end-of-text token
            ids = self.enc.encode_ordinary(example['text'])
            ids.append(self.enc.eot_token)
            # Documents shorter than block_size are skipped entirely
            if len(ids) >= self.block_size:
                # Yield non-overlapping chunks of block_size tokens
                for i in range(0, len(ids) - self.block_size + 1, self.block_size):
                    yield torch.tensor(ids[i:i + self.block_size], dtype=torch.long)
    
    def __iter__(self):
        # Interleave datasets with specified weights
        iterators = []
        weights = []
        for config in self.dataset_configs:
            iterators.append(self.load_and_process_chunk(config['name']))
            weights.append(config['weight'])
            
        # Normalize weights
        weights = np.array(weights) / sum(weights)
        
        while True:
            # Randomly select a dataset based on weights
            dataset_idx = np.random.choice(len(iterators), p=weights)
            try:
                batch = []
                for _ in range(self.batch_size):
                    batch.append(next(iterators[dataset_idx]))
                yield torch.stack(batch)
            except StopIteration:
                # The selected stream is exhausted: drop the partial batch and restart it
                iterators[dataset_idx] = self.load_and_process_chunk(self.dataset_configs[dataset_idx]['name'])
                continue

# Example usage:
dataset_configs = [
    {'name': 'openwebtext', 'weight': 0.4},
    {'name': 'the_pile', 'weight': 0.3},
    {'name': 'red_pajama', 'weight': 0.3}
]
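
# Minimal usage sketch: assumes network access for streaming and that the
# dataset names above are available on the Hugging Face Hub (some may also
# require trust_remote_code=True). Adjust configs to the datasets you can reach.
if __name__ == "__main__":
    train_dataset = StreamingDataset(dataset_configs, block_size=2048, batch_size=12)
    for step, batch in enumerate(train_dataset):
        # Each batch is a LongTensor of shape (batch_size, block_size) of GPT-2 token ids
        print(step, batch.shape)
        if step >= 2:
            break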