kjcjohnson commited on
Commit
901bbd9
·
1 Parent(s): 9fbf7f9

Add GAD libraries

Browse files
Files changed (30) hide show
  1. transformers_gad/__init__.py +5 -0
  2. transformers_gad/__pycache__/__init__.cpython-311.pyc +0 -0
  3. transformers_gad/__pycache__/grammar_utils.cpython-311.pyc +0 -0
  4. transformers_gad/__pycache__/logging_config.cpython-311.pyc +0 -0
  5. transformers_gad/__pycache__/mapping.cpython-311.pyc +0 -0
  6. transformers_gad/__pycache__/parser.cpython-311.pyc +0 -0
  7. transformers_gad/__pycache__/recognizer.cpython-311.pyc +0 -0
  8. transformers_gad/__pycache__/token_grammar_recognizer.cpython-311.pyc +0 -0
  9. transformers_gad/__pycache__/trie.cpython-311.pyc +0 -0
  10. transformers_gad/__pycache__/utf8_utils.cpython-311.pyc +0 -0
  11. transformers_gad/__pycache__/utils.cpython-311.pyc +0 -0
  12. transformers_gad/__pycache__/vocab_struct.cpython-311.pyc +0 -0
  13. transformers_gad/generation/__init__.py +1 -0
  14. transformers_gad/generation/__pycache__/__init__.cpython-311.pyc +0 -0
  15. transformers_gad/generation/__pycache__/logits_process.cpython-311.pyc +0 -0
  16. transformers_gad/generation/logits_process.py +348 -0
  17. transformers_gad/grammar_utils.py +4 -0
  18. transformers_gad/logging_config.py +18 -0
  19. transformers_gad/mapping.py +209 -0
  20. transformers_gad/oracle/__init_.py +1 -0
  21. transformers_gad/oracle/__pycache__/oracle_trie.cpython-311.pyc +0 -0
  22. transformers_gad/oracle/oracle_trie.py +261 -0
  23. transformers_gad/parser.py +576 -0
  24. transformers_gad/parser_cfg.py +530 -0
  25. transformers_gad/recognizer.py +456 -0
  26. transformers_gad/token_grammar_recognizer.py +322 -0
  27. transformers_gad/trie.py +194 -0
  28. transformers_gad/utf8_utils.py +170 -0
  29. transformers_gad/utils.py +98 -0
  30. transformers_gad/vocab_struct.py +83 -0
transformers_gad/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .logging_config import setup_logging
2
+
3
+ setup_logging()
4
+
5
+ __version__ = "0.1.2"
transformers_gad/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (318 Bytes). View file
 
transformers_gad/__pycache__/grammar_utils.cpython-311.pyc ADDED
Binary file (330 Bytes). View file
 
transformers_gad/__pycache__/logging_config.cpython-311.pyc ADDED
Binary file (964 Bytes). View file
 
transformers_gad/__pycache__/mapping.cpython-311.pyc ADDED
Binary file (14.7 kB). View file
 
transformers_gad/__pycache__/parser.cpython-311.pyc ADDED
Binary file (24.1 kB). View file
 
transformers_gad/__pycache__/recognizer.cpython-311.pyc ADDED
Binary file (20.2 kB). View file
 
transformers_gad/__pycache__/token_grammar_recognizer.cpython-311.pyc ADDED
Binary file (16.3 kB). View file
 
transformers_gad/__pycache__/trie.cpython-311.pyc ADDED
Binary file (8.37 kB). View file
 
transformers_gad/__pycache__/utf8_utils.cpython-311.pyc ADDED
Binary file (6.17 kB). View file
 
transformers_gad/__pycache__/utils.cpython-311.pyc ADDED
Binary file (4.66 kB). View file
 
transformers_gad/__pycache__/vocab_struct.cpython-311.pyc ADDED
Binary file (4.44 kB). View file
 
transformers_gad/generation/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .logits_process import GrammarConstrainedLogitsProcessor, GrammarAlignedOracleLogitsProcessor
transformers_gad/generation/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (343 Bytes). View file
 
transformers_gad/generation/__pycache__/logits_process.cpython-311.pyc ADDED
Binary file (17.2 kB). View file
 
transformers_gad/generation/logits_process.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import torch.nn.functional as F
4
+
5
+ import torch
6
+ import logging
7
+ from transformers.generation.logits_process import (
8
+ LogitsProcessor,
9
+ LOGITS_PROCESSOR_INPUTS_DOCSTRING,
10
+ )
11
+ from transformers.utils import add_start_docstrings
12
+ from transformers_gad.grammar_utils import IncrementalGrammarConstraint
13
+ from transformers_gad.oracle.oracle_trie import Trie
14
+
15
+ class GrammarConstrainedLogitsProcessor(LogitsProcessor):
16
+ def __init__(self, grammar_constraint, parse_start_index=None, save_log=False):
17
+ # Parser variables
18
+ self.grammar_constraint = grammar_constraint
19
+ self.batch_parsing_states = None
20
+ self.parse_start_index = parse_start_index
21
+
22
+ # To start with a longer prefix in enumerative search
23
+ self.generate_start_index = None
24
+ self.generated_tokens = None
25
+
26
+ # Generation Log
27
+ self.save_log = save_log
28
+ self.history = []
29
+
30
+ def reset(self):
31
+ self.reset_parser()
32
+ self.reset_history()
33
+
34
+ def reset_parser(self):
35
+ self.batch_parsing_states = None
36
+ if self.grammar_constraint.is_incremental:
37
+ self.grammar_constraint.reset()
38
+
39
+ self.generate_start_index = None
40
+ self.generated_tokens = None
41
+
42
+ def reset_history(self):
43
+ self.history = []
44
+
45
+ def mask_scores(self, scores, device):
46
+ """
47
+ resolve each stack to a tensor of True/False for each token
48
+ indicating acceptance
49
+ """
50
+ masked_scores = scores.clone()
51
+ acceptance = self.grammar_constraint.batch_filter_vocab(
52
+ self.batch_parsing_states, device
53
+ )
54
+
55
+ if self.save_log:
56
+ self.store_detailed_history(acceptance, scores)
57
+
58
+ # Scores to -inf where False
59
+ masked_scores[~acceptance] = -math.inf
60
+
61
+ return masked_scores
62
+
63
+ def process_scores(self, input_ids, scores):
64
+ # we dynamically create stacks at the first call, so that we know the batch size and beam size
65
+ if self.batch_parsing_states is None:
66
+ self.batch_parsing_states = [
67
+ copy.deepcopy(
68
+ self.grammar_constraint.string_recognizer.get_initial_accept_state()
69
+ )
70
+ for _ in range(len(input_ids))
71
+ ]
72
+
73
+ # assume the generation starts from the same index
74
+ if self.generate_start_index is None:
75
+ # the default is the end of input sequence of tokens
76
+ self.generate_start_index = self.parse_start_index \
77
+ if self.parse_start_index else input_ids.size(1)
78
+ self.generated_tokens = input_ids[:, self.generate_start_index:]
79
+
80
+ # Advance parser states
81
+ self.batch_parsing_states = self.grammar_constraint.advance_token_ids(
82
+ input_ids, self.batch_parsing_states, self.parse_start_index
83
+ )
84
+
85
+ masked_scores = self.mask_scores(scores, scores.device)
86
+ return masked_scores
87
+
88
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
89
+ def __call__(
90
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
91
+ ) -> torch.FloatTensor:
92
+ return self.process_scores(input_ids, scores)
93
+
94
+ def reset_parser(self):
95
+ self.batch_parsing_states = None
96
+ if isinstance(self.grammar_constraint, IncrementalGrammarConstraint):
97
+ self.grammar_constraint.reset()
98
+
99
+ def get_accepted_tokens(self, acceptance):
100
+ """
101
+ Get the indices of accepted tokens and their corresponding string values for each item in the batch.
102
+
103
+ Parameters:
104
+ - acceptance (torch.Tensor): A boolean tensor indicating accepted tokens for each item in the batch.
105
+ """
106
+ batch_size, _ = acceptance.shape
107
+ acceptance_np = acceptance.cpu().numpy()
108
+ accepted_x, accepted_y = acceptance_np.nonzero()
109
+
110
+ # Initialize the dictionary with empty lists for indices
111
+ accepted_token_indices = {i: [] for i in range(batch_size)}
112
+ for x, y in zip(accepted_x, accepted_y):
113
+ accepted_token_indices[x].append(y)
114
+
115
+ # Convert token IDs to tokens
116
+ accepted_tokens = {
117
+ i: [self.grammar_constraint.tokenizer.decode([token_id]) for token_id in token_ids]
118
+ for i, token_ids in accepted_token_indices.items()
119
+ }
120
+
121
+ return accepted_tokens
122
+
123
+ def store_detailed_history(self, acceptance, scores):
124
+ """
125
+ Processes and stores information for accepted tokens including their IDs, tokens,
126
+ raw scores, and logits.
127
+
128
+ Parameters:
129
+ - acceptance (torch.Tensor): A boolean tensor indicating accepted tokens for each item in the batch.
130
+ - scores (torch.Tensor): The raw scores from the model output.
131
+ - adjusted_scores (torch.Tensor): The adjusted scores after applying expected future grammaticality.
132
+ """
133
+ likelihoods = F.softmax(scores, dim=-1)
134
+
135
+ # Initializing the list to store detailed information for each step
136
+ batch_accepted_info = []
137
+
138
+ for batch_index in range(acceptance.size(0)): # Iterate over batch items
139
+ accepted_info = []
140
+ accepted_indices = acceptance[batch_index].nonzero().squeeze(-1)
141
+
142
+ for idx in accepted_indices:
143
+ token_id = idx.item()
144
+ raw_score = scores[batch_index, idx].item()
145
+ likelihood = likelihoods[batch_index, idx].item()
146
+ token = self.grammar_constraint.tokenizer.decode([token_id])
147
+
148
+ # Store detailed information as a dictionary
149
+ accepted_info.append({
150
+ "token_id": token_id,
151
+ "token": str(token),
152
+ "raw_score": raw_score,
153
+ "raw_likelihood": likelihood
154
+ })
155
+
156
+ batch_accepted_info.append(accepted_info)
157
+
158
+ # Store this detailed information in the history
159
+ self.history.append(batch_accepted_info)
160
+
161
+ class GrammarAlignedOracleLogitsProcessor(LogitsProcessor):
162
+ def __init__(self, grammar_constraint, oracle_trie=Trie(), parse_start_index=None, save_log=False):
163
+ # Parser variables
164
+ self.grammar_constraint = grammar_constraint
165
+ self.batch_parsing_states = None
166
+ self.parse_start_index = parse_start_index
167
+
168
+ # ASAp oracle trie
169
+ self.oracle_trie = oracle_trie
170
+
171
+ # To start with a longer prefix in enumerative search
172
+ self.generate_start_index = None
173
+ self.generated_tokens = None
174
+
175
+ # Generation Log
176
+ self.save_log = save_log
177
+ self.history = []
178
+
179
+ def adjust_scores(self, scores, device):
180
+ """
181
+ resolve each stack to a tensor of True/False for each token
182
+ indicating acceptance
183
+ """
184
+ acceptance = self.grammar_constraint.batch_filter_vocab(
185
+ self.batch_parsing_states, device
186
+ )
187
+
188
+ current_parent = self.oracle_trie.search_last_parent(self.generated_tokens)
189
+ current_parent.insert_accepted_tokens(scores, acceptance)
190
+ adjusted_scores = self.apply_oracle_adjustments(acceptance, scores, current_parent)
191
+
192
+ if self.save_log:
193
+ self.store_detailed_history(acceptance, scores, adjusted_scores)
194
+
195
+ # Scores to -inf where False
196
+ adjusted_scores[~acceptance] = -math.inf
197
+
198
+ return adjusted_scores
199
+
200
+ def apply_oracle_adjustments(self, acceptance, scores, current_parent):
201
+ """
202
+ Multiply expected future grammarticality
203
+ Use the normalized (and unmasked) probabiltiy
204
+
205
+ Parameters:
206
+ - acceptance (torch.Tensor): A characteristic vector of valid tokens
207
+ used to updated only valid tokens
208
+ - scores (torch.Tensor): Unnormalized logits from language model
209
+ - current_parent (TrieNode): The trie node for the current prefix
210
+ """
211
+ adjusted_scores = scores.clone()
212
+ likelihoods = F.softmax(adjusted_scores, dim=-1)
213
+ log_likelihoods = torch.log(likelihoods)
214
+
215
+ for batch_index in range(acceptance.size(0)):
216
+ accepted_indices = acceptance[batch_index].nonzero().squeeze(-1)
217
+
218
+ for idx in accepted_indices:
219
+ token_id = idx.item()
220
+ log_likelihood = log_likelihoods[batch_index, idx].item()
221
+
222
+ # Get theta (log of expected future grammaticality) for this specific token
223
+ success_rate = current_parent.get_success_rate(token_id)
224
+
225
+ if not isinstance(success_rate, torch.Tensor):
226
+ success_rate = torch.tensor(success_rate, dtype=torch.float)
227
+ log_theta = torch.log(success_rate)
228
+
229
+ # Calculate adjusted score
230
+ adjusted_score = log_likelihood + log_theta
231
+ adjusted_scores[batch_index, idx] = adjusted_score
232
+
233
+ return adjusted_scores
234
+
235
+ def process_scores(self, input_ids, scores):
236
+ # we dynamically create stacks at the first call, so that we know the batch size and beam size
237
+ if self.batch_parsing_states is None:
238
+ self.batch_parsing_states = [
239
+ copy.deepcopy(
240
+ self.grammar_constraint.string_recognizer.get_initial_accept_state()
241
+ )
242
+ for _ in range(len(input_ids))
243
+ ]
244
+
245
+ # assume the generation starts from the same index
246
+ if self.generate_start_index is None:
247
+ # the default is the end of input sequence of tokens
248
+ self.generate_start_index = self.parse_start_index \
249
+ if self.parse_start_index else input_ids.size(1)
250
+ self.generated_tokens = input_ids[:, self.generate_start_index:]
251
+
252
+ # Advance parser states
253
+ self.batch_parsing_states = self.grammar_constraint.advance_token_ids(
254
+ input_ids, self.batch_parsing_states, self.parse_start_index
255
+ )
256
+
257
+ adjusted_scores = self.adjust_scores(scores, scores.device)
258
+
259
+ return adjusted_scores
260
+
261
+ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
262
+ def __call__(
263
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
264
+ ) -> torch.FloatTensor:
265
+ return self.process_scores(input_ids, scores)
266
+
267
+ def reset(self):
268
+ self.reset_parser()
269
+ self.reset_history()
270
+
271
+ def reset_parser(self):
272
+ self.batch_parsing_states = None
273
+ if self.grammar_constraint.is_incremental:
274
+ self.grammar_constraint.reset()
275
+
276
+ self.generate_start_index = None
277
+ self.generated_tokens = None
278
+
279
+ def reset_history(self):
280
+ self.history = []
281
+
282
+ def reset_trie(self):
283
+ self.oracle_trie = Trie()
284
+
285
+ def get_accepted_tokens(self, acceptance):
286
+ """
287
+ Get the indices of accepted tokens and their corresponding string values for each item in the batch.
288
+
289
+ Parameters:
290
+ - acceptance (torch.Tensor): A boolean tensor indicating accepted tokens for each item in the batch.
291
+ """
292
+ batch_size, _ = acceptance.shape
293
+ acceptance_np = acceptance.cpu().numpy()
294
+ accepted_x, accepted_y = acceptance_np.nonzero()
295
+
296
+ # Initialize the dictionary with empty lists for indices
297
+ accepted_token_indices = {i: [] for i in range(batch_size)}
298
+ for x, y in zip(accepted_x, accepted_y):
299
+ accepted_token_indices[x].append(y)
300
+
301
+ # Convert token IDs to tokens
302
+ accepted_tokens = {
303
+ i: [self.grammar_constraint.tokenizer.decode([token_id]) for token_id in token_ids]
304
+ for i, token_ids in accepted_token_indices.items()
305
+ }
306
+
307
+ return accepted_tokens
308
+
309
+ def store_detailed_history(self, acceptance, scores, adjusted_scores):
310
+ """
311
+ Processes and stores information for accepted tokens including their IDs, tokens,
312
+ raw scores, and logits.
313
+
314
+ Parameters:
315
+ - acceptance (torch.Tensor): A boolean tensor indicating accepted tokens for each item in the batch.
316
+ - scores (torch.Tensor): The raw scores from the model output.
317
+ - adjusted_scores (torch.Tensor): The adjusted scores after applying expected future grammaticality.
318
+ """
319
+ likelihoods = F.softmax(scores, dim=-1)
320
+ adjusted_likelihoods = F.softmax(adjusted_scores, dim=-1)
321
+
322
+ # Initializing the list to store detailed information for each step
323
+ batch_accepted_info = []
324
+
325
+ for batch_index in range(acceptance.size(0)): # Iterate over batch items
326
+ accepted_info = []
327
+ accepted_indices = acceptance[batch_index].nonzero().squeeze(-1)
328
+
329
+ for idx in accepted_indices:
330
+ token_id = idx.item()
331
+ raw_score = scores[batch_index, idx].item()
332
+ likelihood = likelihoods[batch_index, idx].item()
333
+ adjusted_likelihood = adjusted_likelihoods[batch_index, idx].item()
334
+ token = self.grammar_constraint.tokenizer.decode([token_id])
335
+
336
+ # Store detailed information as a dictionary
337
+ accepted_info.append({
338
+ "token_id": token_id,
339
+ "token": str(token),
340
+ "raw_score": raw_score,
341
+ "raw_likelihood": likelihood,
342
+ "adjusted_likelihood": adjusted_likelihood
343
+ })
344
+
345
+ batch_accepted_info.append(accepted_info)
346
+
347
+ # Store this detailed information in the history
348
+ self.history.append(batch_accepted_info)
transformers_gad/grammar_utils.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .token_grammar_recognizer import IncrementalTokenRecognizer
2
+
3
+ # Old class name, kept for backward compatibility
4
+ IncrementalGrammarConstraint = IncrementalTokenRecognizer
transformers_gad/logging_config.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # logging_config.py
2
+ import os
3
+ import logging
4
+
5
+
6
+ def setup_logging():
7
+ log_level_name = os.getenv(
8
+ "TCFG_LOG_LEVEL", "WARNING"
9
+ ).upper() # Default to WARNING if not set
10
+ log_levels = {
11
+ "DEBUG": logging.DEBUG,
12
+ "INFO": logging.INFO,
13
+ "WARNING": logging.WARNING,
14
+ "ERROR": logging.ERROR,
15
+ "CRITICAL": logging.CRITICAL,
16
+ }
17
+ log_level = log_levels.get(log_level_name, logging.WARNING)
18
+ logging.basicConfig(level=log_level)
transformers_gad/mapping.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+
3
+ from transformers_gad.utils import get_tokenizer_model_type, ints2bytes
4
+ from transformers import AutoTokenizer
5
+ import logging
6
+
7
+ log = logging.getLogger(__name__)
8
+
9
+
10
+ def get_mapping(tokenizer, unicode=False):
11
+ log.debug(f"tokenizer type: {tokenizer.__class__.__name__}")
12
+ log.debug(f"tokenizer model type: {get_tokenizer_model_type(tokenizer)}")
13
+ if not unicode:
14
+ if (
15
+ "gpt2" in tokenizer.__class__.__name__.lower()
16
+ or "bloom" in tokenizer.__class__.__name__.lower()
17
+ or "pretrainedtokenizer" in tokenizer.__class__.__name__.lower()
18
+ or "codegen" in tokenizer.__class__.__name__.lower()
19
+ or "gptneox" in tokenizer.__class__.__name__.lower()
20
+ ):
21
+ return BBPEMapping(tokenizer)
22
+ elif "t5" in tokenizer.__class__.__name__.lower():
23
+ return BPEMapping(tokenizer)
24
+ elif "llama" in tokenizer.__class__.__name__.lower():
25
+ return LlamaBPEMapping(tokenizer)
26
+ elif "xglm" in tokenizer.__class__.__name__.lower():
27
+ return UniGramMapping(tokenizer)
28
+ else:
29
+ raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__.__name__}")
30
+ else:
31
+ if "gpt2" in tokenizer.__class__.__name__.lower():
32
+ return UnicodeBBPEMapping(tokenizer)
33
+ else:
34
+ raise NotImplementedError(
35
+ f"Unicode mapping for {tokenizer.__class__.__name__}"
36
+ )
37
+
38
+
39
+ class Mapping:
40
+ def __init__(self, tokenizer):
41
+ self.eos_token_id = tokenizer.eos_token_id
42
+ self.bos_token_id = tokenizer.bos_token_id
43
+ self.tokenizer = tokenizer
44
+ self.special = tokenizer.all_special_ids
45
+
46
+ def __len__(self):
47
+ return len(self.tokenizer.get_vocab())
48
+
49
+ def _map(self, token_id: int) -> str:
50
+ # This is the case for BOS,
51
+ if token_id in self.special:
52
+ return ""
53
+ # if token_id is tensor, convert it to int
54
+ if hasattr(token_id, "item"):
55
+ token_id = token_id.item()
56
+ raw_token = self.tokenizer.convert_ids_to_tokens(token_id)
57
+ return raw_token
58
+
59
+ def map(self, token_id: int, verbose=False) -> bytes:
60
+ token = self._map(token_id)
61
+ if verbose:
62
+ log.debug(f"token_id: {token_id}, token: {token}")
63
+ return bytes(token, "utf-8")
64
+
65
+
66
+ class BBPEMapping(Mapping):
67
+ def __init__(self, *args, **kwargs):
68
+ super().__init__(*args, **kwargs)
69
+
70
+ def _map(self, token_id: int) -> str:
71
+ raw_token = super()._map(token_id)
72
+ if raw_token.startswith("Ġ"):
73
+ raw_token = raw_token.replace("Ġ", " ")
74
+ return raw_token
75
+
76
+
77
+ class UnicodeBBPEMapping(Mapping):
78
+ def __init__(self, *args, **kwargs):
79
+ super().__init__(*args, **kwargs)
80
+ self.intermediate_encoding = UnicodeBBPEMapping.get_intermediate_encoding(
81
+ self.tokenizer
82
+ )
83
+
84
+ def _map(self, token_id: int, verbose=False) -> str:
85
+ raw_token = super()._map(token_id)
86
+ # if raw_token.startswith("Ġ"):
87
+ # raw_token = raw_token.replace("Ġ", " ")
88
+ return raw_token
89
+
90
+ def map(self, token_id: int, verbose=False) -> bytes:
91
+ raw_token = self._map(token_id, verbose)
92
+ if verbose:
93
+ log.debug(f"token_id: {token_id}, raw_token: {raw_token}")
94
+ return self.intermediate_encoding.token2bytes(raw_token)
95
+
96
+ @staticmethod
97
+ def get_intermediate_encoding(tokenizer):
98
+ if "gpt2" in tokenizer.__class__.__name__.lower():
99
+ return ByteEncoding(tokenizer)
100
+ else:
101
+ return None
102
+
103
+
104
+ class BPEMapping(Mapping):
105
+ def __init__(self, tokenizer):
106
+ super().__init__(tokenizer)
107
+ self.last_token_id = None
108
+
109
+ def _map(self, token_id: int) -> str:
110
+ raw_token = super()._map(token_id)
111
+
112
+ # we need to check if the token is at the beginning of the sentence to remove the space
113
+ # specific to BPE
114
+ at_bos = False
115
+ if self.last_token_id is not None and self.last_token_id == self.bos_token_id:
116
+ at_bos = True
117
+ self.last_token_id = token_id
118
+ if raw_token.startswith("▁"):
119
+ raw_token = raw_token.replace("▁", " ")
120
+ if at_bos:
121
+ # remove space at the beginning of the sentence
122
+ raw_token = raw_token[1:]
123
+ return raw_token
124
+
125
+
126
+ class LlamaBPEMapping(BPEMapping):
127
+ def __init__(self, tokenizer):
128
+ super().__init__(tokenizer)
129
+
130
+ def _map(self, token_id: int) -> str:
131
+ raw_token = super()._map(token_id)
132
+ # if the token is hex, token is a string like "<0x00>"
133
+ # first 256 tokens are hex
134
+ if raw_token.startswith("<0x"):
135
+ hex_value = raw_token[4:-1]
136
+ raw_token = chr(int(hex_value, 16))
137
+ return raw_token
138
+
139
+
140
+ class WordPieceMapping(Mapping):
141
+ def __init__(self, tokenizer):
142
+ super().__init__(tokenizer)
143
+
144
+ def map(self, token_id: int) -> bytes:
145
+ if token_id in self.special:
146
+ return bytes()
147
+ return bytes(
148
+ self.tokenizer.decode([token_id], clean_up_tokenization_spaces=False),
149
+ "utf-8",
150
+ )
151
+
152
+
153
+ class UniGramMapping(Mapping):
154
+ def __init__(self, tokenizer):
155
+ super().__init__(tokenizer)
156
+
157
+ def map(self, token_id: int) -> bytes:
158
+ if token_id in self.special:
159
+ return bytes()
160
+ return bytes(
161
+ self.tokenizer.decode([token_id], clean_up_tokenization_spaces=False),
162
+ "utf-8",
163
+ )
164
+
165
+
166
+ class XGLMUniGramMapping(Mapping):
167
+ def __init__(self, tokenizer):
168
+ super().__init__(tokenizer)
169
+ self.bos_token_id = tokenizer.eos_token_id
170
+ self.eos_token_id = None
171
+
172
+
173
+ class ByteEncoding:
174
+ def __init__(self, tokenizer):
175
+ # check if the tokenizer is fast, if so, convert it to slow
176
+ if tokenizer.is_fast:
177
+ tokenizer = AutoTokenizer.from_pretrained(
178
+ tokenizer.name_or_path, use_fast=False
179
+ )
180
+ self.tokenizer = tokenizer
181
+ self.byte2char: Dict[int, str] = tokenizer.byte_encoder
182
+ self.char2byte: Dict[str, int] = tokenizer.byte_decoder
183
+ # code point to byte
184
+ self.cdp2byte: Dict[int, int] = {ord(c): b for c, b in self.char2byte.items()}
185
+ self.byte2cdp: Dict[int, int] = {v: k for k, v in self.cdp2byte.items()}
186
+
187
+ def map(self, byte: int) -> int:
188
+ assert 0 <= byte < 256, f"byte: {byte} is not in the range [0, 256)"
189
+ return ord(self.byte2char[byte])
190
+
191
+ def token_ids2bytes(self, token_ids: List[int]) -> bytes:
192
+ tokens: List[str] = self.tokenizer.convert_ids_to_tokens(token_ids)
193
+ # for token id = BOS, the token should be empty string instead of <s>
194
+ # TODO, this may cause issues because this means that special tokens like BOS can appear at any position
195
+ tokens = [
196
+ "" if token in self.tokenizer.all_special_ids else token for token in tokens
197
+ ]
198
+ bytes: List[List[int]] = [self.token2bytes(token) for token in tokens]
199
+ # join the bytes
200
+ return ints2bytes(sum(bytes, []))
201
+
202
+ def token_id2bytes(self, token_id: int) -> bytes:
203
+ token: str = self.tokenizer.convert_ids_to_tokens(token_id)
204
+ return self.token2bytes(token)
205
+
206
+ def token2bytes(self, token: str) -> bytes:
207
+ # import pdb; pdb.set_trace()
208
+ bytes_seq: List[int] = [self.char2byte[c] for c in token]
209
+ return bytes(bytes_seq)
transformers_gad/oracle/__init_.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .oracle_trie import Trie, TrieNode, update_oracle_trie
transformers_gad/oracle/__pycache__/oracle_trie.cpython-311.pyc ADDED
Binary file (12.8 kB). View file
 
transformers_gad/oracle/oracle_trie.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import json
4
+ import logging
5
+
6
+ class TrieNode:
7
+ def __init__(self,
8
+ token_id=None, raw_likelihood=None, raw_score=None,
9
+ success_rate=1,
10
+ is_start_of_sequence=False, is_end_of_sequence=False,
11
+ eos_token_id=2):
12
+ self.children = {}
13
+ self.parent = None
14
+ self.token_id = token_id
15
+ self.raw_likelihood = raw_likelihood
16
+ self.raw_score = raw_score
17
+
18
+ # The default approximation of EFG
19
+ self.success_rate = success_rate
20
+
21
+ self.eos_token_id = eos_token_id
22
+ self.is_start_of_sequence = is_start_of_sequence
23
+ self.is_end_of_sequence = is_end_of_sequence
24
+
25
+ def insert(self, child_node):
26
+ """
27
+ Insert child_node into the children dictionary
28
+ """
29
+ if child_node.token_id not in self.children:
30
+ self.children[child_node.token_id] = child_node
31
+ child_node.parent = self
32
+
33
+ if child_node.token_id == self.eos_token_id:
34
+ child_node.is_end_of_sequence = True
35
+
36
+ # update the success rate of the parent node
37
+ return self.update_success_rate()
38
+ else:
39
+ return 0
40
+
41
+ def insert_accepted_tokens(self, scores, acceptance):
42
+ """
43
+ Create node from acceptance and scores and
44
+ insert as children of self node
45
+ """
46
+ likelihoods = F.softmax(scores, dim=-1)
47
+
48
+ for batch_index in range(acceptance.size(0)):
49
+ accepted_tokens = acceptance[batch_index].nonzero().squeeze(-1)
50
+
51
+ for token_id in accepted_tokens:
52
+ if token_id not in self.children:
53
+ raw_likelihood = likelihoods[batch_index, token_id].item()
54
+ raw_score = scores[batch_index, token_id].item()
55
+
56
+ child_node = TrieNode(
57
+ token_id=token_id.item(),
58
+ raw_likelihood=raw_likelihood,
59
+ raw_score=raw_score)
60
+
61
+ self.insert(child_node)
62
+
63
+ def get_success_rate(self, token_id):
64
+ """
65
+ Return Approximated Expected Future Grammaticality of the token_id
66
+ """
67
+ if token_id in self.children:
68
+ return self.children[token_id].success_rate
69
+ else:
70
+ return 1
71
+
72
+ def update_success_rate(self):
73
+ """
74
+ Re-compute the success rate from the updated success rate of children
75
+ """
76
+ if self.children:
77
+ total_success_rate = sum(child.raw_likelihood * child.success_rate for child in self.children.values())
78
+
79
+ # Get how much of unexplored nodes are covered with this update
80
+ updated_rate = self.success_rate - total_success_rate
81
+ self.success_rate = total_success_rate
82
+
83
+ # Back propagate the success rate
84
+ if self.parent:
85
+ return self.parent.update_success_rate()
86
+
87
+ return updated_rate
88
+
89
+ def prefix_raw_likelihood(self):
90
+ if self.parent:
91
+ return self.raw_likelihood * self.parent.prefix_raw_likelihood()
92
+ else:
93
+ return self.raw_likelihood
94
+
95
+ def search_token(self, token_id):
96
+ """
97
+ Check if the self node has a children with token_id
98
+ Return the children node if it exists, return None otherwise
99
+ """
100
+ if token_id in self.children:
101
+ return self.children[token_id]
102
+ else:
103
+ return None
104
+
105
+ def to_dict(self):
106
+ """
107
+ Convert a trie into a dictionary by removing the pointer to the parent
108
+ """
109
+ return {
110
+ "token_id": self.token_id,
111
+ "raw_likelihood": self.raw_likelihood,
112
+ "raw_score": self.raw_score,
113
+ "success_rate": self.success_rate,
114
+ "eos_token_id": self.eos_token_id,
115
+ "is_start_of_sequence": self.is_start_of_sequence,
116
+ "is_end_of_sequence": self.is_end_of_sequence,
117
+ "children": [child.to_dict() for child in self.children.values()]
118
+ }
119
+
120
+ @staticmethod
121
+ def from_dict(d):
122
+ """
123
+ Recursively (re)construct trie from dictionary
124
+ """
125
+ node = TrieNode(
126
+ token_id=d['token_id'],
127
+ raw_likelihood=d['raw_likelihood'],
128
+ raw_score=d['raw_score'],
129
+ success_rate=d['success_rate'],
130
+ is_start_of_sequence=d['is_start_of_sequence'],
131
+ is_end_of_sequence=d['is_end_of_sequence'],
132
+ eos_token_id=d['eos_token_id'])
133
+
134
+ node.children = {child['token_id']:TrieNode.from_dict(child) for child in node.children}
135
+ for child in node.children.values():
136
+ child.parent = node
137
+
138
+ return node
139
+
140
+ def __repr__(self):
141
+ parent_token_id = 'None (Root Node)' if self.parent is None else self.parent.token_id
142
+ return (f"TrieNode(token_id={self.token_id}', "
143
+ f"raw_likelihood={self.raw_likelihood}, raw_score={self.raw_score}, children={list(self.children.keys())}, "
144
+ f"parent={parent_token_id}, success rate={self.success_rate})")
145
+
146
+ class Trie:
147
+ def __init__(self):
148
+ self.root = TrieNode()
149
+ self.root.is_start_of_sequence = True
150
+
151
+ def search_last_parent(self, prefix: torch.LongTensor):
152
+ """
153
+ Search the longest prefix in the trie that matches to the input sequence of tokens 'prefix'
154
+ """
155
+ matched_prefix = []
156
+ current_parent = self.root
157
+
158
+ # Assume one batch of prefix
159
+ for time_step, token_id in enumerate(prefix[0]):
160
+ token_id = token_id.item()
161
+ if token_id in current_parent.children:
162
+ current_parent = current_parent.children[token_id]
163
+ matched_prefix.append(current_parent.token_id)
164
+ else:
165
+ print(
166
+ f"matched prefix is {matched_prefix}; current {token_id} not found in the trie at time step {time_step}")
167
+ return None
168
+
169
+ return current_parent
170
+
171
+ def search(self, sequence):
172
+ """
173
+ Return the sequence of nodes that exactly matches with the input
174
+ """
175
+ node = self.root
176
+ nodes = []
177
+ for token_id in sequence:
178
+ if token_id not in node.children:
179
+ return None
180
+ node = node.children[token_id]
181
+ nodes.append(node)
182
+ return nodes
183
+
184
+ def raw_likelihood(self, sequence):
185
+ """
186
+ Return the raw likelihood (before the adjustment) of sequence
187
+ """
188
+ if isinstance(sequence, torch.Tensor):
189
+ sequence = sequence.tolist()
190
+
191
+ nodes = self.search(sequence)
192
+ if nodes is None:
193
+ return None
194
+
195
+ likelihood = 1
196
+ for node in nodes:
197
+ likelihood *= node.raw_likelihood
198
+ return likelihood
199
+
200
+ def json(self):
201
+ return json.dumps(self.root.to_dict(), indent=2)
202
+
203
+ @staticmethod
204
+ def loads(js):
205
+ trie = Trie()
206
+ trie.root = TrieNode.from_dict(json.loads(js))
207
+
208
+ return trie
209
+
210
+ def print_trie(self, node=None, prefix=None):
211
+ """
212
+ Print all the leaves in the trie
213
+ """
214
+ if node is None:
215
+ node = self.root
216
+ if prefix is None:
217
+ prefix = []
218
+
219
+ # If current node marks the end of a sequence, print the prefix as a list
220
+ if node.is_end_of_sequence or len(node.children) == 0:
221
+ print(prefix)
222
+
223
+ # Recursively call print_trie for all children, appending the current character/token to the prefix
224
+ for char, child_node in node.children.items():
225
+ self.print_trie(child_node, prefix + [char])
226
+
227
+ def has_full_information(self):
228
+ """
229
+ Checks if all paths in the trie end with an is_end_of_sequence node set to True.
230
+ Returns True if the trie has full information, False otherwise.
231
+ """
232
+ return self._check_full_information(self.root)
233
+
234
+ def _check_full_information(self, node):
235
+ # If the node has no children, check if it is marked as the end of a sequence
236
+ if not node.children:
237
+ return node.is_end_of_sequence
238
+
239
+ # Recursively check all children
240
+ return all(self._check_full_information(child) for child in node.children.values())
241
+
242
+ def print_all_nodes(self, node=None, depth=0):
243
+ """
244
+ Print all the nodes in the trie (including non-leaves)
245
+ """
246
+
247
+ if node is None:
248
+ node = self.root
249
+
250
+ # Print current node's details
251
+ indent = " " * depth # Create indentation based on the depth in the trie
252
+ node_details = (f"{indent}TrieNode(token_id={node.token_id}', "
253
+ f"raw_likelihood={node.raw_likelihood}, raw_score={node.raw_score}, success rate={node.success_rate}, "
254
+ f"children={list(node.children.keys())}, "
255
+ f"parent={node.parent.token_id if node.parent else None}, "
256
+ f"is_end_of_sequence={node.is_end_of_sequence})")
257
+ print(node_details)
258
+
259
+ # Recursively call print_all_nodes for all children
260
+ for child_node in node.children.values():
261
+ self.print_all_nodes(child_node, depth + 1)
transformers_gad/parser.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import sys
4
+ from typing import List
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ END_OF_ALTERNATE_MARKER = 0
9
+ END_OF_RULE_MARKER = 0
10
+ END_OF_GRAMMAR_MARKER = 0xFFFF
11
+ TO_BE_FILLED_MARKER = 0
12
+ REF_RULE_MARKER = 1
13
+ LITERAL_MARKER = 2
14
+
15
+
16
+ ########################
17
+ # EBNF Grammar Parsing #
18
+ ########################
19
+
20
+
21
+ class ParseState:
22
+ def __init__(self):
23
+ self.symbol_table = {}
24
+ self.grammar_encoding = [] # old name: out_grammar
25
+
26
+ def print(self, file=sys.stdout):
27
+ print_grammar(file, self)
28
+
29
+
30
+ def get_symbol_id(state: ParseState, symbol_name: str) -> int:
31
+ if symbol_name not in state.symbol_table:
32
+ state.symbol_table[symbol_name] = len(state.symbol_table)
33
+ return state.symbol_table[symbol_name]
34
+
35
+
36
+ def generate_symbol_id(state: ParseState, base_name: str) -> int:
37
+ next_id = len(state.symbol_table)
38
+ state.symbol_table[base_name + "_" + str(next_id)] = next_id
39
+ return next_id
40
+
41
+
42
+ def is_word_char(c: str) -> bool:
43
+ """
44
+ Check if a char is a-z, A-Z, 0-9, -, _, i.e., chars allowed as rule names
45
+ Returns:
46
+
47
+ """
48
+ return c.isalnum() or c == "-" or c == "_"
49
+
50
+
51
+ def hex_to_int(c: str) -> int:
52
+ """
53
+ Convert a hex char to int, c should be in the range of 0-9, a-f, A-F
54
+ case insensitive
55
+ Args:
56
+ c: a hex char
57
+ Returns:
58
+ int: the int value of the hex char
59
+ """
60
+ if c.isdigit():
61
+ return int(c)
62
+ elif "a" <= c.lower() <= "f":
63
+ return ord(c.lower()) - ord("a") + 10
64
+ return -1
65
+
66
+
67
+ def remove_leading_white_space(src, rm_leading_newline):
68
+ """
69
+ Skips over whitespace and comments in the input string.
70
+
71
+ This function processes the input string, skipping over any spaces, tabs,
72
+ and content following a '#' character, which denotes a comment. The parsing
73
+ of a comment continues until the end of the line (denoted by newline characters
74
+ '\r' or '\n'). If the 'rm_leading_newline' parameter is set to False, the function
75
+ will stop processing and return the remaining string upon encountering a
76
+ newline character, otherwise it will skip over newline characters as well.
77
+
78
+ Parameters:
79
+ src (str): The input string to be processed.
80
+ rm_leading_newline (bool): A flag indicating whether encountering a newline character
81
+ should stop the parsing (False) or if it should be skipped (True).
82
+
83
+ Returns:
84
+ str: The remaining portion of the input string after skipping whitespace and comments.
85
+ """
86
+ pos = 0
87
+ while pos < len(src) and (src[pos].isspace() or src[pos] == "#"):
88
+ if src[pos] == "#":
89
+ while pos < len(src) and src[pos] not in ("\r", "\n"):
90
+ pos += 1
91
+ else:
92
+ if not rm_leading_newline and src[pos] in ("\r", "\n"):
93
+ break
94
+ pos += 1
95
+ return src[pos:]
96
+
97
+
98
+ def parse_name(src) -> (str, str):
99
+ """
100
+ parse the leading name from the input string
101
+ Args:
102
+ src: the input grammar string
103
+
104
+ Returns:
105
+ name, remaining_src
106
+ """
107
+ pos = 0
108
+ while pos < len(src) and is_word_char(src[pos]):
109
+ pos += 1
110
+ if pos == 0:
111
+ raise RuntimeError("expecting name at " + src)
112
+ return src[:pos], src[pos:]
113
+
114
+
115
+ def parse_char(src) -> (str, str):
116
+ """
117
+ parse the leading char from the input string
118
+ :param src:
119
+ :return: char, remaining_src
120
+ """
121
+
122
+ # if we have a backslash, it's maybe an escape
123
+ if src[0] == "\\":
124
+ esc = src[1]
125
+ if esc == "x":
126
+ first = hex_to_int(src[2])
127
+ if first > -1:
128
+ second = hex_to_int(src[3])
129
+ if second > -1:
130
+ return (first << 4) + second, src[4:]
131
+ raise RuntimeError("expecting \\xNN at " + src)
132
+ elif esc in ('"', "[", "]"):
133
+ return esc, src[2:]
134
+ elif esc == "r":
135
+ return "\r", src[2:]
136
+ elif esc == "n":
137
+ return "\n", src[2:]
138
+ elif esc == "t":
139
+ return "\t", src[2:]
140
+ elif esc == "\\":
141
+ return "\\", src[2:]
142
+ elif esc == "/":
143
+ return "\\", src[1:]
144
+ raise RuntimeError("unknown escape at " + src)
145
+ elif src:
146
+ return src[0], src[1:]
147
+ raise RuntimeError("unexpected end of input")
148
+
149
+
150
+ def _parse_rhs_literal_string(src: str, outbuf: List[int]) -> str:
151
+ assert src[0] == '"', f"rule should start with '\"', but got {src[0]}"
152
+ remaining_src = src[1:]
153
+
154
+ # advance until we get an end quote or run out of input
155
+ while remaining_src and remaining_src[0] != '"':
156
+ char, remaining_src = parse_char(remaining_src)
157
+ outbuf.append(LITERAL_MARKER)
158
+ # print(f"char: {char}")
159
+ outbuf.append(ord(char))
160
+ outbuf.append(ord(char))
161
+
162
+ # in case we ran out of input before finding the end quote
163
+ if not remaining_src:
164
+ raise RuntimeError(f"expecting an end quote at {src},but not found")
165
+
166
+ # remove the end quote and return the remaining string
167
+ return remaining_src[1:]
168
+
169
+
170
+ def _parse_rhs_char_ranges(src: str, outbuf: List[int]) -> str:
171
+ assert src[0] == "[", f"rule should start with '[', but got {src[0]}"
172
+ remaining_src = src[1:]
173
+ start_idx = len(outbuf)
174
+ # num chars in range - replaced at end of loop
175
+ outbuf.append(TO_BE_FILLED_MARKER)
176
+ while remaining_src and remaining_src[0] != "]":
177
+ char, remaining_src = parse_char(remaining_src)
178
+
179
+ outbuf.append(ord(char))
180
+ if remaining_src[0] == "-" and remaining_src[1] != "]":
181
+ endchar_pair, remaining_src = parse_char(remaining_src[1:])
182
+ outbuf.append(ord(endchar_pair))
183
+ else:
184
+ # This is the case for enumerate, e.g., [0123456789], [abcdef]
185
+ # Each char is considered as a range of itself, i.e., c-c
186
+ outbuf.append(ord(char))
187
+ if not remaining_src:
188
+ raise RuntimeError(
189
+ f"expecting an ] at {src},but not found, is the char range closed?"
190
+ )
191
+ # replace num chars with actual
192
+ outbuf[start_idx] = len(outbuf) - start_idx - 1
193
+ return remaining_src[1:]
194
+
195
+
196
+ def _parse_rhs_symbol_reference(src: str, state: ParseState, outbuf: List[int]) -> str:
197
+ assert is_word_char(src[0]), f"rule should start with a word char, but got {src[0]}"
198
+ name, remaining_src = parse_name(src)
199
+ ref_rule_id = get_symbol_id(state, name)
200
+ outbuf.append(REF_RULE_MARKER)
201
+ outbuf.append(ref_rule_id)
202
+ return remaining_src
203
+
204
+
205
+ def _parse_rhs_grouping(
206
+ remaining_src: str, state: ParseState, rule_name: str, outbuf: List[int]
207
+ ) -> str:
208
+ assert (
209
+ remaining_src[0] == "("
210
+ ), f"rule should start with '(', but got {remaining_src[0]}"
211
+ remaining_src = remove_leading_white_space(remaining_src[1:], True)
212
+ # parse nested alternates into synthesized rule
213
+ synthetic_rule_id = generate_symbol_id(state, rule_name)
214
+ remaining_src = parse_rhs(state, remaining_src, rule_name, synthetic_rule_id, True)
215
+ # output reference to synthesized rule
216
+ outbuf.append(REF_RULE_MARKER)
217
+ outbuf.append(synthetic_rule_id)
218
+
219
+ if not remaining_src or remaining_src[0] != ")":
220
+ raise RuntimeError("expecting ')' at " + remaining_src)
221
+ return remaining_src[1:]
222
+
223
+
224
+ def _parse_rhs_repetition_operators(
225
+ remaining_src: str,
226
+ state: ParseState,
227
+ rule_name: str,
228
+ last_sym_start: int,
229
+ outbuf: List[int],
230
+ ) -> str:
231
+ assert remaining_src[0] in (
232
+ "*",
233
+ "+",
234
+ "?",
235
+ ), f"rule should start with '*', '+', or '?', but got {remaining_src[0]}"
236
+ out_grammar = state.grammar_encoding
237
+ # last_sym_start = len(outbuf)
238
+
239
+ # apply transformation to previous symbol (last_sym_start -
240
+ # end) according to rewrite rules:
241
+ # S* --> S' ::= S S' |
242
+ # S+ --> S' ::= S S' | S
243
+ # S? --> S' ::= S |
244
+ sub_rule_id = generate_symbol_id(state, rule_name)
245
+ out_grammar.append(sub_rule_id)
246
+ sub_rule_offset = len(out_grammar)
247
+ # placeholder for size of 1st alternate
248
+ out_grammar.append(TO_BE_FILLED_MARKER)
249
+ # add preceding symbol to generated rule
250
+ out_grammar.extend(outbuf[last_sym_start:])
251
+ if remaining_src[0] in ("*", "+"):
252
+ # cause generated rule to recurse
253
+ out_grammar.append(REF_RULE_MARKER)
254
+ out_grammar.append(sub_rule_id)
255
+ # apply actual size
256
+ out_grammar[sub_rule_offset] = len(out_grammar) - sub_rule_offset
257
+ # mark end of 1st alternate
258
+ out_grammar.append(END_OF_ALTERNATE_MARKER)
259
+ sub_rule_offset = len(out_grammar)
260
+ # placeholder for size of 2nd alternate
261
+ out_grammar.append(TO_BE_FILLED_MARKER)
262
+ if remaining_src[0] == "+":
263
+ # add preceding symbol as alternate only for '+'
264
+ out_grammar.extend(outbuf[last_sym_start:])
265
+ # apply actual size of 2nd alternate
266
+ out_grammar[sub_rule_offset] = len(out_grammar) - sub_rule_offset
267
+ # mark end of 2nd alternate, then end of rule
268
+ out_grammar.append(END_OF_ALTERNATE_MARKER)
269
+ out_grammar.append(END_OF_RULE_MARKER)
270
+
271
+ # in original rule, replace previous symbol with reference to generated rule
272
+ outbuf[last_sym_start:] = [REF_RULE_MARKER, sub_rule_id]
273
+ return remaining_src[1:]
274
+
275
+
276
+ def parse_simple_rhs(state, rhs: str, rule_name: str, outbuf, is_nested):
277
+ simple_rhs_offset = len(outbuf)
278
+
279
+ # sequence size, will be replaced at end when known
280
+ outbuf.append(TO_BE_FILLED_MARKER)
281
+
282
+ last_sym_start = len(outbuf)
283
+ remaining_rhs = rhs
284
+ while remaining_rhs:
285
+ if remaining_rhs[0] == '"': # literal string
286
+ # mark the start of the last symbol, for repetition operator
287
+ last_sym_start = len(outbuf)
288
+ remaining_rhs = _parse_rhs_literal_string(remaining_rhs, outbuf)
289
+ elif remaining_rhs[0] == "[": # char range(s)
290
+ # mark the start of the last symbol, for repetition operator
291
+ last_sym_start = len(outbuf)
292
+ remaining_rhs = _parse_rhs_char_ranges(remaining_rhs, outbuf)
293
+ elif is_word_char(remaining_rhs[0]): # rule reference
294
+ # mark the start of the last symbol, for repetition operator
295
+ last_sym_start = len(outbuf)
296
+ remaining_rhs = _parse_rhs_symbol_reference(remaining_rhs, state, outbuf)
297
+ elif remaining_rhs[0] == "(": # grouping
298
+ # mark the start of the last symbol, for repetition operator
299
+ last_sym_start = len(outbuf)
300
+ remaining_rhs = _parse_rhs_grouping(remaining_rhs, state, rule_name, outbuf)
301
+ elif remaining_rhs[0] in ("*", "+", "?"): # repetition operator
302
+ # No need to mark the start of the last symbol, because we already did it
303
+ if len(outbuf) - simple_rhs_offset - 1 == 0:
304
+ raise RuntimeError(
305
+ "expecting preceeding item to */+/? at " + remaining_rhs
306
+ )
307
+ remaining_rhs = _parse_rhs_repetition_operators(
308
+ remaining_rhs, state, rule_name, last_sym_start, outbuf
309
+ )
310
+ else:
311
+ # case for newline, i.e., end of rule
312
+ assert remaining_rhs[0] in [
313
+ "\n",
314
+ "|",
315
+ ")",
316
+ ], f"rule should end with newline or '|', but got {remaining_rhs[0]}"
317
+ # we break here so that we call parse_rule again to parse the next rule
318
+ break
319
+ # Here we do not rm newline deliberately so that we know the rhs is ended
320
+ remaining_rhs = remove_leading_white_space(
321
+ remaining_rhs, rm_leading_newline=is_nested
322
+ )
323
+
324
+ # apply actual size of this alternate sequence
325
+ outbuf[simple_rhs_offset] = len(outbuf) - simple_rhs_offset
326
+ # mark end of alternate
327
+ outbuf.append(END_OF_ALTERNATE_MARKER)
328
+ return remaining_rhs
329
+
330
+
331
+ def parse_rhs(state, rhs: str, rule_name, rule_id, is_nested):
332
+ outbuf = []
333
+ remaining_rhs = parse_simple_rhs(state, rhs, rule_name, outbuf, is_nested)
334
+ while remaining_rhs and remaining_rhs[0] == "|":
335
+ remaining_rhs = remove_leading_white_space(remaining_rhs[1:], True)
336
+ remaining_rhs = parse_simple_rhs(
337
+ state, remaining_rhs, rule_name, outbuf, is_nested
338
+ )
339
+
340
+ # Now we have finished parsing the rhs, we can add the rule to the grammar_encoding
341
+ state.grammar_encoding.append(rule_id)
342
+ state.grammar_encoding.extend(outbuf)
343
+ state.grammar_encoding.append(END_OF_RULE_MARKER)
344
+ return remaining_rhs
345
+
346
+
347
+ def parse_rule(state: ParseState, rule_text: str) -> str:
348
+ name, remaining_rule_text = parse_name(rule_text)
349
+ remaining_rule_text = remove_leading_white_space(remaining_rule_text, False)
350
+ # check if the rule is already defined, TODO: what will happen if the rule is already defined?
351
+ rule_id = get_symbol_id(state, name)
352
+
353
+ if remaining_rule_text[:3] != "::=":
354
+ raise RuntimeError("expecting ::= at " + remaining_rule_text)
355
+ remaining_rule_text = remove_leading_white_space(remaining_rule_text[3:], True)
356
+
357
+ remaining_rule_text = parse_rhs(state, remaining_rule_text, name, rule_id, False)
358
+
359
+ if remaining_rule_text and remaining_rule_text[0] == "\r":
360
+ remaining_rule_text = (
361
+ remaining_rule_text[2:]
362
+ if remaining_rule_text[1] == "\n"
363
+ else remaining_rule_text[1:]
364
+ )
365
+ elif remaining_rule_text and remaining_rule_text[0] == "\n":
366
+ remaining_rule_text = remaining_rule_text[1:]
367
+ elif remaining_rule_text:
368
+ raise RuntimeError("expecting newline or end at " + remaining_rule_text)
369
+ return remove_leading_white_space(remaining_rule_text, True)
370
+
371
+
372
+ def parse_ebnf(grammar_text: str) -> ParseState:
373
+ try:
374
+ state = ParseState()
375
+ remaining_grammar_text = remove_leading_white_space(grammar_text, True)
376
+ last_grammar_repr = ""
377
+ while remaining_grammar_text:
378
+ if last_grammar_repr:
379
+ last_parsed_rule_len = len(last_grammar_repr) - len(
380
+ remaining_grammar_text
381
+ )
382
+ logger.debug(
383
+ f"last_parsed_rule: {last_grammar_repr[:last_parsed_rule_len]}"
384
+ )
385
+ last_grammar_repr = remaining_grammar_text
386
+ remaining_grammar_text = parse_rule(state, remaining_grammar_text)
387
+ state.grammar_encoding.append(END_OF_GRAMMAR_MARKER)
388
+ return state
389
+ except RuntimeError as err:
390
+ logger.warning("error parsing grammar:", err)
391
+ return ParseState()
392
+
393
+
394
+ ###################################
395
+ # EBNF Grammar Parsing ends here #
396
+ ###################################
397
+
398
+
399
+ def break_grammar_into_rules(grammar_encoding: List[int]) -> List[List[int]]:
400
+ offset = 0
401
+ # we loop until we reach the end of the grammar_encoding
402
+ rule_encodings = []
403
+ i = 0
404
+ while i < len(grammar_encoding) - 2:
405
+ if (
406
+ grammar_encoding[i] == END_OF_ALTERNATE_MARKER
407
+ and grammar_encoding[i + 1] == END_OF_RULE_MARKER
408
+ ):
409
+ rule_encodings.append(grammar_encoding[offset : i + 2])
410
+ offset = i + 2
411
+ # skip the END_OF_RULE_MARKER
412
+ # This is mandatory because if we do not skip the END_OF_RULE_MARKER
413
+ # we fail in the case where the next rule has rule_id 0
414
+ i += 1
415
+ i += 1
416
+ return rule_encodings
417
+
418
+
419
+ def break_rule_into_elements(rule_encoding: List[int]) -> List[List[int]]:
420
+ rule_id = rule_encoding.pop(0)
421
+ end_of_rule_marker = rule_encoding.pop(-1)
422
+ assert (
423
+ end_of_rule_marker == END_OF_RULE_MARKER
424
+ ), f"rule should end with {END_OF_RULE_MARKER}, but got {end_of_rule_marker}"
425
+
426
+ offset = 0
427
+ elements = []
428
+ while offset < len(rule_encoding):
429
+ element_size = rule_encoding[offset]
430
+ assert (
431
+ rule_encoding[offset + element_size] == END_OF_ALTERNATE_MARKER
432
+ ), f"element should end with {END_OF_ALTERNATE_MARKER}, but got {rule_encoding[offset + element_size]}"
433
+ elements.append(rule_encoding[offset : offset + element_size + 1])
434
+ offset += element_size + 1
435
+ return elements
436
+
437
+
438
+ def _print_annotated_grammar(file, grammar_encoding, symbol_id_names, index=0):
439
+ rule_id = grammar_encoding[index]
440
+ print(f"<{index}>{symbol_id_names[rule_id]} ::=", end=" ", file=file)
441
+ pos = index + 1
442
+ while grammar_encoding[pos]:
443
+ if pos - 1 > index:
444
+ print("|", end=" ", file=file)
445
+ pos += 1 # sequence size, not needed here
446
+ while grammar_encoding[pos]:
447
+ if grammar_encoding[pos] == REF_RULE_MARKER:
448
+ ref_rule_id = grammar_encoding[pos + 1]
449
+ print(
450
+ f"<{pos}>{symbol_id_names[ref_rule_id]}",
451
+ end=" ",
452
+ file=file,
453
+ )
454
+ pos += 2
455
+ else:
456
+ print("<{}>[".format(pos), end="", file=file)
457
+ num_chars = grammar_encoding[pos]
458
+ pos += 1
459
+
460
+ for i in range(0, num_chars, 2):
461
+ print(
462
+ "{}-".format(chr(grammar_encoding[pos + i])), end="", file=file
463
+ )
464
+ if i + 1 < num_chars:
465
+ print(
466
+ "{}".format(chr(grammar_encoding[pos + i + 1])),
467
+ end="",
468
+ file=file,
469
+ )
470
+ print("]", end=" ", file=file)
471
+ pos += num_chars
472
+ pos += 1
473
+ print(file=file)
474
+ return pos + 1
475
+
476
+
477
+ def print_grammar(file, state):
478
+ pos = 0
479
+ symbol_id_names = {v: k for k, v in state.symbol_table.items()}
480
+ print("Grammar Rules:", file=file)
481
+ while (
482
+ pos < len(state.grammar_encoding)
483
+ and state.grammar_encoding[pos] != END_OF_GRAMMAR_MARKER
484
+ ):
485
+ pos = _print_annotated_grammar(
486
+ file, state.grammar_encoding, symbol_id_names, pos
487
+ )
488
+ if pos > len(state.grammar_encoding):
489
+ raise Warning(f"grammar_encoding is not ended with {END_OF_GRAMMAR_MARKER}")
490
+ pos = 0
491
+ print("\nGrammar Hex representation:", file=file)
492
+ while (
493
+ pos < len(state.grammar_encoding)
494
+ and state.grammar_encoding[pos] != END_OF_GRAMMAR_MARKER
495
+ ):
496
+ print(f"{state.grammar_encoding[pos]:04x}", end=" ", file=file)
497
+ pos += 1
498
+ if pos > len(state.grammar_encoding):
499
+ raise Warning(f"grammar_encoding is not ended with {END_OF_GRAMMAR_MARKER}")
500
+ else:
501
+ print("ffff\n")
502
+
503
+ print("Rules Decimal representation:", file=file)
504
+ # we loop until we reach the end of the grammar_encoding
505
+ rule_encodings = break_grammar_into_rules(state.grammar_encoding)
506
+ for rule_encoding in rule_encodings:
507
+ rule_id = rule_encoding[0]
508
+ print(
509
+ f"<{rule_id}> {break_rule_into_elements(rule_encoding)}",
510
+ file=file,
511
+ )
512
+
513
+
514
+ if __name__ == "__main__":
515
+ parser = argparse.ArgumentParser(description="Parse EBNF grammar files.")
516
+ parser.add_argument(
517
+ "-g",
518
+ "--grammar-file",
519
+ nargs="?",
520
+ default="/nobackup2/yf/mila/GD/examples/sygus/PRE_100_bare.ebnf",
521
+ help="Path to the grammar file",
522
+ )
523
+
524
+ args = parser.parse_args()
525
+
526
+ # set logging level
527
+ logging.basicConfig(level=logging.DEBUG)
528
+
529
+ with open(args.grammar_file, "r") as file:
530
+ input_text = file.read()
531
+ parsed_grammar = parse_ebnf(input_text)
532
+ print("parse state:")
533
+ parsed_grammar.print()
534
+ # print(f"symbol_ids: \n{parsed_grammar.symbol_table}")
535
+
536
+ # start_rule_id = parsed_grammar.symbol_table["root"]
537
+
538
+ # DEBUG: __main__:last_parsed_rule: root: := "0"d | "1"a
539
+ #
540
+ # DEBUG: __main__:last_parsed_rule: a: := "0"c | "1"b
541
+ #
542
+ # DEBUG: __main__:last_parsed_rule: b: := "0" | "1"
543
+ #
544
+ # DEBUG: __main__:last_parsed_rule: c: := "0" | "1"
545
+ #
546
+ # DEBUG: __main__:last_parsed_rule: d: := "0"e
547
+ #
548
+ # parse state:
549
+ # Grammar Rules:
550
+ # < 0 > root: := < 2 > [0 - 0] < 5 > d | < 9 > [1 - 1] < 12 > a
551
+ # < 16 > a: := < 18 > [0 - 0] < 21 > c | < 25 > [1 - 1] < 28 > b
552
+ # < 32 > b: := < 34 > [0 - 0] | < 39 > [1 - 1]
553
+ # < 44 > c: := < 46 > [0 - 0] | < 51 > [1 - 1]
554
+ # < 56 > d: := < 58 > [0 - 0] < 61 > e
555
+ # < 65 > e: := < 67 > [0 - 0]
556
+ #
557
+ # Grammar Hex representation:
558
+ # 0000 0006 0002 0030 0030 0001 0001 0000
559
+ # 0006 0002 0031 0031 0001 0002 0000 0000
560
+ # 0002 0006 0002 0030 0030 0001 0003 0000
561
+ # 0006 0002 0031 0031 0001 0004 0000 0000
562
+ # 0004 0004 0002 0030 0030 0000 0004 0002
563
+ # 0031 0031 0000 0000 0003 0004 0002 0030
564
+ # 0030 0000 0004 0002 0031 0031 0000 0000
565
+ # 0001 0006 0002 0030 0030 0001 0005 0000
566
+ # 0000 0005 0004 0002 0030 0030 0000 0000 ffff
567
+ #
568
+ # Rules Decimal representation:
569
+ # < 0 > [[6, 2, 48, 48, 1, 1, 0], [6, 2, 49, 49, 1, 2, 0]]
570
+ # < 2 > [[6, 2, 48, 48, 1, 3, 0], [6, 2, 49, 49, 1, 4, 0]]
571
+ # < 4 > [[4, 2, 48, 48, 0], [4, 2, 49, 49, 0]]
572
+ # < 3 > [[4, 2, 48, 48, 0], [4, 2, 49, 49, 0]]
573
+ # < 1 > [[6, 2, 48, 48, 1, 5, 0]]
574
+ # < 5 > [[4, 2, 48, 48, 0]]
575
+
576
+
transformers_gad/parser_cfg.py ADDED
@@ -0,0 +1,530 @@
1
+ import argparse
2
+ import logging
3
+ import sys
4
+ from typing import List
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ END_OF_ALTERNATE_MARKER = 0
9
+ END_OF_RULE_MARKER = 0
10
+ END_OF_GRAMMAR_MARKER = 0xFFFF
11
+ TO_BE_FILLED_MARKER = 0
12
+ REF_RULE_MARKER = 1
13
+ LITERAL_MARKER = 2
14
+
15
+
16
+ ########################
17
+ # EBNF Grammar Parsing #
18
+ ########################
19
+
20
+
21
+ class ParseState:
22
+ def __init__(self):
23
+ self.symbol_table = {}
24
+ self.grammar_encoding = [] # old name: out_grammar
25
+
26
+ def print(self, file=sys.stdout):
27
+ print_grammar(file, self)
28
+
29
+
30
+ def get_symbol_id(state: ParseState, symbol_name: str) -> int:
31
+ if symbol_name not in state.symbol_table:
32
+ state.symbol_table[symbol_name] = len(state.symbol_table)
33
+ return state.symbol_table[symbol_name]
34
+
35
+
36
+ def generate_symbol_id(state: ParseState, base_name: str) -> int:
37
+ next_id = len(state.symbol_table)
38
+ state.symbol_table[base_name + "_" + str(next_id)] = next_id
39
+ return next_id
40
+
41
+
42
+ def is_word_char(c: str) -> bool:
43
+ """
44
+ Check if a char is a-z, A-Z, 0-9, -, _, i.e., chars allowed as rule names
45
+ Returns:
46
+
47
+ """
48
+ return c.isalnum() or c == "-" or c == "_"
49
+
50
+
51
+ def hex_to_int(c: str) -> int:
52
+ """
53
+ Convert a hex char to int, c should be in the range of 0-9, a-f, A-F
54
+ case insensitive
55
+ Args:
56
+ c: a hex char
57
+ Returns:
58
+ int: the int value of the hex char
59
+ """
60
+ if c.isdigit():
61
+ return int(c)
62
+ elif "a" <= c.lower() <= "f":
63
+ return ord(c.lower()) - ord("a") + 10
64
+ return -1
65
+
66
+
67
+ def remove_leading_white_space(src, rm_leading_newline):
68
+ """
69
+ Skips over whitespace and comments in the input string.
70
+
71
+ This function processes the input string, skipping over any spaces, tabs,
72
+ and content following a '#' character, which denotes a comment. The parsing
73
+ of a comment continues until the end of the line (denoted by newline characters
74
+ '\r' or '\n'). If the 'rm_leading_newline' parameter is set to False, the function
75
+ will stop processing and return the remaining string upon encountering a
76
+ newline character, otherwise it will skip over newline characters as well.
77
+
78
+ Parameters:
79
+ src (str): The input string to be processed.
80
+ rm_leading_newline (bool): A flag indicating whether encountering a newline character
81
+ should stop the parsing (False) or if it should be skipped (True).
82
+
83
+ Returns:
84
+ str: The remaining portion of the input string after skipping whitespace and comments.
85
+ """
86
+ pos = 0
87
+ while pos < len(src) and (src[pos].isspace() or src[pos] == "#"):
88
+ if src[pos] == "#":
89
+ while pos < len(src) and src[pos] not in ("\r", "\n"):
90
+ pos += 1
91
+ else:
92
+ if not rm_leading_newline and src[pos] in ("\r", "\n"):
93
+ break
94
+ pos += 1
95
+ return src[pos:]
96
+
97
+
98
+ def parse_name(src) -> (str, str):
99
+ """
100
+ parse the leading name from the input string
101
+ Args:
102
+ src: the input grammar string
103
+
104
+ Returns:
105
+ name, remaining_src
106
+ """
107
+ pos = 0
108
+ while pos < len(src) and is_word_char(src[pos]):
109
+ pos += 1
110
+ if pos == 0:
111
+ raise RuntimeError("expecting name at " + src)
112
+ return src[:pos], src[pos:]
113
+
114
+
115
+ def parse_char(src) -> (str, str):
116
+ """
117
+ parse the leading char from the input string
118
+ :param src:
119
+ :return: char, remaining_src
120
+ """
121
+
122
+ # if we have a backslash, it's maybe an escape
123
+ if src[0] == "\\":
124
+ esc = src[1]
125
+ if esc == "x":
126
+ first = hex_to_int(src[2])
127
+ if first > -1:
128
+ second = hex_to_int(src[3])
129
+ if second > -1:
130
+ return chr((first << 4) + second), src[4:]
131
+ raise RuntimeError("expecting \\xNN at " + src)
132
+ elif esc in ('"', "[", "]"):
133
+ return esc, src[2:]
134
+ elif esc == "r":
135
+ return "\r", src[2:]
136
+ elif esc == "n":
137
+ return "\n", src[2:]
138
+ elif esc == "t":
139
+ return "\t", src[2:]
140
+ raise RuntimeError("unknown escape at " + src)
141
+ elif src:
142
+ return src[0], src[1:]
143
+ raise RuntimeError("unexpected end of input")
144
+
145
+
146
+ def _parse_rhs_literal_string(src: str, outbuf: List[int]) -> str:
147
+ assert src[0] == '"', f"rule should start with '\"', but got {src[0]}"
148
+ remaining_src = src[1:]
149
+
150
+ # advance until we get an end quote or run out of input
151
+ while remaining_src and remaining_src[0] != '"':
152
+ char, remaining_src = parse_char(remaining_src)
153
+ outbuf.append(LITERAL_MARKER)
154
+ outbuf.append(ord(char))
155
+ outbuf.append(ord(char))
156
+
157
+ # in case we ran out of input before finding the end quote
158
+ if not remaining_src:
159
+ raise RuntimeError(f"expecting an end quote at {src},but not found")
160
+
161
+ # remove the end quote and return the remaining string
162
+ return remaining_src[1:]
163
+
164
+
165
+ def _parse_rhs_char_ranges(src: str, outbuf: List[int]) -> str:
166
+ assert src[0] == "[", f"rule should start with '[', but got {src[0]}"
167
+ remaining_src = src[1:]
168
+ start_idx = len(outbuf)
169
+ # num chars in range - replaced at end of loop
170
+ outbuf.append(TO_BE_FILLED_MARKER)
171
+ while remaining_src and remaining_src[0] != "]":
172
+ char, remaining_src = parse_char(remaining_src)
173
+
174
+ outbuf.append(ord(char))
175
+ if remaining_src[0] == "-" and remaining_src[1] != "]":
176
+ endchar_pair, remaining_src = parse_char(remaining_src[1:])
177
+ outbuf.append(ord(endchar_pair))
178
+ else:
179
+ # This is the case for enumerate, e.g., [0123456789], [abcdef]
180
+ # Each char is considered as a range of itself, i.e., c-c
181
+ outbuf.append(ord(char))
182
+ if not remaining_src:
183
+ raise RuntimeError(
184
+ f"expecting an ] at {src},but not found, is the char range closed?"
185
+ )
186
+ # replace num chars with actual
187
+ outbuf[start_idx] = len(outbuf) - start_idx - 1
188
+ return remaining_src[1:]
189
+
190
+
191
+ def _parse_rhs_symbol_reference(src: str, state: ParseState, outbuf: List[int]) -> str:
192
+ assert is_word_char(src[0]), f"rule should start with a word char, but got {src[0]}"
193
+ name, remaining_src = parse_name(src)
194
+ ref_rule_id = get_symbol_id(state, name)
195
+ outbuf.append(REF_RULE_MARKER)
196
+ outbuf.append(ref_rule_id)
197
+ return remaining_src
198
+
199
+
200
+ def _parse_rhs_grouping(
201
+ remaining_src: str, state: ParseState, rule_name: str, outbuf: List[int]
202
+ ) -> str:
203
+ assert (
204
+ remaining_src[0] == "("
205
+ ), f"rule should start with '(', but got {remaining_src[0]}"
206
+ remaining_src = remove_leading_white_space(remaining_src[1:], True)
207
+ # parse nested alternates into synthesized rule
208
+ synthetic_rule_id = generate_symbol_id(state, rule_name)
209
+ remaining_src = parse_rhs(state, remaining_src, rule_name, synthetic_rule_id, True)
210
+ # output reference to synthesized rule
211
+ outbuf.append(REF_RULE_MARKER)
212
+ outbuf.append(synthetic_rule_id)
213
+
214
+ if not remaining_src or remaining_src[0] != ")":
215
+ raise RuntimeError("expecting ')' at " + remaining_src)
216
+ return remaining_src[1:]
217
+
218
+
219
+ def _parse_rhs_repetition_operators(
220
+ remaining_src: str,
221
+ state: ParseState,
222
+ rule_name: str,
223
+ last_sym_start: int,
224
+ outbuf: List[int],
225
+ ) -> str:
226
+ assert remaining_src[0] in (
227
+ "*",
228
+ "+",
229
+ "?",
230
+ ), f"rule should start with '*', '+', or '?', but got {remaining_src[0]}"
231
+ out_grammar = state.grammar_encoding
232
+ # last_sym_start = len(outbuf)
233
+
234
+ # apply transformation to previous symbol (last_sym_start -
235
+ # end) according to rewrite rules:
236
+ # S* --> S' ::= S S' |
237
+ # S+ --> S' ::= S S' | S
238
+ # S? --> S' ::= S |
239
+ sub_rule_id = generate_symbol_id(state, rule_name)
240
+ out_grammar.append(sub_rule_id)
241
+ sub_rule_offset = len(out_grammar)
242
+ # placeholder for size of 1st alternate
243
+ out_grammar.append(TO_BE_FILLED_MARKER)
244
+ # add preceding symbol to generated rule
245
+ out_grammar.extend(outbuf[last_sym_start:])
246
+ if remaining_src[0] in ("*", "+"):
247
+ # cause generated rule to recurse
248
+ out_grammar.append(REF_RULE_MARKER)
249
+ out_grammar.append(sub_rule_id)
250
+ # apply actual size
251
+ out_grammar[sub_rule_offset] = len(out_grammar) - sub_rule_offset
252
+ # mark end of 1st alternate
253
+ out_grammar.append(END_OF_ALTERNATE_MARKER)
254
+ sub_rule_offset = len(out_grammar)
255
+ # placeholder for size of 2nd alternate
256
+ out_grammar.append(TO_BE_FILLED_MARKER)
257
+ if remaining_src[0] == "+":
258
+ # add preceding symbol as alternate only for '+'
259
+ out_grammar.extend(outbuf[last_sym_start:])
260
+ # apply actual size of 2nd alternate
261
+ out_grammar[sub_rule_offset] = len(out_grammar) - sub_rule_offset
262
+ # mark end of 2nd alternate, then end of rule
263
+ out_grammar.append(END_OF_ALTERNATE_MARKER)
264
+ out_grammar.append(END_OF_RULE_MARKER)
265
+
266
+ # in original rule, replace previous symbol with reference to generated rule
267
+ outbuf[last_sym_start:] = [REF_RULE_MARKER, sub_rule_id]
268
+ return remaining_src[1:]
269
+
270
+
271
+ def parse_simple_rhs(state, rhs: str, rule_name: str, outbuf, is_nested):
272
+ simple_rhs_offset = len(outbuf)
273
+
274
+ # sequence size, will be replaced at end when known
275
+ outbuf.append(TO_BE_FILLED_MARKER)
276
+
277
+ last_sym_start = len(outbuf)
278
+ remaining_rhs = rhs
279
+ while remaining_rhs:
280
+ if remaining_rhs[0] == '"': # literal string
281
+ # mark the start of the last symbol, for repetition operator
282
+ last_sym_start = len(outbuf)
283
+ remaining_rhs = _parse_rhs_literal_string(remaining_rhs, outbuf)
284
+ elif remaining_rhs[0] == "[": # char range(s)
285
+ # mark the start of the last symbol, for repetition operator
286
+ last_sym_start = len(outbuf)
287
+ remaining_rhs = _parse_rhs_char_ranges(remaining_rhs, outbuf)
288
+ elif is_word_char(remaining_rhs[0]): # rule reference
289
+ # mark the start of the last symbol, for repetition operator
290
+ last_sym_start = len(outbuf)
291
+ remaining_rhs = _parse_rhs_symbol_reference(remaining_rhs, state, outbuf)
292
+ elif remaining_rhs[0] == "(": # grouping
293
+ # mark the start of the last symbol, for repetition operator
294
+ last_sym_start = len(outbuf)
295
+ remaining_rhs = _parse_rhs_grouping(remaining_rhs, state, rule_name, outbuf)
296
+ elif remaining_rhs[0] in ("*", "+", "?"): # repetition operator
297
+ # No need to mark the start of the last symbol, because we already did it
298
+ if len(outbuf) - simple_rhs_offset - 1 == 0:
299
+ raise RuntimeError(
300
+ "expecting preceeding item to */+/? at " + remaining_rhs
301
+ )
302
+ remaining_rhs = _parse_rhs_repetition_operators(
303
+ remaining_rhs, state, rule_name, last_sym_start, outbuf
304
+ )
305
+ else:
306
+ # case for newline, i.e., end of rule
307
+ assert remaining_rhs[0] in [
308
+ "\n",
309
+ "|",
310
+ ")",
311
+ ], f"rule should end with newline or '|', but got {remaining_rhs[0]}"
312
+ # we break here so that we call parse_rule again to parse the next rule
313
+ break
314
+ # Here we do not rm newline deliberately so that we know the rhs is ended
315
+ remaining_rhs = remove_leading_white_space(
316
+ remaining_rhs, rm_leading_newline=is_nested
317
+ )
318
+
319
+ # apply actual size of this alternate sequence
320
+ outbuf[simple_rhs_offset] = len(outbuf) - simple_rhs_offset
321
+ # mark end of alternate
322
+ outbuf.append(END_OF_ALTERNATE_MARKER)
323
+ return remaining_rhs
324
+
325
+
326
+ def parse_rhs(state, rhs: str, rule_name, rule_id, is_nested):
327
+ outbuf = []
328
+ remaining_rhs = parse_simple_rhs(state, rhs, rule_name, outbuf, is_nested)
329
+ while remaining_rhs and remaining_rhs[0] == "|":
330
+ remaining_rhs = remove_leading_white_space(remaining_rhs[1:], True)
331
+ remaining_rhs = parse_simple_rhs(
332
+ state, remaining_rhs, rule_name, outbuf, is_nested
333
+ )
334
+
335
+ # Now we have finished parsing the rhs, we can add the rule to the grammar_encoding
336
+ state.grammar_encoding.append(rule_id)
337
+ state.grammar_encoding.extend(outbuf)
338
+ state.grammar_encoding.append(END_OF_RULE_MARKER)
339
+ return remaining_rhs
340
+
341
+
342
+ def parse_rule(state: ParseState, rule_text: str) -> str:
343
+ name, remaining_rule_text = parse_name(rule_text)
344
+ remaining_rule_text = remove_leading_white_space(remaining_rule_text, False)
345
+ # check if the rule is already defined, TODO: what will happen if the rule is already defined?
346
+ rule_id = get_symbol_id(state, name)
347
+
348
+ if remaining_rule_text[:3] != "::=":
349
+ raise RuntimeError("expecting ::= at " + remaining_rule_text)
350
+ remaining_rule_text = remove_leading_white_space(remaining_rule_text[3:], True)
351
+
352
+ remaining_rule_text = parse_rhs(state, remaining_rule_text, name, rule_id, False)
353
+
354
+ if remaining_rule_text and remaining_rule_text[0] == "\r":
355
+ remaining_rule_text = (
356
+ remaining_rule_text[2:]
357
+ if remaining_rule_text[1] == "\n"
358
+ else remaining_rule_text[1:]
359
+ )
360
+ elif remaining_rule_text and remaining_rule_text[0] == "\n":
361
+ remaining_rule_text = remaining_rule_text[1:]
362
+ elif remaining_rule_text:
363
+ raise RuntimeError("expecting newline or end at " + remaining_rule_text)
364
+ return remove_leading_white_space(remaining_rule_text, True)
365
+
366
+
367
+ def parse_ebnf(grammar_text: str) -> ParseState:
368
+ try:
369
+ state = ParseState()
370
+ remaining_grammar_text = remove_leading_white_space(grammar_text, True)
371
+ last_grammar_repr = ""
372
+ while remaining_grammar_text:
373
+ if last_grammar_repr:
374
+ last_parsed_rule_len = len(last_grammar_repr) - len(
375
+ remaining_grammar_text
376
+ )
377
+ logger.debug(
378
+ f"last_parsed_rule: {last_grammar_repr[:last_parsed_rule_len]}"
379
+ )
380
+ last_grammar_repr = remaining_grammar_text
381
+ remaining_grammar_text = parse_rule(state, remaining_grammar_text)
382
+ state.grammar_encoding.append(END_OF_GRAMMAR_MARKER)
383
+ return state
384
+ except RuntimeError as err:
385
+ logger.warning("error parsing grammar:", err)
386
+ return ParseState()
387
+
388
+
389
+ ###################################
390
+ # EBNF Grammar Parsing ends here #
391
+ ###################################
392
+
393
+
394
+ def break_grammar_into_rules(grammar_encoding: List[int]) -> List[List[int]]:
395
+ offset = 0
396
+ # we loop until we reach the end of the grammar_encoding
397
+ rule_encodings = []
398
+ i = 0
399
+ while i < len(grammar_encoding) - 2:
400
+ if (
401
+ grammar_encoding[i] == END_OF_ALTERNATE_MARKER
402
+ and grammar_encoding[i + 1] == END_OF_RULE_MARKER
403
+ ):
404
+ rule_encodings.append(grammar_encoding[offset : i + 2])
405
+ offset = i + 2
406
+ # skip the END_OF_RULE_MARKER
407
+ # This is mandatory because if we do not skip the END_OF_RULE_MARKER
408
+ # we fail in the case where the next rule has rule_id 0
409
+ i += 1
410
+ i += 1
411
+ return rule_encodings
412
+
413
+
414
+ def break_rule_into_elements(rule_encoding: List[int]) -> List[List[int]]:
415
+ rule_id = rule_encoding.pop(0)
416
+ end_of_rule_marker = rule_encoding.pop(-1)
417
+ assert (
418
+ end_of_rule_marker == END_OF_RULE_MARKER
419
+ ), f"rule should end with {END_OF_RULE_MARKER}, but got {end_of_rule_marker}"
420
+
421
+ offset = 0
422
+ elements = []
423
+ while offset < len(rule_encoding):
424
+ element_size = rule_encoding[offset]
425
+ assert (
426
+ rule_encoding[offset + element_size] == END_OF_ALTERNATE_MARKER
427
+ ), f"element should end with {END_OF_ALTERNATE_MARKER}, but got {rule_encoding[offset + element_size]}"
428
+ elements.append(rule_encoding[offset : offset + element_size + 1])
429
+ offset += element_size + 1
430
+ return elements
431
+
432
+
433
+ def _print_annotated_grammar(file, grammar_encoding, symbol_id_names, index=0):
434
+ rule_id = grammar_encoding[index]
435
+ print(f"<{index}>{symbol_id_names[rule_id]} ::=", end=" ", file=file)
436
+ pos = index + 1
437
+ while grammar_encoding[pos]:
438
+ if pos - 1 > index:
439
+ print("|", end=" ", file=file)
440
+ pos += 1 # sequence size, not needed here
441
+ while grammar_encoding[pos]:
442
+ if grammar_encoding[pos] == REF_RULE_MARKER:
443
+ ref_rule_id = grammar_encoding[pos + 1]
444
+ print(
445
+ f"<{pos}>{symbol_id_names[ref_rule_id]}",
446
+ end=" ",
447
+ file=file,
448
+ )
449
+ pos += 2
450
+ else:
451
+ print("<{}>[".format(pos), end="", file=file)
452
+ num_chars = grammar_encoding[pos]
453
+ pos += 1
454
+
455
+ for i in range(0, num_chars, 2):
456
+ print(
457
+ "{}-".format(chr(grammar_encoding[pos + i])), end="", file=file
458
+ )
459
+ if i + 1 < num_chars:
460
+ print(
461
+ "{}".format(chr(grammar_encoding[pos + i + 1])),
462
+ end="",
463
+ file=file,
464
+ )
465
+ print("]", end=" ", file=file)
466
+ pos += num_chars
467
+ pos += 1
468
+ print(file=file)
469
+ return pos + 1
470
+
471
+
472
+ def print_grammar(file, state):
473
+ pos = 0
474
+ symbol_id_names = {v: k for k, v in state.symbol_table.items()}
475
+ print("Grammar Rules:", file=file)
476
+ while (
477
+ pos < len(state.grammar_encoding)
478
+ and state.grammar_encoding[pos] != END_OF_GRAMMAR_MARKER
479
+ ):
480
+ pos = _print_annotated_grammar(
481
+ file, state.grammar_encoding, symbol_id_names, pos
482
+ )
483
+ if pos > len(state.grammar_encoding):
484
+ raise Warning(f"grammar_encoding is not ended with {END_OF_GRAMMAR_MARKER}")
485
+ pos = 0
486
+ print("\nGrammar Hex representation:", file=file)
487
+ while (
488
+ pos < len(state.grammar_encoding)
489
+ and state.grammar_encoding[pos] != END_OF_GRAMMAR_MARKER
490
+ ):
491
+ print(f"{state.grammar_encoding[pos]:04x}", end=" ", file=file)
492
+ pos += 1
493
+ if pos > len(state.grammar_encoding):
494
+ raise Warning(f"grammar_encoding is not ended with {END_OF_GRAMMAR_MARKER}")
495
+ else:
496
+ print("ffff\n")
497
+
498
+ print("Rules Decimal representation:", file=file)
499
+ # we loop until we reach the end of the grammar_encoding
500
+ rule_encodings = break_grammar_into_rules(state.grammar_encoding)
501
+ for rule_encoding in rule_encodings:
502
+ rule_id = rule_encoding[0]
503
+ print(
504
+ f"<{rule_id}> {break_rule_into_elements(rule_encoding)}",
505
+ file=file,
506
+ )
507
+
508
+
509
+ if __name__ == "__main__":
510
+ parser = argparse.ArgumentParser(description="Parse EBNF grammar files.")
511
+ parser.add_argument(
512
+ "-g",
513
+ "--grammar-file",
514
+ nargs="?",
515
+ default="examples/grammars/json.ebnf",
516
+ help="Path to the grammar file (default: examples/grammars/json.ebnf)",
517
+ )
518
+
519
+ args = parser.parse_args()
520
+
521
+ # set logging level
522
+ logging.basicConfig(level=logging.DEBUG)
523
+
524
+ with open(args.grammar_file, "r") as file:
525
+ input_text = file.read()
526
+ parsed_grammar = parse_ebnf(input_text)
527
+ parsed_grammar.print()
528
+ print(f"symbol_ids: \n{parsed_grammar.symbol_table}")
529
+
530
+ start_rule_id = parsed_grammar.symbol_table["root"]
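parser_cfg.py is a close variant of parser.py; one detail worth noting is how _parse_rhs_repetition_operators desugars '*', '+', and '?' into synthesized recursive rules. A minimal sketch (not part of this commit; the generated rule name and printed dict are illustrative) of what that looks like from the outside:

from transformers_gad.parser_cfg import parse_ebnf

state = parse_ebnf('root ::= "a"+\n')
# The '+' is rewritten into a synthesized rule (generated name, here "root_1"),
# roughly equivalent to writing:
#   root   ::= root_1
#   root_1 ::= "a" root_1 | "a"
print(state.symbol_table)  # e.g. {'root': 0, 'root_1': 1}
state.print()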
transformers_gad/recognizer.py ADDED
@@ -0,0 +1,456 @@
1
+ import copy
2
+ import logging
3
+ from functools import lru_cache
4
+ from typing import List, Tuple, Dict
5
+
6
+ from transformers_gad.parser import (
7
+ END_OF_RULE_MARKER,
8
+ END_OF_ALTERNATE_MARKER,
9
+ parse_ebnf,
10
+ REF_RULE_MARKER,
11
+ )
12
+ from transformers_gad.utf8_utils import PartialUTF8, decode_utf8
13
+ from transformers_gad.utils import intervals_intersect
14
+ import logging
15
+
16
+
17
+ class AcceptState:
18
+ def __init__(self, stacks, partial_utf8):
19
+ self.stacks = stacks
20
+ self.partial_utf8 = partial_utf8
21
+
22
+ @staticmethod
23
+ def empty_state():
24
+ return AcceptState([], PartialUTF8())
25
+
26
+
27
+ class StringRecognizer:
28
+ def __init__(
29
+ self,
30
+ grammar_encoding: List[int],
31
+ start_rule_id: int = None,
32
+ rule_offsets: List[int] = None,
33
+ stacks: List[List[int]] = None,
34
+ ):
35
+ # strictly speaking, we don't need to copy grammar_encoding because we don't modify it
36
+ # but we do it anyway to be safe
37
+ # in case where the grammar is very large, we can consider not copying it
38
+ self.grammar_encoding = grammar_encoding
39
+ if rule_offsets is not None:
40
+ self.rule_offsets = rule_offsets
41
+ else:
42
+ if start_rule_id is None:
43
+ raise ValueError("start_rule_id cannot be None if rule_offsets is None")
44
+ self.rule_offsets = self.init_rules(start_rule_id)
45
+ # each stack is a list of indices into grammar_encoding
46
+ # each index points to a rule's
47
+ if stacks is not None:
48
+ self.stacks = stacks
49
+ else:
50
+ if start_rule_id is None:
51
+ raise ValueError("start_rule_id cannot be None if stacks is None")
52
+ self.stacks: List[List[int]] = self.init_stack(start_rule_id)
53
+ self.start_rule_id = start_rule_id
54
+
55
+ def init_rules(self, start_rule_id: int) -> List[int]:
56
+ _rule_offset = 0
57
+ rule_offsets = []
58
+ # Build `rules` as an array of rule IDs to their positions in `grammar_src`
59
+ while self.grammar_encoding[_rule_offset] != 0xFFFF:
60
+ rule_id = self.grammar_encoding[_rule_offset]
61
+ # store the offset idx
62
+ if len(rule_offsets) <= rule_id:
63
+ rule_offsets.extend([-1] * (rule_id - len(rule_offsets) + 1))
64
+ rule_offsets[rule_id] = _rule_offset
65
+
66
+ # Skip rule ID
67
+ # _rule_offset += 1
68
+ simple_rhs_offset = _rule_offset + 1
69
+
70
+ # Skip rule alternates
71
+ while self.grammar_encoding[simple_rhs_offset] != END_OF_RULE_MARKER:
72
+ simple_rhs_offset = (
73
+ simple_rhs_offset + 1 + self.grammar_encoding[simple_rhs_offset]
74
+ )
75
+
76
+ # Skip 0 denoting end of rule
77
+ # _rule_offset += 1
78
+ _rule_offset = simple_rhs_offset + 1
79
+
80
+ retrieved_start_rule_id = self.grammar_encoding[rule_offsets[start_rule_id]]
81
+ assert retrieved_start_rule_id == start_rule_id
82
+
83
+ return rule_offsets
84
+
85
+ def init_stack(self, start_rule_id: int) -> List[List[int]]:
86
+
87
+ stacks = []
88
+ # Loop over alternates of start rule to build initial stacks
89
+ sub_rhs_offset = self.rule_offsets[start_rule_id] + 1
90
+ while self.grammar_encoding[sub_rhs_offset]:
91
+ stack: List[int] = []
92
+ # If alternate is nonempty, add to stack
93
+ element_offset = sub_rhs_offset + 1
94
+ if self.grammar_encoding[element_offset] != END_OF_ALTERNATE_MARKER:
95
+ stack.append(element_offset)
96
+ stacks.extend(self.advance_stack(tuple(stack)))
97
+ sub_rhs_offset += 1 + self.grammar_encoding[sub_rhs_offset]
98
+ return stacks
99
+
100
+ def get_initial_accept_state(self) -> AcceptState:
101
+ return AcceptState(self.init_stack(self.start_rule_id), PartialUTF8())
102
+
103
+ def get_termination_accept_state(self) -> AcceptState:
104
+ return AcceptState([], PartialUTF8())
105
+
106
+ @lru_cache(maxsize=32768)
107
+ def advance_stack(self, stack: Tuple[int]) -> List[List[int]]:
108
+ stack = list(stack)
109
+ if len(stack) == 0:
110
+ return [stack]
111
+
112
+ # we get the last element of the stack, which is the element we are currently processing
113
+ cur_element_offset = stack[-1]
114
+
115
+ # if the element is a terminal, we don't need to advance the stack
116
+ if self.grammar_encoding[cur_element_offset] != REF_RULE_MARKER:
117
+ return [stack]
118
+ # the remaining case is that the element is a non-terminal, i.e. a reference to another rule
119
+ else:
120
+ ref_rule_id = self.grammar_encoding[cur_element_offset + 1]
121
+ # find the offset of the referenced rule
122
+ ref_subrule_offset = self.rule_offsets[ref_rule_id] + 1
123
+ new_stacks: List[List[int]] = []
124
+ # Loop over alternates of referenced rule to build new stacks
125
+ while self.grammar_encoding[ref_subrule_offset] != END_OF_RULE_MARKER:
126
+ # copy the original stack without the last element
127
+ new_stack = stack[:-1]
128
+ # if the rule ref is followed by another element, we add it to the stack
129
+ next_element_offset = cur_element_offset + 2
130
+ if (
131
+ self.grammar_encoding[next_element_offset]
132
+ != END_OF_ALTERNATE_MARKER
133
+ ):
134
+ new_stack.append(next_element_offset)
135
+
136
+ # if the referenced rule is not empty, we add its element offset to the stack
137
+ ref_element_offset = ref_subrule_offset + 1
138
+ if self.grammar_encoding[ref_element_offset] != END_OF_ALTERNATE_MARKER:
139
+ new_stack.append(ref_element_offset)
140
+
141
+ new_stacks.extend(self.advance_stack(tuple(new_stack)))
142
+ ref_subrule_offset += self.grammar_encoding[ref_subrule_offset] + 1
143
+
144
+ return new_stacks
145
+
146
+ def _consume_byte(self, byte: int, accept_state: AcceptState):
147
+ # suppose we have code point 一, ord('一') = 19968, we need to match 3 bytes
148
+ # we need to match 3 bytes, so we need to call _consume_byte_partial_match 3 times
149
+ self._consume_bytes(bytes([byte]), accept_state)
150
+
151
+ # @lru_cache(maxsize=32768)
152
+ def _probe_bytes(
153
+ self,
154
+ byte_seq: bytes,
155
+ stacks: List[List[int]],
156
+ partial_utf8: PartialUTF8,
157
+ verbose=True,
158
+ ):
159
+ if type(byte_seq) is list:
160
+ byte_seq = bytes(byte_seq)
161
+ code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8)
162
+ if verbose:
163
+ logging.debug(
164
+ f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}"
165
+ )
166
+ new_stacks = self._consume_code_points(code_points, stacks)
167
+
168
+ for stack in new_stacks:
169
+
170
+ # stack is empty, meaning that the variables are all consumed
171
+ if len(stack) == 0:
172
+ return True
173
+ element_offset = stack[-1]
174
+ if self.partial_utf8_accept_at_element(element_offset, new_partial_utf8):
175
+ return True
176
+ return False
177
+
178
+ def _consume_bytes(
179
+ self,
180
+ byte_seq: bytes,
181
+ accept_state: AcceptState = None,
182
+ verbose=True,
183
+ ):
184
+ if accept_state is None:
185
+ accept_state = self.get_initial_accept_state()
186
+ stacks = accept_state.stacks
187
+ partial_utf8 = accept_state.partial_utf8
188
+ if type(byte_seq) is list:
189
+ byte_seq = bytes(byte_seq)
190
+ code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8)
191
+ if verbose:
192
+ logging.debug(
193
+ f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}"
194
+ )
195
+ new_stacks = self._consume_code_points(code_points, stacks)
196
+
197
+ new_new_stacks = []
198
+ for stack in new_stacks:
199
+ if len(stack) == 0:
200
+ continue
201
+ element_offset = stack[-1]
202
+ if self.partial_utf8_accept_at_element(element_offset, new_partial_utf8):
203
+ new_new_stacks.append(stack)
204
+ return AcceptState(new_new_stacks, new_partial_utf8)
205
+
206
+ ##########################
207
+ #
208
+ # Code point recognition
209
+ #
210
+ ##########################
211
+
212
+ @lru_cache(maxsize=30000)
213
+ def _consume_code_point(
214
+ self, code_point: int, stacks: Tuple[Tuple[int]]
215
+ ) -> List[List[int]]:
216
+ """
217
+ consume a character from the stack
218
+ char_code_point: can be a Unicode code point, including ascii code points which are in the range [0, 127]
219
+ """
220
+ new_stacks = []
221
+
222
+ stacks: List[List[int]] = list([list(stack) for stack in stacks])
223
+ if code_point == 0:
224
+ return new_stacks
225
+ for stack in stacks:
226
+ new_stacks.extend(
227
+ self._consume_code_point_per_stack(code_point, tuple(stack))
228
+ )
229
+ return new_stacks
230
+
231
+ @lru_cache(maxsize=30000)
232
+ def _consume_code_point_per_stack(
233
+ self, code_point: int, stack: Tuple[int]
234
+ ) -> List[List[int]]:
235
+ """
236
+ consume a character from the stack
237
+ char_code_point: can be a Unicode code point, including ascii code points which are in the range [0, 127]
238
+ """
239
+ # TODO, the below code will raise an error when the stack is empty, but why is this happening?
240
+ # if len(stacks) == 0:
241
+ # raise ValueError("Stacks don't contain any stack, meaning that no character can be consumed")
242
+ # code_point = 0 is a special case when the uf8 sequence is not complete, we return an empty stack
243
+ # to indicate that the character is not accepted
244
+ stack = list(stack)
245
+ new_stacks = []
246
+ if code_point == 0:
247
+ return new_stacks
248
+ # stack is empty
249
+ if len(stack) == 0:
250
+ return new_stacks
251
+
252
+ element_offset = stack[-1]
253
+
254
+ found = self.accept_code_point_at_element(code_point, element_offset)
255
+ if not found:
256
+ return new_stacks
257
+
258
+ size = self.grammar_encoding[element_offset]
259
+ element_offset += size + 1
260
+ new_stack = stack[:-1]
261
+ if self.grammar_encoding[element_offset]:
262
+ new_stack.append(element_offset)
263
+ return self.advance_stack(tuple(new_stack))
264
+
265
+ def _consume_code_points(
266
+ self, code_points: List[int], stacks: List[List[int]], verbose=False
267
+ ) -> List[List[int]]:
268
+ for i, code_point in enumerate(code_points):
269
+ # for lru_cache to work, we need to convert the list of stacks into a tuple of stacks
270
+ tuple_stacks: Tuple[Tuple[int]] = tuple([tuple(stack) for stack in stacks])
271
+ stacks = self._consume_code_point(code_point, tuple_stacks)
272
+ if len(stacks) > 0 and verbose:
273
+ accepted_code_point = code_points[: i + 1]
274
+ corresponding_char = chr(code_point)
275
+ logging.debug(
276
+ f"code point {accepted_code_point} corresponding to {corresponding_char} is accepted"
277
+ )
278
+ return stacks
279
+
280
+ def _accept_code_points(
281
+ self, code_points: List[int], stacks: List[List[int]], verbose=False
282
+ ) -> bool:
283
+ stacks = self._consume_code_points(code_points, stacks, verbose)
284
+ return len(stacks) > 0
285
+
286
+ @lru_cache(maxsize=30000)
287
+ def accept_code_point_at_element(
288
+ self, code_point: int, element_offset: int
289
+ ) -> bool:
290
+ size = self.grammar_encoding[element_offset]
291
+ # to make idx point to the range_start of the first range
292
+ element_offset += 1
293
+ for i in range(0, size, 2):
294
+ if (
295
+ self.grammar_encoding[element_offset + i]
296
+ <= code_point
297
+ <= self.grammar_encoding[element_offset + i + 1]
298
+ ):
299
+ return True
300
+ return False
301
+
302
+ # def _accept_code_point(self, code_point: int, stacks: List[List[int]]):
303
+ # # for lru_cache to work, we need to convert the list of stacks into a tuple of stacks
304
+ # tuple_stacks: Tuple[Tuple[int]] = tuple([tuple(stack) for stack in stacks])
305
+ # new_stacks: List[List[int]] = self._consume_code_point(code_point, tuple_stacks)
306
+ # return len(new_stacks) > 0
307
+
308
+ #############################
309
+ #
310
+ # Partial UTF-8 recognition
311
+ #
312
+ #############################
313
+
314
+ def partial_utf8_accept_at_element(
315
+ self, element_offset: int, partial_utf8: PartialUTF8
316
+ ) -> bool:
317
+ # Extract the accumulated value and the number of remaining bytes from the partial_utf8 object.
318
+ partial_value = partial_utf8.value
319
+ n_remain = partial_utf8.n_remain
320
+
321
+ # Return False if there are no remaining bytes to process or if it's an invalid UTF-8 sequence.
322
+ if n_remain == 1 and partial_value < 2:
323
+ return False
324
+
325
+ # If there are no remaining bytes, this means we had already consumed a complete UTF-8 sequence.
326
+ if n_remain <= 0:
327
+ return True
328
+
329
+ # Calculate the lowest possible Unicode code point that can be formed with the remaining bytes.
330
+ low = partial_value << (n_remain * 6)
331
+ # Calculate the highest possible Unicode code point by setting all remaining bits to 1.
332
+ high = low | ((1 << (n_remain * 6)) - 1)
333
+
334
+ # If the low end of the range is 0 and a specific number of bytes remain, adjust low to the minimum value
335
+ # that can be represented with that number of bytes. This accounts for UTF-8 encoding rules.
336
+ if low == 0:
337
+ if n_remain == 2:
338
+ low = 1 << 11 # Minimum value representable with 2 additional bytes.
339
+ elif n_remain == 3:
340
+ low = 1 << 16 # Minimum value representable with 3 additional bytes.
341
+
342
+ # Get the size of the grammar rule starting at the current element_offset.
343
+ size = self.grammar_encoding[element_offset]
344
+ # Move the element_offset to the start of the grammar rule's definition.
345
+ element_offset += 1
346
+
347
+ # Iterate over the grammar rule, checking if the range defined by low-high overlaps with any specified ranges.
348
+ for i in range(0, size, 2):
349
+ # If the current range (specified in the grammar encoding) overlaps with the low-high range, return True.
350
+ if intervals_intersect(
351
+ low,
352
+ high,
353
+ self.grammar_encoding[element_offset + i],
354
+ self.grammar_encoding[element_offset + i + 1],
355
+ ):
356
+ return True
357
+
358
+ # If no overlap is found with any of the ranges, return False, indicating no valid partial match.
359
+ return False
360
+
361
+ #############################
362
+ #
363
+ # String recognition
364
+ #
365
+ #############################
366
+
367
+ def _consume_string(self, string: str, accept_state: AcceptState):
368
+ # _bytes = bytes(string, "utf-8")
369
+ code_points = [ord(char) for char in string]
370
+ stacks = self._consume_code_points(code_points, accept_state.stacks)
371
+ return AcceptState(stacks, accept_state.partial_utf8)
372
+
373
+ def _accept_prefix(self, string: str, accept_state: AcceptState = None):
374
+ if accept_state is None:
375
+ accept_state = self.get_initial_accept_state()
376
+ new_accept_state = self._consume_string(string, accept_state)
377
+ return len(new_accept_state.stacks) > 0
378
+
379
+ def _accept_string(self, string: str, accept_state: AcceptState = None):
380
+ if accept_state is None:
381
+ accept_state = self.get_initial_accept_state()
382
+ new_accept_state = self._consume_string(string, accept_state)
383
+ at_least_one_stack_is_empty = any(
384
+ len(stack) == 0 for stack in new_accept_state.stacks
385
+ )
386
+ return at_least_one_stack_is_empty
387
+
388
+ def _can_stop(self, stacks: List[List[int]]):
389
+ # This happens in practice, but maybe it shouldn't? TODO
390
+ if len(stacks) == 0:
391
+ return True
392
+ # if any of the stack is empty, we can stop
393
+ for stack in stacks:
394
+ if len(stack) == 0:
395
+ return True
396
+ else:
397
+ return False
398
+
399
+ def _must_stop(self, stacks: List[List[int]]):
400
+ return len(stacks) == 0 or all(len(stack) == 0 for stack in stacks)
401
+
402
+ #############################
403
+ #
404
+ # Not Used
405
+ #
406
+ #############################
407
+
408
+ # For each sub-rule in the grammar, cache whether each byte is accepted.
409
+ @lru_cache(maxsize=None)
410
+ def char_acceptance_at_element(self, element_offset):
411
+ """
412
+ Caches and returns a dictionary indicating whether a Unicode character is accepted
413
+ at a given rule position. This function considers Unicode characters, dynamically
414
+ inserting accepted ranges into a dictionary to optimize memory usage.
415
+
416
+ Args:
417
+ - rule_offset: The offset in the grammar encoding where the rule starts.
418
+
419
+ Returns:
420
+ - A dictionary where each key is a Unicode character (or range) and the value is True if accepted.
421
+ """
422
+ logging.debug(f"element_offset: {element_offset}")
423
+ acceptance = {}
424
+ num_chars = self.grammar_encoding[element_offset]
425
+ element_offset += 1
426
+ for i in range(0, num_chars, 2):
427
+ start = self.grammar_encoding[element_offset + i]
428
+ end = self.grammar_encoding[element_offset + i + 1]
429
+ for j in range(start, end + 1):
430
+ acceptance[j] = True
431
+ logging.debug(acceptance)
432
+ return acceptance
433
+
434
+ def _consume_code_points_new(
435
+ self, code_points: List[int], stacks: List[List[int]], verbose=False
436
+ ) -> List[List[int]]:
437
+ new_stacks: List[List[int]] = []
438
+ for stack in stacks:
439
+ new_stacks.extend(
440
+ self._consume_code_points_per_stack(
441
+ tuple(code_points), tuple(stack), verbose
442
+ )
443
+ )
444
+ return new_stacks
445
+
446
+ @lru_cache(maxsize=30000)
447
+ def _consume_code_points_per_stack(
448
+ self, code_points: Tuple[int], stack: Tuple[int], verbose=False
449
+ ) -> List[List[int]]:
450
+ code_points = list(code_points)
451
+ stacks = (stack,)
452
+ for i, code_point in enumerate(code_points):
453
+ # for lru_cache to work, we need to convert the list of stacks into a tuple of stacks
454
+ stacks = self._consume_code_point(code_point, stacks)
455
+ stacks = tuple([tuple(stack) for stack in stacks])
456
+ return [list(stack) for stack in stacks]
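A minimal sketch (not part of this commit) of how StringRecognizer is typically driven for plain ASCII input: build it from a parsed grammar, then use _accept_prefix / _accept_string to check partial and complete matches. The grammar string and test inputs are illustrative.

from transformers_gad.parser import parse_ebnf
from transformers_gad.recognizer import StringRecognizer

parsed = parse_ebnf('root ::= "ab" [0-9]\n')
recognizer = StringRecognizer(parsed.grammar_encoding, parsed.symbol_table["root"])

print(recognizer._accept_prefix("ab"))   # True: "ab" can still be extended
print(recognizer._accept_string("ab"))   # False: the trailing digit is missing
print(recognizer._accept_string("ab7"))  # True: a complete sentence of the grammar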
transformers_gad/token_grammar_recognizer.py ADDED
@@ -0,0 +1,322 @@
1
+ import copy
2
+ import logging
3
+ from abc import ABC
4
+ from functools import lru_cache
5
+ from typing import List
6
+
7
+ import torch
8
+
9
+ from transformers_gad.recognizer import StringRecognizer, AcceptState
10
+ from transformers_gad.parser import parse_ebnf
11
+ from transformers_gad.trie import ByteTrie
12
+ from transformers_gad.utf8_utils import PartialUTF8
13
+ from .vocab_struct import LEAF, TokenTrie
14
+ from transformers_gad.mapping import get_mapping
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class AbsTokenRecognizer(ABC):
20
+ def __init__(self, grammar_str, tokenizer, start_rule_name="root", unicode=False):
21
+ parsed_grammar = parse_ebnf(grammar_str)
22
+ grammar_encoding = parsed_grammar.grammar_encoding
23
+ self.start_rule_id = parsed_grammar.symbol_table.get(start_rule_name)
24
+ self.byte_encoding = unicode
25
+
26
+ if unicode and not tokenizer.__class__.__name__.lower().startswith(
27
+ "gpt2"
28
+ ): # gpt2tokenizer or gpt2tokenizerfast
29
+ raise ValueError(
30
+ "Constrained decoding with unicode is only supported for GPT2 model. Support for other models is coming soon."
31
+ "Or you can use the constraints with only ascii characters."
32
+ )
33
+
34
+ self.eos_token_id = tokenizer.eos_token_id
35
+ self.token_trie = TokenTrie(tokenizer)
36
+ self.tokenizer = tokenizer
37
+ self.string_recognizer = StringRecognizer(grammar_encoding, self.start_rule_id)
38
+ self.unicode_trie = ByteTrie.from_tokenizer(tokenizer, unicode=unicode)
39
+ self.mapping = get_mapping(tokenizer, unicode=unicode)
40
+ assert len(self.mapping) == len(
41
+ self.token_trie
42
+ ), f"{len(self.mapping)}, {len(self.token_trie)}"
43
+
44
+ def _consume_token_id(
45
+ self, token_id: int, accept_state: AcceptState
46
+ ) -> AcceptState:
47
+ if self.string_recognizer._must_stop(accept_state.stacks):
48
+ if token_id == self.eos_token_id:
49
+ return self.string_recognizer.get_termination_accept_state()
50
+ else:
51
+ raise ValueError(
52
+ f"All stacks are empty, so the only token accepted is EOS({self.eos_token_id}), but got {token_id}"
53
+ )
54
+ if token_id == self.eos_token_id:
55
+ if self.string_recognizer._can_stop(accept_state.stacks):
56
+ # if at least one of the stack is empty, we can stop
57
+ # we clear all the stacks, meaning that we don't accept any token after EOS
58
+ return self.string_recognizer.get_termination_accept_state()
59
+ else:
60
+ raise ValueError(
61
+ f"At least one of the stack should be empty when EOS is reached. However, "
62
+ f"the stacks are {accept_state.stacks}"
63
+ )
64
+
65
+ bytes_or_codepoints = self.mapping.map(token_id)
66
+ accept_state = self.string_recognizer._consume_bytes(
67
+ bytes_or_codepoints, accept_state
68
+ )
69
+ return accept_state
70
+
71
+ def probe_token_id(self, token_id: int, accept_state: AcceptState) -> bool:
72
+ stacks = accept_state.stacks
73
+ if self.string_recognizer._must_stop(stacks):
74
+ if token_id == self.eos_token_id:
75
+ return True
76
+ else:
77
+ return False
78
+ if token_id == self.eos_token_id:
79
+ if self.string_recognizer._can_stop(stacks):
80
+ # if at least one of the stack is empty, we can stop
81
+ # we clear all the stacks, meaning that we don't accept any token after EOS
82
+ return True
83
+ else:
84
+ return False
85
+ # for code_point in self.mapping.map(token_id):
86
+ # stacks = self.grammar._consume_char_code_point(code_point, stacks)
87
+ bytes_or_codepoints = self.mapping.map(token_id, verbose=False)
88
+ new_acc_state = self.string_recognizer._consume_bytes(
89
+ bytes_or_codepoints, accept_state, verbose=False
90
+ )
91
+ return len(new_acc_state.stacks) > 0
92
+
93
+ def advance_token_ids(self, *args, **kwargs):
94
+ """Process a list of tokens according to the grammar rules."""
95
+ raise NotImplementedError
96
+
97
+ def batch_filter_vocab(self, batch_accept_states, device) -> torch.Tensor:
98
+ batch_acceptance = []
99
+ for accept_state in batch_accept_states:
100
+ batch_acceptance.append(self.filter_vocab(accept_state, device))
101
+ return torch.stack(batch_acceptance)
102
+
103
+ def filter_vocab(self, accept_state, device) -> torch.Tensor:
104
+ if not accept_state.stacks: # Check if stacks is empty
105
+ # Handle the empty case: for example, return a tensor of False
106
+ # The size of the tensor should match the size of your vocabulary
107
+ vocab_size = len(self.mapping)
108
+ logger.debug(f"Empty stack, sum of acceptance: {0}")
109
+ # size of the vocab
110
+ accepts = [False] * vocab_size
111
+ accepts[self.eos_token_id] = True
112
+ return torch.tensor(accepts, dtype=torch.bool, device=device)
113
+
114
+ acceptance = self.get_token_acceptance(accept_state, device)
115
+
116
+ return acceptance
117
+
118
+ def get_token_acceptance(self, accept_state, device) -> torch.Tensor:
119
+ acceptance_matrix = torch.cat(
120
+ [
121
+ self.get_token_acceptance_array_for_stack(
122
+ tuple(stack), accept_state.partial_utf8, device
123
+ )
124
+ for stack in accept_state.stacks
125
+ ]
126
+ )
127
+ # Merge stacks: any True => True
128
+ acceptance = acceptance_matrix.reshape(len(accept_state.stacks), -1).any(dim=0)
129
+ return acceptance
130
+
131
+ @lru_cache(maxsize=32768)
132
+ def get_token_acceptance_array_for_stack(self, stack, partial_utf8, device):
133
+ # stack = list(stack) # needs to come in as a tuple for lru_cache
134
+ assert isinstance(stack, tuple)
135
+ stack = list(stack)
136
+
137
+ if self.byte_encoding:
138
+
139
+ accept_f = lambda x: self.string_recognizer._probe_bytes(
140
+ x, [stack], partial_utf8=partial_utf8
141
+ )
142
+ token_acceptance = self.unicode_trie.get_token_acceptance(
143
+ accept=accept_f, accept_eos=False, eos_token_id=self.eos_token_id
144
+ )
145
+ else:
146
+ accepts = [False] * len(self.mapping)
147
+ token_acceptance = check_token_acceptance_in_trie(
148
+ self.token_trie.trie,
149
+ [stack],
150
+ self.string_recognizer,
151
+ self.eos_token_id,
152
+ accepts,
153
+ )
154
+ x = torch.tensor(token_acceptance, dtype=torch.bool, device=device)
155
+ x_eos = self.validate_and_set_eos_acceptance(x)
156
+ return x_eos
157
+
158
+ def validate_and_set_eos_acceptance(self, acceptance: torch.Tensor) -> torch.Tensor:
159
+ if torch.any(acceptance) == 0:
160
+ acceptance[self.eos_token_id] = True
161
+ else:
162
+ if acceptance[self.eos_token_id]:
163
+ raise ValueError()
164
+ acceptance[self.eos_token_id] = False
165
+ return acceptance
166
+
167
+
168
+ class IncrementalTokenRecognizer(AbsTokenRecognizer):
169
+ def __init__(self, grammar_str, start_rule_name, tokenizer, unicode=False):
170
+ super().__init__(grammar_str, tokenizer, start_rule_name, unicode)
171
+ self.last_size = None
172
+ self.is_incremental = True
173
+
174
+ # if self.last_size is not set (which would be the case when processing the first token).
175
+ # In this case, do nothing.
176
+
177
+ def advance_token_ids(self, input_ids, batch_accept_states, parse_start_index=None):
178
+
179
+ if self.last_size is None:
180
+ prefix_to_parse = [
181
+ single_input_ids[parse_start_index:]
182
+ if parse_start_index is not None
183
+ else []
184
+ for single_input_ids in input_ids
185
+ ]
186
+
187
+ # self.grammar_acceptor.accept_token_ids(prefix_to_parse, self.stacks)
188
+ batch_accept_states = [
189
+ self._consume_token_ids(prefix, accept_state)
190
+ for prefix, accept_state in zip(prefix_to_parse, batch_accept_states)
191
+ ]
192
+ # if the length of the current input IDs (input_ids[0]) is exactly one more than self.last_size.
193
+ # This is expected in a scenario where inputs are processed incrementally, one token at a time.
194
+ elif len(input_ids[0]) == self.last_size + 1:
195
+ batch_accept_states = [
196
+ self._consume_token_id(single_input_ids[-1], accept_state)
197
+ for single_input_ids, accept_state in zip(
198
+ input_ids, batch_accept_states
199
+ )
200
+ ]
201
+ # ensure that the input size is consistent with the expected incremental processing
202
+ # (i.e., one token at a time).
203
+ else:
204
+ # here we check if the input_ids are one token longer than the last time we processed
205
+ # but we don't check if input_ids are actually valid.
206
+ # Imagine a scenario where we generate 10 tokens, then we replace the 10 generated tokens with 10 new tokens.
207
+ # In this case, the input_ids will be consistent with the last_size, but the input_ids are not valid.
208
+ # However, should we really check if the input_ids are valid here?
209
+ # If we do, then we need to reparse the whole input_ids at each call, which is not efficient.
210
+ # Maybe we should just trust the user to provide valid input_ids?
211
+ # The conclusion is that, we assume the input_ids are valid, and our generation will be correct.
212
+ # If the input_ids are not valid, then the generation result will be wrong and we don't take responsibility for that.
213
+ raise RuntimeError(
214
+ "Input ID's length is inconsistent with the current state of "
215
+ "the GrammarConstrainedLogitsProcessor. If you want to process "
216
+ "another input sequence, please instantiate a new "
217
+ "GrammarConstrainedLogitsProcessor "
218
+ "or call reset_parser method of GrammarAlignedOracleLogitsProcessor"
219
+ )
220
+ self.last_size = len(input_ids[0])
221
+
222
+ return batch_accept_states
223
+
224
+ def _consume_token_ids(
225
+ self, token_ids: List[int], accept_state: AcceptState = None, as_string=True
226
+ ):
227
+ if accept_state is None:
228
+ accept_state = self.string_recognizer.get_initial_accept_state()
229
+ if as_string:
230
+ string = self.tokenizer.decode(token_ids)
231
+ accept_state = self.string_recognizer._consume_string(string, accept_state)
232
+ else:
233
+ for i, token_id in enumerate(token_ids):
234
+ accept_state = self._consume_token_id(token_id, accept_state)
235
+ if len(accept_state.stacks) > 0:
236
+ cur_token_ids = token_ids[: i + 1]
237
+ logging.debug(f"{cur_token_ids} is accepted")
238
+ decoded_string = self.tokenizer.decode(cur_token_ids)
239
+ logging.debug(f"The decoded string is {decoded_string}")
240
+ return accept_state
241
+
242
+ def reset(self):
243
+ self.last_size = None
244
+
245
+ def check_token_acceptance_in_trie(trie, stacks, grammar, eos_token_id, accepts):
246
+
247
+ for byte, next_trie in trie.items():
248
+ if byte == LEAF:
249
+ token_id = next_trie
250
+ if token_id != eos_token_id:
251
+ # if the stacks is not empty, it means we can still continue to parse
252
+ # so we should accept the token
253
+ accepts[token_id] = bool(stacks)
254
+ continue
255
+
256
+ new_stacks = []
257
+ for stk in stacks:
258
+ if not stk:
259
+ continue
260
+
261
+ next_element_offset = stk[-1]
262
+ num_chars = grammar.grammar_encoding[next_element_offset]
263
+
264
+ if not grammar.char_acceptance_at_element(next_element_offset).get(
265
+ byte, False
266
+ ):
267
+ # if the current byte is not accepted by the current rule, we need to try next rule
268
+ continue
269
+
270
+ next_element_offset += num_chars + 1
271
+ new_stack = stk[:-1]
272
+ if grammar.grammar_encoding[next_element_offset]:
273
+ new_stack.append(next_element_offset)
274
+ new_stacks.extend(grammar.advance_stack(tuple(new_stack)))
275
+
276
+ if new_stacks:
277
+ check_token_acceptance_in_trie(
278
+ next_trie, new_stacks, grammar, eos_token_id, accepts
279
+ )
280
+
281
+ return accepts
282
+
283
+
284
+ if __name__ == "__main__":
285
+ from transformers import AutoTokenizer
286
+
287
+ with open("examples/grammars/japanese.ebnf", "r") as file:
288
+ input_text = file.read()
289
+ parsed_grammar = parse_ebnf(input_text)
290
+ parsed_grammar.print()
291
+
292
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
293
+
294
+ tokenRecognizer = IncrementalTokenRecognizer(
295
+ grammar_str=input_text, start_rule_name="root", tokenizer=tokenizer
296
+ )
297
+
298
+ japanese = "トリーム" # "こんにちは"
299
+ token_ids = tokenizer.encode(japanese)
300
+ # 13298, 12675, 12045, 254
301
+ stacks = tokenRecognizer._consume_token_ids(
302
+ token_ids, tokenRecognizer.string_recognizer.stacks, as_string=False
303
+ )
304
+
305
+ if stacks:
306
+ print("The Japanese input is accepted")
307
+ else:
308
+ print("The Japanese input is not accepted")
309
+
310
+ korean = "안녕하세요"
311
+ token_ids = tokenizer.encode(korean)
312
+
313
+ try:
314
+ stacks = tokenRecognizer._consume_token_ids(
315
+ token_ids, tokenRecognizer.string_recognizer.stacks, as_string=False
316
+ )
317
+ if stacks:
318
+ print("The Korean input is accepted")
319
+ else:
320
+ print("The Korean input is not accepted")
321
+ except ValueError as e:
322
+ print("The Korean input is not accepted")
transformers_gad/trie.py ADDED
@@ -0,0 +1,194 @@
1
+ import logging
2
+ from functools import lru_cache
3
+ from typing import Dict, List, Tuple
4
+ from collections import deque
5
+
6
+ from transformers_gad.mapping import get_mapping
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class TrieNode:
11
+ def __init__(self):
12
+ self.children = {}
13
+ self.is_end_of_word = False
14
+ self.token_id = None
15
+
16
+
17
+ class ByteTrie:
18
+ def __init__(self):
19
+ self.root = TrieNode()
20
+
21
+ def insert(self, word, token_id=None):
22
+ node = self.root
23
+ for char in word:
24
+ if char not in node.children:
25
+ node.children[char] = TrieNode()
26
+ node = node.children[char]
27
+ node.is_end_of_word = True
28
+ node.token_id = token_id
29
+
30
+ def search(self, word):
31
+ node = self.root
32
+ for char in word:
33
+ if char not in node.children:
34
+ return False
35
+ node = node.children[char]
36
+ return node.is_end_of_word
37
+
38
+ def start_with_prefix(self, prefix):
39
+ node = self.root
40
+ for char in prefix:
41
+ if char not in node.children:
42
+ return False
43
+ node = node.children[char]
44
+ return True
45
+
46
+ @classmethod
47
+ def from_tokenizer(cls, tokenizer, unicode=True):
48
+ vocab: Dict[str, int] = tokenizer.get_vocab()
49
+ trie = cls()
50
+ mapping = get_mapping(tokenizer, unicode=unicode)
51
+ for token_id in vocab.values():
52
+ byte_repr = mapping.map(token_id)
53
+ trie.insert(byte_repr, token_id)
54
+ return trie
55
+
56
+ @lru_cache(maxsize=128)
57
+ def __len__(self):
58
+ return len(self.dfs(verbose=False))
59
+
60
+ def dfs(self, accept=lambda x: True, verbose=False) -> List[Tuple[List[int], int]]:
61
+ result = []
62
+ counter = {"visited": 0, "pruned": 0}
63
+ _dfs(self.root, [], result, accept, counter)
64
+ return result
65
+
66
+ def bfs(
67
+ self, predicate=lambda x: True, verbose=False
68
+ ) -> List[Tuple[List[int], int]]:
69
+ queue = deque([(self.root, [])])
70
+ valid_byte_seqs: List[Tuple[List[int], int]] = []
71
+ counter = {"visited": 0, "pruned": 0}
72
+
73
+ while queue:
74
+ counter["visited"] += 1
75
+ node, byte_seq = queue.popleft()
76
+ if predicate(byte_seq):
77
+ if node.is_end_of_word:
78
+ valid_byte_seqs.append((byte_seq, node.token_id))
79
+ for char, next_node in node.children.items():
80
+ new_byte_seq: List[int] = byte_seq.copy()
81
+ new_byte_seq.append(char)
82
+ queue.append((next_node, new_byte_seq))
83
+ else:
84
+ counter["pruned"] += 1
85
+ return valid_byte_seqs
86
+
87
+ def get_token_acceptance(
88
+ self, accept=lambda x: True, accept_eos=True, eos_token_id=None
89
+ ) -> List[bool]:
90
+ valid_byte_seqs: List[Tuple[List[int], int]] = self.bfs(accept, verbose=True)
91
+ valid_token_ids: List[int] = [token_id for _, token_id in valid_byte_seqs]
92
+ token_acceptance: List[bool] = [False] * (len(self))
93
+ for token_id in valid_token_ids:
94
+ token_acceptance[token_id] = True
95
+ if not accept_eos:
96
+ # eos_token is mapped to an empty string, so it's always accepted regardless of the accept function
97
+ # this can be undesirable, so we can set it to False to ignore it
98
+ token_acceptance[eos_token_id] = False
99
+ return token_acceptance
100
+
101
+
102
+ def _dfs(
103
+ node,
104
+ cur_byte_seq: List[int],
105
+ result: List[Tuple[List[int], int]],
106
+ accept: callable,
107
+ counter: Dict[str, int],
108
+ ):
109
+ counter["visited"] += 1
110
+ if accept(cur_byte_seq):
111
+ if node.is_end_of_word:
112
+ result.append((cur_byte_seq, node.token_id))
113
+ for char, next_node in node.children.items():
114
+ new_byte_seq: List[int] = cur_byte_seq.copy()
115
+ new_byte_seq.append(char)
116
+ _dfs(next_node, new_byte_seq, result, accept, counter)
117
+ else:
118
+ # Skip the entire subtree if the predict function returns False
119
+ counter["pruned"] += 1
120
+ return
121
+
122
+
123
+ def starts_with_prefix(prefix, target):
124
+ """
125
+ Check if the given prefix is a valid start of the target word or if the target word is a valid start of the given prefix.
126
+
127
+ Args:
128
+ prefix (str): The string prefix to be checked.
129
+ target (str): The target word to compare the prefix against.
130
+
131
+ Returns:
132
+ bool: True if prefix is a valid start of target or if target is a valid start of prefix, False otherwise.
133
+ """
134
+
135
+ # Check if the target word starts with the given prefix.
136
+ # This covers the case where the prefix is shorter than the target word.
137
+ if target.startswith(prefix):
138
+ return True
139
+
140
+ # Check if the given prefix starts with the target word.
141
+ # This covers the case where the prefix is longer than or equal to the target word.
142
+ if prefix.startswith(target):
143
+ return True
144
+
145
+ # If neither of the above conditions are true, return False.
146
+ return False
147
+
148
+
149
+ if __name__ == "__main__":
150
+ import logging
151
+
152
+ # Configure logging
153
+ logging.basicConfig(level=logging.INFO)
154
+ from transformers import AutoTokenizer
155
+
156
+ tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
157
+
158
+ trie = ByteTrie.from_tokenizer(tokenizer, unicode=True)
159
+ print(f"length of trie: {len(trie)}=={len(tokenizer.vocab.items())}")
160
+
161
+ #
162
+ # print(trie.search("hello")) # Example, replace with actual words from the vocab
163
+ # print(trie.start_with_prefix("hell"))
164
+ #
165
+ # # Example Usage
166
+ # words = trie.dfs(accept=lambda x: len(x) > 0 and x[0] == 65 or len(x)==0)
167
+ # for word in words:
168
+ # print(bytes(word[0]).decode("utf-8"))
169
+ #
170
+ # # Example Usage
171
+ # words = trie.bfs(predicate=lambda x: len(x) > 0 and x[0] == 65 or len(x)==0)
172
+ # for word in words:
173
+ # print(bytes(word[0]).decode("utf-8"))
174
+ #
175
+ # token_acceptance = trie.get_token_acceptance(accept=lambda x: len(x) > 0 and x[0] == 65 or len(x)==0)
176
+ # print(sum(token_acceptance))
177
+ # assert sum(token_acceptance) == len(words)
178
+
179
+ ########################
180
+ # UTF-8
181
+ ########################
182
+
183
+ # from transformers import AutoTokenizer
184
+ #
185
+ # japanese = "こんにちは世界"
186
+ # with open("examples/grammars/japanese.ebnf", "r") as file:
187
+ # input_text = file.read()
188
+ # parsed_grammar = parse_ebnf(input_text)
189
+ #
190
+ # start_rule_id = parsed_grammar.symbol_table["root"]
191
+ #
192
+ # recognizer = GrammarRecognizer(parsed_grammar.grammar_encoding, start_rule_id)
193
+ # accept_state = recognizer.init_accept_state()
194
+ # token_acc = trie.get_token_acceptance(accept=lambda x: recognizer._probe_bytes_partial_match(x, accept_state=accept_state))
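For illustration, a minimal usage sketch of the ByteTrie added above. It is not part of the commit: the import path transformers_gad.trie is assumed from the file layout, and the prefix predicate is a hypothetical stand-in for a grammar-driven accept function.

from transformers import AutoTokenizer
from transformers_gad.trie import ByteTrie  # assumed import path for this commit

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
trie = ByteTrie.from_tokenizer(tokenizer, unicode=True)

# Hypothetical predicate: keep byte sequences that are consistent with the prefix b"ab".
prefix = list(b"ab")

def accept(seq):
    return seq == prefix[: len(seq)] or prefix == seq[: len(prefix)]

mask = trie.get_token_acceptance(accept=accept, accept_eos=False,
                                 eos_token_id=tokenizer.eos_token_id)
print(sum(mask), "of", len(mask), "tokens remain decodable under this prefix")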
transformers_gad/utf8_utils.py ADDED
@@ -0,0 +1,170 @@
1
+ from dataclasses import dataclass
2
+ from typing import Tuple
3
+
4
+ from dataclasses import dataclass
5
+
6
+
7
+ @dataclass
8
+ class PartialUTF8:
9
+ """
10
+ A data class representing the state of a partially decoded UTF-8 sequence.
11
+
12
+ Attributes:
13
+ - value (int): The current accumulated value of the partially decoded Unicode code point.
14
+ This attribute stores the bits that have been decoded so far. For a fully decoded
15
+ character or before any partial decoding has started, this would typically be `0`.
16
+
17
+ - n_remain (int): The number of bytes remaining to complete the current UTF-8 encoded character.
18
+ A value of `-1` indicates that there is no ongoing partial decoding, i.e.,
19
+ either decoding has not started, or the last character was fully decoded.
20
+
21
+ This class is used to handle situations where UTF-8 encoded data may end in the middle of a character
22
+ sequence, allowing for the decoding process to be resumed when more data becomes available.
23
+ """
24
+
25
+ value: int = 0 # Default to 0, indicating no partial value accumulated
26
+ n_remain: int = (
27
+ -1
28
+ ) # Default to -1, indicating no bytes are currently expected to complete the character
29
+
30
+ def __hash__(self):
31
+ return hash((self.value, self.n_remain))
32
+
33
+ def __eq__(self, other):
34
+ if not isinstance(other, PartialUTF8):
35
+ return NotImplemented
36
+ return self.value == other.value and self.n_remain == other.n_remain
37
+
38
+
39
+ from typing import List, Tuple
40
+ from functools import lru_cache
41
+
42
+
43
+ @lru_cache(maxsize=3000000)
44
+ def decode_utf8(
45
+ src: bytes, partial_start: PartialUTF8
46
+ ) -> Tuple[List[int], PartialUTF8]:
47
+ # Lookup table for determining the total bytes based on the first byte's high 4 bits
48
+ lookup = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4]
49
+ pos = 0 # Position in the src bytes to start decoding from
50
+ code_points = [] # List to store the decoded Unicode code points
51
+ value = partial_start.value # Start with any previously partial decoded value
52
+ n_remain = partial_start.n_remain # Number of bytes remaining from a partial decode
53
+
54
+ # If there's a partial sequence left from last decode, try to continue decoding it
55
+ while pos < len(src) and n_remain > 0:
56
+ next_byte = src[pos] # Get the next byte to process
57
+ # Check if the continuation byte format is correct (`10xxxxxx`)
58
+ if (next_byte >> 6) != 2:
59
+ # If not, it's an invalid sequence. Abort and return a special error state.
60
+ code_points = [0]
61
+ return code_points, PartialUTF8(0, -1)
62
+
63
+ # Accumulate the value by shifting left and adding the relevant 6 bits
64
+ value = (value << 6) + (next_byte & 0x3F)
65
+ pos += 1 # Move to the next byte
66
+ n_remain -= 1 # Decrement the number of remaining bytes
67
+
68
+ # If we've completed a partial sequence, add its value to the code points
69
+ if partial_start.n_remain > 0 and n_remain == 0:
70
+ code_points.append(value)
71
+
72
+ # Process the rest of src as complete or new UTF-8 sequences
73
+ while pos < len(src):
74
+ first_byte = src[pos] # Get the first byte of the next sequence
75
+ highbits = first_byte >> 4 # Extract the high 4 bits for the lookup table
76
+ n_remain = lookup[highbits] - 1 # Determine remaining bytes in this sequence
77
+
78
+ # If lookup returns an invalid number, it's an invalid sequence. Abort.
79
+ if n_remain < 0:
80
+ # raise ValueError("Invalid UTF-8 sequence")
81
+ code_points = [0]
82
+ return code_points, PartialUTF8(0, -1)
83
+
84
+ # Calculate the mask to isolate significant bits from the first byte
85
+ mask = (1 << (7 - n_remain)) - 1
86
+ value = first_byte & mask # Apply the mask to get the initial value
87
+ pos += 1 # Move to the next byte
88
+
89
+ # Process the continuation bytes
90
+ while pos < len(src) and n_remain > 0:
91
+ next_byte = src[pos]
92
+ # Shift the accumulated value and add the next 6 significant bits
93
+ value = (value << 6) + (next_byte & 0x3F)
94
+ pos += 1 # Move to the next byte
95
+ n_remain -= 1 # Decrement the number of remaining bytes
96
+
97
+ # If the sequence is complete, add its decoded value to the code points
98
+ if n_remain == 0:
99
+ code_points.append(value)
100
+
101
+ # # Append a terminating value to indicate the end (following llama-cpp implementation)
102
+ # code_points.append(0)
103
+ # reset to the initial state once decoding completes; this is crucial for the LRU cache to hit on subsequent calls
104
+ if n_remain == 0:
105
+ n_remain = -1
106
+ value = 0
107
+
108
+ # Return the decoded code points and the state of any partial decoding
109
+ return code_points, PartialUTF8(value, n_remain)
110
+
111
+
112
+ def decode_utf8_leading_char(src: bytes) -> tuple:
113
+ first_byte = src[0]
114
+ highbits = first_byte >> 4
115
+ lookup = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4]
116
+ char_len = lookup[highbits]
117
+
118
+ # Extract the relevant bytes for the UTF-8 character
119
+ utf8_char_bytes = src[:char_len]
120
+
121
+ # Decode the character
122
+ char = utf8_char_bytes.decode("utf-8")
123
+
124
+ # Use ord() to convert the single character to its Unicode code point
125
+ code_point = ord(char)
126
+
127
+ # Remaining bytes
128
+ remaining_bytes = src[char_len:]
129
+
130
+ return code_point, remaining_bytes
131
+
132
+
133
+ def decode_utf8_string(utf8_bytes: bytes) -> list:
134
+ code_points = []
135
+ while utf8_bytes:
136
+ code_point, utf8_bytes = decode_utf8_leading_char(utf8_bytes)
137
+ code_points.append(code_point)
138
+ return code_points
139
+
140
+ if __name__ == "__main__":
141
+ # Given string
142
+ my_string = "€Hello" # The Euro symbol followed by "Hello"
143
+
144
+ # Get UTF-8 encoded bytes
145
+ utf8_bytes = my_string.encode("utf-8")
146
+
147
+ assert utf8_bytes == b"\xe2\x82\xacHello"
148
+
149
+ # Example usage with the Euro symbol followed by more characters
150
+ code_point, remaining_bytes = decode_utf8_leading_char(utf8_bytes)
151
+
152
+ print(f"Code Point: {code_point}") # Expected Output: 8364 (Euro symbol)
153
+ print(f"Remaining Bytes: {remaining_bytes}") # Expected Output: b'Hello'
154
+
155
+ # Example usage with the entire string
156
+ code_points = decode_utf8_string(utf8_bytes)
157
+
158
+ print(
159
+ f"Code Points: {code_points}"
160
+ ) # Expected Output: [8364, 72, 101, 108, 108, 111]
161
+
162
+ print("-" * 50)
163
+
164
+ # Example usage:
165
+ utf8_bytes = b"\xe2\x82\xacHello" # UTF-8 encoded string (Euro symbol + "Hello")
166
+ partial_start = PartialUTF8() # Assuming start with no partial sequence
167
+ code_points, partial_utf8 = decode_utf8(utf8_bytes, partial_start)
168
+
169
+ print("Code Points:", code_points)
170
+ print("Remaining UTF-8 State:", partial_utf8.value, partial_utf8.n_remain)
transformers_gad/utils.py ADDED
@@ -0,0 +1,98 @@
1
+ import json
2
+ import warnings
3
+ from typing import List
4
+
5
+ from termcolor import colored
6
+
7
+
8
+ def ints2bytes(sequence: List[int]) -> bytes:
9
+ # check in the range of 0-255
10
+ for item in sequence:
11
+ if not 0 <= item <= 255:
12
+ raise ValueError(f"item: {item} is not in the range [0, 255]")
13
+ return bytes(sequence)
14
+
15
+
16
+ def bytes2ints(byte_sequence: bytes) -> List[int]:
17
+ return list(byte_sequence)
18
+
19
+
20
+ def intervals_intersect(low1, high1, low2, high2):
21
+ """
22
+ Check if two intervals [low1, high1] and [low2, high2] intersect.
23
+
24
+ :param high1: High bound of the first interval.
25
+ :param low1: Low bound of the first interval.
26
+ :param high2: High bound of the second interval.
27
+ :param low2: Low bound of the second interval.
28
+ :return: True if the intervals intersect, False otherwise.
29
+ """
30
+ # Check if one interval is completely to the right of the other
31
+ if low1 > high2 or low2 > high1:
32
+ return False
33
+
34
+ # If the above condition is not met, the intervals intersect
35
+ return True
36
+
37
+
38
+ def pprint_token_ids(tokenizer, token_ids=None, text=None):
39
+ if token_ids is None and text is None:
40
+ raise ValueError("Either token_ids or text should be provided")
41
+ if token_ids is None:
42
+ token_ids = tokenizer.encode(text, add_special_tokens=False)
43
+ special_token_ids = tokenizer.all_special_ids
44
+ special_tokens = tokenizer.all_special_tokens
45
+ special_id2token = {
46
+ id: token for id, token in zip(special_token_ids, special_tokens)
47
+ }
48
+ # loop over token_ids and color the special tokens
49
+ colored_token_ids = []
50
+
51
+ for token_id in token_ids:
52
+ if token_id in special_id2token:
53
+ colored_token_ids.append(colored(token_id, "red", attrs=["bold"]))
54
+ else:
55
+ colored_token_ids.append(str(token_id))
56
+ colored_token_ids_str = [str(item) for item in colored_token_ids]
57
+ print("[" + ", ".join(colored_token_ids_str) + "]")
58
+
59
+
60
+ def get_tokenizer_model_type(model: str = "gpt2"):
61
+ """
62
+ reference https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_fast.py#L729
63
+ :param model:
64
+ :return: BPE, Unigram, WordPiece, WordLevel
65
+ SentencePiece is used in conjunction with Unigram
66
+ """
67
+ from transformers import AutoTokenizer
68
+
69
+ # if the tokenizer is not in the repo, it will raise OSError
70
+ # OSError: Can't load tokenizer for 'xxx'
71
+ # This happens when the model reuses the tokenizer of another model
72
+ if type(model) == str:
73
+ try:
74
+ tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
75
+ # check if the tokenizer is fast
76
+ except OSError:
77
+ return None
78
+ else:
79
+ tokenizer = model
80
+
81
+ if not tokenizer.is_fast:
82
+ raise ValueError(f"The tokenizer {model} is not fast tokenizer")
83
+ tokenizer_json = json.loads(tokenizer._tokenizer.to_str())
84
+ model_type = tokenizer_json["model"]["type"]
85
+ if (
86
+ model_type == "BPE"
87
+ and tokenizer_json["pre_tokenizer"] is not None
88
+ and (
89
+ tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel"
90
+ or (
91
+ "pretokenizers" in tokenizer_json["pre_tokenizer"]
92
+ and tokenizer_json["pre_tokenizer"]["pretokenizers"][1]["type"]
93
+ == "ByteLevel"
94
+ )
95
+ )
96
+ ):
97
+ model_type = "ByteLevelBPE"
98
+ return model_type
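A small usage sketch for the helpers in utils.py, not part of the commit itself; the import path is assumed from this commit's layout, and the printed tokenizer type is the expected value for gpt2 rather than a guaranteed one.

from transformers_gad.utils import (
    ints2bytes, bytes2ints, intervals_intersect, get_tokenizer_model_type,
)  # assumed import path

# Round-trip between byte values and Python bytes.
seq = bytes2ints("héllo".encode("utf-8"))           # [104, 195, 169, 108, 108, 111]
assert ints2bytes(seq).decode("utf-8") == "héllo"

# Character-range checks of the kind a grammar recognizer performs.
assert intervals_intersect(ord("a"), ord("z"), ord("m"), ord("p"))      # overlapping
assert not intervals_intersect(ord("a"), ord("f"), ord("x"), ord("z"))  # disjoint

print(get_tokenizer_model_type("gpt2"))  # expected: ByteLevelBPE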
transformers_gad/vocab_struct.py ADDED
@@ -0,0 +1,83 @@
1
+ #################
2
+ # DATA STRUCTURES
3
+ #################
4
+
5
+ import logging
6
+ import re
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ LEAF = -1
11
+
12
+ # TokenTrie stores each token's byte representation and indexes the vocabulary in a byte-level trie (byte sequence -> token ID)
13
+
14
+ class TokenTrie:
15
+ def __init__(self, tokenizer):
16
+ self.eos_token_id = tokenizer.eos_token_id
17
+ self.tokens = []
18
+ self.trie = {}
19
+ self.load_tokens(tokenizer)
20
+
21
+ def id2str(self, token_id):
22
+ return self.tokens[token_id]
23
+
24
+ def __len__(self):
25
+ return len(self.tokens)
26
+
27
+ def load_tokens(self, tokenizer):
28
+ def replace_hex(match):
29
+ hex_value = match.group(1)
30
+ return chr(int(hex_value, 16))
31
+
32
+ if "gpt2" in tokenizer.__class__.__name__.lower():
33
+ special = tokenizer.additional_special_tokens_ids
34
+
35
+ # Here, the decoder does a string replace on a bunch of sequences
36
+ # like ' .' for '.'. This interferes with our assumptions, where a
37
+ # token should always have exactly one representation.
38
+ # Fortunately(?) text-generation-inference doesn't seem to run this
39
+ # cleanup, so we get extraneous spaces. So, in order to generate
40
+ # the right token set for TGI, we have to skip the space trimming.
41
+ # See:
42
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3588-L3600
43
+ def fmt_token(id):
44
+ if id in special:
45
+ return None
46
+ return bytes(
47
+ tokenizer.decode([id], clean_up_tokenization_spaces=False), "utf-8"
48
+ )
49
+
50
+ elif (
51
+ "llama" in tokenizer.__class__.__name__.lower()
52
+ or "t5" in tokenizer.__class__.__name__.lower()
53
+ ):
54
+
55
+ def fmt_token(id):
56
+ token = tokenizer.convert_ids_to_tokens(id)
57
+ token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token)
58
+ token = token.replace("▁", " ")
59
+ return bytes(token, "utf-8")  # return the byte representation of the token
60
+
61
+ else:
62
+ logger.warning(
63
+ "Warning: unrecognized tokenizer: using default token formatting"
64
+ )
65
+
66
+ def fmt_token(id):
67
+ token = tokenizer.convert_ids_to_tokens(id)
68
+ return bytes(token, "utf-8")
69
+
70
+ # note: vocab_size doesn't work here because there are also
71
+ # get_added_vocab() tokens
72
+ self.tokens = [fmt_token(i) for i in range(len(tokenizer.get_vocab()))]
73
+ for token_id, token_bytes in enumerate(self.tokens):
74
+ if token_bytes is not None:
75
+ self.insert_into_trie(self.trie, token_bytes, token_id)
76
+
77
+ def insert_into_trie(self, trie, token_bytes, token_id):
78
+ current = trie
79
+ for byte in token_bytes:
80
+ if byte not in current:
81
+ current[byte] = {}
82
+ current = current[byte]
83
+ current[LEAF] = token_id
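To close, a sketch of walking the nested-dict trie built by TokenTrie. The import path and the LEAF constant are taken from the file above, while the b"th" prefix and the collect helper are illustrative choices, not part of the commit.

from transformers import AutoTokenizer
from transformers_gad.vocab_struct import TokenTrie, LEAF  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
token_trie = TokenTrie(tokenizer)

# Descend one nested dict per byte of the (hypothetical) prefix b"th".
node = token_trie.trie
for byte in b"th":
    node = node.get(byte, {})

def collect(node, out):
    # Gather the token id stored under LEAF at every node of this subtree.
    for key, child in node.items():
        if key == LEAF:
            out.append(child)
        else:
            collect(child, out)
    return out

ids = collect(node, [])
print(len(ids), [token_trie.id2str(i) for i in ids[:5]])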