File size: 2,710 Bytes
f4e648b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
import torch


class DataLoader:
    """Read a text file and serve batches of it for language-model training.

    Two access patterns are offered once ``load_data`` has been called:

    * ``loader[i]`` / ``len(loader)`` / iteration: consecutive, non-overlapping
      chunks of ``batch_size`` raw characters as a NumPy array.
    * ``get_batch(split)``: ``batch_size`` random windows of ``block_size``
      token ids (inputs) together with the same windows shifted by one
      position (targets), drawn from the train or validation split.
    """

    def __init__(self, data_path):
        """Remember ``data_path``; nothing is read until ``load_data``."""
        self.data_path = data_path
        self.batch_size = None
        self.block_size = None
        self.device = None      # set by load_data
        self.data = None        # raw text of the file
        self.train_data = None  # 1-D LongTensor of token ids
        self.val_data = None    # 1-D LongTensor of token ids

    def load_data(self, block_size=128, split=0.8, batch_size=64, device='cpu'):
        """Read the file and build the train/validation id tensors.

        Args:
            block_size: Length of each window returned by ``get_batch``.
            split: Fraction (0..1) of the tokens assigned to the train split;
                the remainder becomes the validation split.
            batch_size: Number of windows per ``get_batch`` call and number of
                characters per ``loader[i]`` chunk.
            device: Stored on the instance as ``self.device``. ``get_batch``
                uses its own ``device`` argument.

        Raises:
            ValueError: If ``split`` is outside ``[0, 1]``.

        ``self.data`` keeps the raw text. ``train_data`` / ``val_data`` hold
        character-level ids, where a character's id is its position in the
        sorted set of characters found in the file. Callers that need another
        tokenisation can overwrite those two attributes afterwards with 1-D
        tensors of their own.
        """
        if not 0.0 <= split <= 1.0:
            raise ValueError('split must be between 0 and 1')
        # Explicit encoding so the result does not depend on the OS locale.
        with open(self.data_path, 'r', encoding='utf-8') as f:
            data = f.read()
        self.block_size = block_size
        self.batch_size = batch_size
        self.device = device
        self.data = data

        # get_batch slices and stacks tensors, so the text must be turned into
        # ids here; previously the splits were never populated at all.
        chars = sorted(set(data))
        stoi = {ch: i for i, ch in enumerate(chars)}
        ids = torch.tensor([stoi[ch] for ch in data], dtype=torch.long)
        n_train = int(split * len(ids))
        self.train_data = ids[:n_train]
        self.val_data = ids[n_train:]

    def __len__(self):
        """Number of ``batch_size``-character chunks (the last may be short).

        Raises:
            ValueError: If ``load_data`` has not been called.
        """
        if self.data is None:
            raise ValueError('Data not loaded')
        return int(np.ceil(len(self.data) / self.batch_size))

    def __getitem__(self, index):
        """Return chunk ``index`` of the raw text as a NumPy array of characters.

        Negative indices count from the end. An out-of-range index raises
        ``IndexError``, which is what lets ``for chunk in loader`` terminate.
        """
        n_batches = len(self)
        if index < 0:
            index += n_batches
        if not 0 <= index < n_batches:
            raise IndexError('batch index out of range')
        start = index * self.batch_size
        return np.array(list(self.data[start:start + self.batch_size]))

    def get_batch(self, split, device='cpu'):
        """Sample a random batch of input/target windows.

        Args:
            split: ``'train'`` selects the train split; any other value selects
                the validation split.
            device: Device the returned tensors are moved to.

        Returns:
            ``(x, y)``, each of shape ``(batch_size, block_size)``; ``y`` is
            ``x`` shifted one position to the right in the source sequence.

        Raises:
            ValueError: If no data is loaded, or the chosen split has
                ``block_size`` or fewer tokens (no full window plus its
                next-token target fits).
        """
        if self.data is None:
            raise ValueError('Data not loaded')
        data = self.train_data if split == 'train' else self.val_data
        if data is None:
            raise ValueError('Data not loaded')
        max_start = len(data) - self.block_size
        if max_start <= 0:
            raise ValueError(
                'split has %d tokens; need more than block_size=%d'
                % (len(data), self.block_size))
        # Plain ints keep the slicing below independent of tensor indexing.
        starts = torch.randint(max_start, (self.batch_size,)).tolist()
        x = torch.stack([data[i:i + self.block_size] for i in starts])
        y = torch.stack([data[i + 1:i + self.block_size + 1] for i in starts])
        return x.to(device), y.to(device)


class Encoder:
    """Map text to integer token ids and back, at character or word level.

    The vocabulary is the sorted set of distinct tokens found in ``data``:
    single characters for ``type='char'``, whitespace-separated words for
    ``type='word'``. A token's id is its position in that sorted vocabulary.

    Attributes set by the constructor: ``data``, ``type``, ``stoi``
    (token -> id), ``itos`` (id -> token), ``vocab_size``, and ``chars`` or
    ``words`` (the sorted vocabulary, depending on ``type``).
    """

    def __init__(self, data, type='char'):
        """Build the vocabulary from ``data``.

        Args:
            data: Text the vocabulary is derived from.
            type: ``'char'`` or ``'word'``.

        Raises:
            ValueError: If ``type`` is neither ``'char'`` nor ``'word'``.
        """
        self.data = data
        self.type = type
        self.vocab_size = None
        if type == 'char':
            self.chars = sorted(set(data))
            vocab = self.chars
        elif type == 'word':
            # Deduplicate, as the char branch does; otherwise vocab_size would
            # be the token count and repeated words would leave unused ids.
            self.words = sorted(set(data.split()))
            vocab = self.words
        else:
            raise ValueError('Type must be either "char" or "word"')
        self.stoi = {tok: i for i, tok in enumerate(vocab)}
        self.itos = {i: tok for i, tok in enumerate(vocab)}
        self.vocab_size = len(vocab)

    def encode(self, string: str):
        """Convert ``string`` to a 1-D ``torch.long`` tensor of token ids.

        Raises:
            KeyError: If ``string`` contains a token not in the vocabulary.
            ValueError: If ``self.type`` is neither ``'char'`` nor ``'word'``.
        """
        if self.type == 'char':
            tokens = string
        elif self.type == 'word':
            tokens = string.split()
        else:
            raise ValueError('Type must be either "char" or "word"')
        # Explicit dtype: an empty id list would otherwise become float32.
        return torch.tensor([self.stoi[t] for t in tokens], dtype=torch.long)

    def decode(self, ids: list):
        """Convert an iterable of token ids back to text.

        ``ids`` may be a list of ints or a 1-D tensor/array. Each element is
        passed through ``int()`` because tensor elements do not hash like the
        integer keys of ``itos``. Words are joined with a single space, so the
        original whitespace of word-level text is not restored.

        Raises:
            KeyError: If an id is not in the vocabulary.
            ValueError: If ``self.type`` is neither ``'char'`` nor ``'word'``.
        """
        if self.type == 'char':
            separator = ''
        elif self.type == 'word':
            separator = ' '
        else:
            raise ValueError('Type must be either "char" or "word"')
        return separator.join(self.itos[int(i)] for i in ids)