# minimal byte-pair encoding (BPE) tokenizer training on raw UTF-8 bytes

# Read text from a file
with open('text_file.txt', 'r', encoding='utf-8') as file:
    text = file.read()

tokens = text.encode("utf-8")  # raw bytes
tokens = list(tokens)  # convert to a list of integers in range 0..255 for convenience

print("---")
print("length of text (characters):", len(text))
print("length of tokens (bytes):", len(tokens))


def get_stats(ids):
    # count how often each consecutive pair of ints occurs in ids
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts
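
# quick sanity check (illustrative example, not in the original script):
# get_stats([1, 2, 2, 1, 2]) -> {(1, 2): 2, (2, 2): 1, (2, 1): 1}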

def merge(ids, pair, idx):
    # in the list of ints (ids), replace all consecutive occurrences of pair with the new token idx
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
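
# quick sanity check (illustrative example, not in the original script):
# merge([5, 6, 6, 7, 9, 1], (6, 7), 99) -> [5, 6, 99, 9, 1]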

# ---
vocab_size = 500  # the desired final vocabulary size
num_merges = vocab_size - 256
ids = list(tokens)  # copy so we don't destroy the original list

merges = {}  # (int, int) -> int
for i in range(num_merges):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)  # greedily pick the most frequent pair
    idx = 256 + i
    #print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx

print("---")
print("ids length:", len(ids))
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")