File size: 3,661 Bytes
d8b92ee
1e8ebcb
781de59
c44b75c
fc724d2
af587c3
fc724d2
 
850b586
c9f3e85
c44b75c
76f084f
1e8ebcb
 
781de59
 
 
76f084f
 
c44b75c
76f084f
 
c44b75c
76f084f
 
 
 
 
 
c44b75c
 
76f084f
 
fc724d2
76f084f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc724d2
c44b75c
d8b92ee
c9f3e85
d8b92ee
76f084f
 
 
 
 
d8b92ee
c128a5f
76f084f
 
 
 
 
d8b92ee
 
76f084f
 
 
d8b92ee
76f084f
 
 
 
d8b92ee
76f084f
 
 
 
 
 
 
 
d8b92ee
76f084f
d8b92ee
 
 
af587c3
 
d8b92ee
 
af587c3
d8b92ee
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import pickle
import regex as re
from tqdm import tqdm

# Module-level corpus loading and pre-tokenization.
# NOTE: this runs at import time (it is not inside a function or the
# __main__ guard), so importing this module requires 'text_file.txt' to exist
# in the current working directory.

# Read the whole training corpus into memory as one UTF-8 decoded string.
with open('text_file.txt', 'r', encoding='utf-8') as file:
    text = file.read()

# Hindi-focused pre-tokenization pattern, tried left to right as alternatives:
#   1. the contractions 's 't 're 've 'm 'll 'd
#   2. an optional leading space followed by a run of digits (\p{N})
#   3. an optional leading space followed by one or more units, each being one
#      character from the first bracketed set and then zero or more characters
#      from the second set (presumably Devanagari/Vedic base letters followed
#      by their combining signs -- verify the ranges against the Unicode charts)
#   4. an optional leading space followed by a run of letters (\p{L})
#   5. an optional leading space followed by a run of characters that are not
#      whitespace, letters, or digits
#   6. whitespace that is not followed by a non-whitespace character
#   7. any remaining whitespace
# The \p{...} property classes rely on `re` being the third-party `regex`
# package (see the imports); the standard-library `re` does not support them.
gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{N}+| ?(?:[\u0904-\u0939\u093d-\u093d\u0950-\u0950\u0958-\u0961\u0970-\u097f\ua8f2-\ua8fe\U00011b00-\U00011b09\u1cd3-\u1cd3\u1ce9-\u1cec\u1cee-\u1cf3\u1cf5-\u1cf6\u1cfa-\u1cfa][\u0900-\u0903\u093a-\u093c\u093e-\u094f\u0951-\u0957\u0962-\u0963\ua8e0-\ua8f1\ua8ff-\ua8ff\u1cd0-\u1cd2\u1cd4-\u1ce8\u1ced-\u1ced\u1cf4-\u1cf4\u1cf7-\u1cf9]*)+| ?\p{L}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

# Split the raw text into string chunks; at this point `tokens` is a list of str.
tokens = re.findall(gpt2pat, text)

# UTF-8 encode every chunk, giving a list of `bytes` objects.
byte_tokens = [token.encode('utf-8') for token in tokens]

# Rebind `tokens` to a list of lists of ints (one int in 0..255 per byte).
# This is the representation the functions below operate on; pairs are only
# ever counted/merged inside a single inner list.
tokens = [list(token) for token in byte_tokens]

def get_stats(token_list):
    """Tally how often each adjacent pair of ids occurs inside the tokens.

    Pairs are formed only from neighbours within the same token; nothing is
    counted across the boundary between two tokens. Tokens shorter than two
    elements contribute nothing.

    Args:
        token_list: Iterable of sequences of ids (e.g. lists of ints).

    Returns:
        A dict mapping ``(left, right)`` tuples to their occurrence count,
        with keys in order of first appearance.
    """
    pair_counts = {}
    for seq in token_list:
        size = len(seq)
        if size >= 2:
            # Walk every position that has a right-hand neighbour.
            for pos in range(size - 1):
                key = (seq[pos], seq[pos + 1])
                pair_counts[key] = pair_counts.get(key, 0) + 1
    return pair_counts

def merge(token_list, pair, idx):
    """Replace each occurrence of ``pair`` inside every token with ``idx``.

    Each token is scanned left to right and matches do not overlap: once two
    neighbours have been fused, scanning resumes after them. Merging never
    crosses the boundary between two tokens.

    Args:
        token_list: Iterable of sequences of ids.
        pair: Tuple ``(left, right)`` of adjacent ids to fuse.
        idx: The id written in place of each matched pair.

    Returns:
        A new list with one entry per input token. Tokens shorter than two
        elements are passed through as the same object; all others are
        rebuilt as fresh lists.
    """
    merged_tokens = []
    for seq in token_list:
        if len(seq) < 2:
            # Nothing to merge; reuse the token as-is.
            merged_tokens.append(seq)
            continue

        last = len(seq) - 1
        rebuilt = []
        skip_next = False
        for pos, symbol in enumerate(seq):
            if skip_next:
                # This element was consumed as the right half of a match.
                skip_next = False
                continue
            if pos < last and (symbol, seq[pos + 1]) == pair:
                rebuilt.append(idx)
                skip_next = True
            else:
                rebuilt.append(symbol)
        merged_tokens.append(rebuilt)
    return merged_tokens

def perform_bpe(vocab_size=4000):
    """Learn byte-pair-encoding merges over the module-level ``tokens``.

    Starting from the 256 raw byte values, repeatedly finds the most frequent
    adjacent pair (via ``get_stats``) and replaces it with a new id (via
    ``merge``), assigning ids 256, 257, ... in the order the merges are
    learned. Prints before/after sizes and the compression ratio.

    Args:
        vocab_size: Desired final vocabulary size, including the 256 byte
            values. Must be at least 256. Defaults to 4000.

    Returns:
        A tuple ``(merges, flat_ids, num_merges)``:
            merges: dict mapping ``(int, int)`` pairs to the new id created
                for them, in the order they were learned.
            flat_ids: every token's ids concatenated into one flat list.
                Token boundaries are NOT recoverable from this list.
            num_merges: the number of merges *requested*
                (``vocab_size - 256``). Training stops early if no pairs
                remain, so use ``len(merges)`` for the number actually done.

    Raises:
        ValueError: If ``vocab_size`` is smaller than 256.
    """
    if vocab_size < 256:
        raise ValueError("vocab_size must be at least 256 (the raw byte values)")

    num_merges = vocab_size - 256
    # Shallow copy is enough: merge() builds new lists rather than mutating
    # the tokens it is given, so the module-level `tokens` stays intact.
    token_list = list(tokens)

    # Total number of ids before any merge (one per byte)
    total_bytes_before = sum(len(token) for token in token_list)

    merges = {}  # (int, int) -> int
    for i in tqdm(range(num_merges), desc="Performing BPE", unit="merge"):
        stats = get_stats(token_list)
        if not stats:  # No more pairs to merge
            break

        # Most frequent pair; ties go to the pair seen first by get_stats
        pair = max(stats, key=stats.get)
        # The loop only continues after a successful merge, so ids stay
        # contiguous starting at 256.
        idx = 256 + i

        token_list = merge(token_list, pair, idx)
        merges[pair] = idx

    # Total number of ids after all merges
    total_bytes_after = sum(len(token) for token in token_list)

    print("---")
    print("Total bytes before:", total_bytes_before)
    print("Total bytes after:", total_bytes_after)
    if total_bytes_after:
        print(f"Compression ratio: {total_bytes_before / total_bytes_after:.2f}X")
    else:
        # An empty corpus yields zero ids; avoid dividing by zero.
        print("Compression ratio: n/a (no tokens)")

    # Flatten into a single id sequence for storage (boundaries are dropped)
    flat_ids = [symbol for token in token_list for symbol in token]

    return merges, flat_ids, num_merges

if __name__ == "__main__":
    # Summary of the loaded corpus, printed one row at a time.
    report_rows = [
        ('---',),
        ("length of text (characters):", len(text)),
        ("length of text (words):", len(text.split())),
        ('---',),
        ("length of tokens:", len(tokens)),
    ]
    for row in report_rows:
        print(*row)

    # Train BPE, then persist everything it returned.
    learned_merges, encoded_ids, requested_merges = perform_bpe()
    payload = (learned_merges, encoded_ids, requested_merges)

    # Pickle the (merges, ids, num_merges) tuple to disk.
    with open('bpe_results.pkl', 'wb') as out_file:
        pickle.dump(payload, out_file)