# Spaces: Sleeping
# (Hugging Face Space status banner captured along with the source;
#  commented out so it is not parsed as code)
import regex as re  # NOTE(review): third-party `regex`; unused in this chunk — presumably needed later, verify

# Load the training text and view it as its raw UTF-8 byte sequence.
with open('text_file.txt', 'r', encoding='utf-8') as file:
    text = file.read()

tokens = text.encode("utf-8")  # raw bytes
tokens = list(tokens)          # ints in range 0..255 (iterating bytes yields ints)

print('---')
print("length of text:", len(text))
print('---')
#print(tokens)
print("length of tokens:", len(tokens))
def get_stats(ids):
    """Count how often each adjacent pair of token ids occurs in `ids`.

    Args:
        ids: sequence of token ids (ints).

    Returns:
        A Counter (dict subclass) mapping (id_a, id_b) -> occurrence count
        for every consecutive pair; empty when `ids` has fewer than two
        elements.
    """
    # Counter over the pairwise zip replaces the hand-rolled counting loop;
    # zip(ids, ids[1:]) yields every consecutive pair.
    from collections import Counter
    return Counter(zip(ids, ids[1:]))
def merge(ids, pair, idx):
    """Return a copy of `ids` with every consecutive occurrence of
    `pair` collapsed into the single new token `idx`.

    Matches are taken left-to-right and do not overlap.
    """
    first, second = pair
    out = []
    pos = 0
    last = len(ids) - 1
    while pos < len(ids):
        if pos < last and ids[pos] == first and ids[pos + 1] == second:
            # found the pair: emit the replacement token, skip both elements
            out.append(idx)
            pos += 2
        else:
            out.append(ids[pos])
            pos += 1
    return out
# --- BPE training: repeatedly merge the most frequent adjacent pair ---
vocab_size = 500                # the desired final vocabulary size
num_merges = vocab_size - 256   # 256 base byte tokens already exist

ids = list(tokens)  # copy so we don't destroy the original list
merges = {}         # (int, int) -> int

for i in range(num_merges):
    stats = get_stats(ids)
    if not stats:
        # Fewer than two tokens remain: no pair left to merge. Without this
        # guard, max() on an empty dict raises ValueError on short inputs.
        break
    pair = max(stats, key=stats.get)
    idx = 256 + i
    #print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx

print("---")
print("ids length:", len(ids))
if ids:
    # Guard against division by zero when the input text was empty.
    print(f"compression ratio: {len(tokens) / len(ids):.2f}X")