Spaces:
Sleeping
Sleeping
File size: 1,493 Bytes
d8b92ee c128a5f d8b92ee 1e8ebcb c128a5f d8b92ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import pickle
from BPE import get_stats, merge
# Load the saved tokenizer state. The pickle holds a 3-tuple that is unpacked
# as (merges, ids, num_merges); merges maps a pair of token ids (p0, p1) to
# the id that replaces that pair.
# SECURITY: pickle.load can execute arbitrary code while unpickling, so only
# load 'bpe_results.pkl' when it comes from a source you trust.
with open('bpe_results.pkl', 'rb') as f:
    merges, ids, num_merges = pickle.load(f)

# Build the id -> bytes table: ids 0..255 are the single raw byte values, and
# each merged id is the concatenation of the byte strings of its two parts.
vocab = {idx: bytes([idx]) for idx in range(256)}
# Walk the merges in ascending id order so both halves of a pair are already
# in vocab no matter how the pickled dict happens to iterate.
# NOTE(review): assumes every merged id is larger than the two ids it
# combines (ids handed out sequentially during training) - confirm.
for (p0, p1), idx in sorted(merges.items(), key=lambda item: item[1]):
    vocab[idx] = vocab[p0] + vocab[p1]
def decode(ids, sep=' ', output_path='decoded_output.txt'):
    """Turn a list of token ids into text and optionally save it to a file.

    Args:
        ids: Iterable of integer token ids; each must be a key of the
            module-level ``vocab`` table.
        sep: String placed between the decoded tokens. With the default
            single space every token is decoded on its own, so a multi-byte
            UTF-8 character that is split across two tokens shows up as
            U+FFFD replacement characters. Pass ``""`` to concatenate the
            token bytes first and decode them in one go.
        output_path: File the decoded text is written to (UTF-8). Pass
            ``None`` to skip writing.

    Returns:
        The decoded string.

    Raises:
        KeyError: If an id is not present in ``vocab``.
    """
    if sep:
        # Token-by-token view: decode each token separately, then join.
        pieces = [vocab[idx].decode("utf-8", errors="replace") for idx in ids]
        text = sep.join(pieces)
    else:
        # No separator requested: join the raw bytes before decoding so that
        # characters spanning several tokens are decoded intact.
        raw = b"".join(vocab[idx] for idx in ids)
        text = raw.decode("utf-8", errors="replace")

    if output_path is not None:
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(text)
    return text
# Demo that runs whenever this module is executed or imported: decode a
# hand-picked list of token ids and print the resulting text.
set_of_ids = [25, 345, 992, 1353]
decoded_text = decode(ids=set_of_ids)
print(decoded_text)
def encode(text=None, input_path='encode_input.txt'):
    """Convert text into a list of integer token ids using the merge table.

    Args:
        text: String to encode. When ``None`` (the default) the text is read
            from ``input_path`` instead.
        input_path: UTF-8 text file that is read when ``text`` is ``None``.

    Returns:
        A list of ints: the UTF-8 byte values of the text, with pairs
        replaced by their merged id for as long as a pair listed in the
        module-level ``merges`` table is present.
    """
    if text is None:
        with open(input_path, 'r', encoding='utf-8') as f:
            text = f.read()

    tokens = list(text.encode("utf-8"))
    # With fewer than two tokens there is no pair left to merge.
    while len(tokens) >= 2:
        # NOTE(review): get_stats comes from the BPE module; it is iterated
        # here as a collection of (id, id) pairs - confirm its return type.
        stats = get_stats(tokens)
        # Choose the candidate pair with the lowest merged id; pairs that are
        # not in merges rank as infinity and so are only picked when no
        # mergeable pair remains.
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # nothing else can be merged
        tokens = merge(tokens, pair, merges[pair])
    return tokens
# Example: encode the text stored in the input file
# encoded_tokens = encode()
# print(encoded_tokens)