import pickle

from BPE import get_stats, merge

# Load the trained BPE state from disk.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a bpe_results.pkl that you produced yourself or otherwise trust.
with open('bpe_results.pkl', 'rb') as f:
    # The pickle holds a 3-tuple. `ids` and `num_merges` are not used below but
    # are kept as module-level names in case other code imports them.
    merges, ids, num_merges = pickle.load(f)

# Token id -> raw bytes. Ids 0-255 are the single bytes; every merged id maps to
# the concatenation of its two parent tokens. Iterate merges in ascending id
# order so both parents are already in `vocab` when a child id is built.
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in sorted(merges.items(), key=lambda item: item[1]):
    vocab[idx] = vocab[p0] + vocab[p1]


def decode(ids, *, sep=None, output_path='decoded_output.txt'):
    """Convert token ids back into text, write it to a file, and return it.

    Args:
        ids: iterable of integer token ids; each must be a key of ``vocab``.
        sep: when None (default), the raw bytes of all tokens are concatenated
            and decoded once, so a multi-byte UTF-8 character that BPE split
            across several tokens is reconstructed correctly. When a string
            (e.g. ' '), each token is decoded on its own and the pieces are
            joined with ``sep``. This shows token boundaries, but characters
            split across tokens appear as U+FFFD.
        output_path: file the decoded text is written to, as UTF-8.

    Returns:
        The decoded string.

    Raises:
        KeyError: if an id is not present in ``vocab``.
    """
    if sep is None:
        raw = b"".join(vocab[idx] for idx in ids)
        text = raw.decode("utf-8", errors="replace")
    else:
        text = sep.join(
            vocab[idx].decode("utf-8", errors="replace") for idx in ids
        )

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(text)
    return text


def encode(input_path='encode_input.txt'):
    """Read text from ``input_path`` and return its BPE token ids.

    Starts from the UTF-8 bytes of the text and repeatedly applies, among the
    adjacent pairs currently present, the merge with the lowest merge id. It
    stops when no adjacent pair has a learned merge or fewer than two tokens
    remain.

    Args:
        input_path: UTF-8 text file to tokenize.

    Returns:
        A list of integer token ids (empty for an empty file).
    """
    with open(input_path, 'r', encoding='utf-8') as f:
        text = f.read()

    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = get_stats(tokens)
        # Pairs with no learned merge rank as +inf, so min() returns a
        # mergeable pair whenever one exists.
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # nothing else can be merged
        tokens = merge(tokens, pair, merges[pair])
    return tokens


if __name__ == "__main__":
    # Example: decode a list of ids (also writes decoded_output.txt).
    set_of_ids = [25, 345, 992, 1353]
    decoded_text = decode(set_of_ids)
    print(decoded_text)

    # Example: encode text from encode_input.txt
    # encoded_tokens = encode()
    # print(encoded_tokens)