import pickle
import regex as re
from tqdm import tqdm
# Read text from a file
with open('text_file.txt', 'r', encoding='utf-8') as file:
    text = file.read()
# GPT-2-style split pattern, extended with a Devanagari (Hindi) consonant + combining-mark alternative
gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{N}+| ?(?:[\u0904-\u0939\u093d-\u093d\u0950-\u0950\u0958-\u0961\u0970-\u097f\ua8f2-\ua8fe\U00011b00-\U00011b09\u1cd3-\u1cd3\u1ce9-\u1cec\u1cee-\u1cf3\u1cf5-\u1cf6\u1cfa-\u1cfa][\u0900-\u0903\u093a-\u093c\u093e-\u094f\u0951-\u0957\u0962-\u0963\ua8e0-\ua8f1\ua8ff-\ua8ff\u1cd0-\u1cd2\u1cd4-\u1ce8\u1ced-\u1ced\u1cf4-\u1cf4\u1cf7-\u1cf9]*)+| ?\p{L}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
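# The pattern splits raw text into word-like chunks before byte-level BPE,
# e.g. re.findall(gpt2pat, "नमस्ते world 123!") -> ['नमस्ते', ' world', ' 123', '!']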
# Apply the regex pattern to the raw text to tokenize it
tokens = re.findall(gpt2pat, text)
# Convert tokens to byte sequences
byte_tokens = [token.encode('utf-8') for token in tokens]
# Represent each token as a list of integer byte values (0-255)
tokens = [list(token) for token in byte_tokens]

def get_stats(token_list):
    """Count frequency of pairs across all tokens"""
    counts = {}
    # Count pairs within each token
    for token in token_list:
        if len(token) < 2:
            continue
        for pair in zip(token, token[1:]):
            counts[pair] = counts.get(pair, 0) + 1
    return counts
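# Worked example: get_stats([[1, 2, 3], [1, 2]]) == {(1, 2): 2, (2, 3): 1}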

def merge(token_list, pair, idx):
    """Merge all occurrences of pair within each token"""
    newids = []
    for token in token_list:
        if len(token) < 2:
            newids.append(token)
            continue
        new_token = []
        i = 0
        while i < len(token):
            if i < len(token) - 1 and (token[i], token[i+1]) == pair:
                new_token.append(idx)
                i += 2
            else:
                new_token.append(token[i])
                i += 1
        newids.append(new_token)
    return newids
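# Worked example: merge([[1, 2, 3, 1, 2]], (1, 2), 256) == [[256, 3, 256]]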

def perform_bpe():
    vocab_size = 4000  # the desired final vocabulary size
    num_merges = vocab_size - 256  # the 256 single-byte values form the base vocabulary
    token_list = list(tokens)  # copy so we don't destroy the original list
    # Calculate total bytes before compression
    total_bytes_before = sum(len(token) for token in token_list)
    merges = {}  # (int, int) -> int
    for i in tqdm(range(num_merges), desc="Performing BPE", unit="merge"):
        stats = get_stats(token_list)
        if not stats:  # no more pairs to merge
            break
        # Find the most frequent pair
        pair = max(stats, key=stats.get)
        idx = 256 + i
        # Perform the merge
        token_list = merge(token_list, pair, idx)
        merges[pair] = idx
    # Calculate total bytes after compression
    total_bytes_after = sum(len(token) for token in token_list)
    print("---")
    print("Total bytes before:", total_bytes_before)
    print("Total bytes after:", total_bytes_after)
    print(f"Compression ratio: {total_bytes_before / total_bytes_after:.2f}X")
    # Flatten the per-token id lists into one sequence for storage
    # (note: token boundaries are not preserved in flat_ids)
    flat_ids = []
    for token in token_list:
        flat_ids.extend(token)
    return merges, flat_ids, num_merges
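
# A minimal sketch (not part of the original training flow) of how the learned
# merges could be applied to encode new text; the helper name `encode` and its
# greedy lowest-merge-index strategy are assumptions, not taken from this file.
def encode(new_text, merges):
    """Encode text by replaying learned merges, earliest-learned pair first."""
    encoded = []
    for chunk in re.findall(gpt2pat, new_text):
        ids = list(chunk.encode('utf-8'))
        while len(ids) >= 2:
            stats = get_stats([ids])
            # Pick the adjacent pair that was merged earliest during training
            pair = min(stats, key=lambda p: merges.get(p, float('inf')))
            if pair not in merges:
                break  # no learned merge applies to the remaining pairs
            ids = merge([ids], pair, merges[pair])[0]
        encoded.extend(ids)
    return encoded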

if __name__ == "__main__":
    print('---')
    print("length of text (characters):", len(text))
    print("length of text (words):", len(text.split()))
    print('---')
    print("length of tokens:", len(tokens))
    # print("sample tokens:", tokens[:5])  # show first 5 tokens
    # Run BPE and save the results
    merges, ids, num_merges = perform_bpe()
    # Save the merges and the flattened token ids to a file
    with open('bpe_results.pkl', 'wb') as f:
        pickle.dump((merges, ids, num_merges), f)
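
# Illustrative: the saved results can be reloaded in another script with
#   with open('bpe_results.pkl', 'rb') as f:
#       merges, ids, num_merges = pickle.load(f)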