File size: 3,798 Bytes
f4e648b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import json
import os
from typing import Iterable
import torch

class Tokenizer:
    def __init__(self, data_path: str = None):
        self.config = None
        self.stoi = None
        self.itos = None
        self.vocab_size = None
        if data_path:
            self.data = self.load_data(data_path)
        else:
            self.data = None
    
    def from_pretrained(self, config_path: str):
        with open(config_path) as f:
            config = json.load(f)
        self.config = config
        if 'encode' not in config:
            raise ValueError("Config file must contain an 'encode' key.")
        if 'decode' not in config:
            raise ValueError("Config file must contain a 'decode' key.")
        if 'vocab_size' not in config:
            raise ValueError("Config file must contain a 'vocab_size' key.")
        stoi = config['encode']
        self.stoi = {k: int(v) for k, v in stoi.items()}
        itos = config['decode']
        self.itos = {int(k): v for k, v in itos.items()}
        self.vocab_size = config['vocab_size']
        return self
    
    def load_data(self, path: str) -> str:
        if not os.path.exists(path):
            raise FileNotFoundError("File not found.")
        if not path.endswith('.txt'):
            raise ValueError("File must be a text file.")
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        chars = sorted(list(set(text)))
        vocab_size = len(chars)
        stoi = {ch: i for i, ch in enumerate(chars)}
        itos = {i: ch for i, ch in enumerate(chars)}
        self.config = {"vocab_size": vocab_size, "encode": stoi, "decode": itos}
        self.stoi = stoi
        self.itos = itos
        data = torch.tensor(self(text), dtype=torch.long)
        n = int(0.9*len(data))
        train_data = data[:n]
        val_data = data[n:]
        self.train_data = train_data
        self.val_data = val_data
        self.vocab_size = vocab_size
        return text

    def __repr__(self) -> str:
        if self.config:
            return f"Tokenizer(config={self.config})"
        else:
            return f"Tokenizer()"
    
    def __str__(self) -> str:
        if self.config:
            return f"Tokenizer(config_path={self.config})"
        else:
            return f"Tokenizer()"
    
    def __len__(self) -> int:
        return len(self.stoi)
    
    def __getitem__(self, key: str) -> int:
        return self.stoi[key]
    
    def __contains__(self, key: str) -> bool:
        return key in self.stoi
    
    def __iter__(self):
        return iter(self.stoi)
    
    def __reversed__(self):
        return reversed(self.stoi)
    
    def keys(self):
        return self.stoi.keys()
    
    def values(self):
        return self.stoi.values()
    
    def items(self):
        return self.stoi.items()
    
    def __call__(self, *args, **kwds) -> list[int]:
        return self.encode(*args, **kwds)

    def encode(self, s: str | list[str]) -> list[int]:
        if isinstance(s, str):
            return [self.stoi[c] for c in s]
        elif isinstance(s, list):
            return [[self.stoi[i] for i in c] for c in s]
        else:
            raise ValueError("Input must be a string or a list of strings.")

    def decode(self, l: list[int]) -> str:
        if isinstance(l[0], int):
            return ''.join([self.itos[i] for i in l])
        elif isinstance(l[0], Iterable):
            return [''.join([self.itos[i] for i in c]) for c in l]
        else:
            raise ValueError("Input must be a list of integers or a list of list of integers.")
    
    def save_pretrained(self, path: str) -> str:
        with open(path + 'vocab.json', 'w') as f:
            json.dump(self.config, f)
        return "Tokenizer saved at {}.".format(path)