atiwari751 commited on
Commit
fa753cb
·
2 Parent(s): fa76461 c9f3e85

removed pkl file to address merge conflict

Browse files
Files changed (5) hide show
  1. .gitignore +1 -1
  2. BPE.py +65 -26
  3. decoded_output.txt +3 -1
  4. encode_decode.py +38 -15
  5. encode_input.txt +3 -1
.gitignore CHANGED
@@ -3,4 +3,4 @@ __pycache__
3
  test.csv
4
  GPT2_encoder.py
5
  Hindi_Regex.txt
6
- Hindi_no_Regex.txt
 
3
  test.csv
4
  GPT2_encoder.py
5
  Hindi_Regex.txt
6
+ Hindi_no_Regex.txt
BPE.py CHANGED
@@ -1,51 +1,90 @@
1
  import pickle
2
- from tqdm import tqdm # Import tqdm for progress bar
 
3
 
4
  # Read text from a file
5
  with open('text_file_eng_long.txt', 'r', encoding='utf-8') as file:
6
  text = file.read()
7
 
8
- tokens = text.encode("utf-8") # raw bytes
9
- tokens = list(map(int, tokens)) # convert to a list of integers in range 0..255 for convenience
10
 
11
- def get_stats(ids):
 
 
 
 
 
 
 
 
 
 
12
  counts = {}
13
- for pair in zip(ids, ids[1:]):
14
- counts[pair] = counts.get(pair, 0) + 1
 
 
 
 
15
  return counts
16
 
17
- def merge(ids, pair, idx):
18
- # in the list of ints (ids), replace all consecutive occurrences of pair with the new token idx
19
  newids = []
20
- i = 0
21
- while i < len(ids):
22
- if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
23
- newids.append(idx)
24
- i += 2
25
- else:
26
- newids.append(ids[i])
27
- i += 1
 
 
 
 
 
 
 
28
  return newids
29
 
30
  def perform_bpe():
31
- vocab_size = 3500 # the desired final vocabulary size
32
  num_merges = vocab_size - 256
33
- ids = list(tokens) # copy so we don't destroy the original list
34
-
 
 
 
35
  merges = {} # (int, int) -> int
36
- # Use tqdm to add a progress bar
37
  for i in tqdm(range(num_merges), desc="Performing BPE", unit="merge"):
38
- stats = get_stats(ids)
 
 
 
 
39
  pair = max(stats, key=stats.get)
40
  idx = 256 + i
41
- ids = merge(ids, pair, idx)
 
 
42
  merges[pair] = idx
43
-
 
 
 
44
  print("---")
45
- print("ids length:", len(ids))
46
- print(f"compression ratio: {len(tokens) / len(ids):.2f}X")
 
 
 
 
 
 
47
 
48
- return merges, ids, num_merges
49
 
50
  if __name__ == "__main__":
51
  print('---')
 
1
  import pickle
2
+ import regex as re
3
+ from tqdm import tqdm
4
 
5
  # Read text from a file
6
  with open('text_file_eng_long.txt', 'r', encoding='utf-8') as file:
7
  text = file.read()
8
 
9
+ # Hindi-focused pattern
10
+ gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{N}+| ?(?:[\u0904-\u0939\u093d-\u093d\u0950-\u0950\u0958-\u0961\u0970-\u097f\ua8f2-\ua8fe\U00011b00-\U00011b09\u1cd3-\u1cd3\u1ce9-\u1cec\u1cee-\u1cf3\u1cf5-\u1cf6\u1cfa-\u1cfa][\u0900-\u0903\u093a-\u093c\u093e-\u094f\u0951-\u0957\u0962-\u0963\ua8e0-\ua8f1\ua8ff-\ua8ff\u1cd0-\u1cd2\u1cd4-\u1ce8\u1ced-\u1ced\u1cf4-\u1cf4\u1cf7-\u1cf9]*)+| ?\p{L}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
11
 
12
+ # Apply the regex pattern to the raw text to tokenize it
13
+ tokens = re.findall(gpt2pat, text)
14
+
15
+ # Convert tokens to byte sequences
16
+ byte_tokens = [token.encode('utf-8') for token in tokens]
17
+
18
+ # Create a list of byte sequences, each representing a token
19
+ tokens = [list(token) for token in byte_tokens]
20
+
21
+ def get_stats(token_list):
22
+ """Count frequency of pairs across all tokens"""
23
  counts = {}
24
+ # Count pairs within each token
25
+ for token in token_list:
26
+ if len(token) < 2:
27
+ continue
28
+ for pair in zip(token, token[1:]):
29
+ counts[pair] = counts.get(pair, 0) + 1
30
  return counts
31
 
32
+ def merge(token_list, pair, idx):
33
+ """Merge all occurrences of pair within each token"""
34
  newids = []
35
+ for token in token_list:
36
+ if len(token) < 2:
37
+ newids.append(token)
38
+ continue
39
+
40
+ new_token = []
41
+ i = 0
42
+ while i < len(token):
43
+ if i < len(token) - 1 and (token[i], token[i+1]) == pair:
44
+ new_token.append(idx)
45
+ i += 2
46
+ else:
47
+ new_token.append(token[i])
48
+ i += 1
49
+ newids.append(new_token)
50
  return newids
51
 
52
  def perform_bpe():
53
+ vocab_size = 4000 # the desired final vocabulary size
54
  num_merges = vocab_size - 256
55
+ token_list = list(tokens) # copy so we don't destroy the original list
56
+
57
+ # Calculate total bytes before compression
58
+ total_bytes_before = sum(len(token) for token in token_list)
59
+
60
  merges = {} # (int, int) -> int
 
61
  for i in tqdm(range(num_merges), desc="Performing BPE", unit="merge"):
62
+ stats = get_stats(token_list)
63
+ if not stats: # No more pairs to merge
64
+ break
65
+
66
+ # Find most frequent pair
67
  pair = max(stats, key=stats.get)
68
  idx = 256 + i
69
+
70
+ # Perform the merge
71
+ token_list = merge(token_list, pair, idx)
72
  merges[pair] = idx
73
+
74
+ # Calculate total bytes after compression
75
+ total_bytes_after = sum(len(token) for token in token_list)
76
+
77
  print("---")
78
+ print("Total bytes before:", total_bytes_before)
79
+ print("Total bytes after:", total_bytes_after)
80
+ print(f"Compression ratio: {total_bytes_before / total_bytes_after:.2f}X")
81
+
82
+ # Flatten for storage, but maintain token boundaries
83
+ flat_ids = []
84
+ for token in token_list:
85
+ flat_ids.extend(token)
86
 
87
+ return merges, flat_ids, num_merges
88
 
89
  if __name__ == "__main__":
90
  print('---')
decoded_output.txt CHANGED
@@ -1 +1,3 @@
1
- Th ere 's a ch anc e this is not work ing, is n 't it ? Th ere ' re many pa per s, why will this work ? I ' ve got to make su re . I ' m now thin king some thing 's w ron g . I t 'll be sa d if there 's something w r ong and I mis s it, I'll be sor r y. I t 'd bet ter be re view ed well , I 'd want to be cer tain .
 
 
 
1
+ अम ज द के परिवार की तीन पी ढ़ ियां चांद नी चौ क निर्� � ाचन क्षेत्र में ह वे ली आज ़ म ख ां के नाम से पहच ाने जाने वाले एक दम स ट कर बने घ रों के झ ुण ्ड में रहती हैं . यह इलाक ा दिल्ली की ऐ त िहास िक ज ामा मस्जिद से पै दल की द ूरी पर है , और इस परिवार के 23 सदस्य मतदान केंद्र 10 पर प ंज ीक ृत मत द ाता हैं . लेकिन पिछले साल लोकसभा चुनावों के दौरान अम ज द को पता चला कि वह अपने परिवार के उन 20 लोगों में से एक हैं , जिन का नाम मत द ाता सू ची से इस वजह से काट दिया गया कि उन्होंने अपना घर बदल लिया है .
2
+
3
+ 5 5 वर्षीय अम ज द ने न्यू ज़ ल ॉन ्ड ्री को बताया , " हम ारे सामने ये पहली बार हुआ है . लेकिन नाम कट ने के बारे में सबसे ज्यादा निर ाश ाज न क बात ये थी कि इसका पता मतदान के दिन ही चला . जब हम पहली बार बू थ 10 पर गए तो उन्होंने हमें बताया कि उन्हें मत द ाता सू ची में हमारा नाम नहीं मिला . इसलिए हमें ज ामा मस्जिद में किसी दूसरे बू थ पर जाकर देखना चाहिए . वहां से हमें दूसरे बू थ पर भेज दिया गया . इस तरह हमने पांच से छह बू थ ों का दौर ा किया . और फिर अंत में हमें जो कारण बताया गया , वो यह था कि शायद घर - घ र जाकर सर्वे क्षण के दौरान बी एल ओ ( ब ू थ ले वल ऑफिस र ) को हम घर पर नहीं मिले इसलिए उन्होंने हमारे नाम काट दिए .”
encode_decode.py CHANGED
@@ -1,10 +1,14 @@
1
  import pickle
2
  from BPE import get_stats, merge
 
3
 
4
  # Load merges and vocab from the file
5
  with open('bpe_results.pkl', 'rb') as f:
6
  merges, ids, num_merges = pickle.load(f)
7
 
 
 
 
8
  vocab = {idx: bytes([idx]) for idx in range(256)}
9
  for (p0, p1), idx in merges.items():
10
  vocab[idx] = vocab[p0] + vocab[p1]
@@ -12,7 +16,7 @@ for (p0, p1), idx in merges.items():
12
  def decode(ids):
13
  # given ids (list of integers), return Python string
14
  tokens = [vocab[idx].decode("utf-8", errors="replace") for idx in ids]
15
- text = ' '.join(tokens) # Join tokens with a single space
16
 
17
  # Write the decoded text to a new file
18
  with open('decoded_output.txt', 'w', encoding='utf-8') as f:
@@ -21,7 +25,7 @@ def decode(ids):
21
  return text
22
 
23
  # Example: Decode a list of IDs
24
- set_of_ids = [312, 1366, 565, 278, 302, 717, 256, 429, 1496, 1687, 808, 411, 110, 2862, 289, 670, 312, 1366, 39, 1281, 1191, 2358, 456, 374, 2453, 574, 429, 1687, 670, 73, 39, 353, 1176, 286, 904, 367, 279, 2310, 39, 695, 1398, 999, 806, 1271, 3455, 565, 119, 1902, 103, 2310, 116, 851, 403, 379, 260, 846, 2713, 565, 3466, 119, 114, 588, 292, 360, 1263, 258, 1285, 1402, 403, 3305, 114, 1278, 73, 116, 887, 773, 363, 403, 279, 2035, 274, 1150, 3273, 887, 2398, 1219, 1031, 2514, 46]
25
  decoded_text = decode(set_of_ids) # Pass the list of IDs
26
  print(decoded_text)
27
 
@@ -30,18 +34,37 @@ def encode():
30
  with open('encode_input.txt', 'r', encoding='utf-8') as f:
31
  text = f.read()
32
 
33
- # given a string, return list of integers (the tokens)
34
- tokens = list(text.encode("utf-8"))
35
- while len(tokens) >= 2:
36
- stats = get_stats(tokens)
37
- pair = min(stats, key=lambda p: merges.get(p, float("inf")))
38
- if pair not in merges:
39
- break # nothing else can be merged
40
- idx = merges[pair]
41
- tokens = merge(tokens, pair, idx)
42
-
43
- return tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  # Example: Encode text from a file
46
- encoded_tokens = encode()
47
- print(encoded_tokens)
 
1
  import pickle
2
  from BPE import get_stats, merge
3
+ import regex as re
4
 
5
  # Load merges and vocab from the file
6
  with open('bpe_results.pkl', 'rb') as f:
7
  merges, ids, num_merges = pickle.load(f)
8
 
9
+ # Define the GPT-2 regex pattern (same as in BPE.py)
10
+ gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{N}+| ?(?:[\u0904-\u0939\u093d-\u093d\u0950-\u0950\u0958-\u0961\u0970-\u097f\ua8f2-\ua8fe\U00011b00-\U00011b09\u1cd3-\u1cd3\u1ce9-\u1cec\u1cee-\u1cf3\u1cf5-\u1cf6\u1cfa-\u1cfa][\u0900-\u0903\u093a-\u093c\u093e-\u094f\u0951-\u0957\u0962-\u0963\ua8e0-\ua8f1\ua8ff-\ua8ff\u1cd0-\u1cd2\u1cd4-\u1ce8\u1ced-\u1ced\u1cf4-\u1cf4\u1cf7-\u1cf9]*)+| ?\p{L}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
11
+
12
  vocab = {idx: bytes([idx]) for idx in range(256)}
13
  for (p0, p1), idx in merges.items():
14
  vocab[idx] = vocab[p0] + vocab[p1]
 
16
  def decode(ids):
17
  # given ids (list of integers), return Python string
18
  tokens = [vocab[idx].decode("utf-8", errors="replace") for idx in ids]
19
+ text = '\t'.join(tokens) # Join tokens with tabs
20
 
21
  # Write the decoded text to a new file
22
  with open('decoded_output.txt', 'w', encoding='utf-8') as f:
 
25
  return text
26
 
27
  # Example: Decode a list of IDs
28
+ set_of_ids = [2342, 307, 295, 286, 1413, 302, 839, 644, 574, 982, 3877, 405, 1086, 272, 978, 181, 3927, 1171, 294, 274, 964, 438, 767, 337, 284, 361, 332, 286, 776, 315, 2331, 429, 841, 631, 385, 1694, 273, 310, 418, 1607, 445, 935, 286, 962, 1244, 698, 294, 3069, 347, 46, 450, 1462, 259, 646, 302, 554, 276, 2252, 334, 292, 2835, 2500, 315, 1006, 3367, 302, 296, 1299, 330, 289, 44, 327, 345, 1413, 286, 2911, 1906, 2592, 1322, 888, 330, 279, 711, 1474, 997, 1068, 295, 1236, 347, 46, 513, 1067, 579, 1194, 2596, 286, 847, 732, 307, 295, 309, 1423, 1953, 340, 555, 563, 1413, 286, 376, 466, 596, 294, 315, 385, 347, 44, 1001, 478, 776, 1068, 295, 1236, 919, 1216, 315, 345, 1115, 315, 3189, 481, 437, 340, 557, 1125, 1135, 1501, 857, 289, 46, 10, 10, 53, 53, 2794, 732, 307, 295, 317, 2705, 2246, 280, 1308, 698, 486, 309, 739, 44, 32, 34, 808, 830, 1015, 516, 1315, 544, 667, 289, 46, 513, 776, 1914, 311, 286, 948, 294, 856, 915, 2438, 658, 367, 271, 272, 564, 516, 472, 340, 1571, 1423, 2592, 286, 638, 416, 1953, 46, 586, 462, 1315, 544, 3075, 583, 888, 330, 588, 444, 557, 1448, 739, 340, 737, 1068, 295, 1236, 919, 1216, 294, 3253, 776, 391, 1410, 46, 1496, 1448, 292, 2835, 2500, 294, 738, 1374, 3075, 583, 330, 2660, 3252, 904, 46, 1441, 315, 1448, 1374, 3075, 583, 330, 1473, 481, 437, 46, 345, 778, 1758, 1307, 315, 2210, 3075, 583, 299, 333, 751, 259, 420, 46, 327, 766, 1200, 294, 1448, 499, 1394, 739, 437, 44, 707, 450, 413, 340, 3602, 1135, 45, 864, 261, 2660, 2749, 1930, 286, 847, 447, 1782, 1633, 510, 308, 306, 583, 399, 1508, 2632, 261, 41, 309, 462, 1135, 330, 391, 1193, 1496, 557, 1574, 776, 3189, 1340, 3435]
29
  decoded_text = decode(set_of_ids) # Pass the list of IDs
30
  print(decoded_text)
31
 
 
34
  with open('encode_input.txt', 'r', encoding='utf-8') as f:
35
  text = f.read()
36
 
37
+ # Tokenize the text using the regex pattern
38
+ tokens = re.findall(gpt2pat, text)
39
+
40
+ # Convert tokens to byte sequences and maintain grouping
41
+ byte_tokens = [token.encode('utf-8') for token in tokens]
42
+ token_list = [list(token) for token in byte_tokens]
43
+
44
+ # Calculate total bytes before compression
45
+ total_bytes_before = sum(len(token) for token in token_list)
46
+
47
+ # Process each token
48
+ final_tokens = []
49
+ for token in token_list:
50
+ current_token = list(token)
51
+ while len(current_token) >= 2:
52
+ stats = get_stats([current_token])
53
+ if not stats:
54
+ break
55
+ pair = min(stats, key=lambda p: merges.get(p, float("inf")))
56
+ if pair not in merges:
57
+ break
58
+ idx = merges[pair]
59
+ current_token = merge([current_token], pair, idx)[0]
60
+ final_tokens.extend(current_token)
61
+
62
+ # Calculate compression ratio
63
+ compression_ratio = total_bytes_before / len(final_tokens)
64
+ print(f"Compression ratio: {compression_ratio:.2f}X")
65
+
66
+ return final_tokens, compression_ratio
67
 
68
  # Example: Encode text from a file
69
+ encoded_tokens, ratio = encode()
70
+ print(f"Encoded tokens: {encoded_tokens}")
encode_input.txt CHANGED
@@ -1 +1,3 @@
1
- There's a chance this is not working, isn't it? There're many papers, why will this work? I've got to make sure. I'm now thinking something's wrong. It'll be sad if there's something wrong and I miss it, I'll be sorry. It'd better be reviewed well, I'd want to be certain.
 
 
 
1
+ अमजद के परिवार की तीन पीढ़ियां चांदनी चौक निर्वाचन क्षेत्र में हवेली आज़म खां के नाम से पहचाने जाने वाले एकदम सटकर बने घरों के झुण्ड में रहती हैं. यह इलाका दिल्ली की ऐतिहासिक जामा मस्जिद से पैदल की दूरी पर है, और इस परिवार के 23 सदस्य मतदान केंद्र 10 पर पंजीकृत मतदाता हैं. लेकिन पिछले साल लोकसभा चुनावों के दौरान अमजद को पता चला कि वह अपने परिवार के उन 20 लोगों में से एक हैं, जिनका नाम मतदाता सूची से इस वजह से काट दिया गया कि उन्होंने अपना घर बदल लिया है.
2
+
3
+ 55 वर्षीय अमजद ने न्यूज़लॉन्ड्री को बताया, "हमारे सामने ये पहली बार हुआ है. लेकिन नाम कटने के बारे में सबसे ज्यादा निराशाजनक बात ये थी कि इसका पता मतदान के दिन ही चला. जब हम पहली बार बूथ 10 पर गए तो उन्होंने हमें बताया कि उन्हें मतदाता सूची में हमारा नाम नहीं मिला. इसलिए हमें जामा मस्जिद में किसी दूसरे बूथ पर जाकर देखना चाहिए. वहां से हमें दूसरे बूथ पर भेज दिया गया. इस तरह हमने पांच से छह बूथों का दौरा किया. और फिर अंत में हमें जो कारण बताया गया, वो यह था कि शायद घर-घर जाकर सर्वेक्षण के दौरान बीएलओ (बूथ लेवल ऑफिसर) को हम घर पर नहीं मिले इसलिए उन्होंने हमारे नाम काट दिए.”