import pickle
import regex as re
from tqdm import tqdm

# Read text from a file
with open('text_file.txt', 'r', encoding='utf-8') as file:
    text = file.read()

# Hindi-focused pattern
gpt2pat = re.compile(r"""
# Syllable-based grouping. Conjuncts come first: alternation is ordered,
# so a bare-consonant branch listed earlier would shadow them.
[क-ह]्[क-ह][ा-ौ\u093C\u0901-\u0903]?  # Basic conjuncts
# Matches consonant + halant (्) + consonant, with an optional trailing
# modifier so forms like स्ते stay intact.
# Examples: क्क, न्न, त्त
|[क-ह][ा-ौ\u093C\u0901-\u0903]?  # Consonant + modifiers
# Matches any consonant [क-ह], optionally followed by:
# - a maatra [ा-ौ] (like ा ि ी ु ू े ै ो ौ)
# - OR nukta (\u093C = ़)
# - OR chandrabindu (\u0901 = ँ)
# - OR anusvara (\u0902 = ं)
# - OR visarga (\u0903 = ः)
|[\u0905-\u0914]  # Independent vowels
# Matches standalone vowels like अ आ इ ई उ ऊ ए ऐ ओ औ
|\p{N}+  # Numbers
# Matches one or more digits (including Devanagari digits)
|\s+  # Whitespace
# Matches spaces, tabs, newlines
|[।॥]  # Punctuation
# Matches the Hindi danda and double danda
|[^\s\p{Devanagari}\p{N}]+  # Other characters
# Matches any sequence of characters that aren't:
# - whitespace
# - Devanagari script
# - numbers
""", re.VERBOSE)

# Apply the regex pattern to the raw text to tokenize it
tokens = re.findall(gpt2pat, text)

# Convert tokens to UTF-8 byte sequences
byte_tokens = [token.encode('utf-8') for token in tokens]

# Represent each token as a list of byte values (integers 0-255)
tokens = [list(token) for token in byte_tokens]
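
# Example (illustrative): each Devanagari character is 3 bytes in UTF-8:
#   list('नम'.encode('utf-8')) -> [224, 164, 168, 224, 164, 174]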

def get_stats(token_list):
    """Count frequency of adjacent byte/id pairs across all tokens."""
    counts = {}
    # Count pairs within each token; pairs never span token boundaries
    for token in token_list:
        if len(token) < 2:
            continue
        for pair in zip(token, token[1:]):
            counts[pair] = counts.get(pair, 0) + 1
    return counts
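
# Example (illustrative): pairs are counted within tokens, never across them:
#   get_stats([[1, 2, 2], [1, 2]]) -> {(1, 2): 2, (2, 2): 1}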

def merge(token_list, pair, idx):
    """Merge all occurrences of pair within each token, replacing them with idx."""
    newids = []
    for token in token_list:
        if len(token) < 2:
            newids.append(token)
            continue
        new_token = []
        i = 0
        while i < len(token):
            if i < len(token) - 1 and (token[i], token[i+1]) == pair:
                new_token.append(idx)
                i += 2
            else:
                new_token.append(token[i])
                i += 1
        newids.append(new_token)
    return newids
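
# Example (illustrative): merging pair (1, 2) into the new id 256:
#   merge([[1, 2, 2, 1, 2]], (1, 2), 256) -> [[256, 2, 256]]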

def perform_bpe():
    vocab_size = 3500  # the desired final vocabulary size
    num_merges = vocab_size - 256
    token_list = list(tokens)  # copy so we don't destroy the original list

    # Calculate total bytes before compression
    total_bytes_before = sum(len(token) for token in token_list)

    merges = {}  # (int, int) -> int
    for i in tqdm(range(num_merges), desc="Performing BPE", unit="merge"):
        stats = get_stats(token_list)
        if not stats:  # No more pairs to merge
            break
        # Find the most frequent pair
        pair = max(stats, key=stats.get)
        idx = 256 + i
        # Perform the merge
        token_list = merge(token_list, pair, idx)
        merges[pair] = idx

    # Calculate total bytes after compression
    total_bytes_after = sum(len(token) for token in token_list)
    print("---")
    print("Total bytes before:", total_bytes_before)
    print("Total bytes after:", total_bytes_after)
    print(f"Compression ratio: {total_bytes_before / total_bytes_after:.2f}X")

    # Flatten for storage (note: token boundaries are not preserved in flat_ids)
    flat_ids = []
    for token in token_list:
        flat_ids.extend(token)
    # Report len(merges), the number of merges actually performed,
    # which can be smaller than num_merges if the loop broke early
    return merges, flat_ids, len(merges)
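
# A minimal decoding sketch (assumed downstream usage; the helper name
# `decode` is illustrative, not part of the training script above). It
# rebuilds the byte-level vocabulary from the learned merges and maps ids
# back to text. It relies on dicts preserving insertion order (Python 3.7+),
# so each merged id is defined after both of its parts.
def decode(ids, merges):
    """Map a flat list of token ids back to a UTF-8 string."""
    vocab = {i: bytes([i]) for i in range(256)}
    for (p0, p1), idx in merges.items():
        vocab[idx] = vocab[p0] + vocab[p1]
    return b"".join(vocab[i] for i in ids).decode('utf-8', errors='replace')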

if __name__ == "__main__":
    print('---')
    print("length of text:", len(text))
    print('---')
    print("length of tokens:", len(tokens))

    # Run BPE and save results
    merges, ids, num_merges = perform_bpe()

    # Save merges and the flattened ids to a file
    with open('bpe_results.pkl', 'wb') as f:
        pickle.dump((merges, ids, num_merges), f)
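
    # Reloading the saved results later (illustrative sketch):
    #   with open('bpe_results.pkl', 'rb') as f:
    #       merges, ids, num_merges = pickle.load(f)
    #   print(decode(ids, merges)[:200])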