File size: 2,361 Bytes
d8b92ee
 
781de59
d8b92ee
 
 
 
 
76f084f
 
 
d8b92ee
 
 
 
 
 
781de59
76f084f
 
 
d8b92ee
 
 
 
 
 
 
 
76f084f
c128a5f
d8b92ee
 
 
 
 
 
 
781de59
 
 
76f084f
781de59
76f084f
781de59
76f084f
 
 
 
 
 
 
 
 
 
 
 
 
 
d8b92ee
76f084f
d8b92ee
 
781de59
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import pickle
from BPE import get_stats, merge
import regex as re

# Load the trained BPE state: the merges dict ((id, id) pair -> merged id)
# plus `ids` and `num_merges`, which are not used further in this script.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a bpe_results.pkl that you produced yourself.
with open('bpe_results.pkl', 'rb') as f:
    merges, ids, num_merges = pickle.load(f)

# Define the GPT-2 regex pattern (same as in BPE.py)
gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

# Build the id -> bytes table. Ids 0..255 are the raw single bytes; every
# merged id is the concatenation of its two parents' bytes. Merges are applied
# in ascending merged-id order so that both parents of a pair are always
# already in `vocab`, independent of the order the pickled dict iterates in.
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in sorted(merges.items(), key=lambda item: item[1]):
    vocab[idx] = vocab[p0] + vocab[p1]

def decode(ids):
    """Render token ids as tab-separated text and save it to decoded_output.txt.

    Each id is looked up in the module-level ``vocab`` table and its bytes are
    decoded as UTF-8 on their own; invalid or incomplete byte sequences become
    U+FFFD, so token boundaries stay visible in the result.

    Args:
        ids: iterable of integer token ids, each a key of ``vocab``
            (a missing id raises KeyError before anything is written).

    Returns:
        The per-token strings joined with tab characters.
    """
    token_bytes = [vocab[token_id] for token_id in ids]

    pieces = []
    for chunk in token_bytes:
        pieces.append(chunk.decode("utf-8", errors="replace"))
    rendered = '\t'.join(pieces)

    # Persist the rendering as a side effect, overwriting any previous output.
    with open('decoded_output.txt', 'w', encoding='utf-8') as out:
        out.write(rendered)

    return rendered

# Demo: decode a fixed sample of token ids and print the tab-separated result
# (decode() also writes it to decoded_output.txt).
set_of_ids = [
    1072, 415, 308, 1406, 103, 279, 999, 260, 550, 46,
    301, 39, 299, 1076, 1172, 562, 284, 111, 414, 1460,
    46, 301, 116, 373, 308, 259, 562, 798, 832, 1460,
    1449, 44, 892, 415, 308, 311, 112, 112, 549, 46,
]
decoded_text = decode(set_of_ids)
print(decoded_text)

def encode(text=None, path='encode_input.txt'):
    """BPE-encode text into a flat list of token ids.

    The text is split into chunks with the GPT-2 regex ``gpt2pat``. Each chunk
    is UTF-8 encoded and merged independently, so merges never cross a chunk
    boundary. Within a chunk, the adjacent pair with the lowest merged id in
    ``merges`` is merged repeatedly until no remaining pair is in ``merges``.

    Args:
        text: the string to encode. If None (the default), the contents of
            ``path`` are read instead, which keeps the original no-argument
            behaviour.
        path: UTF-8 text file to read when ``text`` is None.

    Returns:
        List of integer token ids for the whole text, chunk results
        concatenated in order.
    """
    if text is None:
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()

    final_tokens = []
    for chunk in gpt2pat.findall(text):
        # Start from the chunk's raw UTF-8 bytes (base ids 0..255).
        current_token = list(chunk.encode('utf-8'))
        while len(current_token) >= 2:
            # NOTE(review): get_stats/merge come from BPE.py and are called with a
            # list of chunks here, matching how this script has always used them.
            stats = get_stats([current_token])
            if not stats:
                break
            # Prefer the pair with the lowest merged id (presumably the
            # earliest-learned merge); unknown pairs rank as +inf.
            pair = min(stats, key=lambda p: merges.get(p, float("inf")))
            if pair not in merges:
                break  # no mergeable pair left in this chunk
            idx = merges[pair]
            current_token = merge([current_token], pair, idx)[0]
        final_tokens.extend(current_token)

    return final_tokens

# Demo: BPE-encode the contents of encode_input.txt and print the token ids.
encoded_tokens = encode()
print(encoded_tokens)