atiwari751 committed
Commit 781de59 · Parent: cae9627

regex on byte sequences

Files changed (3):
  1. .gitignore +1 -0
  2. BPE.py +7 -7
  3. encode_decode.py +14 -6
.gitignore CHANGED
@@ -1,3 +1,4 @@
 .venv
 __pycache__
 test.csv
+GPT2_encoder.py
BPE.py CHANGED
@@ -1,6 +1,6 @@
 import pickle
 import regex as re
-from tqdm import tqdm # Import tqdm for progress bar
+from tqdm import tqdm
 
 # Read text from a file
 with open('text_file_eng.txt', 'r', encoding='utf-8') as file:
@@ -9,12 +9,14 @@ with open('text_file_eng.txt', 'r', encoding='utf-8') as file:
 # Define the GPT-2 regex pattern
 gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
 
-# Apply the regex pattern to tokenize the text
+# Tokenize the text using the regex pattern
 tokens = re.findall(gpt2pat, text)
 
-# Convert tokens to a list of integers in range 0..255 for convenience
-tokens = [ord(char) for token in tokens for char in token]
-#print(tokens)
+# Convert tokens to byte sequences
+byte_tokens = [token.encode('utf-8') for token in tokens]
+
+# Flatten the list of byte sequences into a single list of bytes
+tokens = [b for token in byte_tokens for b in token]
 
 def get_stats(ids):
     counts = {}
@@ -23,7 +25,6 @@ def get_stats(ids):
     return counts
 
 def merge(ids, pair, idx):
-    # in the list of ints (ids), replace all consecutive occurrences of pair with the new token idx
     newids = []
     i = 0
     while i < len(ids):
@@ -41,7 +42,6 @@ def perform_bpe():
     ids = list(tokens) # copy so we don't destroy the original list
 
     merges = {} # (int, int) -> int
-    # Use tqdm to add a progress bar
     for i in tqdm(range(num_merges), desc="Performing BPE", unit="merge"):
         stats = get_stats(ids)
         pair = max(stats, key=stats.get)
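
Why the BPE.py change matters: ord() returns Unicode code points, not bytes, so the old tokens list was only "in range 0..255" for ASCII-ish text; any character beyond U+00FF broke the 256-symbol base vocabulary of byte-level BPE, and the result could not round-trip through a UTF-8 decode. Encoding each regex token to UTF-8 and flattening keeps every starting id within 0..255. A quick standalone sketch (not part of the commit):

    text = "€5"
    print([ord(c) for c in text])      # [8364, 53] -- 8364 is far outside 0..255
    print(list(text.encode('utf-8')))  # [226, 130, 172, 53] -- every value fits in a byte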
encode_decode.py CHANGED
@@ -1,5 +1,6 @@
 import pickle
 from BPE import get_stats, merge
+import regex as re
 
 # Load merges and vocab from the file
 with open('bpe_results.pkl', 'rb') as f:
@@ -11,8 +12,8 @@ for (p0, p1), idx in merges.items():
 
 def decode(ids):
     # given ids (list of integers), return Python string
-    tokens = [vocab[idx].decode("utf-8", errors="replace") for idx in ids]
-    text = ' '.join(tokens) # Join tokens with a single space
+    tokens = [vocab[idx] for idx in ids]
+    text = b''.join(tokens).decode("utf-8", errors="replace")
 
     # Write the decoded text to a new file
     with open('decoded_output.txt', 'w', encoding='utf-8') as f:
@@ -30,8 +31,15 @@ def encode():
     with open('encode_input.txt', 'r', encoding='utf-8') as f:
         text = f.read()
 
-    # given a string, return list of integers (the tokens)
-    tokens = list(text.encode("utf-8"))
+    # Tokenize the text using the regex pattern
+    tokens = re.findall(gpt2pat, text)
+
+    # Convert tokens to byte sequences
+    byte_tokens = [token.encode('utf-8') for token in tokens]
+
+    # Flatten the list of byte sequences into a single list of bytes
+    tokens = [b for token in byte_tokens for b in token]
+
     while len(tokens) >= 2:
         stats = get_stats(tokens)
         pair = min(stats, key=lambda p: merges.get(p, float("inf")))
@@ -43,5 +51,5 @@ def encode():
     return tokens
 
 # Example: Encode text from a file
-#encoded_tokens = encode()
-#print(encoded_tokens)
+encoded_tokens = encode()
+print(encoded_tokens)
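
Two notes on the encode_decode.py changes. First, the rewritten decode() is the substantive fix: the old version decoded each token's bytes separately and joined the pieces with spaces, which inserted spaces that were never in the input and replaced any multi-byte character whose bytes were split across tokens. Joining the raw bytes first and decoding once handles both. A minimal standalone sketch with a hypothetical two-token vocab (not the commit's actual vocab):

    # 'é' is two bytes in UTF-8 (b'\xc3\xa9'); suppose a merge left them in separate tokens
    vocab = {0: b'\xc3', 1: b'\xa9'}
    ids = [0, 1]

    # old approach: per-token decode, then join with spaces
    print(' '.join(vocab[i].decode('utf-8', errors='replace') for i in ids))  # '\ufffd \ufffd'

    # new approach: join the bytes, decode once
    print(b''.join(vocab[i] for i in ids).decode('utf-8', errors='replace'))  # 'é'

Second, encode() now calls re.findall(gpt2pat, text), but the diff only adds import regex as re; unless gpt2pat is defined elsewhere in encode_decode.py, this raises a NameError at runtime. One way to resolve it, assuming nothing else in the file defines the pattern, is to import the compiled pattern alongside the helpers already taken from BPE:

    # assumption, not part of this commit: BPE.py defines gpt2pat at module level
    from BPE import get_stats, merge, gpt2pat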