Spaces:
Runtime error
Runtime error
from typing import List | |
import re | |
import sqlparse | |
class TreeNode(object):
    """One node of the tree of QA(...) calls found in an nsql string.

    Attributes set in ``__init__``:
      name: the text span this node stands for (never rewritten by the
        rename_* helpers).
      rename: working copy of ``name``; children substitute their own span
        inside it with a column name or a value literal.
      father: the parent node, or None.
      children: child nodes, in the order they were added.
      produced_col_name_s: column names produced for this node by
        ``rename_father_col``; None until that method is called.
    """

    def __init__(self, name=None, father=None):
        self.name: str = name
        self.rename: str = name
        self.father: TreeNode = father
        self.children: List = []
        self.produced_col_name_s = None

    def __eq__(self, other):
        """Two nodes are equal when their ``rename`` texts are equal."""
        if not isinstance(other, TreeNode):
            # Let Python fall back to its default handling instead of
            # raising AttributeError on e.g. ``node == None``.
            return NotImplemented
        return self.rename == other.rename

    def __hash__(self):
        # Kept consistent with __eq__: both are based on ``rename``.
        return hash(self.rename)

    def set_name(self, name):
        """Set both ``name`` and ``rename`` to ``name``."""
        self.name = name
        self.rename = name

    def add_child(self, child):
        """Append ``child`` to ``children`` and make this node its father."""
        self.children.append(child)
        child.father = self

    def rename_father_col(self, col_idx: int, col_prefix: str = "col_"):
        """Replace this node's span in the father's ``rename`` with a column name.

        The column name is ``col_prefix`` followed by ``col_idx`` and is also
        recorded in ``produced_col_name_s``.
        """
        new_col_name = "{}{}".format(col_prefix, col_idx)
        self.father.rename = self.father.rename.replace(self.name, new_col_name)
        self.produced_col_name_s = [new_col_name]  # fixme when multiple outputs for a qa func

    def rename_father_val(self, val_names):
        """Replace this node's span in the father's ``rename`` with value(s).

        A single value is single-quoted only when ``convert_type`` returns a
        str for it, otherwise it is inserted bare. Several values become a
        parenthesised, comma-separated list of single-quoted items.
        """
        # NOTE(review): values are inserted without escaping embedded single
        # quotes, and an empty ``val_names`` yields "()" -- confirm callers
        # never pass such input.
        if len(val_names) == 1:
            val_name = val_names[0]
            if isinstance(convert_type(val_name), str):
                new_val_equals_str = "'{}'".format(val_name)
            else:
                new_val_equals_str = "{}".format(val_name)
        else:
            quoted_values = ["'{}'".format(val_name) for val_name in val_names]
            new_val_equals_str = '({})'.format(', '.join(quoted_values))
        self.father.rename = self.father.rename.replace(self.name, new_val_equals_str)
def get_cfg_tree(nsql: str):
    """
    Build the tree of nested QA(...) calls found in an nsql string.

    The root node carries the whole string. Every "QA(" that does not start
    at index 0 opens a child of the node currently being built; its matching
    ")" names that child with the full QA(...) span and steps back to the
    parent. A "QA(" at index 0 is not turned into a child.
    @param nsql: the nsql string to scan.
    @return: the node that is current when the scan ends (the root when every
        nested QA( has been closed).
    """

    def _opens_nested_qa(paren_pos):
        # The "(" at paren_pos is directly preceded by "QA", and that "QA"
        # does not sit at the very beginning of the string.
        return nsql[paren_pos - 2:paren_pos + 1] == "QA(" and paren_pos - 2 != 0

    open_paren_positions: List = []  # positions of "(" not yet closed
    ancestor_nodes: List = []  # nodes to return to when a nested QA closes
    node = TreeNode(name=nsql)

    for pos, char in enumerate(nsql):
        if char == "(":
            open_paren_positions.append(pos)
            if pos > 1 and _opens_nested_qa(pos):
                nested = TreeNode()
                node.add_child(nested)
                ancestor_nodes.append(node)
                node = nested
        elif char == ")":
            opened_at = open_paren_positions.pop()
            if pos > 1 and _opens_nested_qa(opened_at):
                # This ")" closes a nested QA clause: name it with its span.
                node.set_name(nsql[opened_at - 2:pos + 1])
                node = ancestor_nodes.pop()

    return node
def get_steps(tree_node: TreeNode, steps: List):
    """Append the nodes under ``tree_node`` to ``steps``, children before parents.

    Each node's children are visited in list order and the node itself is
    appended after all of them (post-order), so ``tree_node`` ends up last.
    ``steps`` is modified in place; nothing is returned.
    """
    # Explicit stack of (node, children_already_scheduled) pairs.
    pending = [(tree_node, False)]
    while pending:
        node, children_scheduled = pending.pop()
        if children_scheduled:
            steps.append(node)
            continue
        # Revisit the node once its whole subtree has been emitted.
        pending.append((node, True))
        # Push in reverse so the first child is popped (and emitted) first.
        for child in reversed(node.children):
            pending.append((child, False))
def parse_question_paras(nsql: str, qa_model):
    """Split a single, non-nested QA(...) call into its question and parameters.

    The question is the text between the first two double quotes; everything
    after the second quote (up to the closing parenthesis) is split into
    parameters with ``sqlparse.split``.
    @param nsql: text of the form QA("question"; para; ...), optionally
        surrounded by spaces/semicolons.
    @param qa_model: not used by this function.
    @return: (question, paras) where paras is a list of stripped strings.
    """
    # We assume there's no nested qa inside when running this func
    call_text = nsql.strip(" ;")
    assert call_text.startswith("QA(") and call_text.endswith(")"), "must start with QA( symbol and end with )"
    assert "QA" not in call_text[2:-1], "must have no nested qa inside"

    # The first two double quotes delimit the question.
    quote_positions = [pos for pos, char in enumerate(call_text) if char == '"']
    question_start = quote_positions[0]
    question_end = quote_positions[1]
    question = call_text[question_start + 1:question_end]

    # Whatever follows the question, minus the closing ")", holds the parameters.
    paras_raw_str = call_text[question_end + 1:-1].strip(" ;")

    # Split Parameters(SQL/column/value) from all parameters.
    paras = []
    for raw_para in sqlparse.split(paras_raw_str):
        paras.append(raw_para.strip(' ;'))
    return question, paras
def convert_type(value):
    """Return ``eval(value)`` when that succeeds, otherwise ``value`` unchanged."""
    # SECURITY NOTE(review): eval() executes arbitrary Python. If ``value``
    # can come from untrusted input (model output, user text), this should be
    # replaced by a restricted parser -- confirm with the callers.
    try:
        evaluated = eval(value)
    except Exception:
        # Not a valid Python expression (or not a string at all): keep as is.
        return value
    return evaluated
def nsql_role_recognize(nsql_like_str, all_headers, all_passage_titles, all_image_titles):
    """Recognize the role of one parameter string (column / title / value / SQL).

    Checks are made in this order, and the first match wins:
      1. 'col': the text (with one pair of surrounding backticks removed) is
         in ``all_headers`` or in their lower-cased forms.
      2. 'passage_title_and_image_title', 'passage_title', 'image_title':
         the text, or the str() of its eval() result, matches a title
         case-insensitively. For each of the three roles the raw text is
         tried before the evaluated text.
      3. 'val': the text can be evaluated by eval().
      4. 'complete_sql': anything else.

    @param nsql_like_str: the parameter text to classify.
    @param all_headers: table column names.
    @param all_passage_titles: passage titles; None is treated as empty.
    @param all_image_titles: image titles; None is treated as empty.
    @return: (role, text). ``text`` is the original input, except for title
        matches found through the evaluated form, where it is that form.
    """
    orig_nsql_like_str = nsql_like_str

    # strip the first and the last '`'
    if nsql_like_str.startswith('`') and nsql_like_str.endswith('`'):
        nsql_like_str = nsql_like_str[1:-1]

    # Case 1: if col in header, it is column type.
    if nsql_like_str in all_headers or nsql_like_str in [header.lower() for header in all_headers]:
        return 'col', orig_nsql_like_str

    # Lower-case each title collection once instead of once per check.
    passage_titles = {title.lower() for title in (() if all_passage_titles is None else all_passage_titles)}
    image_titles = {title.lower() for title in (() if all_image_titles is None else all_image_titles)}
    raw_key = nsql_like_str.lower()

    # fixme: add case when the this nsql_like_str both in table headers, images title and in passages title.
    # Case 2.1 (raw text): title of both a passage and an image.
    if raw_key in passage_titles and raw_key in image_titles:
        return "passage_title_and_image_title", orig_nsql_like_str

    # Evaluate the text once; the result serves every remaining check.
    # SECURITY NOTE(review): eval() executes arbitrary Python and can see
    # this function's locals and the module globals. If the text can come
    # from untrusted input it should be replaced by a restricted parser.
    try:
        nsql_like_str_evaled = str(eval(nsql_like_str))
    except Exception:
        # Not a Python expression: it can be neither a quoted title nor a value.
        nsql_like_str_evaled = None
    is_evaluable = nsql_like_str_evaled is not None
    evaled_key = nsql_like_str_evaled.lower() if is_evaluable else None

    # Case 2.1 (evaluated text).
    if is_evaluable and evaled_key in passage_titles and evaled_key in image_titles:
        return "passage_title_and_image_title", nsql_like_str_evaled

    # Case 2.2: if it is title of certain passage.
    if raw_key in passage_titles:
        return "passage_title", orig_nsql_like_str
    if is_evaluable and evaled_key in passage_titles:
        return "passage_title", nsql_like_str_evaled

    # Case 2.3: if it is title of certain picture.
    if raw_key in image_titles:
        return "image_title", orig_nsql_like_str
    if is_evaluable and evaled_key in image_titles:
        return "image_title", nsql_like_str_evaled

    # Case 4: if it can be parsed by eval(), it is value type.
    if is_evaluable:
        return 'val', orig_nsql_like_str

    # Case 5: else it should be the sql, if it isn't, exception will be raised.
    return 'complete_sql', orig_nsql_like_str
def remove_duplicate(original_list):
    """Return the items of ``original_list`` in first-seen order, without repeats.

    Membership is tested with ``in`` on a list (equality, not hashing), so
    unhashable items such as lists are supported.
    """
    no_duplicate_list = []
    for item in original_list:
        if item not in no_duplicate_list:
            no_duplicate_list.append(item)
    return no_duplicate_list
def extract_answers(sub_table):
    """Flatten the cells of ``sub_table`` into one list, skipping the row_id column.

    @param sub_table: mapping with a 'header' entry (column names, or None)
        and a 'rows' entry (iterable of rows); may be None or empty.
    @return: all cells in row order. When 'row_id' is one of the header
        names, the cell at that column position is left out of every row.
        Returns [] when ``sub_table`` is falsy or its header is None.
    """
    if not sub_table or sub_table['header'] is None:
        return []

    header = list(sub_table['header'])
    rows = sub_table['rows']

    if 'row_id' not in header:
        return [cell for row in rows for cell in row]

    # Drop the column where row_id actually sits rather than assuming it is
    # always the first one.
    row_id_pos = header.index('row_id')
    return [cell for row in rows for pos, cell in enumerate(row) if pos != row_id_pos]