Commit 901bbd9
1 Parent(s): 9fbf7f9
Add GAD libraries

Files changed:
- transformers_gad/__init__.py +5 -0
- transformers_gad/__pycache__/__init__.cpython-311.pyc +0 -0
- transformers_gad/__pycache__/grammar_utils.cpython-311.pyc +0 -0
- transformers_gad/__pycache__/logging_config.cpython-311.pyc +0 -0
- transformers_gad/__pycache__/mapping.cpython-311.pyc +0 -0
- transformers_gad/__pycache__/parser.cpython-311.pyc +0 -0
- transformers_gad/__pycache__/recognizer.cpython-311.pyc +0 -0
- transformers_gad/__pycache__/token_grammar_recognizer.cpython-311.pyc +0 -0
- transformers_gad/__pycache__/trie.cpython-311.pyc +0 -0
- transformers_gad/__pycache__/utf8_utils.cpython-311.pyc +0 -0
- transformers_gad/__pycache__/utils.cpython-311.pyc +0 -0
- transformers_gad/__pycache__/vocab_struct.cpython-311.pyc +0 -0
- transformers_gad/generation/__init__.py +1 -0
- transformers_gad/generation/__pycache__/__init__.cpython-311.pyc +0 -0
- transformers_gad/generation/__pycache__/logits_process.cpython-311.pyc +0 -0
- transformers_gad/generation/logits_process.py +348 -0
- transformers_gad/grammar_utils.py +4 -0
- transformers_gad/logging_config.py +18 -0
- transformers_gad/mapping.py +209 -0
- transformers_gad/oracle/__init_.py +1 -0
- transformers_gad/oracle/__pycache__/oracle_trie.cpython-311.pyc +0 -0
- transformers_gad/oracle/oracle_trie.py +261 -0
- transformers_gad/parser.py +576 -0
- transformers_gad/parser_cfg.py +530 -0
- transformers_gad/recognizer.py +456 -0
- transformers_gad/token_grammar_recognizer.py +322 -0
- transformers_gad/trie.py +194 -0
- transformers_gad/utf8_utils.py +170 -0
- transformers_gad/utils.py +98 -0
- transformers_gad/vocab_struct.py +83 -0
transformers_gad/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .logging_config import setup_logging

setup_logging()

__version__ = "0.1.2"
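A quick sanity check of the package entry point (a minimal sketch, not part of the commit): importing the package runs setup_logging() and exposes the version string.

import transformers_gad

print(transformers_gad.__version__)  # "0.1.2"; logging was configured as a side effect of the import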
transformers_gad/__pycache__/__init__.cpython-311.pyc ADDED (binary, 318 Bytes)
transformers_gad/__pycache__/grammar_utils.cpython-311.pyc ADDED (binary, 330 Bytes)
transformers_gad/__pycache__/logging_config.cpython-311.pyc ADDED (binary, 964 Bytes)
transformers_gad/__pycache__/mapping.cpython-311.pyc ADDED (binary, 14.7 kB)
transformers_gad/__pycache__/parser.cpython-311.pyc ADDED (binary, 24.1 kB)
transformers_gad/__pycache__/recognizer.cpython-311.pyc ADDED (binary, 20.2 kB)
transformers_gad/__pycache__/token_grammar_recognizer.cpython-311.pyc ADDED (binary, 16.3 kB)
transformers_gad/__pycache__/trie.cpython-311.pyc ADDED (binary, 8.37 kB)
transformers_gad/__pycache__/utf8_utils.cpython-311.pyc ADDED (binary, 6.17 kB)
transformers_gad/__pycache__/utils.cpython-311.pyc ADDED (binary, 4.66 kB)
transformers_gad/__pycache__/vocab_struct.cpython-311.pyc ADDED (binary, 4.44 kB)
transformers_gad/generation/__init__.py
ADDED
@@ -0,0 +1 @@
from .logits_process import GrammarConstrainedLogitsProcessor, GrammarAlignedOracleLogitsProcessor
transformers_gad/generation/__pycache__/__init__.cpython-311.pyc ADDED (binary, 343 Bytes)
transformers_gad/generation/__pycache__/logits_process.cpython-311.pyc ADDED (binary, 17.2 kB)
transformers_gad/generation/logits_process.py
ADDED
@@ -0,0 +1,348 @@
import copy
import math
import torch.nn.functional as F

import torch
import logging
from transformers.generation.logits_process import (
    LogitsProcessor,
    LOGITS_PROCESSOR_INPUTS_DOCSTRING,
)
from transformers.utils import add_start_docstrings
from transformers_gad.grammar_utils import IncrementalGrammarConstraint
from transformers_gad.oracle.oracle_trie import Trie

class GrammarConstrainedLogitsProcessor(LogitsProcessor):
    def __init__(self, grammar_constraint, parse_start_index=None, save_log=False):
        # Parser variables
        self.grammar_constraint = grammar_constraint
        self.batch_parsing_states = None
        self.parse_start_index = parse_start_index

        # To start with a longer prefix in enumerative search
        self.generate_start_index = None
        self.generated_tokens = None

        # Generation Log
        self.save_log = save_log
        self.history = []

    def reset(self):
        self.reset_parser()
        self.reset_history()

    def reset_parser(self):
        self.batch_parsing_states = None
        if self.grammar_constraint.is_incremental:
            self.grammar_constraint.reset()

        self.generate_start_index = None
        self.generated_tokens = None

    def reset_history(self):
        self.history = []

    def mask_scores(self, scores, device):
        """
        resolve each stack to a tensor of True/False for each token
        indicating acceptance
        """
        masked_scores = scores.clone()
        acceptance = self.grammar_constraint.batch_filter_vocab(
            self.batch_parsing_states, device
        )

        if self.save_log:
            self.store_detailed_history(acceptance, scores)

        # Scores to -inf where False
        masked_scores[~acceptance] = -math.inf

        return masked_scores

    def process_scores(self, input_ids, scores):
        # we dynamically create stacks at the first call, so that we know the batch size and beam size
        if self.batch_parsing_states is None:
            self.batch_parsing_states = [
                copy.deepcopy(
                    self.grammar_constraint.string_recognizer.get_initial_accept_state()
                )
                for _ in range(len(input_ids))
            ]

        # assume the generation starts from the same index
        if self.generate_start_index is None:
            # the default is the end of input sequence of tokens
            self.generate_start_index = self.parse_start_index \
                if self.parse_start_index else input_ids.size(1)
        self.generated_tokens = input_ids[:, self.generate_start_index:]

        # Advance parser states
        self.batch_parsing_states = self.grammar_constraint.advance_token_ids(
            input_ids, self.batch_parsing_states, self.parse_start_index
        )

        masked_scores = self.mask_scores(scores, scores.device)
        return masked_scores

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        return self.process_scores(input_ids, scores)

    def reset_parser(self):
        self.batch_parsing_states = None
        if isinstance(self.grammar_constraint, IncrementalGrammarConstraint):
            self.grammar_constraint.reset()

    def get_accepted_tokens(self, acceptance):
        """
        Get the indices of accepted tokens and their corresponding string values for each item in the batch.

        Parameters:
        - acceptance (torch.Tensor): A boolean tensor indicating accepted tokens for each item in the batch.
        """
        batch_size, _ = acceptance.shape
        acceptance_np = acceptance.cpu().numpy()
        accepted_x, accepted_y = acceptance_np.nonzero()

        # Initialize the dictionary with empty lists for indices
        accepted_token_indices = {i: [] for i in range(batch_size)}
        for x, y in zip(accepted_x, accepted_y):
            accepted_token_indices[x].append(y)

        # Convert token IDs to tokens
        accepted_tokens = {
            i: [self.grammar_constraint.tokenizer.decode([token_id]) for token_id in token_ids]
            for i, token_ids in accepted_token_indices.items()
        }

        return accepted_tokens

    def store_detailed_history(self, acceptance, scores):
        """
        Processes and stores information for accepted tokens including their IDs, tokens,
        raw scores, and logits.

        Parameters:
        - acceptance (torch.Tensor): A boolean tensor indicating accepted tokens for each item in the batch.
        - scores (torch.Tensor): The raw scores from the model output.
        """
        likelihoods = F.softmax(scores, dim=-1)

        # Initializing the list to store detailed information for each step
        batch_accepted_info = []

        for batch_index in range(acceptance.size(0)):  # Iterate over batch items
            accepted_info = []
            accepted_indices = acceptance[batch_index].nonzero().squeeze(-1)

            for idx in accepted_indices:
                token_id = idx.item()
                raw_score = scores[batch_index, idx].item()
                likelihood = likelihoods[batch_index, idx].item()
                token = self.grammar_constraint.tokenizer.decode([token_id])

                # Store detailed information as a dictionary
                accepted_info.append({
                    "token_id": token_id,
                    "token": str(token),
                    "raw_score": raw_score,
                    "raw_likelihood": likelihood
                })

            batch_accepted_info.append(accepted_info)

        # Store this detailed information in the history
        self.history.append(batch_accepted_info)

class GrammarAlignedOracleLogitsProcessor(LogitsProcessor):
    def __init__(self, grammar_constraint, oracle_trie=Trie(), parse_start_index=None, save_log=False):
        # Parser variables
        self.grammar_constraint = grammar_constraint
        self.batch_parsing_states = None
        self.parse_start_index = parse_start_index

        # ASAp oracle trie
        self.oracle_trie = oracle_trie

        # To start with a longer prefix in enumerative search
        self.generate_start_index = None
        self.generated_tokens = None

        # Generation Log
        self.save_log = save_log
        self.history = []

    def adjust_scores(self, scores, device):
        """
        resolve each stack to a tensor of True/False for each token
        indicating acceptance
        """
        acceptance = self.grammar_constraint.batch_filter_vocab(
            self.batch_parsing_states, device
        )

        current_parent = self.oracle_trie.search_last_parent(self.generated_tokens)
        current_parent.insert_accepted_tokens(scores, acceptance)
        adjusted_scores = self.apply_oracle_adjustments(acceptance, scores, current_parent)

        if self.save_log:
            self.store_detailed_history(acceptance, scores, adjusted_scores)

        # Scores to -inf where False
        adjusted_scores[~acceptance] = -math.inf

        return adjusted_scores

    def apply_oracle_adjustments(self, acceptance, scores, current_parent):
        """
        Multiply expected future grammaticality
        Use the normalized (and unmasked) probability

        Parameters:
        - acceptance (torch.Tensor): A characteristic vector of valid tokens
          used to update only valid tokens
        - scores (torch.Tensor): Unnormalized logits from language model
        - current_parent (TrieNode): The trie node for the current prefix
        """
        adjusted_scores = scores.clone()
        likelihoods = F.softmax(adjusted_scores, dim=-1)
        log_likelihoods = torch.log(likelihoods)

        for batch_index in range(acceptance.size(0)):
            accepted_indices = acceptance[batch_index].nonzero().squeeze(-1)

            for idx in accepted_indices:
                token_id = idx.item()
                log_likelihood = log_likelihoods[batch_index, idx].item()

                # Get theta (log of expected future grammaticality) for this specific token
                success_rate = current_parent.get_success_rate(token_id)

                if not isinstance(success_rate, torch.Tensor):
                    success_rate = torch.tensor(success_rate, dtype=torch.float)
                log_theta = torch.log(success_rate)

                # Calculate adjusted score
                adjusted_score = log_likelihood + log_theta
                adjusted_scores[batch_index, idx] = adjusted_score

        return adjusted_scores

    def process_scores(self, input_ids, scores):
        # we dynamically create stacks at the first call, so that we know the batch size and beam size
        if self.batch_parsing_states is None:
            self.batch_parsing_states = [
                copy.deepcopy(
                    self.grammar_constraint.string_recognizer.get_initial_accept_state()
                )
                for _ in range(len(input_ids))
            ]

        # assume the generation starts from the same index
        if self.generate_start_index is None:
            # the default is the end of input sequence of tokens
            self.generate_start_index = self.parse_start_index \
                if self.parse_start_index else input_ids.size(1)
        self.generated_tokens = input_ids[:, self.generate_start_index:]

        # Advance parser states
        self.batch_parsing_states = self.grammar_constraint.advance_token_ids(
            input_ids, self.batch_parsing_states, self.parse_start_index
        )

        adjusted_scores = self.adjust_scores(scores, scores.device)

        return adjusted_scores

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        return self.process_scores(input_ids, scores)

    def reset(self):
        self.reset_parser()
        self.reset_history()

    def reset_parser(self):
        self.batch_parsing_states = None
        if self.grammar_constraint.is_incremental:
            self.grammar_constraint.reset()

        self.generate_start_index = None
        self.generated_tokens = None

    def reset_history(self):
        self.history = []

    def reset_trie(self):
        self.oracle_trie = Trie()

    def get_accepted_tokens(self, acceptance):
        """
        Get the indices of accepted tokens and their corresponding string values for each item in the batch.

        Parameters:
        - acceptance (torch.Tensor): A boolean tensor indicating accepted tokens for each item in the batch.
        """
        batch_size, _ = acceptance.shape
        acceptance_np = acceptance.cpu().numpy()
        accepted_x, accepted_y = acceptance_np.nonzero()

        # Initialize the dictionary with empty lists for indices
        accepted_token_indices = {i: [] for i in range(batch_size)}
        for x, y in zip(accepted_x, accepted_y):
            accepted_token_indices[x].append(y)

        # Convert token IDs to tokens
        accepted_tokens = {
            i: [self.grammar_constraint.tokenizer.decode([token_id]) for token_id in token_ids]
            for i, token_ids in accepted_token_indices.items()
        }

        return accepted_tokens

    def store_detailed_history(self, acceptance, scores, adjusted_scores):
        """
        Processes and stores information for accepted tokens including their IDs, tokens,
        raw scores, and logits.

        Parameters:
        - acceptance (torch.Tensor): A boolean tensor indicating accepted tokens for each item in the batch.
        - scores (torch.Tensor): The raw scores from the model output.
        - adjusted_scores (torch.Tensor): The adjusted scores after applying expected future grammaticality.
        """
        likelihoods = F.softmax(scores, dim=-1)
        adjusted_likelihoods = F.softmax(adjusted_scores, dim=-1)

        # Initializing the list to store detailed information for each step
        batch_accepted_info = []

        for batch_index in range(acceptance.size(0)):  # Iterate over batch items
            accepted_info = []
            accepted_indices = acceptance[batch_index].nonzero().squeeze(-1)

            for idx in accepted_indices:
                token_id = idx.item()
                raw_score = scores[batch_index, idx].item()
                likelihood = likelihoods[batch_index, idx].item()
                adjusted_likelihood = adjusted_likelihoods[batch_index, idx].item()
                token = self.grammar_constraint.tokenizer.decode([token_id])

                # Store detailed information as a dictionary
                accepted_info.append({
                    "token_id": token_id,
                    "token": str(token),
                    "raw_score": raw_score,
                    "raw_likelihood": likelihood,
                    "adjusted_likelihood": adjusted_likelihood
                })

            batch_accepted_info.append(accepted_info)

        # Store this detailed information in the history
        self.history.append(batch_accepted_info)
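For context, a minimal usage sketch (not part of this commit) of how these processors plug into transformers' generate(). The IncrementalGrammarConstraint constructor arguments below are an assumption; token_grammar_recognizer.py is not shown in this diff.

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
from transformers_gad.grammar_utils import IncrementalGrammarConstraint
from transformers_gad.generation.logits_process import GrammarAlignedOracleLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

grammar = 'root ::= "yes" | "no"\n'
# Assumed constructor signature (grammar string, start rule, tokenizer); not shown in this diff.
constraint = IncrementalGrammarConstraint(grammar, "root", tokenizer)
processor = GrammarAlignedOracleLogitsProcessor(constraint, save_log=True)

inputs = tokenizer("Answer yes or no:", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=4,
    logits_processor=LogitsProcessorList([processor]),
)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:]))

# processor.history holds the per-step acceptance log when save_log=True;
# reset() clears parser state and history between independent generations.
processor.reset()

GrammarConstrainedLogitsProcessor only masks ungrammatical tokens to -inf, while GrammarAlignedOracleLogitsProcessor additionally adds log(theta), the trie's current estimate of expected future grammaticality, to each accepted token's log-likelihood.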
transformers_gad/grammar_utils.py
ADDED
@@ -0,0 +1,4 @@
from .token_grammar_recognizer import IncrementalTokenRecognizer

# Old class name, kept for backward compatibility
IncrementalGrammarConstraint = IncrementalTokenRecognizer
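Since this module only aliases the new class name, both import paths resolve to the same object; a small illustrative check:

from transformers_gad.grammar_utils import IncrementalGrammarConstraint
from transformers_gad.token_grammar_recognizer import IncrementalTokenRecognizer

assert IncrementalGrammarConstraint is IncrementalTokenRecognizer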
transformers_gad/logging_config.py
ADDED
@@ -0,0 +1,18 @@
# logging_config.py
import os
import logging


def setup_logging():
    log_level_name = os.getenv(
        "TCFG_LOG_LEVEL", "WARNING"
    ).upper()  # Default to WARNING if not set
    log_levels = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }
    log_level = log_levels.get(log_level_name, logging.WARNING)
    logging.basicConfig(level=log_level)
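The log level is read once at import time, so TCFG_LOG_LEVEL has to be set before the package is imported; a minimal sketch:

import os

os.environ["TCFG_LOG_LEVEL"] = "DEBUG"  # must be set before the first import
import transformers_gad  # __init__.py calls setup_logging(), which reads the variable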
transformers_gad/mapping.py
ADDED
@@ -0,0 +1,209 @@
from typing import Dict, List

from transformers_gad.utils import get_tokenizer_model_type, ints2bytes
from transformers import AutoTokenizer
import logging

log = logging.getLogger(__name__)


def get_mapping(tokenizer, unicode=False):
    log.debug(f"tokenizer type: {tokenizer.__class__.__name__}")
    log.debug(f"tokenizer model type: {get_tokenizer_model_type(tokenizer)}")
    if not unicode:
        if (
            "gpt2" in tokenizer.__class__.__name__.lower()
            or "bloom" in tokenizer.__class__.__name__.lower()
            or "pretrainedtokenizer" in tokenizer.__class__.__name__.lower()
            or "codegen" in tokenizer.__class__.__name__.lower()
            or "gptneox" in tokenizer.__class__.__name__.lower()
        ):
            return BBPEMapping(tokenizer)
        elif "t5" in tokenizer.__class__.__name__.lower():
            return BPEMapping(tokenizer)
        elif "llama" in tokenizer.__class__.__name__.lower():
            return LlamaBPEMapping(tokenizer)
        elif "xglm" in tokenizer.__class__.__name__.lower():
            return UniGramMapping(tokenizer)
        else:
            raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__.__name__}")
    else:
        if "gpt2" in tokenizer.__class__.__name__.lower():
            return UnicodeBBPEMapping(tokenizer)
        else:
            raise NotImplementedError(
                f"Unicode mapping for {tokenizer.__class__.__name__}"
            )


class Mapping:
    def __init__(self, tokenizer):
        self.eos_token_id = tokenizer.eos_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.tokenizer = tokenizer
        self.special = tokenizer.all_special_ids

    def __len__(self):
        return len(self.tokenizer.get_vocab())

    def _map(self, token_id: int) -> str:
        # This is the case for BOS,
        if token_id in self.special:
            return ""
        # if token_id is tensor, convert it to int
        if hasattr(token_id, "item"):
            token_id = token_id.item()
        raw_token = self.tokenizer.convert_ids_to_tokens(token_id)
        return raw_token

    def map(self, token_id: int, verbose=False) -> bytes:
        token = self._map(token_id)
        if verbose:
            log.debug(f"token_id: {token_id}, token: {token}")
        return bytes(token, "utf-8")


class BBPEMapping(Mapping):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _map(self, token_id: int) -> str:
        raw_token = super()._map(token_id)
        if raw_token.startswith("Ġ"):
            raw_token = raw_token.replace("Ġ", " ")
        return raw_token


class UnicodeBBPEMapping(Mapping):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.intermediate_encoding = UnicodeBBPEMapping.get_intermediate_encoding(
            self.tokenizer
        )

    def _map(self, token_id: int, verbose=False) -> str:
        raw_token = super()._map(token_id)
        # if raw_token.startswith("Ġ"):
        #     raw_token = raw_token.replace("Ġ", " ")
        return raw_token

    def map(self, token_id: int, verbose=False) -> bytes:
        raw_token = self._map(token_id, verbose)
        if verbose:
            log.debug(f"token_id: {token_id}, raw_token: {raw_token}")
        return self.intermediate_encoding.token2bytes(raw_token)

    @staticmethod
    def get_intermediate_encoding(tokenizer):
        if "gpt2" in tokenizer.__class__.__name__.lower():
            return ByteEncoding(tokenizer)
        else:
            return None


class BPEMapping(Mapping):
    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        self.last_token_id = None

    def _map(self, token_id: int) -> str:
        raw_token = super()._map(token_id)

        # we need to check if the token is at the beginning of the sentence to remove the space
        # specific to BPE
        at_bos = False
        if self.last_token_id is not None and self.last_token_id == self.bos_token_id:
            at_bos = True
        self.last_token_id = token_id
        if raw_token.startswith("▁"):
            raw_token = raw_token.replace("▁", " ")
            if at_bos:
                # remove space at the beginning of the sentence
                raw_token = raw_token[1:]
        return raw_token


class LlamaBPEMapping(BPEMapping):
    def __init__(self, tokenizer):
        super().__init__(tokenizer)

    def _map(self, token_id: int) -> str:
        raw_token = super()._map(token_id)
        # if the token is hex, token is a string like "<0x00>"
        # first 256 tokens are hex
        if raw_token.startswith("<0x"):
            hex_value = raw_token[4:-1]
            raw_token = chr(int(hex_value, 16))
        return raw_token


class WordPieceMapping(Mapping):
    def __init__(self, tokenizer):
        super().__init__(tokenizer)

    def map(self, token_id: int) -> bytes:
        if token_id in self.special:
            return bytes()
        return bytes(
            self.tokenizer.decode([token_id], clean_up_tokenization_spaces=False),
            "utf-8",
        )


class UniGramMapping(Mapping):
    def __init__(self, tokenizer):
        super().__init__(tokenizer)

    def map(self, token_id: int) -> bytes:
        if token_id in self.special:
            return bytes()
        return bytes(
            self.tokenizer.decode([token_id], clean_up_tokenization_spaces=False),
            "utf-8",
        )


class XGLMUniGramMapping(Mapping):
    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        self.bos_token_id = tokenizer.eos_token_id
        self.eos_token_id = None


class ByteEncoding:
    def __init__(self, tokenizer):
        # check if the tokenizer is fast, if so, convert it to slow
        if tokenizer.is_fast:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer.name_or_path, use_fast=False
            )
        self.tokenizer = tokenizer
        self.byte2char: Dict[int, str] = tokenizer.byte_encoder
        self.char2byte: Dict[str, int] = tokenizer.byte_decoder
        # code point to byte
        self.cdp2byte: Dict[int, int] = {ord(c): b for c, b in self.char2byte.items()}
        self.byte2cdp: Dict[int, int] = {v: k for k, v in self.cdp2byte.items()}

    def map(self, byte: int) -> int:
        assert 0 <= byte < 256, f"byte: {byte} is not in the range [0, 256)"
        return ord(self.byte2char[byte])

    def token_ids2bytes(self, token_ids: List[int]) -> bytes:
        tokens: List[str] = self.tokenizer.convert_ids_to_tokens(token_ids)
        # for token id = BOS, the token should be empty string instead of <s>
        # TODO, this may cause issues because this means that special tokens like BOS can appear at any position
        tokens = [
            "" if token in self.tokenizer.all_special_ids else token for token in tokens
        ]
        bytes: List[List[int]] = [self.token2bytes(token) for token in tokens]
        # join the bytes
        return ints2bytes(sum(bytes, []))

    def token_id2bytes(self, token_id: int) -> bytes:
        token: str = self.tokenizer.convert_ids_to_tokens(token_id)
        return self.token2bytes(token)

    def token2bytes(self, token: str) -> bytes:
        # import pdb; pdb.set_trace()
        bytes_seq: List[int] = [self.char2byte[c] for c in token]
        return bytes(bytes_seq)
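A minimal sketch of how get_mapping turns token ids back into raw bytes (the GPT-2 checkpoint here is only an example):

from transformers import AutoTokenizer
from transformers_gad.mapping import get_mapping

tokenizer = AutoTokenizer.from_pretrained("gpt2")
mapping = get_mapping(tokenizer)       # selects BBPEMapping for GPT-2-style byte-level BPE

token_id = tokenizer.encode(" hello")[0]
print(mapping.map(token_id))           # b' hello': the 'Ġ' marker is mapped back to a plain space
print(len(mapping))                    # vocabulary size seen by the recognizer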
transformers_gad/oracle/__init_.py
ADDED
@@ -0,0 +1 @@
from .oracle_trie import Trie, TrieNode, update_oracle_trie
transformers_gad/oracle/__pycache__/oracle_trie.cpython-311.pyc ADDED (binary, 12.8 kB)
transformers_gad/oracle/oracle_trie.py
ADDED
@@ -0,0 +1,261 @@
import torch
import torch.nn.functional as F
import json
import logging

class TrieNode:
    def __init__(self,
                 token_id=None, raw_likelihood=None, raw_score=None,
                 success_rate=1,
                 is_start_of_sequence=False, is_end_of_sequence=False,
                 eos_token_id=2):
        self.children = {}
        self.parent = None
        self.token_id = token_id
        self.raw_likelihood = raw_likelihood
        self.raw_score = raw_score

        # The default approximation of EFG
        self.success_rate = success_rate

        self.eos_token_id = eos_token_id
        self.is_start_of_sequence = is_start_of_sequence
        self.is_end_of_sequence = is_end_of_sequence

    def insert(self, child_node):
        """
        Insert child_node into the children dictionary
        """
        if child_node.token_id not in self.children:
            self.children[child_node.token_id] = child_node
            child_node.parent = self

            if child_node.token_id == self.eos_token_id:
                child_node.is_end_of_sequence = True

            # update the success rate of the parent node
            return self.update_success_rate()
        else:
            return 0

    def insert_accepted_tokens(self, scores, acceptance):
        """
        Create node from acceptance and scores and
        insert as children of self node
        """
        likelihoods = F.softmax(scores, dim=-1)

        for batch_index in range(acceptance.size(0)):
            accepted_tokens = acceptance[batch_index].nonzero().squeeze(-1)

            for token_id in accepted_tokens:
                if token_id not in self.children:
                    raw_likelihood = likelihoods[batch_index, token_id].item()
                    raw_score = scores[batch_index, token_id].item()

                    child_node = TrieNode(
                        token_id=token_id.item(),
                        raw_likelihood=raw_likelihood,
                        raw_score=raw_score)

                    self.insert(child_node)

    def get_success_rate(self, token_id):
        """
        Return Approximated Expected Future Grammaticality of the token_id
        """
        if token_id in self.children:
            return self.children[token_id].success_rate
        else:
            return 1

    def update_success_rate(self):
        """
        Re-compute the success rate from the updated success rate of children
        """
        if self.children:
            total_success_rate = sum(child.raw_likelihood * child.success_rate for child in self.children.values())

            # Get how much of unexplored nodes are covered with this update
            updated_rate = self.success_rate - total_success_rate
            self.success_rate = total_success_rate

            # Back propagate the success rate
            if self.parent:
                return self.parent.update_success_rate()

            return updated_rate

    def prefix_raw_likelihood(self):
        if self.parent:
            return self.raw_likelihood * self.parent.prefix_raw_likelihood()
        else:
            return self.raw_likelihood

    def search_token(self, token_id):
        """
        Check if the self node has a children with token_id
        Return the children node if it exists, return None otherwise
        """
        if token_id in self.children:
            return self.children[token_id]
        else:
            return None

    def to_dict(self):
        """
        Convert a trie into a dictionary by removing the pointer to the parent
        """
        return {
            "token_id": self.token_id,
            "raw_likelihood": self.raw_likelihood,
            "raw_score": self.raw_score,
            "success_rate": self.success_rate,
            "eos_token_id": self.eos_token_id,
            "is_start_of_sequence": self.is_start_of_sequence,
            "is_end_of_sequence": self.is_end_of_sequence,
            "children": [child.to_dict() for child in self.children.values()]
        }

    @staticmethod
    def from_dict(d):
        """
        Recursively (re)construct trie from dictionary
        """
        node = TrieNode(
            token_id=d['token_id'],
            raw_likelihood=d['raw_likelihood'],
            raw_score=d['raw_score'],
            success_rate=d['success_rate'],
            is_start_of_sequence=d['is_start_of_sequence'],
            is_end_of_sequence=d['is_end_of_sequence'],
            eos_token_id=d['eos_token_id'])

        node.children = {child['token_id']:TrieNode.from_dict(child) for child in node.children}
        for child in node.children.values():
            child.parent = node

        return node

    def __repr__(self):
        parent_token_id = 'None (Root Node)' if self.parent is None else self.parent.token_id
        return (f"TrieNode(token_id={self.token_id}', "
                f"raw_likelihood={self.raw_likelihood}, raw_score={self.raw_score}, children={list(self.children.keys())}, "
                f"parent={parent_token_id}, success rate={self.success_rate})")

class Trie:
    def __init__(self):
        self.root = TrieNode()
        self.root.is_start_of_sequence = True

    def search_last_parent(self, prefix: torch.LongTensor):
        """
        Search the longest prefix in the trie that matches to the input sequence of tokens 'prefix'
        """
        matched_prefix = []
        current_parent = self.root

        # Assume one batch of prefix
        for time_step, token_id in enumerate(prefix[0]):
            token_id = token_id.item()
            if token_id in current_parent.children:
                current_parent = current_parent.children[token_id]
                matched_prefix.append(current_parent.token_id)
            else:
                print(
                    f"matched prefix is {matched_prefix}; current {token_id} not found in the trie at time step {time_step}")
                return None

        return current_parent

    def search(self, sequence):
        """
        Return the sequence of nodes that exactly matches with the input
        """
        node = self.root
        nodes = []
        for token_id in sequence:
            if token_id not in node.children:
                return None
            node = node.children[token_id]
            nodes.append(node)
        return nodes

    def raw_likelihood(self, sequence):
        """
        Return the raw likelihood (before the adjustment) of sequence
        """
        if isinstance(sequence, torch.Tensor):
            sequence = sequence.tolist()

        nodes = self.search(sequence)
        if nodes is None:
            return None

        likelihood = 1
        for node in nodes:
            likelihood *= node.raw_likelihood
        return likelihood

    def json(self):
        return json.dumps(self.root.to_dict(), indent=2)

    @staticmethod
    def loads(js):
        trie = Trie()
        trie.root = TrieNode.from_dict(json.loads(js))

        return trie

    def print_trie(self, node=None, prefix=None):
        """
        Print all the leaves in the trie
        """
        if node is None:
            node = self.root
        if prefix is None:
            prefix = []

        # If current node marks the end of a sequence, print the prefix as a list
        if node.is_end_of_sequence or len(node.children) == 0:
            print(prefix)

        # Recursively call print_trie for all children, appending the current character/token to the prefix
        for char, child_node in node.children.items():
            self.print_trie(child_node, prefix + [char])

    def has_full_information(self):
        """
        Checks if all paths in the trie end with an is_end_of_sequence node set to True.
        Returns True if the trie has full information, False otherwise.
        """
        return self._check_full_information(self.root)

    def _check_full_information(self, node):
        # If the node has no children, check if it is marked as the end of a sequence
        if not node.children:
            return node.is_end_of_sequence

        # Recursively check all children
        return all(self._check_full_information(child) for child in node.children.values())

    def print_all_nodes(self, node=None, depth=0):
        """
        Print all the nodes in the trie (including non-leaves)
        """

        if node is None:
            node = self.root

        # Print current node's details
        indent = " " * depth  # Create indentation based on the depth in the trie
        node_details = (f"{indent}TrieNode(token_id={node.token_id}', "
                        f"raw_likelihood={node.raw_likelihood}, raw_score={node.raw_score}, success rate={node.success_rate}, "
                        f"children={list(node.children.keys())}, "
                        f"parent={node.parent.token_id if node.parent else None}, "
                        f"is_end_of_sequence={node.is_end_of_sequence})")
        print(node_details)

        # Recursively call print_all_nodes for all children
        for child_node in node.children.values():
            self.print_all_nodes(child_node, depth + 1)
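A minimal sketch (illustrative values only) of how the oracle trie is grown and queried by the logits processors above:

import torch
from transformers_gad.oracle.oracle_trie import Trie

trie = Trie()

# One batch item over a toy 4-token vocabulary; tokens 1 and 3 are the grammatical options.
scores = torch.tensor([[1.0, 2.0, 0.5, 1.5]])
acceptance = torch.tensor([[False, True, False, True]])

trie.root.insert_accepted_tokens(scores, acceptance)

print(trie.root.get_success_rate(1))           # 1: nothing below token 1 has been explored yet
print(trie.raw_likelihood(torch.tensor([1])))  # softmax(scores)[0, 1], the unadjusted probability

As deeper levels are inserted, update_success_rate propagates the probability mass of grammatical continuations back toward the root, which is the quantity apply_oracle_adjustments reads as theta.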
transformers_gad/parser.py
ADDED
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import sys
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
END_OF_ALTERNATE_MARKER = 0
|
9 |
+
END_OF_RULE_MARKER = 0
|
10 |
+
END_OF_GRAMMAR_MARKER = 0xFFFF
|
11 |
+
TO_BE_FILLED_MARKER = 0
|
12 |
+
REF_RULE_MARKER = 1
|
13 |
+
LITERAL_MARKER = 2
|
14 |
+
|
15 |
+
|
16 |
+
########################
|
17 |
+
# EBNF Grammar Parsing #
|
18 |
+
########################
|
19 |
+
|
20 |
+
|
21 |
+
class ParseState:
|
22 |
+
def __init__(self):
|
23 |
+
self.symbol_table = {}
|
24 |
+
self.grammar_encoding = [] # old name: out_grammar
|
25 |
+
|
26 |
+
def print(self, file=sys.stdout):
|
27 |
+
print_grammar(file, self)
|
28 |
+
|
29 |
+
|
30 |
+
def get_symbol_id(state: ParseState, symbol_name: str) -> int:
|
31 |
+
if symbol_name not in state.symbol_table:
|
32 |
+
state.symbol_table[symbol_name] = len(state.symbol_table)
|
33 |
+
return state.symbol_table[symbol_name]
|
34 |
+
|
35 |
+
|
36 |
+
def generate_symbol_id(state: ParseState, base_name: str) -> int:
|
37 |
+
next_id = len(state.symbol_table)
|
38 |
+
state.symbol_table[base_name + "_" + str(next_id)] = next_id
|
39 |
+
return next_id
|
40 |
+
|
41 |
+
|
42 |
+
def is_word_char(c: str) -> bool:
|
43 |
+
"""
|
44 |
+
Check if a char is a-z, A-Z, 0-9, -, _, i.e., chars allowed as rule names
|
45 |
+
Returns:
|
46 |
+
|
47 |
+
"""
|
48 |
+
return c.isalnum() or c == "-" or c == "_"
|
49 |
+
|
50 |
+
|
51 |
+
def hex_to_int(c: str) -> int:
|
52 |
+
"""
|
53 |
+
Convert a hex char to int, c should be in the range of 0-9, a-f, A-F
|
54 |
+
case insensitive
|
55 |
+
Args:
|
56 |
+
c: a hex char
|
57 |
+
Returns:
|
58 |
+
int: the int value of the hex char
|
59 |
+
"""
|
60 |
+
if c.isdigit():
|
61 |
+
return int(c)
|
62 |
+
elif "a" <= c.lower() <= "f":
|
63 |
+
return ord(c.lower()) - ord("a") + 10
|
64 |
+
return -1
|
65 |
+
|
66 |
+
|
67 |
+
def remove_leading_white_space(src, rm_leading_newline):
|
68 |
+
"""
|
69 |
+
Skips over whitespace and comments in the input string.
|
70 |
+
|
71 |
+
This function processes the input string, skipping over any spaces, tabs,
|
72 |
+
and content following a '#' character, which denotes a comment. The parsing
|
73 |
+
of a comment continues until the end of the line (denoted by newline characters
|
74 |
+
'\r' or '\n'). If the 'rm_leading_newline' parameter is set to False, the function
|
75 |
+
will stop processing and return the remaining string upon encountering a
|
76 |
+
newline character, otherwise it will skip over newline characters as well.
|
77 |
+
|
78 |
+
Parameters:
|
79 |
+
src (str): The input string to be processed.
|
80 |
+
rm_leading_newline (bool): A flag indicating whether encountering a newline character
|
81 |
+
should stop the parsing (False) or if it should be skipped (True).
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
str: The remaining portion of the input string after skipping whitespace and comments.
|
85 |
+
"""
|
86 |
+
pos = 0
|
87 |
+
while pos < len(src) and (src[pos].isspace() or src[pos] == "#"):
|
88 |
+
if src[pos] == "#":
|
89 |
+
while pos < len(src) and src[pos] not in ("\r", "\n"):
|
90 |
+
pos += 1
|
91 |
+
else:
|
92 |
+
if not rm_leading_newline and src[pos] in ("\r", "\n"):
|
93 |
+
break
|
94 |
+
pos += 1
|
95 |
+
return src[pos:]
|
96 |
+
|
97 |
+
|
98 |
+
def parse_name(src) -> (str, str):
|
99 |
+
"""
|
100 |
+
parse the leading name from the input string
|
101 |
+
Args:
|
102 |
+
src: the input grammar string
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
name, remaining_src
|
106 |
+
"""
|
107 |
+
pos = 0
|
108 |
+
while pos < len(src) and is_word_char(src[pos]):
|
109 |
+
pos += 1
|
110 |
+
if pos == 0:
|
111 |
+
raise RuntimeError("expecting name at " + src)
|
112 |
+
return src[:pos], src[pos:]
|
113 |
+
|
114 |
+
|
115 |
+
def parse_char(src) -> (str, str):
|
116 |
+
"""
|
117 |
+
parse the leading char from the input string
|
118 |
+
:param src:
|
119 |
+
:return: char, remaining_src
|
120 |
+
"""
|
121 |
+
|
122 |
+
# if we have a backslash, it's maybe an escape
|
123 |
+
if src[0] == "\\":
|
124 |
+
esc = src[1]
|
125 |
+
if esc == "x":
|
126 |
+
first = hex_to_int(src[2])
|
127 |
+
if first > -1:
|
128 |
+
second = hex_to_int(src[3])
|
129 |
+
if second > -1:
|
130 |
+
return (first << 4) + second, src[4:]
|
131 |
+
raise RuntimeError("expecting \\xNN at " + src)
|
132 |
+
elif esc in ('"', "[", "]"):
|
133 |
+
return esc, src[2:]
|
134 |
+
elif esc == "r":
|
135 |
+
return "\r", src[2:]
|
136 |
+
elif esc == "n":
|
137 |
+
return "\n", src[2:]
|
138 |
+
elif esc == "t":
|
139 |
+
return "\t", src[2:]
|
140 |
+
elif esc == "\\":
|
141 |
+
return "\\", src[2:]
|
142 |
+
elif esc == "/":
|
143 |
+
return "\\", src[1:]
|
144 |
+
raise RuntimeError("unknown escape at " + src)
|
145 |
+
elif src:
|
146 |
+
return src[0], src[1:]
|
147 |
+
raise RuntimeError("unexpected end of input")
|
148 |
+
|
149 |
+
|
150 |
+
def _parse_rhs_literal_string(src: str, outbuf: List[int]) -> str:
|
151 |
+
assert src[0] == '"', f"rule should start with '\"', but got {src[0]}"
|
152 |
+
remaining_src = src[1:]
|
153 |
+
|
154 |
+
# advance until we get an end quote or run out of input
|
155 |
+
while remaining_src and remaining_src[0] != '"':
|
156 |
+
char, remaining_src = parse_char(remaining_src)
|
157 |
+
outbuf.append(LITERAL_MARKER)
|
158 |
+
# print(f"char: {char}")
|
159 |
+
outbuf.append(ord(char))
|
160 |
+
outbuf.append(ord(char))
|
161 |
+
|
162 |
+
# in case we ran out of input before finding the end quote
|
163 |
+
if not remaining_src:
|
164 |
+
raise RuntimeError(f"expecting an end quote at {src},but not found")
|
165 |
+
|
166 |
+
# remove the end quote and return the remaining string
|
167 |
+
return remaining_src[1:]
|
168 |
+
|
169 |
+
|
170 |
+
def _parse_rhs_char_ranges(src: str, outbuf: List[int]) -> str:
|
171 |
+
assert src[0] == "[", f"rule should start with '[', but got {src[0]}"
|
172 |
+
remaining_src = src[1:]
|
173 |
+
start_idx = len(outbuf)
|
174 |
+
# num chars in range - replaced at end of loop
|
175 |
+
outbuf.append(TO_BE_FILLED_MARKER)
|
176 |
+
while remaining_src and remaining_src[0] != "]":
|
177 |
+
char, remaining_src = parse_char(remaining_src)
|
178 |
+
|
179 |
+
outbuf.append(ord(char))
|
180 |
+
if remaining_src[0] == "-" and remaining_src[1] != "]":
|
181 |
+
endchar_pair, remaining_src = parse_char(remaining_src[1:])
|
182 |
+
outbuf.append(ord(endchar_pair))
|
183 |
+
else:
|
184 |
+
# This is the case for enumerate, e.g., [0123456789], [abcdef]
|
185 |
+
# Each char is considered as a range of itself, i.e., c-c
|
186 |
+
outbuf.append(ord(char))
|
187 |
+
if not remaining_src:
|
188 |
+
raise RuntimeError(
|
189 |
+
f"expecting an ] at {src},but not found, is the char range closed?"
|
190 |
+
)
|
191 |
+
# replace num chars with actual
|
192 |
+
outbuf[start_idx] = len(outbuf) - start_idx - 1
|
193 |
+
return remaining_src[1:]
|
194 |
+
|
195 |
+
|
196 |
+
def _parse_rhs_symbol_reference(src: str, state: ParseState, outbuf: List[int]) -> str:
|
197 |
+
assert is_word_char(src[0]), f"rule should start with a word char, but got {src[0]}"
|
198 |
+
name, remaining_src = parse_name(src)
|
199 |
+
ref_rule_id = get_symbol_id(state, name)
|
200 |
+
outbuf.append(REF_RULE_MARKER)
|
201 |
+
outbuf.append(ref_rule_id)
|
202 |
+
return remaining_src
|
203 |
+
|
204 |
+
|
205 |
+
def _parse_rhs_grouping(
|
206 |
+
remaining_src: str, state: ParseState, rule_name: str, outbuf: List[int]
|
207 |
+
) -> str:
|
208 |
+
assert (
|
209 |
+
remaining_src[0] == "("
|
210 |
+
), f"rule should start with '(', but got {remaining_src[0]}"
|
211 |
+
remaining_src = remove_leading_white_space(remaining_src[1:], True)
|
212 |
+
# parse nested alternates into synthesized rule
|
213 |
+
synthetic_rule_id = generate_symbol_id(state, rule_name)
|
214 |
+
remaining_src = parse_rhs(state, remaining_src, rule_name, synthetic_rule_id, True)
|
215 |
+
# output reference to synthesized rule
|
216 |
+
outbuf.append(REF_RULE_MARKER)
|
217 |
+
outbuf.append(synthetic_rule_id)
|
218 |
+
|
219 |
+
if not remaining_src or remaining_src[0] != ")":
|
220 |
+
raise RuntimeError("expecting ')' at " + remaining_src)
|
221 |
+
return remaining_src[1:]
|
222 |
+
|
223 |
+
|
224 |
+
def _parse_rhs_repetition_operators(
|
225 |
+
remaining_src: str,
|
226 |
+
state: ParseState,
|
227 |
+
rule_name: str,
|
228 |
+
last_sym_start: int,
|
229 |
+
outbuf: List[int],
|
230 |
+
) -> str:
|
231 |
+
assert remaining_src[0] in (
|
232 |
+
"*",
|
233 |
+
"+",
|
234 |
+
"?",
|
235 |
+
), f"rule should start with '*', '+', or '?', but got {remaining_src[0]}"
|
236 |
+
out_grammar = state.grammar_encoding
|
237 |
+
# last_sym_start = len(outbuf)
|
238 |
+
|
239 |
+
# apply transformation to previous symbol (last_sym_start -
|
240 |
+
# end) according to rewrite rules:
|
241 |
+
# S* --> S' ::= S S' |
|
242 |
+
# S+ --> S' ::= S S' | S
|
243 |
+
# S? --> S' ::= S |
|
244 |
+
sub_rule_id = generate_symbol_id(state, rule_name)
|
245 |
+
out_grammar.append(sub_rule_id)
|
246 |
+
sub_rule_offset = len(out_grammar)
|
247 |
+
# placeholder for size of 1st alternate
|
248 |
+
out_grammar.append(TO_BE_FILLED_MARKER)
|
249 |
+
# add preceding symbol to generated rule
|
250 |
+
out_grammar.extend(outbuf[last_sym_start:])
|
251 |
+
if remaining_src[0] in ("*", "+"):
|
252 |
+
# cause generated rule to recurse
|
253 |
+
out_grammar.append(REF_RULE_MARKER)
|
254 |
+
out_grammar.append(sub_rule_id)
|
255 |
+
# apply actual size
|
256 |
+
out_grammar[sub_rule_offset] = len(out_grammar) - sub_rule_offset
|
257 |
+
# mark end of 1st alternate
|
258 |
+
out_grammar.append(END_OF_ALTERNATE_MARKER)
|
259 |
+
sub_rule_offset = len(out_grammar)
|
260 |
+
# placeholder for size of 2nd alternate
|
261 |
+
out_grammar.append(TO_BE_FILLED_MARKER)
|
262 |
+
if remaining_src[0] == "+":
|
263 |
+
# add preceding symbol as alternate only for '+'
|
264 |
+
out_grammar.extend(outbuf[last_sym_start:])
|
265 |
+
# apply actual size of 2nd alternate
|
266 |
+
out_grammar[sub_rule_offset] = len(out_grammar) - sub_rule_offset
|
267 |
+
# mark end of 2nd alternate, then end of rule
|
268 |
+
out_grammar.append(END_OF_ALTERNATE_MARKER)
|
269 |
+
out_grammar.append(END_OF_RULE_MARKER)
|
270 |
+
|
271 |
+
# in original rule, replace previous symbol with reference to generated rule
|
272 |
+
outbuf[last_sym_start:] = [REF_RULE_MARKER, sub_rule_id]
|
273 |
+
return remaining_src[1:]
|
274 |
+
|
275 |
+
|
276 |
+
def parse_simple_rhs(state, rhs: str, rule_name: str, outbuf, is_nested):
|
277 |
+
simple_rhs_offset = len(outbuf)
|
278 |
+
|
279 |
+
# sequence size, will be replaced at end when known
|
280 |
+
outbuf.append(TO_BE_FILLED_MARKER)
|
281 |
+
|
282 |
+
last_sym_start = len(outbuf)
|
283 |
+
remaining_rhs = rhs
|
284 |
+
while remaining_rhs:
|
285 |
+
if remaining_rhs[0] == '"': # literal string
|
286 |
+
# mark the start of the last symbol, for repetition operator
|
287 |
+
last_sym_start = len(outbuf)
|
288 |
+
remaining_rhs = _parse_rhs_literal_string(remaining_rhs, outbuf)
|
289 |
+
elif remaining_rhs[0] == "[": # char range(s)
|
290 |
+
# mark the start of the last symbol, for repetition operator
|
291 |
+
last_sym_start = len(outbuf)
|
292 |
+
remaining_rhs = _parse_rhs_char_ranges(remaining_rhs, outbuf)
|
293 |
+
elif is_word_char(remaining_rhs[0]): # rule reference
|
294 |
+
# mark the start of the last symbol, for repetition operator
|
295 |
+
last_sym_start = len(outbuf)
|
296 |
+
remaining_rhs = _parse_rhs_symbol_reference(remaining_rhs, state, outbuf)
|
297 |
+
elif remaining_rhs[0] == "(": # grouping
|
298 |
+
# mark the start of the last symbol, for repetition operator
|
299 |
+
last_sym_start = len(outbuf)
|
300 |
+
remaining_rhs = _parse_rhs_grouping(remaining_rhs, state, rule_name, outbuf)
|
301 |
+
elif remaining_rhs[0] in ("*", "+", "?"): # repetition operator
|
302 |
+
# No need to mark the start of the last symbol, because we already did it
|
303 |
+
if len(outbuf) - simple_rhs_offset - 1 == 0:
|
304 |
+
raise RuntimeError(
|
305 |
+
"expecting preceeding item to */+/? at " + remaining_rhs
|
306 |
+
)
|
307 |
+
remaining_rhs = _parse_rhs_repetition_operators(
|
308 |
+
remaining_rhs, state, rule_name, last_sym_start, outbuf
|
309 |
+
)
|
310 |
+
else:
|
311 |
+
# case for newline, i.e., end of rule
|
312 |
+
assert remaining_rhs[0] in [
|
313 |
+
"\n",
|
314 |
+
"|",
|
315 |
+
")",
|
316 |
+
], f"rule should end with newline or '|', but got {remaining_rhs[0]}"
|
317 |
+
# we break here so that we call parse_rule again to parse the next rule
|
318 |
+
break
|
319 |
+
# Here we do not rm newline deliberately so that we know the rhs is ended
|
320 |
+
remaining_rhs = remove_leading_white_space(
|
321 |
+
remaining_rhs, rm_leading_newline=is_nested
|
322 |
+
)
|
323 |
+
|
324 |
+
# apply actual size of this alternate sequence
|
325 |
+
outbuf[simple_rhs_offset] = len(outbuf) - simple_rhs_offset
|
326 |
+
# mark end of alternate
|
327 |
+
outbuf.append(END_OF_ALTERNATE_MARKER)
|
328 |
+
return remaining_rhs
|
329 |
+
|
330 |
+
|
331 |
+
def parse_rhs(state, rhs: str, rule_name, rule_id, is_nested):
|
332 |
+
outbuf = []
|
333 |
+
remaining_rhs = parse_simple_rhs(state, rhs, rule_name, outbuf, is_nested)
|
334 |
+
while remaining_rhs and remaining_rhs[0] == "|":
|
335 |
+
remaining_rhs = remove_leading_white_space(remaining_rhs[1:], True)
|
336 |
+
remaining_rhs = parse_simple_rhs(
|
337 |
+
state, remaining_rhs, rule_name, outbuf, is_nested
|
338 |
+
)
|
339 |
+
|
340 |
+
# Now we have finished parsing the rhs, we can add the rule to the grammar_encoding
|
341 |
+
state.grammar_encoding.append(rule_id)
|
342 |
+
state.grammar_encoding.extend(outbuf)
|
343 |
+
state.grammar_encoding.append(END_OF_RULE_MARKER)
|
344 |
+
return remaining_rhs
|
345 |
+
|
346 |
+
|
347 |
+
def parse_rule(state: ParseState, rule_text: str) -> str:
|
348 |
+
name, remaining_rule_text = parse_name(rule_text)
|
349 |
+
remaining_rule_text = remove_leading_white_space(remaining_rule_text, False)
|
350 |
+
# check if the rule is already defined, TODO: what will happen if the rule is already defined?
|
351 |
+
rule_id = get_symbol_id(state, name)
|
352 |
+
|
353 |
+
if remaining_rule_text[:3] != "::=":
|
354 |
+
raise RuntimeError("expecting ::= at " + remaining_rule_text)
|
355 |
+
remaining_rule_text = remove_leading_white_space(remaining_rule_text[3:], True)
|
356 |
+
|
357 |
+
remaining_rule_text = parse_rhs(state, remaining_rule_text, name, rule_id, False)
|
358 |
+
|
359 |
+
if remaining_rule_text and remaining_rule_text[0] == "\r":
|
360 |
+
remaining_rule_text = (
|
361 |
+
remaining_rule_text[2:]
|
362 |
+
if remaining_rule_text[1] == "\n"
|
363 |
+
else remaining_rule_text[1:]
|
364 |
+
)
|
365 |
+
elif remaining_rule_text and remaining_rule_text[0] == "\n":
|
366 |
+
remaining_rule_text = remaining_rule_text[1:]
|
367 |
+
elif remaining_rule_text:
|
368 |
+
raise RuntimeError("expecting newline or end at " + remaining_rule_text)
|
369 |
+
return remove_leading_white_space(remaining_rule_text, True)
|
370 |
+
|
371 |
+
|
372 |
+
def parse_ebnf(grammar_text: str) -> ParseState:
|
373 |
+
try:
|
374 |
+
state = ParseState()
|
375 |
+
remaining_grammar_text = remove_leading_white_space(grammar_text, True)
|
376 |
+
last_grammar_repr = ""
|
377 |
+
while remaining_grammar_text:
|
378 |
+
if last_grammar_repr:
|
379 |
+
last_parsed_rule_len = len(last_grammar_repr) - len(
|
380 |
+
remaining_grammar_text
|
381 |
+
)
|
382 |
+
logger.debug(
|
383 |
+
f"last_parsed_rule: {last_grammar_repr[:last_parsed_rule_len]}"
|
384 |
+
)
|
385 |
+
last_grammar_repr = remaining_grammar_text
|
386 |
+
remaining_grammar_text = parse_rule(state, remaining_grammar_text)
|
387 |
+
state.grammar_encoding.append(END_OF_GRAMMAR_MARKER)
|
388 |
+
return state
|
389 |
+
except RuntimeError as err:
|
390 |
+
logger.warning("error parsing grammar: %s", err)
|
391 |
+
return ParseState()
|
392 |
+
|
393 |
+
|
394 |
+
###################################
|
395 |
+
# EBNF Grammar Parsing ends here #
|
396 |
+
###################################
|
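A minimal usage sketch of the entry point above (the grammar string is made up for illustration):

```python
from transformers_gad.parser import parse_ebnf

state = parse_ebnf('root ::= "0" | "1"\n')
print(state.symbol_table)      # {'root': 0}
print(state.grammar_encoding)  # flat int encoding, terminated by END_OF_GRAMMAR_MARKER (0xFFFF)
```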
397 |
+
|
398 |
+
|
399 |
+
def break_grammar_into_rules(grammar_encoding: List[int]) -> List[List[int]]:
|
400 |
+
offset = 0
|
401 |
+
# we loop until we reach the end of the grammar_encoding
|
402 |
+
rule_encodings = []
|
403 |
+
i = 0
|
404 |
+
while i < len(grammar_encoding) - 2:
|
405 |
+
if (
|
406 |
+
grammar_encoding[i] == END_OF_ALTERNATE_MARKER
|
407 |
+
and grammar_encoding[i + 1] == END_OF_RULE_MARKER
|
408 |
+
):
|
409 |
+
rule_encodings.append(grammar_encoding[offset : i + 2])
|
410 |
+
offset = i + 2
|
411 |
+
# skip the END_OF_RULE_MARKER
|
412 |
+
# This is mandatory because if we do not skip the END_OF_RULE_MARKER
|
413 |
+
# we fail in the case where the next rule has rule_id 0
|
414 |
+
i += 1
|
415 |
+
i += 1
|
416 |
+
return rule_encodings
|
417 |
+
|
418 |
+
|
419 |
+
def break_rule_into_elements(rule_encoding: List[int]) -> List[List[int]]:
|
420 |
+
rule_id = rule_encoding.pop(0)
|
421 |
+
end_of_rule_marker = rule_encoding.pop(-1)
|
422 |
+
assert (
|
423 |
+
end_of_rule_marker == END_OF_RULE_MARKER
|
424 |
+
), f"rule should end with {END_OF_RULE_MARKER}, but got {end_of_rule_marker}"
|
425 |
+
|
426 |
+
offset = 0
|
427 |
+
elements = []
|
428 |
+
while offset < len(rule_encoding):
|
429 |
+
element_size = rule_encoding[offset]
|
430 |
+
assert (
|
431 |
+
rule_encoding[offset + element_size] == END_OF_ALTERNATE_MARKER
|
432 |
+
), f"element should end with {END_OF_ALTERNATE_MARKER}, but got {rule_encoding[offset + element_size]}"
|
433 |
+
elements.append(rule_encoding[offset : offset + element_size + 1])
|
434 |
+
offset += element_size + 1
|
435 |
+
return elements
|
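A sketch of how the two helpers above decompose the flat encoding; the input list is illustrative and mirrors the single-rule layout shown earlier.

```python
grammar_encoding = [4, 4, 2, 48, 48, 0, 4, 2, 49, 49, 0, 0, 0xFFFF]
rules = break_grammar_into_rules(grammar_encoding)
# -> [[4, 4, 2, 48, 48, 0, 4, 2, 49, 49, 0, 0]]
print(break_rule_into_elements(rules[0]))
# -> [[4, 2, 48, 48, 0], [4, 2, 49, 49, 0]]
```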
436 |
+
|
437 |
+
|
438 |
+
def _print_annotated_grammar(file, grammar_encoding, symbol_id_names, index=0):
|
439 |
+
rule_id = grammar_encoding[index]
|
440 |
+
print(f"<{index}>{symbol_id_names[rule_id]} ::=", end=" ", file=file)
|
441 |
+
pos = index + 1
|
442 |
+
while grammar_encoding[pos]:
|
443 |
+
if pos - 1 > index:
|
444 |
+
print("|", end=" ", file=file)
|
445 |
+
pos += 1 # sequence size, not needed here
|
446 |
+
while grammar_encoding[pos]:
|
447 |
+
if grammar_encoding[pos] == REF_RULE_MARKER:
|
448 |
+
ref_rule_id = grammar_encoding[pos + 1]
|
449 |
+
print(
|
450 |
+
f"<{pos}>{symbol_id_names[ref_rule_id]}",
|
451 |
+
end=" ",
|
452 |
+
file=file,
|
453 |
+
)
|
454 |
+
pos += 2
|
455 |
+
else:
|
456 |
+
print("<{}>[".format(pos), end="", file=file)
|
457 |
+
num_chars = grammar_encoding[pos]
|
458 |
+
pos += 1
|
459 |
+
|
460 |
+
for i in range(0, num_chars, 2):
|
461 |
+
print(
|
462 |
+
"{}-".format(chr(grammar_encoding[pos + i])), end="", file=file
|
463 |
+
)
|
464 |
+
if i + 1 < num_chars:
|
465 |
+
print(
|
466 |
+
"{}".format(chr(grammar_encoding[pos + i + 1])),
|
467 |
+
end="",
|
468 |
+
file=file,
|
469 |
+
)
|
470 |
+
print("]", end=" ", file=file)
|
471 |
+
pos += num_chars
|
472 |
+
pos += 1
|
473 |
+
print(file=file)
|
474 |
+
return pos + 1
|
475 |
+
|
476 |
+
|
477 |
+
def print_grammar(file, state):
|
478 |
+
pos = 0
|
479 |
+
symbol_id_names = {v: k for k, v in state.symbol_table.items()}
|
480 |
+
print("Grammar Rules:", file=file)
|
481 |
+
while (
|
482 |
+
pos < len(state.grammar_encoding)
|
483 |
+
and state.grammar_encoding[pos] != END_OF_GRAMMAR_MARKER
|
484 |
+
):
|
485 |
+
pos = _print_annotated_grammar(
|
486 |
+
file, state.grammar_encoding, symbol_id_names, pos
|
487 |
+
)
|
488 |
+
if pos > len(state.grammar_encoding):
|
489 |
+
raise Warning(f"grammar_encoding is not ended with {END_OF_GRAMMAR_MARKER}")
|
490 |
+
pos = 0
|
491 |
+
print("\nGrammar Hex representation:", file=file)
|
492 |
+
while (
|
493 |
+
pos < len(state.grammar_encoding)
|
494 |
+
and state.grammar_encoding[pos] != END_OF_GRAMMAR_MARKER
|
495 |
+
):
|
496 |
+
print(f"{state.grammar_encoding[pos]:04x}", end=" ", file=file)
|
497 |
+
pos += 1
|
498 |
+
if pos > len(state.grammar_encoding):
|
499 |
+
raise Warning(f"grammar_encoding is not ended with {END_OF_GRAMMAR_MARKER}")
|
500 |
+
else:
|
501 |
+
print("ffff\n")
|
502 |
+
|
503 |
+
print("Rules Decimal representation:", file=file)
|
504 |
+
# we loop until we reach the end of the grammar_encoding
|
505 |
+
rule_encodings = break_grammar_into_rules(state.grammar_encoding)
|
506 |
+
for rule_encoding in rule_encodings:
|
507 |
+
rule_id = rule_encoding[0]
|
508 |
+
print(
|
509 |
+
f"<{rule_id}> {break_rule_into_elements(rule_encoding)}",
|
510 |
+
file=file,
|
511 |
+
)
|
512 |
+
|
513 |
+
|
514 |
+
if __name__ == "__main__":
|
515 |
+
parser = argparse.ArgumentParser(description="Parse EBNF grammar files.")
|
516 |
+
parser.add_argument(
|
517 |
+
"-g",
|
518 |
+
"--grammar-file",
|
519 |
+
nargs="?",
|
520 |
+
default="/nobackup2/yf/mila/GD/examples/sygus/PRE_100_bare.ebnf",
|
521 |
+
help="Path to the grammar file",
|
522 |
+
)
|
523 |
+
|
524 |
+
args = parser.parse_args()
|
525 |
+
|
526 |
+
# set logging level
|
527 |
+
logging.basicConfig(level=logging.DEBUG)
|
528 |
+
|
529 |
+
with open(args.grammar_file, "r") as file:
|
530 |
+
input_text = file.read()
|
531 |
+
parsed_grammar = parse_ebnf(input_text)
|
532 |
+
print("parse state:")
|
533 |
+
parsed_grammar.print()
|
534 |
+
# print(f"symbol_ids: \n{parsed_grammar.symbol_table}")
|
535 |
+
|
536 |
+
# start_rule_id = parsed_grammar.symbol_table["root"]
|
537 |
+
|
538 |
+
# DEBUG: __main__:last_parsed_rule: root: := "0"d | "1"a
|
539 |
+
#
|
540 |
+
# DEBUG: __main__:last_parsed_rule: a: := "0"c | "1"b
|
541 |
+
#
|
542 |
+
# DEBUG: __main__:last_parsed_rule: b: := "0" | "1"
|
543 |
+
#
|
544 |
+
# DEBUG: __main__:last_parsed_rule: c: := "0" | "1"
|
545 |
+
#
|
546 |
+
# DEBUG: __main__:last_parsed_rule: d: := "0"e
|
547 |
+
#
|
548 |
+
# parse state:
|
549 |
+
# Grammar Rules:
|
550 |
+
# < 0 > root: := < 2 > [0 - 0] < 5 > d | < 9 > [1 - 1] < 12 > a
|
551 |
+
# < 16 > a: := < 18 > [0 - 0] < 21 > c | < 25 > [1 - 1] < 28 > b
|
552 |
+
# < 32 > b: := < 34 > [0 - 0] | < 39 > [1 - 1]
|
553 |
+
# < 44 > c: := < 46 > [0 - 0] | < 51 > [1 - 1]
|
554 |
+
# < 56 > d: := < 58 > [0 - 0] < 61 > e
|
555 |
+
# < 65 > e: := < 67 > [0 - 0]
|
556 |
+
#
|
557 |
+
# Grammar Hex representation:
|
558 |
+
# 0000 0006 0002 0030 0030 0001 0001 0000
|
559 |
+
# 0006 0002 0031 0031 0001 0002 0000 0000
|
560 |
+
# 0002 0006 0002 0030 0030 0001 0003 0000
|
561 |
+
# 0006 0002 0031 0031 0001 0004 0000 0000
|
562 |
+
# 0004 0004 0002 0030 0030 0000 0004 0002
|
563 |
+
# 0031 0031 0000 0000 0003 0004 0002 0030
|
564 |
+
# 0030 0000 0004 0002 0031 0031 0000 0000
|
565 |
+
# 0001 0006 0002 0030 0030 0001 0005 0000
|
566 |
+
# 0000 0005 0004 0002 0030 0030 0000 0000 ffff
|
567 |
+
#
|
568 |
+
# Rules Decimal representation:
|
569 |
+
# < 0 > [[6, 2, 48, 48, 1, 1, 0], [6, 2, 49, 49, 1, 2, 0]]
|
570 |
+
# < 2 > [[6, 2, 48, 48, 1, 3, 0], [6, 2, 49, 49, 1, 4, 0]]
|
571 |
+
# < 4 > [[4, 2, 48, 48, 0], [4, 2, 49, 49, 0]]
|
572 |
+
# < 3 > [[4, 2, 48, 48, 0], [4, 2, 49, 49, 0]]
|
573 |
+
# < 1 > [[6, 2, 48, 48, 1, 5, 0]]
|
574 |
+
# < 5 > [[4, 2, 48, 48, 0]]
|
575 |
+
|
576 |
+
|
transformers_gad/parser_cfg.py
ADDED
@@ -0,0 +1,530 @@
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import sys
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
END_OF_ALTERNATE_MARKER = 0
|
9 |
+
END_OF_RULE_MARKER = 0
|
10 |
+
END_OF_GRAMMAR_MARKER = 0xFFFF
|
11 |
+
TO_BE_FILLED_MARKER = 0
|
12 |
+
REF_RULE_MARKER = 1
|
13 |
+
LITERAL_MARKER = 2
|
14 |
+
|
15 |
+
|
16 |
+
########################
|
17 |
+
# EBNF Grammar Parsing #
|
18 |
+
########################
|
19 |
+
|
20 |
+
|
21 |
+
class ParseState:
|
22 |
+
def __init__(self):
|
23 |
+
self.symbol_table = {}
|
24 |
+
self.grammar_encoding = [] # old name: out_grammar
|
25 |
+
|
26 |
+
def print(self, file=sys.stdout):
|
27 |
+
print_grammar(file, self)
|
28 |
+
|
29 |
+
|
30 |
+
def get_symbol_id(state: ParseState, symbol_name: str) -> int:
|
31 |
+
if symbol_name not in state.symbol_table:
|
32 |
+
state.symbol_table[symbol_name] = len(state.symbol_table)
|
33 |
+
return state.symbol_table[symbol_name]
|
34 |
+
|
35 |
+
|
36 |
+
def generate_symbol_id(state: ParseState, base_name: str) -> int:
|
37 |
+
next_id = len(state.symbol_table)
|
38 |
+
state.symbol_table[base_name + "_" + str(next_id)] = next_id
|
39 |
+
return next_id
|
40 |
+
|
41 |
+
|
42 |
+
def is_word_char(c: str) -> bool:
|
43 |
+
"""
|
44 |
+
Check if a char is a-z, A-Z, 0-9, -, _, i.e., chars allowed as rule names
|
45 |
+
Returns:
|
46 |
+
True if c is allowed in a rule name (alphanumeric, '-' or '_'), False otherwise.
|
47 |
+
"""
|
48 |
+
return c.isalnum() or c == "-" or c == "_"
|
49 |
+
|
50 |
+
|
51 |
+
def hex_to_int(c: str) -> int:
|
52 |
+
"""
|
53 |
+
Convert a hex char to int, c should be in the range of 0-9, a-f, A-F
|
54 |
+
case insensitive
|
55 |
+
Args:
|
56 |
+
c: a hex char
|
57 |
+
Returns:
|
58 |
+
int: the int value of the hex char
|
59 |
+
"""
|
60 |
+
if c.isdigit():
|
61 |
+
return int(c)
|
62 |
+
elif "a" <= c.lower() <= "f":
|
63 |
+
return ord(c.lower()) - ord("a") + 10
|
64 |
+
return -1
|
65 |
+
|
66 |
+
|
67 |
+
def remove_leading_white_space(src, rm_leading_newline):
|
68 |
+
"""
|
69 |
+
Skips over whitespace and comments in the input string.
|
70 |
+
|
71 |
+
This function processes the input string, skipping over any spaces, tabs,
|
72 |
+
and content following a '#' character, which denotes a comment. The parsing
|
73 |
+
of a comment continues until the end of the line (denoted by newline characters
|
74 |
+
'\r' or '\n'). If the 'rm_leading_newline' parameter is set to False, the function
|
75 |
+
will stop processing and return the remaining string upon encountering a
|
76 |
+
newline character, otherwise it will skip over newline characters as well.
|
77 |
+
|
78 |
+
Parameters:
|
79 |
+
src (str): The input string to be processed.
|
80 |
+
rm_leading_newline (bool): A flag indicating whether encountering a newline character
|
81 |
+
should stop the parsing (False) or if it should be skipped (True).
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
str: The remaining portion of the input string after skipping whitespace and comments.
|
85 |
+
"""
|
86 |
+
pos = 0
|
87 |
+
while pos < len(src) and (src[pos].isspace() or src[pos] == "#"):
|
88 |
+
if src[pos] == "#":
|
89 |
+
while pos < len(src) and src[pos] not in ("\r", "\n"):
|
90 |
+
pos += 1
|
91 |
+
else:
|
92 |
+
if not rm_leading_newline and src[pos] in ("\r", "\n"):
|
93 |
+
break
|
94 |
+
pos += 1
|
95 |
+
return src[pos:]
|
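Two illustrative calls (inputs are made up) showing the effect of `rm_leading_newline`:

```python
remove_leading_white_space("  # a comment\nroot ::= x", rm_leading_newline=True)
# -> 'root ::= x'      (the comment and the newline after it are skipped)
remove_leading_white_space("   \nroot ::= x", rm_leading_newline=False)
# -> '\nroot ::= x'    (stops at the newline so the caller can detect the end of a rule)
```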
96 |
+
|
97 |
+
|
98 |
+
def parse_name(src) -> (str, str):
|
99 |
+
"""
|
100 |
+
parse the leading name from the input string
|
101 |
+
Args:
|
102 |
+
src: the input grammar string
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
name, remaining_src
|
106 |
+
"""
|
107 |
+
pos = 0
|
108 |
+
while pos < len(src) and is_word_char(src[pos]):
|
109 |
+
pos += 1
|
110 |
+
if pos == 0:
|
111 |
+
raise RuntimeError("expecting name at " + src)
|
112 |
+
return src[:pos], src[pos:]
|
113 |
+
|
114 |
+
|
115 |
+
def parse_char(src) -> (str, str):
|
116 |
+
"""
|
117 |
+
parse the leading char from the input string
|
118 |
+
:param src:
|
119 |
+
:return: char, remaining_src
|
120 |
+
"""
|
121 |
+
|
122 |
+
# if we have a backslash, it's maybe an escape
|
123 |
+
if src[0] == "\\":
|
124 |
+
esc = src[1]
|
125 |
+
if esc == "x":
|
126 |
+
first = hex_to_int(src[2])
|
127 |
+
if first > -1:
|
128 |
+
second = hex_to_int(src[3])
|
129 |
+
if second > -1:
|
130 |
+
return chr((first << 4) + second), src[4:]  # return a str so downstream ord() calls work
|
131 |
+
raise RuntimeError("expecting \\xNN at " + src)
|
132 |
+
elif esc in ('"', "[", "]"):
|
133 |
+
return esc, src[2:]
|
134 |
+
elif esc == "r":
|
135 |
+
return "\r", src[2:]
|
136 |
+
elif esc == "n":
|
137 |
+
return "\n", src[2:]
|
138 |
+
elif esc == "t":
|
139 |
+
return "\t", src[2:]
|
140 |
+
raise RuntimeError("unknown escape at " + src)
|
141 |
+
elif src:
|
142 |
+
return src[0], src[1:]
|
143 |
+
raise RuntimeError("unexpected end of input")
|
144 |
+
|
145 |
+
|
146 |
+
def _parse_rhs_literal_string(src: str, outbuf: List[int]) -> str:
|
147 |
+
assert src[0] == '"', f"rule should start with '\"', but got {src[0]}"
|
148 |
+
remaining_src = src[1:]
|
149 |
+
|
150 |
+
# advance until we get an end quote or run out of input
|
151 |
+
while remaining_src and remaining_src[0] != '"':
|
152 |
+
char, remaining_src = parse_char(remaining_src)
|
153 |
+
outbuf.append(LITERAL_MARKER)
|
154 |
+
outbuf.append(ord(char))
|
155 |
+
outbuf.append(ord(char))
|
156 |
+
|
157 |
+
# in case we ran out of input before finding the end quote
|
158 |
+
if not remaining_src:
|
159 |
+
raise RuntimeError(f"expecting an end quote at {src},but not found")
|
160 |
+
|
161 |
+
# remove the end quote and return the remaining string
|
162 |
+
return remaining_src[1:]
|
163 |
+
|
164 |
+
|
165 |
+
def _parse_rhs_char_ranges(src: str, outbuf: List[int]) -> str:
|
166 |
+
assert src[0] == "[", f"rule should start with '[', but got {src[0]}"
|
167 |
+
remaining_src = src[1:]
|
168 |
+
start_idx = len(outbuf)
|
169 |
+
# num chars in range - replaced at end of loop
|
170 |
+
outbuf.append(TO_BE_FILLED_MARKER)
|
171 |
+
while remaining_src and remaining_src[0] != "]":
|
172 |
+
char, remaining_src = parse_char(remaining_src)
|
173 |
+
|
174 |
+
outbuf.append(ord(char))
|
175 |
+
if remaining_src[0] == "-" and remaining_src[1] != "]":
|
176 |
+
endchar_pair, remaining_src = parse_char(remaining_src[1:])
|
177 |
+
outbuf.append(ord(endchar_pair))
|
178 |
+
else:
|
179 |
+
# This is the case for enumerate, e.g., [0123456789], [abcdef]
|
180 |
+
# Each char is considered as a range of itself, i.e., c-c
|
181 |
+
outbuf.append(ord(char))
|
182 |
+
if not remaining_src:
|
183 |
+
raise RuntimeError(
|
184 |
+
f"expecting an ] at {src},but not found, is the char range closed?"
|
185 |
+
)
|
186 |
+
# replace num chars with actual
|
187 |
+
outbuf[start_idx] = len(outbuf) - start_idx - 1
|
188 |
+
return remaining_src[1:]
|
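A sketch (illustrative input) of the ints this function appends for a small character class; a single character is stored as the degenerate range c-c:

```python
outbuf = []
remaining = _parse_rhs_char_ranges("[a-z0] rest", outbuf)  # remaining == " rest"
print(outbuf)
# -> [4, 97, 122, 48, 48]   i.e. 4 range ints: 'a'-'z' and '0'-'0'
```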
189 |
+
|
190 |
+
|
191 |
+
def _parse_rhs_symbol_reference(src: str, state: ParseState, outbuf: List[int]) -> str:
|
192 |
+
assert is_word_char(src[0]), f"rule should start with a word char, but got {src[0]}"
|
193 |
+
name, remaining_src = parse_name(src)
|
194 |
+
ref_rule_id = get_symbol_id(state, name)
|
195 |
+
outbuf.append(REF_RULE_MARKER)
|
196 |
+
outbuf.append(ref_rule_id)
|
197 |
+
return remaining_src
|
198 |
+
|
199 |
+
|
200 |
+
def _parse_rhs_grouping(
|
201 |
+
remaining_src: str, state: ParseState, rule_name: str, outbuf: List[int]
|
202 |
+
) -> str:
|
203 |
+
assert (
|
204 |
+
remaining_src[0] == "("
|
205 |
+
), f"rule should start with '(', but got {remaining_src[0]}"
|
206 |
+
remaining_src = remove_leading_white_space(remaining_src[1:], True)
|
207 |
+
# parse nested alternates into synthesized rule
|
208 |
+
synthetic_rule_id = generate_symbol_id(state, rule_name)
|
209 |
+
remaining_src = parse_rhs(state, remaining_src, rule_name, synthetic_rule_id, True)
|
210 |
+
# output reference to synthesized rule
|
211 |
+
outbuf.append(REF_RULE_MARKER)
|
212 |
+
outbuf.append(synthetic_rule_id)
|
213 |
+
|
214 |
+
if not remaining_src or remaining_src[0] != ")":
|
215 |
+
raise RuntimeError("expecting ')' at " + remaining_src)
|
216 |
+
return remaining_src[1:]
|
217 |
+
|
218 |
+
|
219 |
+
def _parse_rhs_repetition_operators(
|
220 |
+
remaining_src: str,
|
221 |
+
state: ParseState,
|
222 |
+
rule_name: str,
|
223 |
+
last_sym_start: int,
|
224 |
+
outbuf: List[int],
|
225 |
+
) -> str:
|
226 |
+
assert remaining_src[0] in (
|
227 |
+
"*",
|
228 |
+
"+",
|
229 |
+
"?",
|
230 |
+
), f"rule should start with '*', '+', or '?', but got {remaining_src[0]}"
|
231 |
+
out_grammar = state.grammar_encoding
|
232 |
+
# last_sym_start = len(outbuf)
|
233 |
+
|
234 |
+
# apply transformation to previous symbol (last_sym_start -
|
235 |
+
# end) according to rewrite rules:
|
236 |
+
# S* --> S' ::= S S' |
|
237 |
+
# S+ --> S' ::= S S' | S
|
238 |
+
# S? --> S' ::= S |
|
239 |
+
sub_rule_id = generate_symbol_id(state, rule_name)
|
240 |
+
out_grammar.append(sub_rule_id)
|
241 |
+
sub_rule_offset = len(out_grammar)
|
242 |
+
# placeholder for size of 1st alternate
|
243 |
+
out_grammar.append(TO_BE_FILLED_MARKER)
|
244 |
+
# add preceding symbol to generated rule
|
245 |
+
out_grammar.extend(outbuf[last_sym_start:])
|
246 |
+
if remaining_src[0] in ("*", "+"):
|
247 |
+
# cause generated rule to recurse
|
248 |
+
out_grammar.append(REF_RULE_MARKER)
|
249 |
+
out_grammar.append(sub_rule_id)
|
250 |
+
# apply actual size
|
251 |
+
out_grammar[sub_rule_offset] = len(out_grammar) - sub_rule_offset
|
252 |
+
# mark end of 1st alternate
|
253 |
+
out_grammar.append(END_OF_ALTERNATE_MARKER)
|
254 |
+
sub_rule_offset = len(out_grammar)
|
255 |
+
# placeholder for size of 2nd alternate
|
256 |
+
out_grammar.append(TO_BE_FILLED_MARKER)
|
257 |
+
if remaining_src[0] == "+":
|
258 |
+
# add preceding symbol as alternate only for '+'
|
259 |
+
out_grammar.extend(outbuf[last_sym_start:])
|
260 |
+
# apply actual size of 2nd alternate
|
261 |
+
out_grammar[sub_rule_offset] = len(out_grammar) - sub_rule_offset
|
262 |
+
# mark end of 2nd alternate, then end of rule
|
263 |
+
out_grammar.append(END_OF_ALTERNATE_MARKER)
|
264 |
+
out_grammar.append(END_OF_RULE_MARKER)
|
265 |
+
|
266 |
+
# in original rule, replace previous symbol with reference to generated rule
|
267 |
+
outbuf[last_sym_start:] = [REF_RULE_MARKER, sub_rule_id]
|
268 |
+
return remaining_src[1:]
|
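A concrete illustration of the rewrite rules in the comments above (the grammar string is made up): a `+` repetition produces a synthesized rule named after the enclosing rule and replaces the repeated symbol with a reference to it.

```python
state = parse_ebnf('root ::= [0-9]+\n')
print(state.symbol_table)
# -> {'root': 0, 'root_1': 1}
# root_1 encodes  root_1 ::= [0-9] root_1 | [0-9]  and the rhs of root becomes
# a single reference to root_1.
```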
269 |
+
|
270 |
+
|
271 |
+
def parse_simple_rhs(state, rhs: str, rule_name: str, outbuf, is_nested):
|
272 |
+
simple_rhs_offset = len(outbuf)
|
273 |
+
|
274 |
+
# sequence size, will be replaced at end when known
|
275 |
+
outbuf.append(TO_BE_FILLED_MARKER)
|
276 |
+
|
277 |
+
last_sym_start = len(outbuf)
|
278 |
+
remaining_rhs = rhs
|
279 |
+
while remaining_rhs:
|
280 |
+
if remaining_rhs[0] == '"': # literal string
|
281 |
+
# mark the start of the last symbol, for repetition operator
|
282 |
+
last_sym_start = len(outbuf)
|
283 |
+
remaining_rhs = _parse_rhs_literal_string(remaining_rhs, outbuf)
|
284 |
+
elif remaining_rhs[0] == "[": # char range(s)
|
285 |
+
# mark the start of the last symbol, for repetition operator
|
286 |
+
last_sym_start = len(outbuf)
|
287 |
+
remaining_rhs = _parse_rhs_char_ranges(remaining_rhs, outbuf)
|
288 |
+
elif is_word_char(remaining_rhs[0]): # rule reference
|
289 |
+
# mark the start of the last symbol, for repetition operator
|
290 |
+
last_sym_start = len(outbuf)
|
291 |
+
remaining_rhs = _parse_rhs_symbol_reference(remaining_rhs, state, outbuf)
|
292 |
+
elif remaining_rhs[0] == "(": # grouping
|
293 |
+
# mark the start of the last symbol, for repetition operator
|
294 |
+
last_sym_start = len(outbuf)
|
295 |
+
remaining_rhs = _parse_rhs_grouping(remaining_rhs, state, rule_name, outbuf)
|
296 |
+
elif remaining_rhs[0] in ("*", "+", "?"): # repetition operator
|
297 |
+
# No need to mark the start of the last symbol, because we already did it
|
298 |
+
if len(outbuf) - simple_rhs_offset - 1 == 0:
|
299 |
+
raise RuntimeError(
|
300 |
+
"expecting preceeding item to */+/? at " + remaining_rhs
|
301 |
+
)
|
302 |
+
remaining_rhs = _parse_rhs_repetition_operators(
|
303 |
+
remaining_rhs, state, rule_name, last_sym_start, outbuf
|
304 |
+
)
|
305 |
+
else:
|
306 |
+
# case for newline, i.e., end of rule
|
307 |
+
assert remaining_rhs[0] in [
|
308 |
+
"\n",
|
309 |
+
"|",
|
310 |
+
")",
|
311 |
+
], f"rule should end with newline or '|', but got {remaining_rhs[0]}"
|
312 |
+
# we break here so that we call parse_rule again to parse the next rule
|
313 |
+
break
|
314 |
+
# Here we do not rm newline deliberately so that we know the rhs is ended
|
315 |
+
remaining_rhs = remove_leading_white_space(
|
316 |
+
remaining_rhs, rm_leading_newline=is_nested
|
317 |
+
)
|
318 |
+
|
319 |
+
# apply actual size of this alternate sequence
|
320 |
+
outbuf[simple_rhs_offset] = len(outbuf) - simple_rhs_offset
|
321 |
+
# mark end of alternate
|
322 |
+
outbuf.append(END_OF_ALTERNATE_MARKER)
|
323 |
+
return remaining_rhs
|
324 |
+
|
325 |
+
|
326 |
+
def parse_rhs(state, rhs: str, rule_name, rule_id, is_nested):
|
327 |
+
outbuf = []
|
328 |
+
remaining_rhs = parse_simple_rhs(state, rhs, rule_name, outbuf, is_nested)
|
329 |
+
while remaining_rhs and remaining_rhs[0] == "|":
|
330 |
+
remaining_rhs = remove_leading_white_space(remaining_rhs[1:], True)
|
331 |
+
remaining_rhs = parse_simple_rhs(
|
332 |
+
state, remaining_rhs, rule_name, outbuf, is_nested
|
333 |
+
)
|
334 |
+
|
335 |
+
# Now we have finished parsing the rhs, we can add the rule to the grammar_encoding
|
336 |
+
state.grammar_encoding.append(rule_id)
|
337 |
+
state.grammar_encoding.extend(outbuf)
|
338 |
+
state.grammar_encoding.append(END_OF_RULE_MARKER)
|
339 |
+
return remaining_rhs
|
340 |
+
|
341 |
+
|
342 |
+
def parse_rule(state: ParseState, rule_text: str) -> str:
|
343 |
+
name, remaining_rule_text = parse_name(rule_text)
|
344 |
+
remaining_rule_text = remove_leading_white_space(remaining_rule_text, False)
|
345 |
+
# check if the rule is already defined, TODO: what will happen if the rule is already defined?
|
346 |
+
rule_id = get_symbol_id(state, name)
|
347 |
+
|
348 |
+
if remaining_rule_text[:3] != "::=":
|
349 |
+
raise RuntimeError("expecting ::= at " + remaining_rule_text)
|
350 |
+
remaining_rule_text = remove_leading_white_space(remaining_rule_text[3:], True)
|
351 |
+
|
352 |
+
remaining_rule_text = parse_rhs(state, remaining_rule_text, name, rule_id, False)
|
353 |
+
|
354 |
+
if remaining_rule_text and remaining_rule_text[0] == "\r":
|
355 |
+
remaining_rule_text = (
|
356 |
+
remaining_rule_text[2:]
|
357 |
+
if remaining_rule_text[1] == "\n"
|
358 |
+
else remaining_rule_text[1:]
|
359 |
+
)
|
360 |
+
elif remaining_rule_text and remaining_rule_text[0] == "\n":
|
361 |
+
remaining_rule_text = remaining_rule_text[1:]
|
362 |
+
elif remaining_rule_text:
|
363 |
+
raise RuntimeError("expecting newline or end at " + remaining_rule_text)
|
364 |
+
return remove_leading_white_space(remaining_rule_text, True)
|
365 |
+
|
366 |
+
|
367 |
+
def parse_ebnf(grammar_text: str) -> ParseState:
|
368 |
+
try:
|
369 |
+
state = ParseState()
|
370 |
+
remaining_grammar_text = remove_leading_white_space(grammar_text, True)
|
371 |
+
last_grammar_repr = ""
|
372 |
+
while remaining_grammar_text:
|
373 |
+
if last_grammar_repr:
|
374 |
+
last_parsed_rule_len = len(last_grammar_repr) - len(
|
375 |
+
remaining_grammar_text
|
376 |
+
)
|
377 |
+
logger.debug(
|
378 |
+
f"last_parsed_rule: {last_grammar_repr[:last_parsed_rule_len]}"
|
379 |
+
)
|
380 |
+
last_grammar_repr = remaining_grammar_text
|
381 |
+
remaining_grammar_text = parse_rule(state, remaining_grammar_text)
|
382 |
+
state.grammar_encoding.append(END_OF_GRAMMAR_MARKER)
|
383 |
+
return state
|
384 |
+
except RuntimeError as err:
|
385 |
+
logger.warning("error parsing grammar: %s", err)
|
386 |
+
return ParseState()
|
387 |
+
|
388 |
+
|
389 |
+
###################################
|
390 |
+
# EBNF Grammar Parsing ends here #
|
391 |
+
###################################
|
392 |
+
|
393 |
+
|
394 |
+
def break_grammar_into_rules(grammar_encoding: List[int]) -> List[List[int]]:
|
395 |
+
offset = 0
|
396 |
+
# we loop until we reach the end of the grammar_encoding
|
397 |
+
rule_encodings = []
|
398 |
+
i = 0
|
399 |
+
while i < len(grammar_encoding) - 2:
|
400 |
+
if (
|
401 |
+
grammar_encoding[i] == END_OF_ALTERNATE_MARKER
|
402 |
+
and grammar_encoding[i + 1] == END_OF_RULE_MARKER
|
403 |
+
):
|
404 |
+
rule_encodings.append(grammar_encoding[offset : i + 2])
|
405 |
+
offset = i + 2
|
406 |
+
# skip the END_OF_RULE_MARKER
|
407 |
+
# This is mandatory because if we do not skip the END_OF_RULE_MARKER
|
408 |
+
# we fail in the case where the next rule has rule_id 0
|
409 |
+
i += 1
|
410 |
+
i += 1
|
411 |
+
return rule_encodings
|
412 |
+
|
413 |
+
|
414 |
+
def break_rule_into_elements(rule_encoding: List[int]) -> List[List[int]]:
|
415 |
+
rule_id = rule_encoding.pop(0)
|
416 |
+
end_of_rule_marker = rule_encoding.pop(-1)
|
417 |
+
assert (
|
418 |
+
end_of_rule_marker == END_OF_RULE_MARKER
|
419 |
+
), f"rule should end with {END_OF_RULE_MARKER}, but got {end_of_rule_marker}"
|
420 |
+
|
421 |
+
offset = 0
|
422 |
+
elements = []
|
423 |
+
while offset < len(rule_encoding):
|
424 |
+
element_size = rule_encoding[offset]
|
425 |
+
assert (
|
426 |
+
rule_encoding[offset + element_size] == END_OF_ALTERNATE_MARKER
|
427 |
+
), f"element should end with {END_OF_ALTERNATE_MARKER}, but got {rule_encoding[offset + element_size]}"
|
428 |
+
elements.append(rule_encoding[offset : offset + element_size + 1])
|
429 |
+
offset += element_size + 1
|
430 |
+
return elements
|
431 |
+
|
432 |
+
|
433 |
+
def _print_annotated_grammar(file, grammar_encoding, symbol_id_names, index=0):
|
434 |
+
rule_id = grammar_encoding[index]
|
435 |
+
print(f"<{index}>{symbol_id_names[rule_id]} ::=", end=" ", file=file)
|
436 |
+
pos = index + 1
|
437 |
+
while grammar_encoding[pos]:
|
438 |
+
if pos - 1 > index:
|
439 |
+
print("|", end=" ", file=file)
|
440 |
+
pos += 1 # sequence size, not needed here
|
441 |
+
while grammar_encoding[pos]:
|
442 |
+
if grammar_encoding[pos] == REF_RULE_MARKER:
|
443 |
+
ref_rule_id = grammar_encoding[pos + 1]
|
444 |
+
print(
|
445 |
+
f"<{pos}>{symbol_id_names[ref_rule_id]}",
|
446 |
+
end=" ",
|
447 |
+
file=file,
|
448 |
+
)
|
449 |
+
pos += 2
|
450 |
+
else:
|
451 |
+
print("<{}>[".format(pos), end="", file=file)
|
452 |
+
num_chars = grammar_encoding[pos]
|
453 |
+
pos += 1
|
454 |
+
|
455 |
+
for i in range(0, num_chars, 2):
|
456 |
+
print(
|
457 |
+
"{}-".format(chr(grammar_encoding[pos + i])), end="", file=file
|
458 |
+
)
|
459 |
+
if i + 1 < num_chars:
|
460 |
+
print(
|
461 |
+
"{}".format(chr(grammar_encoding[pos + i + 1])),
|
462 |
+
end="",
|
463 |
+
file=file,
|
464 |
+
)
|
465 |
+
print("]", end=" ", file=file)
|
466 |
+
pos += num_chars
|
467 |
+
pos += 1
|
468 |
+
print(file=file)
|
469 |
+
return pos + 1
|
470 |
+
|
471 |
+
|
472 |
+
def print_grammar(file, state):
|
473 |
+
pos = 0
|
474 |
+
symbol_id_names = {v: k for k, v in state.symbol_table.items()}
|
475 |
+
print("Grammar Rules:", file=file)
|
476 |
+
while (
|
477 |
+
pos < len(state.grammar_encoding)
|
478 |
+
and state.grammar_encoding[pos] != END_OF_GRAMMAR_MARKER
|
479 |
+
):
|
480 |
+
pos = _print_annotated_grammar(
|
481 |
+
file, state.grammar_encoding, symbol_id_names, pos
|
482 |
+
)
|
483 |
+
if pos > len(state.grammar_encoding):
|
484 |
+
raise Warning(f"grammar_encoding is not ended with {END_OF_GRAMMAR_MARKER}")
|
485 |
+
pos = 0
|
486 |
+
print("\nGrammar Hex representation:", file=file)
|
487 |
+
while (
|
488 |
+
pos < len(state.grammar_encoding)
|
489 |
+
and state.grammar_encoding[pos] != END_OF_GRAMMAR_MARKER
|
490 |
+
):
|
491 |
+
print(f"{state.grammar_encoding[pos]:04x}", end=" ", file=file)
|
492 |
+
pos += 1
|
493 |
+
if pos > len(state.grammar_encoding):
|
494 |
+
raise Warning(f"grammar_encoding is not ended with {END_OF_GRAMMAR_MARKER}")
|
495 |
+
else:
|
496 |
+
print("ffff\n")
|
497 |
+
|
498 |
+
print("Rules Decimal representation:", file=file)
|
499 |
+
# we loop until we reach the end of the grammar_encoding
|
500 |
+
rule_encodings = break_grammar_into_rules(state.grammar_encoding)
|
501 |
+
for rule_encoding in rule_encodings:
|
502 |
+
rule_id = rule_encoding[0]
|
503 |
+
print(
|
504 |
+
f"<{rule_id}> {break_rule_into_elements(rule_encoding)}",
|
505 |
+
file=file,
|
506 |
+
)
|
507 |
+
|
508 |
+
|
509 |
+
if __name__ == "__main__":
|
510 |
+
parser = argparse.ArgumentParser(description="Parse EBNF grammar files.")
|
511 |
+
parser.add_argument(
|
512 |
+
"-g",
|
513 |
+
"--grammar-file",
|
514 |
+
nargs="?",
|
515 |
+
default="examples/grammars/json.ebnf",
|
516 |
+
help="Path to the grammar file (default: examples/grammars/json.ebnf)",
|
517 |
+
)
|
518 |
+
|
519 |
+
args = parser.parse_args()
|
520 |
+
|
521 |
+
# set logging level
|
522 |
+
logging.basicConfig(level=logging.DEBUG)
|
523 |
+
|
524 |
+
with open(args.grammar_file, "r") as file:
|
525 |
+
input_text = file.read()
|
526 |
+
parsed_grammar = parse_ebnf(input_text)
|
527 |
+
parsed_grammar.print()
|
528 |
+
print(f"symbol_ids: \n{parsed_grammar.symbol_table}")
|
529 |
+
|
530 |
+
start_rule_id = parsed_grammar.symbol_table["root"]
|
transformers_gad/recognizer.py
ADDED
@@ -0,0 +1,456 @@
1 |
+
import copy
|
2 |
+
import logging
|
3 |
+
from functools import lru_cache
|
4 |
+
from typing import List, Tuple, Dict
|
5 |
+
|
6 |
+
from transformers_gad.parser import (
|
7 |
+
END_OF_RULE_MARKER,
|
8 |
+
END_OF_ALTERNATE_MARKER,
|
9 |
+
parse_ebnf,
|
10 |
+
REF_RULE_MARKER,
|
11 |
+
)
|
12 |
+
from transformers_gad.utf8_utils import PartialUTF8, decode_utf8
|
13 |
+
from transformers_gad.utils import intervals_intersect
|
14 |
+
import logging
|
15 |
+
|
16 |
+
|
17 |
+
class AcceptState:
|
18 |
+
def __init__(self, stacks, partial_utf8):
|
19 |
+
self.stacks = stacks
|
20 |
+
self.partial_utf8 = partial_utf8
|
21 |
+
|
22 |
+
@staticmethod
|
23 |
+
def empty_state():
|
24 |
+
return AcceptState([], PartialUTF8())
|
25 |
+
|
26 |
+
|
27 |
+
class StringRecognizer:
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
grammar_encoding: List[int],
|
31 |
+
start_rule_id: int = None,
|
32 |
+
rule_offsets: List[int] = None,
|
33 |
+
stacks: List[List[int]] = None,
|
34 |
+
):
|
35 |
+
# strictly speaking, we don't need to copy grammar_encoding because we don't modify it
|
36 |
+
# but we do it anyway to be safe
|
37 |
+
# in case where the grammar is very large, we can consider not copying it
|
38 |
+
self.grammar_encoding = grammar_encoding
|
39 |
+
if rule_offsets is not None:
|
40 |
+
self.rule_offsets = rule_offsets
|
41 |
+
else:
|
42 |
+
if start_rule_id is None:
|
43 |
+
raise ValueError("start_rule_id cannot be None if rule_offsets is None")
|
44 |
+
self.rule_offsets = self.init_rules(start_rule_id)
|
45 |
+
# each stack is a list of indices into grammar_encoding
|
46 |
+
# each index points to an element position (offset) inside a rule's alternate
|
47 |
+
if stacks is not None:
|
48 |
+
self.stacks = stacks
|
49 |
+
else:
|
50 |
+
if start_rule_id is None:
|
51 |
+
raise ValueError("start_rule_id cannot be None if stacks is None")
|
52 |
+
self.stacks: List[List[int]] = self.init_stack(start_rule_id)
|
53 |
+
self.start_rule_id = start_rule_id
|
54 |
+
|
55 |
+
def init_rules(self, start_rule_id: int) -> List[int]:
|
56 |
+
_rule_offset = 0
|
57 |
+
rule_offsets = []
|
58 |
+
# Build `rule_offsets` as a map from rule IDs to their positions in `grammar_encoding`
|
59 |
+
while self.grammar_encoding[_rule_offset] != 0xFFFF:
|
60 |
+
rule_id = self.grammar_encoding[_rule_offset]
|
61 |
+
# store the offset idx
|
62 |
+
if len(rule_offsets) <= rule_id:
|
63 |
+
rule_offsets.extend([-1] * (rule_id - len(rule_offsets) + 1))
|
64 |
+
rule_offsets[rule_id] = _rule_offset
|
65 |
+
|
66 |
+
# Skip rule ID
|
67 |
+
# _rule_offset += 1
|
68 |
+
simple_rhs_offset = _rule_offset + 1
|
69 |
+
|
70 |
+
# Skip rule alternates
|
71 |
+
while self.grammar_encoding[simple_rhs_offset] != END_OF_RULE_MARKER:
|
72 |
+
simple_rhs_offset = (
|
73 |
+
simple_rhs_offset + 1 + self.grammar_encoding[simple_rhs_offset]
|
74 |
+
)
|
75 |
+
|
76 |
+
# Skip 0 denoting end of rule
|
77 |
+
# _rule_offset += 1
|
78 |
+
_rule_offset = simple_rhs_offset + 1
|
79 |
+
|
80 |
+
retrieved_start_rule_id = self.grammar_encoding[rule_offsets[start_rule_id]]
|
81 |
+
assert retrieved_start_rule_id == start_rule_id
|
82 |
+
|
83 |
+
return rule_offsets
|
84 |
+
|
85 |
+
def init_stack(self, start_rule_id: int) -> List[List[int]]:
|
86 |
+
|
87 |
+
stacks = []
|
88 |
+
# Loop over alternates of start rule to build initial stacks
|
89 |
+
sub_rhs_offset = self.rule_offsets[start_rule_id] + 1
|
90 |
+
while self.grammar_encoding[sub_rhs_offset]:
|
91 |
+
stack: List[int] = []
|
92 |
+
# If alternate is nonempty, add to stack
|
93 |
+
element_offset = sub_rhs_offset + 1
|
94 |
+
if self.grammar_encoding[element_offset] != END_OF_ALTERNATE_MARKER:
|
95 |
+
stack.append(element_offset)
|
96 |
+
stacks.extend(self.advance_stack(tuple(stack)))
|
97 |
+
sub_rhs_offset += 1 + self.grammar_encoding[sub_rhs_offset]
|
98 |
+
return stacks
|
99 |
+
|
100 |
+
def get_initial_accept_state(self) -> AcceptState:
|
101 |
+
return AcceptState(self.init_stack(self.start_rule_id), PartialUTF8())
|
102 |
+
|
103 |
+
def get_termination_accept_state(self) -> AcceptState:
|
104 |
+
return AcceptState([], PartialUTF8())
|
105 |
+
|
106 |
+
@lru_cache(maxsize=32768)
|
107 |
+
def advance_stack(self, stack: Tuple[int]) -> List[List[int]]:
|
108 |
+
stack = list(stack)
|
109 |
+
if len(stack) == 0:
|
110 |
+
return [stack]
|
111 |
+
|
112 |
+
# we get the last element of the stack, which is the element we are currently processing
|
113 |
+
cur_element_offset = stack[-1]
|
114 |
+
|
115 |
+
# if the element is a terminal, we don't need to advance the stack
|
116 |
+
if self.grammar_encoding[cur_element_offset] != REF_RULE_MARKER:
|
117 |
+
return [stack]
|
118 |
+
# the remaining case is that the element is a non-terminal, i.e. a reference to another rule
|
119 |
+
else:
|
120 |
+
ref_rule_id = self.grammar_encoding[cur_element_offset + 1]
|
121 |
+
# find the offset of the referenced rule
|
122 |
+
ref_subrule_offset = self.rule_offsets[ref_rule_id] + 1
|
123 |
+
new_stacks: List[List[int]] = []
|
124 |
+
# Loop over alternates of referenced rule to build new stacks
|
125 |
+
while self.grammar_encoding[ref_subrule_offset] != END_OF_RULE_MARKER:
|
126 |
+
# copy the original stack without the last element
|
127 |
+
new_stack = stack[:-1]
|
128 |
+
# if the rule ref is followed by another element, we add it to the stack
|
129 |
+
next_element_offset = cur_element_offset + 2
|
130 |
+
if (
|
131 |
+
self.grammar_encoding[next_element_offset]
|
132 |
+
!= END_OF_ALTERNATE_MARKER
|
133 |
+
):
|
134 |
+
new_stack.append(next_element_offset)
|
135 |
+
|
136 |
+
# if the referenced rule is not empty, we add its element offset to the stack
|
137 |
+
ref_element_offset = ref_subrule_offset + 1
|
138 |
+
if self.grammar_encoding[ref_element_offset] != END_OF_ALTERNATE_MARKER:
|
139 |
+
new_stack.append(ref_element_offset)
|
140 |
+
|
141 |
+
new_stacks.extend(self.advance_stack(tuple(new_stack)))
|
142 |
+
ref_subrule_offset += self.grammar_encoding[ref_subrule_offset] + 1
|
143 |
+
|
144 |
+
return new_stacks
|
145 |
+
|
146 |
+
def _consume_byte(self, byte: int, accept_state: AcceptState):
|
147 |
+
# suppose we have the code point 一 (ord('一') = 19968), which is encoded as 3 UTF-8 bytes;
|
148 |
+
# this method therefore has to be called 3 times, once per byte, before the code point is fully consumed
|
149 |
+
self._consume_bytes(bytes([byte]), accept_state)
|
150 |
+
|
151 |
+
# @lru_cache(maxsize=32768)
|
152 |
+
def _probe_bytes(
|
153 |
+
self,
|
154 |
+
byte_seq: bytes,
|
155 |
+
stacks: List[List[int]],
|
156 |
+
partial_utf8: PartialUTF8,
|
157 |
+
verbose=True,
|
158 |
+
):
|
159 |
+
if type(byte_seq) is list:
|
160 |
+
byte_seq = bytes(byte_seq)
|
161 |
+
code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8)
|
162 |
+
if verbose:
|
163 |
+
logging.debug(
|
164 |
+
f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}"
|
165 |
+
)
|
166 |
+
new_stacks = self._consume_code_points(code_points, stacks)
|
167 |
+
|
168 |
+
for stack in new_stacks:
|
169 |
+
|
170 |
+
# stack is empty, meaning that the variables are all consumed
|
171 |
+
if len(stack) == 0:
|
172 |
+
return True
|
173 |
+
element_offset = stack[-1]
|
174 |
+
if self.partial_utf8_accept_at_element(element_offset, new_partial_utf8):
|
175 |
+
return True
|
176 |
+
return False
|
177 |
+
|
178 |
+
def _consume_bytes(
|
179 |
+
self,
|
180 |
+
byte_seq: bytes,
|
181 |
+
accept_state: AcceptState = None,
|
182 |
+
verbose=True,
|
183 |
+
):
|
184 |
+
if accept_state is None:
|
185 |
+
accept_state = self.get_initial_accept_state()
|
186 |
+
stacks = accept_state.stacks
|
187 |
+
partial_utf8 = accept_state.partial_utf8
|
188 |
+
if type(byte_seq) is list:
|
189 |
+
byte_seq = bytes(byte_seq)
|
190 |
+
code_points, new_partial_utf8 = decode_utf8(byte_seq, partial_utf8)
|
191 |
+
if verbose:
|
192 |
+
logging.debug(
|
193 |
+
f"code_points: {code_points}; new_partial_utf8: {new_partial_utf8}"
|
194 |
+
)
|
195 |
+
new_stacks = self._consume_code_points(code_points, stacks)
|
196 |
+
|
197 |
+
new_new_stacks = []
|
198 |
+
for stack in new_stacks:
|
199 |
+
if len(stack) == 0:
|
200 |
+
continue
|
201 |
+
element_offset = stack[-1]
|
202 |
+
if self.partial_utf8_accept_at_element(element_offset, new_partial_utf8):
|
203 |
+
new_new_stacks.append(stack)
|
204 |
+
return AcceptState(new_new_stacks, new_partial_utf8)
|
205 |
+
|
206 |
+
##########################
|
207 |
+
#
|
208 |
+
# Code point recognition
|
209 |
+
#
|
210 |
+
##########################
|
211 |
+
|
212 |
+
@lru_cache(maxsize=30000)
|
213 |
+
def _consume_code_point(
|
214 |
+
self, code_point: int, stacks: Tuple[Tuple[int]]
|
215 |
+
) -> List[List[int]]:
|
216 |
+
"""
|
217 |
+
consume a character from the stack
|
218 |
+
char_code_point: can be a Unicode code point, including ascii code points which are in the range [0, 127]
|
219 |
+
"""
|
220 |
+
new_stacks = []
|
221 |
+
|
222 |
+
stacks: List[List[int]] = list([list(stack) for stack in stacks])
|
223 |
+
if code_point == 0:
|
224 |
+
return new_stacks
|
225 |
+
for stack in stacks:
|
226 |
+
new_stacks.extend(
|
227 |
+
self._consume_code_point_per_stack(code_point, tuple(stack))
|
228 |
+
)
|
229 |
+
return new_stacks
|
230 |
+
|
231 |
+
@lru_cache(maxsize=30000)
|
232 |
+
def _consume_code_point_per_stack(
|
233 |
+
self, code_point: int, stack: Tuple[int]
|
234 |
+
) -> List[List[int]]:
|
235 |
+
"""
|
236 |
+
consume a character from the stack
|
237 |
+
char_code_point: can be a Unicode code point, including ascii code points which are in the range [0, 127]
|
238 |
+
"""
|
239 |
+
# TODO, the below code will raise an error when the stack is empty, but why is this happening?
|
240 |
+
# if len(stacks) == 0:
|
241 |
+
# raise ValueError("Stacks don't contain any stack, meaning that no character can be consumed")
|
242 |
+
# code_point = 0 is a special case when the utf-8 sequence is not complete, we return an empty stack
|
243 |
+
# to indicate that the character is not accepted
|
244 |
+
stack = list(stack)
|
245 |
+
new_stacks = []
|
246 |
+
if code_point == 0:
|
247 |
+
return new_stacks
|
248 |
+
# stack is empty
|
249 |
+
if len(stack) == 0:
|
250 |
+
return new_stacks
|
251 |
+
|
252 |
+
element_offset = stack[-1]
|
253 |
+
|
254 |
+
found = self.accept_code_point_at_element(code_point, element_offset)
|
255 |
+
if not found:
|
256 |
+
return new_stacks
|
257 |
+
|
258 |
+
size = self.grammar_encoding[element_offset]
|
259 |
+
element_offset += size + 1
|
260 |
+
new_stack = stack[:-1]
|
261 |
+
if self.grammar_encoding[element_offset]:
|
262 |
+
new_stack.append(element_offset)
|
263 |
+
return self.advance_stack(tuple(new_stack))
|
264 |
+
|
265 |
+
def _consume_code_points(
|
266 |
+
self, code_points: List[int], stacks: List[List[int]], verbose=False
|
267 |
+
) -> List[List[int]]:
|
268 |
+
for i, code_point in enumerate(code_points):
|
269 |
+
# for lru_cache to work, we need to convert the list of stacks into a tuple of stacks
|
270 |
+
tuple_stacks: Tuple[Tuple[int]] = tuple([tuple(stack) for stack in stacks])
|
271 |
+
stacks = self._consume_code_point(code_point, tuple_stacks)
|
272 |
+
if len(stacks) > 0 and verbose:
|
273 |
+
accepted_code_point = code_points[: i + 1]
|
274 |
+
corresponding_char = chr(code_point)
|
275 |
+
logging.debug(
|
276 |
+
f"code point {accepted_code_point} corresponding to {corresponding_char} is accepted"
|
277 |
+
)
|
278 |
+
return stacks
|
279 |
+
|
280 |
+
def _accept_code_points(
|
281 |
+
self, code_points: List[int], stacks: List[List[int]], verbose=False
|
282 |
+
) -> bool:
|
283 |
+
stacks = self._consume_code_points(code_points, stacks, verbose)
|
284 |
+
return len(stacks) > 0
|
285 |
+
|
286 |
+
@lru_cache(maxsize=30000)
|
287 |
+
def accept_code_point_at_element(
|
288 |
+
self, code_point: int, element_offset: int
|
289 |
+
) -> bool:
|
290 |
+
size = self.grammar_encoding[element_offset]
|
291 |
+
# to make idx point to the range_start of the first range
|
292 |
+
element_offset += 1
|
293 |
+
for i in range(0, size, 2):
|
294 |
+
if (
|
295 |
+
self.grammar_encoding[element_offset + i]
|
296 |
+
<= code_point
|
297 |
+
<= self.grammar_encoding[element_offset + i + 1]
|
298 |
+
):
|
299 |
+
return True
|
300 |
+
return False
|
301 |
+
|
302 |
+
# def _accept_code_point(self, code_point: int, stacks: List[List[int]]):
|
303 |
+
# # for lru_cache to work, we need to convert the list of stacks into a tuple of stacks
|
304 |
+
# tuple_stacks: Tuple[Tuple[int]] = tuple([tuple(stack) for stack in stacks])
|
305 |
+
# new_stacks: List[List[int]] = self._consume_code_point(code_point, tuple_stacks)
|
306 |
+
# return len(new_stacks) > 0
|
307 |
+
|
308 |
+
#############################
|
309 |
+
#
|
310 |
+
# Partial UTF-8 recognition
|
311 |
+
#
|
312 |
+
#############################
|
313 |
+
|
314 |
+
def partial_utf8_accept_at_element(
|
315 |
+
self, element_offset: int, partial_utf8: PartialUTF8
|
316 |
+
) -> bool:
|
317 |
+
# Extract the accumulated value and the number of remaining bytes from the partial_utf8 object.
|
318 |
+
partial_value = partial_utf8.value
|
319 |
+
n_remain = partial_utf8.n_remain
|
320 |
+
|
321 |
+
# Return False if there are no remaining bytes to process or if it's an invalid UTF-8 sequence.
|
322 |
+
if n_remain == 1 and partial_value < 2:
|
323 |
+
return False
|
324 |
+
|
325 |
+
# If there are no remaining bytes, this means we had already consumed a complete UTF-8 sequence.
|
326 |
+
if n_remain <= 0:
|
327 |
+
return True
|
328 |
+
|
329 |
+
# Calculate the lowest possible Unicode code point that can be formed with the remaining bytes.
|
330 |
+
low = partial_value << (n_remain * 6)
|
331 |
+
# Calculate the highest possible Unicode code point by setting all remaining bits to 1.
|
332 |
+
high = low | ((1 << (n_remain * 6)) - 1)
|
333 |
+
|
334 |
+
# If the low end of the range is 0 and a specific number of bytes remain, adjust low to the minimum value
|
335 |
+
# that can be represented with that number of bytes. This accounts for UTF-8 encoding rules.
|
336 |
+
if low == 0:
|
337 |
+
if n_remain == 2:
|
338 |
+
low = 1 << 11 # Minimum value representable with 2 additional bytes.
|
339 |
+
elif n_remain == 3:
|
340 |
+
low = 1 << 16 # Minimum value representable with 3 additional bytes.
|
341 |
+
|
342 |
+
# Get the size of the grammar rule starting at the current element_offset.
|
343 |
+
size = self.grammar_encoding[element_offset]
|
344 |
+
# Move the element_offset to the start of the grammar rule's definition.
|
345 |
+
element_offset += 1
|
346 |
+
|
347 |
+
# Iterate over the grammar rule, checking if the range defined by low-high overlaps with any specified ranges.
|
348 |
+
for i in range(0, size, 2):
|
349 |
+
# If the current range (specified in the grammar encoding) overlaps with the low-high range, return True.
|
350 |
+
if intervals_intersect(
|
351 |
+
low,
|
352 |
+
high,
|
353 |
+
self.grammar_encoding[element_offset + i],
|
354 |
+
self.grammar_encoding[element_offset + i + 1],
|
355 |
+
):
|
356 |
+
return True
|
357 |
+
|
358 |
+
# If no overlap is found with any of the ranges, return False, indicating no valid partial match.
|
359 |
+
return False
|
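A numeric illustration of the low/high computation above (the `PartialUTF8` fields are assumed from how they are used here): after seeing only the lead byte of '一' (U+4E00, UTF-8 bytes E4 B8 80), the accumulated value is 0x4 with two bytes remaining, so element ranges are intersected with [0x4000, 0x4FFF], which contains U+4E00.

```python
value, n_remain = 0x4, 2                   # assumed state after the lead byte 0xE4
low = value << (n_remain * 6)              # 0x4000
high = low | ((1 << (n_remain * 6)) - 1)   # 0x4FFF
assert low <= 0x4E00 <= high               # a range covering '一' keeps this stack alive
```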
360 |
+
|
361 |
+
#############################
|
362 |
+
#
|
363 |
+
# String recognition
|
364 |
+
#
|
365 |
+
#############################
|
366 |
+
|
367 |
+
def _consume_string(self, string: str, accept_state: AcceptState):
|
368 |
+
# _bytes = bytes(string, "utf-8")
|
369 |
+
code_points = [ord(char) for char in string]
|
370 |
+
stacks = self._consume_code_points(code_points, accept_state.stacks)
|
371 |
+
return AcceptState(stacks, accept_state.partial_utf8)
|
372 |
+
|
373 |
+
def _accept_prefix(self, string: str, accept_state: AcceptState = None):
|
374 |
+
if accept_state is None:
|
375 |
+
accept_state = self.get_initial_accept_state()
|
376 |
+
new_accept_state = self._consume_string(string, accept_state)
|
377 |
+
return len(new_accept_state.stacks) > 0
|
378 |
+
|
379 |
+
def _accept_string(self, string: str, accept_state: AcceptState = None):
|
380 |
+
if accept_state is None:
|
381 |
+
accept_state = self.get_initial_accept_state()
|
382 |
+
new_accept_state = self._consume_string(string, accept_state)
|
383 |
+
at_least_one_stack_is_empty = any(
|
384 |
+
len(stack) == 0 for stack in new_accept_state.stacks
|
385 |
+
)
|
386 |
+
return at_least_one_stack_is_empty
|
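A hedged end-to-end sketch of the string-level API (the grammar string is made up; import paths follow the file layout added in this commit):

```python
from transformers_gad.parser import parse_ebnf
from transformers_gad.recognizer import StringRecognizer

state = parse_ebnf('root ::= "ab" | "cd"\n')
recognizer = StringRecognizer(state.grammar_encoding, state.symbol_table["root"])

print(recognizer._accept_prefix("a"))   # True:  "a" can still be extended to "ab"
print(recognizer._accept_string("ab"))  # True:  a complete sentence of the grammar
print(recognizer._accept_prefix("b"))   # False: no alternate starts with "b"
```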
387 |
+
|
388 |
+
def _can_stop(self, stacks: List[List[int]]):
|
389 |
+
# This happens in practice, but maybe it shouldn't? TODO
|
390 |
+
if len(stacks) == 0:
|
391 |
+
return True
|
392 |
+
# if any of the stack is empty, we can stop
|
393 |
+
for stack in stacks:
|
394 |
+
if len(stack) == 0:
|
395 |
+
return True
|
396 |
+
else:
|
397 |
+
return False
|
398 |
+
|
399 |
+
def _must_stop(self, stacks: List[List[int]]):
|
400 |
+
return len(stacks) == 0 or all(len(stack) == 0 for stack in stacks)
|
401 |
+
|
402 |
+
#############################
|
403 |
+
#
|
404 |
+
# Not Used
|
405 |
+
#
|
406 |
+
#############################
|
407 |
+
|
408 |
+
# For each sub-rule in the grammar, cache whether each byte is accepted.
|
409 |
+
@lru_cache(maxsize=None)
|
410 |
+
def char_acceptance_at_element(self, element_offset):
|
411 |
+
"""
|
412 |
+
Caches and returns a dictionary indicating whether a Unicode character is accepted
|
413 |
+
at a given rule position. This function considers Unicode characters, dynamically
|
414 |
+
inserting accepted ranges into a dictionary to optimize memory usage.
|
415 |
+
|
416 |
+
Args:
|
417 |
+
- rule_offset: The offset in the grammar encoding where the rule starts.
|
418 |
+
|
419 |
+
Returns:
|
420 |
+
- A dictionary where each key is a Unicode character (or range) and the value is True if accepted.
|
421 |
+
"""
|
422 |
+
logging.debug(f"element_offset: {element_offset}")
|
423 |
+
acceptance = {}
|
424 |
+
num_chars = self.grammar_encoding[element_offset]
|
425 |
+
element_offset += 1
|
426 |
+
for i in range(0, num_chars, 2):
|
427 |
+
start = self.grammar_encoding[element_offset + i]
|
428 |
+
end = self.grammar_encoding[element_offset + i + 1]
|
429 |
+
for j in range(start, end + 1):
|
430 |
+
acceptance[j] = True
|
431 |
+
logging.debug(acceptance)
|
432 |
+
return acceptance
|
433 |
+
|
434 |
+
def _consume_code_points_new(
|
435 |
+
self, code_points: List[int], stacks: List[List[int]], verbose=False
|
436 |
+
) -> List[List[int]]:
|
437 |
+
new_stacks: List[List[int]] = []
|
438 |
+
for stack in stacks:
|
439 |
+
new_stacks.extend(
|
440 |
+
self._consume_code_points_per_stack(
|
441 |
+
tuple(code_points), tuple(stack), verbose
|
442 |
+
)
|
443 |
+
)
|
444 |
+
return new_stacks
|
445 |
+
|
446 |
+
@lru_cache(maxsize=30000)
|
447 |
+
def _consume_code_points_per_stack(
|
448 |
+
self, code_points: Tuple[int], stack: Tuple[int], verbose=False
|
449 |
+
) -> List[List[int]]:
|
450 |
+
code_points = list(code_points)
|
451 |
+
stacks = (stack,)
|
452 |
+
for i, code_point in enumerate(code_points):
|
453 |
+
# for lru_cache to work, we need to convert the list of stacks into a tuple of stacks
|
454 |
+
stacks = self._consume_code_point(code_point, stacks)
|
455 |
+
stacks = tuple([tuple(stack) for stack in stacks])
|
456 |
+
return [list(stack) for stack in stacks]
|
transformers_gad/token_grammar_recognizer.py
ADDED
@@ -0,0 +1,322 @@
1 |
+
import copy
|
2 |
+
import logging
|
3 |
+
from abc import ABC
|
4 |
+
from functools import lru_cache
|
5 |
+
from typing import List
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from transformers_gad.recognizer import StringRecognizer, AcceptState
|
10 |
+
from transformers_gad.parser import parse_ebnf
|
11 |
+
from transformers_gad.trie import ByteTrie
|
12 |
+
from transformers_gad.utf8_utils import PartialUTF8
|
13 |
+
from .vocab_struct import LEAF, TokenTrie
|
14 |
+
from transformers_gad.mapping import get_mapping
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
class AbsTokenRecognizer(ABC):
|
20 |
+
def __init__(self, grammar_str, tokenizer, start_rule_name="root", unicode=False):
|
21 |
+
parsed_grammar = parse_ebnf(grammar_str)
|
22 |
+
grammar_encoding = parsed_grammar.grammar_encoding
|
23 |
+
self.start_rule_id = parsed_grammar.symbol_table.get(start_rule_name)
|
24 |
+
self.byte_encoding = unicode
|
25 |
+
|
26 |
+
if unicode and not tokenizer.__class__.__name__.lower().startswith(
|
27 |
+
"gpt2"
|
28 |
+
): # gpt2tokenizer or gpt2tokenizerfast
|
29 |
+
raise ValueError(
|
30 |
+
"Constrained decoding with unicode is only supported for GPT2 model. Support for other models is coming soon."
|
31 |
+
"Or you can use the constraints with only ascii characters."
|
32 |
+
)
|
33 |
+
|
34 |
+
self.eos_token_id = tokenizer.eos_token_id
|
35 |
+
self.token_trie = TokenTrie(tokenizer)
|
36 |
+
self.tokenizer = tokenizer
|
37 |
+
self.string_recognizer = StringRecognizer(grammar_encoding, self.start_rule_id)
|
38 |
+
self.unicode_trie = ByteTrie.from_tokenizer(tokenizer, unicode=unicode)
|
39 |
+
self.mapping = get_mapping(tokenizer, unicode=unicode)
|
40 |
+
assert len(self.mapping) == len(
|
41 |
+
self.token_trie
|
42 |
+
), f"{len(self.mapping)}, {len(self.token_trie)}"
|
43 |
+
|
44 |
+
def _consume_token_id(
|
45 |
+
self, token_id: int, accept_state: AcceptState
|
46 |
+
) -> AcceptState:
|
47 |
+
if self.string_recognizer._must_stop(accept_state.stacks):
|
48 |
+
if token_id == self.eos_token_id:
|
49 |
+
return self.string_recognizer.get_termination_accept_state()
|
50 |
+
else:
|
51 |
+
raise ValueError(
|
52 |
+
f"All stacks are empty, so the only token accepted is EOS({self.eos_token_id}), but got {token_id}"
|
53 |
+
)
|
54 |
+
if token_id == self.eos_token_id:
|
55 |
+
if self.string_recognizer._can_stop(accept_state.stacks):
|
56 |
+
# if at least one of the stacks is empty, we can stop
|
57 |
+
# we clear all the stacks, meaning that we don't accept any token after EOS
|
58 |
+
return self.string_recognizer.get_termination_accept_state()
|
59 |
+
else:
|
60 |
+
raise ValueError(
|
61 |
+
f"At least one of the stack should be empty when EOS is reached. However, "
|
62 |
+
f"the stacks are {accept_state.stacks}"
|
63 |
+
)
|
64 |
+
|
65 |
+
bytes_or_codepoints = self.mapping.map(token_id)
|
66 |
+
accept_state = self.string_recognizer._consume_bytes(
|
67 |
+
bytes_or_codepoints, accept_state
|
68 |
+
)
|
69 |
+
return accept_state
|
70 |
+
|
71 |
+
def probe_token_id(self, token_id: int, accept_state: AcceptState) -> bool:
|
72 |
+
stacks = accept_state.stacks
|
73 |
+
if self.string_recognizer._must_stop(stacks):
|
74 |
+
if token_id == self.eos_token_id:
|
75 |
+
return True
|
76 |
+
else:
|
77 |
+
return False
|
78 |
+
if token_id == self.eos_token_id:
|
79 |
+
if self.string_recognizer._can_stop(stacks):
|
80 |
+
# if at least one of the stacks is empty, we can stop
|
81 |
+
# we clear all the stacks, meaning that we don't accept any token after EOS
|
82 |
+
return True
|
83 |
+
else:
|
84 |
+
return False
|
85 |
+
# for code_point in self.mapping.map(token_id):
|
86 |
+
# stacks = self.grammar._consume_char_code_point(code_point, stacks)
|
87 |
+
bytes_or_codepoints = self.mapping.map(token_id, verbose=False)
|
88 |
+
new_acc_state = self.string_recognizer._consume_bytes(
|
89 |
+
bytes_or_codepoints, accept_state, verbose=False
|
90 |
+
)
|
91 |
+
return len(new_acc_state.stacks) > 0
|
92 |
+
|
93 |
+
def advance_token_ids(self, *args, **kwargs):
|
94 |
+
"""Process a list of tokens according to the grammar rules."""
|
95 |
+
raise NotImplementedError
|
96 |
+
|
97 |
+
def batch_filter_vocab(self, batch_accept_states, device) -> torch.Tensor:
|
98 |
+
batch_acceptance = []
|
99 |
+
for accept_state in batch_accept_states:
|
100 |
+
batch_acceptance.append(self.filter_vocab(accept_state, device))
|
101 |
+
return torch.stack(batch_acceptance)
|
102 |
+
|
103 |
+
def filter_vocab(self, accept_state, device) -> torch.Tensor:
|
104 |
+
if not accept_state.stacks: # Check if stacks is empty
|
105 |
+
# Handle the empty case: for example, return a tensor of False
|
106 |
+
# The size of the tensor should match the size of your vocabulary
|
107 |
+
vocab_size = len(self.mapping)
|
108 |
+
logger.debug(f"Empty stack, sum of acceptance: {0}")
|
109 |
+
# size of the vocab
|
110 |
+
accepts = [False] * vocab_size
|
111 |
+
accepts[self.eos_token_id] = True
|
112 |
+
return torch.tensor(accepts, dtype=torch.bool, device=device)
|
113 |
+
|
114 |
+
acceptance = self.get_token_acceptance(accept_state, device)
|
115 |
+
|
116 |
+
return acceptance
|
117 |
+
|
118 |
+
def get_token_acceptance(self, accept_state, device) -> torch.Tensor:
|
119 |
+
acceptance_matrix = torch.cat(
|
120 |
+
[
|
121 |
+
self.get_token_acceptance_array_for_stack(
|
122 |
+
tuple(stack), accept_state.partial_utf8, device
|
123 |
+
)
|
124 |
+
for stack in accept_state.stacks
|
125 |
+
]
|
126 |
+
)
|
127 |
+
# Merge stacks: any True => True
|
128 |
+
acceptance = acceptance_matrix.reshape(len(accept_state.stacks), -1).any(dim=0)
|
129 |
+
return acceptance
|
130 |
+
|
131 |
+
@lru_cache(maxsize=32768)
|
132 |
+
def get_token_acceptance_array_for_stack(self, stack, partial_utf8, device):
|
133 |
+
# stack = list(stack) # needs to come in as a tuple for lru_cache
|
134 |
+
assert isinstance(stack, tuple)
|
135 |
+
stack = list(stack)
|
136 |
+
|
137 |
+
if self.byte_encoding:
|
138 |
+
|
139 |
+
accept_f = lambda x: self.string_recognizer._probe_bytes(
|
140 |
+
x, [stack], partial_utf8=partial_utf8
|
141 |
+
)
|
142 |
+
token_acceptance = self.unicode_trie.get_token_acceptance(
|
143 |
+
accept=accept_f, accept_eos=False, eos_token_id=self.eos_token_id
|
144 |
+
)
|
145 |
+
else:
|
146 |
+
accepts = [False] * len(self.mapping)
|
147 |
+
token_acceptance = check_token_acceptance_in_trie(
|
148 |
+
self.token_trie.trie,
|
149 |
+
[stack],
|
150 |
+
self.string_recognizer,
|
151 |
+
self.eos_token_id,
|
152 |
+
accepts,
|
153 |
+
)
|
154 |
+
x = torch.tensor(token_acceptance, dtype=torch.bool, device=device)
|
155 |
+
x_eos = self.validate_and_set_eos_acceptance(x)
|
156 |
+
return x_eos
|
157 |
+
|
158 |
+
def validate_and_set_eos_acceptance(self, acceptance: torch.Tensor) -> torch.Tensor:
|
159 |
+
if not torch.any(acceptance):
|
160 |
+
acceptance[self.eos_token_id] = True
|
161 |
+
else:
|
162 |
+
if acceptance[self.eos_token_id]:
|
163 |
+
raise ValueError("EOS should not be acceptable while other tokens are still acceptable")
|
164 |
+
acceptance[self.eos_token_id] = False
|
165 |
+
return acceptance
|
166 |
+
|
167 |
+
|
168 |
+
class IncrementalTokenRecognizer(AbsTokenRecognizer):
|
169 |
+
def __init__(self, grammar_str, start_rule_name, tokenizer, unicode=False):
|
170 |
+
super().__init__(grammar_str, tokenizer, start_rule_name, unicode)
|
171 |
+
self.last_size = None
|
172 |
+
self.is_incremental = True
|
173 |
+
|
174 |
+
# if self.last_size is not set (which would be the case when processing the first token).
|
175 |
+
# parse only the prompt suffix starting at parse_start_index (or nothing if it is None).
|
176 |
+
|
177 |
+
def advance_token_ids(self, input_ids, batch_accept_states, parse_start_index=None):
|
178 |
+
|
179 |
+
if self.last_size is None:
|
180 |
+
prefix_to_parse = [
|
181 |
+
single_input_ids[parse_start_index:]
|
182 |
+
if parse_start_index is not None
|
183 |
+
else []
|
184 |
+
for single_input_ids in input_ids
|
185 |
+
]
|
186 |
+
|
187 |
+
# self.grammar_acceptor.accept_token_ids(prefix_to_parse, self.stacks)
|
188 |
+
batch_accept_states = [
|
189 |
+
self._consume_token_ids(prefix, accept_state)
|
190 |
+
for prefix, accept_state in zip(prefix_to_parse, batch_accept_states)
|
191 |
+
]
|
192 |
+
# if the length of the current input IDs (input_ids[0]) is exactly one more than self.last_size.
|
193 |
+
# This is expected in a scenario where inputs are processed incrementally, one token at a time.
|
194 |
+
elif len(input_ids[0]) == self.last_size + 1:
|
195 |
+
batch_accept_states = [
|
196 |
+
self._consume_token_id(single_input_ids[-1], accept_state)
|
197 |
+
for single_input_ids, accept_state in zip(
|
198 |
+
input_ids, batch_accept_states
|
199 |
+
)
|
200 |
+
]
|
201 |
+
# ensure that the input size is consistent with the expected incremental processing
|
202 |
+
# (i.e., one token at a time).
|
203 |
+
else:
|
204 |
+
# here we check if the input_ids are one token longer than the last time we processed
|
205 |
+
# but we don't check if input_ids are actually valid.
|
206 |
+
# Imagine a scenario where we generate 10 tokens, then we replace the 10 generated tokens with 10 new tokens.
|
207 |
+
# In this case, the input_ids will be consistent with the last_size, but the input_ids are not valid.
|
208 |
+
# However, should we really check if the input_ids are valid here?
|
209 |
+
# If we do, then we need to reparse the whole input_ids at each call, which is not efficient.
|
210 |
+
# Maybe we should just trust the user to provide valid input_ids?
|
211 |
+
# The conclusion: we assume the input_ids are valid, so the generation will be correct.
|
212 |
+
# If the input_ids are not valid, the generation result will be wrong; ensuring validity is the caller's responsibility.
|
213 |
+
raise RuntimeError(
|
214 |
+
"Input ID's length is inconsistent with the current state of "
|
215 |
+
"the GrammarConstrainedLogitsProcessor. If you want to process "
|
216 |
+
"another input sequence, please instantiate a new "
|
217 |
+
"GrammarConstrainedLogitsProcessor "
|
218 |
+
"or call reset_parser method of GrammarAlignedOracleLogitsProcessor"
|
219 |
+
)
|
220 |
+
self.last_size = len(input_ids[0])
|
221 |
+
|
222 |
+
return batch_accept_states
|
223 |
+
|
224 |
+
def _consume_token_ids(
|
225 |
+
self, token_ids: List[int], accept_state: AcceptState = None, as_string=True
|
226 |
+
):
|
227 |
+
if accept_state is None:
|
228 |
+
accept_state = self.string_recognizer.get_initial_accept_state()
|
229 |
+
if as_string:
|
230 |
+
string = self.tokenizer.decode(token_ids)
|
231 |
+
accept_state = self.string_recognizer._consume_string(string, accept_state)
|
232 |
+
else:
|
233 |
+
for i, token_id in enumerate(token_ids):
|
234 |
+
accept_state = self._consume_token_id(token_id, accept_state)
|
235 |
+
if len(accept_state.stacks) > 0:
|
236 |
+
cur_token_ids = token_ids[: i + 1]
|
237 |
+
logging.debug(f"{cur_token_ids} is accepted")
|
238 |
+
decoded_string = self.tokenizer.decode(cur_token_ids)
|
239 |
+
logging.debug(f"The decoded string is {decoded_string}")
|
240 |
+
return accept_state
|
241 |
+
|
242 |
+
def reset(self):
|
243 |
+
self.last_size = None
|
244 |
+
|
245 |
+
def check_token_acceptance_in_trie(trie, stacks, grammar, eos_token_id, accepts):
|
246 |
+
|
247 |
+
for byte, next_trie in trie.items():
|
248 |
+
if byte == LEAF:
|
249 |
+
token_id = next_trie
|
250 |
+
if token_id != eos_token_id:
|
251 |
+
# if stacks is not empty, it means we can still continue to parse
|
252 |
+
# so we should accept the token
|
253 |
+
accepts[token_id] = bool(stacks)
|
254 |
+
continue
|
255 |
+
|
256 |
+
new_stacks = []
|
257 |
+
for stk in stacks:
|
258 |
+
if not stk:
|
259 |
+
continue
|
260 |
+
|
261 |
+
next_element_offset = stk[-1]
|
262 |
+
num_chars = grammar.grammar_encoding[next_element_offset]
|
263 |
+
|
264 |
+
if not grammar.char_acceptance_at_element(next_element_offset).get(
|
265 |
+
byte, False
|
266 |
+
):
|
267 |
+
# if the current byte is not accepted by the current rule, we need to try next rule
|
268 |
+
continue
|
269 |
+
|
270 |
+
next_element_offset += num_chars + 1
|
271 |
+
new_stack = stk[:-1]
|
272 |
+
if grammar.grammar_encoding[next_element_offset]:
|
273 |
+
new_stack.append(next_element_offset)
|
274 |
+
new_stacks.extend(grammar.advance_stack(tuple(new_stack)))
|
275 |
+
|
276 |
+
if new_stacks:
|
277 |
+
check_token_acceptance_in_trie(
|
278 |
+
next_trie, new_stacks, grammar, eos_token_id, accepts
|
279 |
+
)
|
280 |
+
|
281 |
+
return accepts
|
282 |
+
|
283 |
+
|
284 |
+
if __name__ == "__main__":
|
285 |
+
from transformers import AutoTokenizer
|
286 |
+
|
287 |
+
with open("examples/grammars/japanese.ebnf", "r") as file:
|
288 |
+
input_text = file.read()
|
289 |
+
parsed_grammar = parse_ebnf(input_text)
|
290 |
+
parsed_grammar.print()
|
291 |
+
|
292 |
+
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
293 |
+
|
294 |
+
tokenRecognizer = IncrementalTokenRecognizer(
|
295 |
+
grammar_str=input_text, start_rule_name="root", tokenizer=tokenizer
|
296 |
+
)
|
297 |
+
|
298 |
+
japanese = "トリーム" # "こんにちは"
|
299 |
+
token_ids = tokenizer.encode(japanese)
|
300 |
+
# 13298, 12675, 12045, 254
|
301 |
+
accept_state = tokenRecognizer._consume_token_ids(
|
302 |
+
token_ids, as_string=False
|
303 |
+
)
|
304 |
+
|
305 |
+
if accept_state.stacks:
|
306 |
+
print("The Japanese input is accepted")
|
307 |
+
else:
|
308 |
+
print("The Japanese input is not accepted")
|
309 |
+
|
310 |
+
korean = "안녕하세요"
|
311 |
+
token_ids = tokenizer.encode(korean)
|
312 |
+
|
313 |
+
try:
|
314 |
+
accept_state = tokenRecognizer._consume_token_ids(
|
315 |
+
token_ids, as_string=False
|
316 |
+
)
|
317 |
+
if accept_state.stacks:
|
318 |
+
print("The Korean input is accepted")
|
319 |
+
else:
|
320 |
+
print("The Korean input is not accepted")
|
321 |
+
except ValueError as e:
|
322 |
+
print("The Korean input is not accepted")
|
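
For orientation, here is a minimal usage sketch of the IncrementalTokenRecognizer added above. It is not part of the commit: the grammar string and the probed text are illustrative placeholders, assuming a GBNF/EBNF-style grammar that parse_ebnf accepts and a GPT-2 tokenizer.

import torch
from transformers import AutoTokenizer
from transformers_gad.token_grammar_recognizer import IncrementalTokenRecognizer

# Toy grammar: the root rule accepts either "yes" or "no".
grammar_str = 'root ::= "yes" | "no"\n'

tokenizer = AutoTokenizer.from_pretrained("gpt2")
recognizer = IncrementalTokenRecognizer(
    grammar_str=grammar_str, start_rule_name="root", tokenizer=tokenizer
)

# Consume a candidate completion and check whether the grammar still accepts it.
accept_state = recognizer._consume_token_ids(tokenizer.encode("yes"))
print("accepted:", len(accept_state.stacks) > 0)

# Ask which vocabulary entries are grammatically admissible as the next token.
mask = recognizer.filter_vocab(accept_state, device=torch.device("cpu"))
print("admissible next tokens:", int(mask.sum()))
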
transformers_gad/trie.py
ADDED
@@ -0,0 +1,194 @@
1 |
+
import logging
|
2 |
+
from functools import lru_cache
|
3 |
+
from typing import Dict, List, Tuple
|
4 |
+
from collections import deque
|
5 |
+
|
6 |
+
from transformers_gad.mapping import get_mapping
|
7 |
+
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
class TrieNode:
|
11 |
+
def __init__(self):
|
12 |
+
self.children = {}
|
13 |
+
self.is_end_of_word = False
|
14 |
+
self.token_id = None
|
15 |
+
|
16 |
+
|
17 |
+
class ByteTrie:
|
18 |
+
def __init__(self):
|
19 |
+
self.root = TrieNode()
|
20 |
+
|
21 |
+
def insert(self, word, token_id=None):
|
22 |
+
node = self.root
|
23 |
+
for char in word:
|
24 |
+
if char not in node.children:
|
25 |
+
node.children[char] = TrieNode()
|
26 |
+
node = node.children[char]
|
27 |
+
node.is_end_of_word = True
|
28 |
+
node.token_id = token_id
|
29 |
+
|
30 |
+
def search(self, word):
|
31 |
+
node = self.root
|
32 |
+
for char in word:
|
33 |
+
if char not in node.children:
|
34 |
+
return False
|
35 |
+
node = node.children[char]
|
36 |
+
return node.is_end_of_word
|
37 |
+
|
38 |
+
def start_with_prefix(self, prefix):
|
39 |
+
node = self.root
|
40 |
+
for char in prefix:
|
41 |
+
if char not in node.children:
|
42 |
+
return False
|
43 |
+
node = node.children[char]
|
44 |
+
return True
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def from_tokenizer(cls, tokenizer, unicode=True):
|
48 |
+
vocab: Dict[str, int] = tokenizer.get_vocab()
|
49 |
+
trie = cls()
|
50 |
+
mapping = get_mapping(tokenizer, unicode=unicode)
|
51 |
+
for token_id in vocab.values():
|
52 |
+
byte_repr = mapping.map(token_id)
|
53 |
+
trie.insert(byte_repr, token_id)
|
54 |
+
return trie
|
55 |
+
|
56 |
+
@lru_cache(maxsize=128)
|
57 |
+
def __len__(self):
|
58 |
+
return len(self.dfs(verbose=False))
|
59 |
+
|
60 |
+
def dfs(self, accept=lambda x: True, verbose=False) -> List[Tuple[List[int], int]]:
|
61 |
+
result = []
|
62 |
+
counter = {"visited": 0, "pruned": 0}
|
63 |
+
_dfs(self.root, [], result, accept, counter)
|
64 |
+
return result
|
65 |
+
|
66 |
+
def bfs(
|
67 |
+
self, predicate=lambda x: True, verbose=False
|
68 |
+
) -> List[Tuple[List[int], int]]:
|
69 |
+
queue = deque([(self.root, [])])
|
70 |
+
valid_byte_seqs: List[Tuple[List[int], int]] = []
|
71 |
+
counter = {"visited": 0, "pruned": 0}
|
72 |
+
|
73 |
+
while queue:
|
74 |
+
counter["visited"] += 1
|
75 |
+
node, byte_seq = queue.popleft()
|
76 |
+
if predicate(byte_seq):
|
77 |
+
if node.is_end_of_word:
|
78 |
+
valid_byte_seqs.append((byte_seq, node.token_id))
|
79 |
+
for char, next_node in node.children.items():
|
80 |
+
new_byte_seq: List[int] = byte_seq.copy()
|
81 |
+
new_byte_seq.append(char)
|
82 |
+
queue.append((next_node, new_byte_seq))
|
83 |
+
else:
|
84 |
+
counter["pruned"] += 1
|
85 |
+
return valid_byte_seqs
|
86 |
+
|
87 |
+
def get_token_acceptance(
|
88 |
+
self, accept=lambda x: True, accept_eos=True, eos_token_id=None
|
89 |
+
) -> List[bool]:
|
90 |
+
valid_byte_seqs: List[Tuple[List[int], int]] = self.bfs(accept, verbose=True)
|
91 |
+
valid_token_ids: List[int] = [token_id for _, token_id in valid_byte_seqs]
|
92 |
+
token_acceptance: List[bool] = [False] * (len(self))
|
93 |
+
for token_id in valid_token_ids:
|
94 |
+
token_acceptance[token_id] = True
|
95 |
+
if not accept_eos:
|
96 |
+
# eos_token is mapped to an empty string, so it's always accepted regardless of the accept function
|
97 |
+
# this can be undesirable, so we can set it to False to ignore it
|
98 |
+
token_acceptance[eos_token_id] = False
|
99 |
+
return token_acceptance
|
100 |
+
|
101 |
+
|
102 |
+
def _dfs(
|
103 |
+
node,
|
104 |
+
cur_byte_seq: List[int],
|
105 |
+
result: List[Tuple[List[int], int]],
|
106 |
+
accept: callable,
|
107 |
+
counter: Dict[str, int],
|
108 |
+
):
|
109 |
+
counter["visited"] += 1
|
110 |
+
if accept(cur_byte_seq):
|
111 |
+
if node.is_end_of_word:
|
112 |
+
result.append((cur_byte_seq, node.token_id))
|
113 |
+
for char, next_node in node.children.items():
|
114 |
+
new_byte_seq: List[int] = cur_byte_seq.copy()
|
115 |
+
new_byte_seq.append(char)
|
116 |
+
_dfs(next_node, new_byte_seq, result, accept, counter)
|
117 |
+
else:
|
118 |
+
# Skip the entire subtree if the accept predicate returns False
|
119 |
+
counter["pruned"] += 1
|
120 |
+
return
|
121 |
+
|
122 |
+
|
123 |
+
def starts_with_prefix(prefix, target):
|
124 |
+
"""
|
125 |
+
Check if the given prefix is a valid start of the target word or if the target word is a valid start of the given prefix.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
prefix (str): The string prefix to be checked.
|
129 |
+
target (str): The target word to compare the prefix against.
|
130 |
+
|
131 |
+
Returns:
|
132 |
+
bool: True if prefix is a valid start of target or if target is a valid start of prefix, False otherwise.
|
133 |
+
"""
|
134 |
+
|
135 |
+
# Check if the target word starts with the given prefix.
|
136 |
+
# This covers the case where the prefix is shorter than the target word.
|
137 |
+
if target.startswith(prefix):
|
138 |
+
return True
|
139 |
+
|
140 |
+
# Check if the given prefix starts with the target word.
|
141 |
+
# This covers the case where the prefix is longer than or equal to the target word.
|
142 |
+
if prefix.startswith(target):
|
143 |
+
return True
|
144 |
+
|
145 |
+
# If neither of the above conditions are true, return False.
|
146 |
+
return False
|
147 |
+
|
148 |
+
|
149 |
+
if __name__ == "__main__":
|
150 |
+
import logging
|
151 |
+
|
152 |
+
# Configure logging
|
153 |
+
logging.basicConfig(level=logging.INFO)
|
154 |
+
from transformers import AutoTokenizer
|
155 |
+
|
156 |
+
tokenizer = AutoTokenizer.from_pretrained("gpt2", fast=True)
|
157 |
+
|
158 |
+
trie = ByteTrie.from_tokenizer(tokenizer, unicode=True)
|
159 |
+
print(f"length of trie: {len(trie)}=={len(tokenizer.vocab.items())}")
|
160 |
+
|
161 |
+
#
|
162 |
+
# print(trie.search("hello")) # Example, replace with actual words from the vocab
|
163 |
+
# print(trie.start_with_prefix("hell"))
|
164 |
+
#
|
165 |
+
# # Example Usage
|
166 |
+
# words = trie.dfs(accept=lambda x: len(x) > 0 and x[0] == 65 or len(x)==0)
|
167 |
+
# for word in words:
|
168 |
+
# print(bytes(word[0]).decode("utf-8"))
|
169 |
+
#
|
170 |
+
# # Example Usage
|
171 |
+
# words = trie.bfs(predicate=lambda x: len(x) > 0 and x[0] == 65 or len(x)==0)
|
172 |
+
# for word in words:
|
173 |
+
# print(bytes(word[0]).decode("utf-8"))
|
174 |
+
#
|
175 |
+
# token_acceptance = trie.get_token_acceptance(accept=lambda x: len(x) > 0 and x[0] == 65 or len(x)==0)
|
176 |
+
# print(sum(token_acceptance))
|
177 |
+
# assert sum(token_acceptance) == len(words)
|
178 |
+
|
179 |
+
########################
|
180 |
+
# UTF-8
|
181 |
+
########################
|
182 |
+
|
183 |
+
# from transformers import AutoTokenizer
|
184 |
+
#
|
185 |
+
# japanese = "こんにちは世界"
|
186 |
+
# with open("examples/grammars/japanese.ebnf", "r") as file:
|
187 |
+
# input_text = file.read()
|
188 |
+
# parsed_grammar = parse_ebnf(input_text)
|
189 |
+
#
|
190 |
+
# start_rule_id = parsed_grammar.symbol_table["root"]
|
191 |
+
#
|
192 |
+
# recognizer = GrammarRecognizer(parsed_grammar.grammar_encoding, start_rule_id)
|
193 |
+
# accept_state = recognizer.init_accept_state()
|
194 |
+
# token_acc = trie.get_token_acceptance(accept=lambda x: recognizer._probe_bytes_partial_match(x, accept_state=accept_state))
|
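
As a quick illustration of the ByteTrie added above (again a sketch, not part of the commit), the following builds the trie over the GPT-2 vocabulary and turns a byte-level predicate into a per-token acceptance mask; the predicate used here, "empty or starting with byte 65 ('A')", is purely illustrative.

from transformers import AutoTokenizer
from transformers_gad.trie import ByteTrie

tokenizer = AutoTokenizer.from_pretrained("gpt2", fast=True)
trie = ByteTrie.from_tokenizer(tokenizer, unicode=True)

# Predicate over (possibly partial) byte sequences: empty, or starting with byte 65 ("A").
predicate = lambda seq: len(seq) == 0 or seq[0] == 65

# Enumerate matching tokens as (byte_sequence, token_id) pairs ...
matches = trie.bfs(predicate=predicate)
# ... or get the same information as a boolean mask indexed by token id.
acceptance = trie.get_token_acceptance(accept=predicate)
print(len(matches), sum(acceptance))
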
transformers_gad/utf8_utils.py
ADDED
@@ -0,0 +1,170 @@
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
from dataclasses import dataclass
|
5 |
+
|
6 |
+
|
7 |
+
@dataclass
|
8 |
+
class PartialUTF8:
|
9 |
+
"""
|
10 |
+
A data class representing the state of a partially decoded UTF-8 sequence.
|
11 |
+
|
12 |
+
Attributes:
|
13 |
+
- value (int): The current accumulated value of the partially decoded Unicode code point.
|
14 |
+
This attribute stores the bits that have been decoded so far. For a fully decoded
|
15 |
+
character or before any partial decoding has started, this would typically be `0`.
|
16 |
+
|
17 |
+
- n_remain (int): The number of bytes remaining to complete the current UTF-8 encoded character.
|
18 |
+
A value of `-1` indicates that there is no ongoing partial decoding, i.e.,
|
19 |
+
either decoding has not started, or the last character was fully decoded.
|
20 |
+
|
21 |
+
This class is used to handle situations where UTF-8 encoded data may end in the middle of a character
|
22 |
+
sequence, allowing for the decoding process to be resumed when more data becomes available.
|
23 |
+
"""
|
24 |
+
|
25 |
+
value: int = 0 # Default to 0, indicating no partial value accumulated
|
26 |
+
n_remain: int = (
|
27 |
+
-1
|
28 |
+
) # Default to -1, indicating no bytes are currently expected to complete the character
|
29 |
+
|
30 |
+
def __hash__(self):
|
31 |
+
return hash((self.value, self.n_remain))
|
32 |
+
|
33 |
+
def __eq__(self, other):
|
34 |
+
if not isinstance(other, PartialUTF8):
|
35 |
+
return NotImplemented
|
36 |
+
return self.value == other.value and self.n_remain == other.n_remain
|
37 |
+
|
38 |
+
|
39 |
+
from typing import List, Tuple
|
40 |
+
from functools import lru_cache
|
41 |
+
|
42 |
+
|
43 |
+
@lru_cache(maxsize=3000000)
|
44 |
+
def decode_utf8(
|
45 |
+
src: bytes, partial_start: PartialUTF8
|
46 |
+
) -> Tuple[List[int], PartialUTF8]:
|
47 |
+
# Lookup table for determining the total bytes based on the first byte's high 4 bits
|
48 |
+
lookup = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4]
|
49 |
+
pos = 0 # Position in the src bytes to start decoding from
|
50 |
+
code_points = [] # List to store the decoded Unicode code points
|
51 |
+
value = partial_start.value # Start with any previously partial decoded value
|
52 |
+
n_remain = partial_start.n_remain # Number of bytes remaining from a partial decode
|
53 |
+
|
54 |
+
# If there's a partial sequence left from last decode, try to continue decoding it
|
55 |
+
while pos < len(src) and n_remain > 0:
|
56 |
+
next_byte = src[pos] # Get the next byte to process
|
57 |
+
# Check if the continuation byte format is correct (`10xxxxxx`)
|
58 |
+
if (next_byte >> 6) != 2:
|
59 |
+
# If not, it's an invalid sequence. Abort and return a special error state.
|
60 |
+
code_points = [0]
|
61 |
+
return code_points, PartialUTF8(0, -1)
|
62 |
+
|
63 |
+
# Accumulate the value by shifting left and adding the relevant 6 bits
|
64 |
+
value = (value << 6) + (next_byte & 0x3F)
|
65 |
+
pos += 1 # Move to the next byte
|
66 |
+
n_remain -= 1 # Decrement the number of remaining bytes
|
67 |
+
|
68 |
+
# If we've completed a partial sequence, add its value to the code points
|
69 |
+
if partial_start.n_remain > 0 and n_remain == 0:
|
70 |
+
code_points.append(value)
|
71 |
+
|
72 |
+
# Process the rest of src as complete or new UTF-8 sequences
|
73 |
+
while pos < len(src):
|
74 |
+
first_byte = src[pos] # Get the first byte of the next sequence
|
75 |
+
highbits = first_byte >> 4 # Extract the high 4 bits for the lookup table
|
76 |
+
n_remain = lookup[highbits] - 1 # Determine remaining bytes in this sequence
|
77 |
+
|
78 |
+
# If lookup returns an invalid number, it's an invalid sequence. Abort.
|
79 |
+
if n_remain < 0:
|
80 |
+
# raise ValueError("Invalid UTF-8 sequence")
|
81 |
+
code_points = [0]
|
82 |
+
return code_points, PartialUTF8(0, -1)
|
83 |
+
|
84 |
+
# Calculate the mask to isolate significant bits from the first byte
|
85 |
+
mask = (1 << (7 - n_remain)) - 1
|
86 |
+
value = first_byte & mask # Apply the mask to get the initial value
|
87 |
+
pos += 1 # Move to the next byte
|
88 |
+
|
89 |
+
# Process the continuation bytes
|
90 |
+
while pos < len(src) and n_remain > 0:
|
91 |
+
next_byte = src[pos]
|
92 |
+
# Shift the accumulated value and add the next 6 significant bits
|
93 |
+
value = (value << 6) + (next_byte & 0x3F)
|
94 |
+
pos += 1 # Move to the next byte
|
95 |
+
n_remain -= 1 # Decrement the number of remaining bytes
|
96 |
+
|
97 |
+
# If the sequence is complete, add its decoded value to the code points
|
98 |
+
if n_remain == 0:
|
99 |
+
code_points.append(value)
|
100 |
+
|
101 |
+
# # Append a terminating value to indicate the end (following llama-cpp implementation)
|
102 |
+
# code_points.append(0)
|
103 |
+
# the following line is crucial for LRU cache to work, as it reset to the initial state
|
104 |
+
if n_remain == 0:
|
105 |
+
n_remain = -1
|
106 |
+
value = 0
|
107 |
+
|
108 |
+
# Return the decoded code points and the state of any partial decoding
|
109 |
+
return code_points, PartialUTF8(value, n_remain)
|
110 |
+
|
111 |
+
|
112 |
+
def decode_utf8_leading_char(src: bytes) -> tuple:
|
113 |
+
first_byte = src[0]
|
114 |
+
highbits = first_byte >> 4
|
115 |
+
lookup = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4]
|
116 |
+
char_len = lookup[highbits]
|
117 |
+
|
118 |
+
# Extract the relevant bytes for the UTF-8 character
|
119 |
+
utf8_char_bytes = src[:char_len]
|
120 |
+
|
121 |
+
# Decode the character
|
122 |
+
char = utf8_char_bytes.decode("utf-8")
|
123 |
+
|
124 |
+
# Use ord() to convert the single character to its Unicode code point
|
125 |
+
code_point = ord(char)
|
126 |
+
|
127 |
+
# Remaining bytes
|
128 |
+
remaining_bytes = src[char_len:]
|
129 |
+
|
130 |
+
return code_point, remaining_bytes
|
131 |
+
|
132 |
+
|
133 |
+
def decode_utf8_string(utf8_bytes: bytes) -> list:
|
134 |
+
code_points = []
|
135 |
+
while utf8_bytes:
|
136 |
+
code_point, utf8_bytes = decode_utf8_leading_char(utf8_bytes)
|
137 |
+
code_points.append(code_point)
|
138 |
+
return code_points
|
139 |
+
|
140 |
+
if __name__ == "__main__":
|
141 |
+
# Given string
|
142 |
+
my_string = "€Hello" # The Euro symbol followed by "Hello"
|
143 |
+
|
144 |
+
# Get UTF-8 encoded bytes
|
145 |
+
utf8_bytes = my_string.encode("utf-8")
|
146 |
+
|
147 |
+
assert utf8_bytes == b"\xe2\x82\xacHello"
|
148 |
+
|
149 |
+
# Example usage with the Euro symbol followed by more characters
|
150 |
+
code_point, remaining_bytes = decode_utf8_leading_char(utf8_bytes)
|
151 |
+
|
152 |
+
print(f"Code Point: {code_point}") # Expected Output: 8364 (Euro symbol)
|
153 |
+
print(f"Remaining Bytes: {remaining_bytes}") # Expected Output: b'Hello'
|
154 |
+
|
155 |
+
# Example usage with the entire string
|
156 |
+
code_points = decode_utf8_string(utf8_bytes)
|
157 |
+
|
158 |
+
print(
|
159 |
+
f"Code Points: {code_points}"
|
160 |
+
) # Expected Output: [8364, 72, 101, 108, 108, 111]
|
161 |
+
|
162 |
+
print("-" * 50)
|
163 |
+
|
164 |
+
# Example usage:
|
165 |
+
utf8_bytes = b"\xe2\x82\xacHello" # UTF-8 encoded string (Euro symbol + "Hello")
|
166 |
+
partial_start = PartialUTF8() # Assuming start with no partial sequence
|
167 |
+
code_points, partial_utf8 = decode_utf8(utf8_bytes, partial_start)
|
168 |
+
|
169 |
+
print("Code Points:", code_points)
|
170 |
+
print("Remaining UTF-8 State:", partial_utf8.value, partial_utf8.n_remain)
|
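
To make the resumable-decoding contract of decode_utf8 concrete, here is a small sketch (not part of the commit) that splits a 3-byte character across two calls and carries the PartialUTF8 state between them.

from transformers_gad.utf8_utils import PartialUTF8, decode_utf8

data = "€".encode("utf-8")  # b"\xe2\x82\xac", a single 3-byte character

# Feed the first two bytes: no code point is complete yet, but the state remembers the prefix.
code_points, state = decode_utf8(data[:2], PartialUTF8())
assert code_points == [] and state.n_remain == 1

# Feed the final byte together with the carried-over state: the code point is completed.
code_points, state = decode_utf8(data[2:], state)
assert code_points == [8364] and state.n_remain == -1
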
transformers_gad/utils.py
ADDED
@@ -0,0 +1,98 @@
1 |
+
import json
|
2 |
+
import warnings
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
from termcolor import colored
|
6 |
+
|
7 |
+
|
8 |
+
def ints2bytes(sequence: List[int]) -> bytes:
|
9 |
+
# check in the range of 0-255
|
10 |
+
for item in sequence:
|
11 |
+
if not 0 <= item <= 255:
|
12 |
+
raise ValueError(f"item: {item} is not in the range [0, 255]")
|
13 |
+
return bytes(sequence)
|
14 |
+
|
15 |
+
|
16 |
+
def bytes2ints(byte_sequence: bytes) -> List[int]:
|
17 |
+
return list(byte_sequence)
|
18 |
+
|
19 |
+
|
20 |
+
def intervals_intersect(low1, high1, low2, high2):
|
21 |
+
"""
|
22 |
+
Check if two intervals [low1, high1] and [low2, high2] intersect.
|
23 |
+
|
24 |
+
:param high1: High bound of the first interval.
|
25 |
+
:param low1: Low bound of the first interval.
|
26 |
+
:param high2: High bound of the second interval.
|
27 |
+
:param low2: Low bound of the second interval.
|
28 |
+
:return: True if the intervals intersect, False otherwise.
|
29 |
+
"""
|
30 |
+
# Check if one interval is completely to the right of the other
|
31 |
+
if low1 > high2 or low2 > high1:
|
32 |
+
return False
|
33 |
+
|
34 |
+
# If the above condition is not met, the intervals intersect
|
35 |
+
return True
|
36 |
+
|
37 |
+
|
38 |
+
def pprint_token_ids(tokenizer, token_ids=None, text=None):
|
39 |
+
if token_ids is None and text is None:
|
40 |
+
raise ValueError("Either token_ids or text should be provided")
|
41 |
+
if token_ids is None:
|
42 |
+
token_ids = tokenizer.encode(text, add_special_tokens=False)
|
43 |
+
special_token_ids = tokenizer.all_special_ids
|
44 |
+
special_tokens = tokenizer.all_special_tokens
|
45 |
+
special_id2token = {
|
46 |
+
id: token for id, token in zip(special_token_ids, special_tokens)
|
47 |
+
}
|
48 |
+
# loop over token_ids and color the special tokens
|
49 |
+
colored_token_ids = []
|
50 |
+
|
51 |
+
for token_id in token_ids:
|
52 |
+
if token_id in special_id2token:
|
53 |
+
colored_token_ids.append(colored(token_id, "red", attrs=["bold"]))
|
54 |
+
else:
|
55 |
+
colored_token_ids.append(str(token_id))
|
56 |
+
colored_token_ids_str = [str(item) for item in colored_token_ids]
|
57 |
+
print("[" + ", ".join(colored_token_ids_str) + "]")
|
58 |
+
|
59 |
+
|
60 |
+
def get_tokenizer_model_type(model: str = "gpt2"):
|
61 |
+
"""
|
62 |
+
reference https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_fast.py#L729
|
63 |
+
:param model:
|
64 |
+
:return: BPE, Unigram, WordPiece, WordLevel
|
65 |
+
SentencePiece is used in conjunction with Unigram
|
66 |
+
"""
|
67 |
+
from transformers import AutoTokenizer
|
68 |
+
|
69 |
+
# if the tokenizer is not in the repo, it will raise OSError
|
70 |
+
# OSError: Can't load tokenizer for 'xxx'
|
71 |
+
# This happens when the model reuses the tokenizer of another model
|
72 |
+
if isinstance(model, str):
|
73 |
+
try:
|
74 |
+
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
|
75 |
+
# check if the tokenizer is fast
|
76 |
+
except OSError:
|
77 |
+
return None
|
78 |
+
else:
|
79 |
+
tokenizer = model
|
80 |
+
|
81 |
+
if not tokenizer.is_fast:
|
82 |
+
raise ValueError(f"The tokenizer {model} is not fast tokenizer")
|
83 |
+
tokenizer_json = json.loads(tokenizer._tokenizer.to_str())
|
84 |
+
model_type = tokenizer_json["model"]["type"]
|
85 |
+
if (
|
86 |
+
model_type == "BPE"
|
87 |
+
and tokenizer_json["pre_tokenizer"] is not None
|
88 |
+
and (
|
89 |
+
tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel"
|
90 |
+
or (
|
91 |
+
"pretokenizers" in tokenizer_json["pre_tokenizer"]
|
92 |
+
and tokenizer_json["pre_tokenizer"]["pretokenizers"][1]["type"]
|
93 |
+
== "ByteLevel"
|
94 |
+
)
|
95 |
+
)
|
96 |
+
):
|
97 |
+
model_type = "ByteLevelBPE"
|
98 |
+
return model_type
|
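
A short sketch of the helpers added in utils.py (not part of the commit; the model name is illustrative): round-tripping between byte values and bytes, and detecting the tokenizer model type.

from transformers_gad.utils import bytes2ints, get_tokenizer_model_type, ints2bytes

# Byte <-> int round trip, as used when feeding token bytes to the recognizer.
values = bytes2ints(b"hello")  # [104, 101, 108, 108, 111]
assert ints2bytes(values) == b"hello"

# GPT-2 ships a byte-level BPE tokenizer, so this should report "ByteLevelBPE".
print(get_tokenizer_model_type("gpt2"))
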
transformers_gad/vocab_struct.py
ADDED
@@ -0,0 +1,83 @@
1 |
+
#################
|
2 |
+
# DATA STRUCTURES
|
3 |
+
#################
|
4 |
+
|
5 |
+
import logging
|
6 |
+
import re
|
7 |
+
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
LEAF = -1
|
11 |
+
|
12 |
+
# TokenTrie stores the byte representation of every token and a byte-level trie that maps byte sequences back to token IDs
|
13 |
+
|
14 |
+
class TokenTrie:
|
15 |
+
def __init__(self, tokenizer):
|
16 |
+
self.eos_token_id = tokenizer.eos_token_id
|
17 |
+
self.tokens = []
|
18 |
+
self.trie = {}
|
19 |
+
self.load_tokens(tokenizer)
|
20 |
+
|
21 |
+
def id2str(self, token_id):
|
22 |
+
return self.tokens[token_id]
|
23 |
+
|
24 |
+
def __len__(self):
|
25 |
+
return len(self.tokens)
|
26 |
+
|
27 |
+
def load_tokens(self, tokenizer):
|
28 |
+
def replace_hex(match):
|
29 |
+
hex_value = match.group(1)
|
30 |
+
return chr(int(hex_value, 16))
|
31 |
+
|
32 |
+
if "gpt2" in tokenizer.__class__.__name__.lower():
|
33 |
+
special = tokenizer.additional_special_tokens_ids
|
34 |
+
|
35 |
+
# Here, the decoder does a string replace on a bunch of sequences
|
36 |
+
# like ' .' for '.'. This interferes with our assumptions, where a
|
37 |
+
# token should always have exactly one representation.
|
38 |
+
# Fortunately(?) text-generation-inference doesn't seem to run this
|
39 |
+
# cleanup, so we get extraneous spaces. So, in order to generate
|
40 |
+
# the right token set for TGI, we have to skip the space trimming.
|
41 |
+
# See:
|
42 |
+
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3588-L3600
|
43 |
+
def fmt_token(id):
|
44 |
+
if id in special:
|
45 |
+
return None
|
46 |
+
return bytes(
|
47 |
+
tokenizer.decode([id], clean_up_tokenization_spaces=False), "utf-8"
|
48 |
+
)
|
49 |
+
|
50 |
+
elif (
|
51 |
+
"llama" in tokenizer.__class__.__name__.lower()
|
52 |
+
or "t5" in tokenizer.__class__.__name__.lower()
|
53 |
+
):
|
54 |
+
|
55 |
+
def fmt_token(id):
|
56 |
+
token = tokenizer.convert_ids_to_tokens(id)
|
57 |
+
token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token)
|
58 |
+
token = token.replace("▁", " ")
|
59 |
+
return bytes(token, "utf-8") # here return bytes representations of the tokens
|
60 |
+
|
61 |
+
else:
|
62 |
+
logger.warning(
|
63 |
+
"Warning: unrecognized tokenizer: using default token formatting"
|
64 |
+
)
|
65 |
+
|
66 |
+
def fmt_token(id):
|
67 |
+
token = tokenizer.convert_ids_to_tokens(id)
|
68 |
+
return bytes(token, "utf-8")
|
69 |
+
|
70 |
+
# note: vocab_size doesn't work here because there are also
|
71 |
+
# get_added_vocab() tokens
|
72 |
+
self.tokens = [fmt_token(i) for i in range(len(tokenizer.get_vocab()))]
|
73 |
+
for token_id, token_bytes in enumerate(self.tokens):
|
74 |
+
if token_bytes is not None:
|
75 |
+
self.insert_into_trie(self.trie, token_bytes, token_id)
|
76 |
+
|
77 |
+
def insert_into_trie(self, trie, token_bytes, token_id):
|
78 |
+
current = trie
|
79 |
+
for byte in token_bytes:
|
80 |
+
if byte not in current:
|
81 |
+
current[byte] = {}
|
82 |
+
current = current[byte]
|
83 |
+
current[LEAF] = token_id
|