Binder / nsql /parser.py
Timothyxxx
Init
f6f97d8
from typing import List
import re
import sqlparse
class TreeNode(object):
    """A node in the tree of nested ``QA(...)`` calls inside an NSQL string.

    ``name`` is the node's original text span. ``rename`` starts out equal to
    ``name`` and is rewritten in place as child QA calls are replaced (via
    ``rename_father_col`` / ``rename_father_val``) by the column names or
    values they produce.
    """

    def __init__(self, name=None, father=None):
        self.name: str = name
        self.rename: str = name
        self.father: TreeNode = father
        self.children: List = []
        # Column name(s) this node produced; set by rename_father_col.
        self.produced_col_name_s = None

    def __eq__(self, other):
        # Nodes compare by their current (possibly rewritten) text. Returning
        # NotImplemented for foreign types avoids an AttributeError on e.g. `node == None`.
        if not isinstance(other, TreeNode):
            return NotImplemented
        return self.rename == other.rename

    def __hash__(self):
        # Consistent with __eq__. NOTE: `rename` is mutable, so a node should not be
        # rewritten while it is stored in a set or used as a dict key.
        return hash(self.rename)

    def set_name(self, name):
        """Set both the original and the rewritten text of this node."""
        self.name = name
        self.rename = name

    def add_child(self, child):
        """Append ``child`` to this node's children and make this node its father."""
        self.children.append(child)
        child.father = self

    def rename_father_col(self, col_idx: int, col_prefix: str = "col_"):
        """Replace this node's text inside the father's ``rename`` with a column name.

        The column name is ``col_prefix`` followed by ``col_idx``. Every occurrence
        of ``self.name`` in the father's text is replaced, and the new name is
        recorded in ``produced_col_name_s``.
        """
        new_col_name = "{}{}".format(col_prefix, col_idx)
        self.father.rename = self.father.rename.replace(self.name, new_col_name)
        self.produced_col_name_s = [new_col_name]  # fixme when multiple outputs for a qa func

    @staticmethod
    def _quote_sql_str(value) -> str:
        # Wrap in single quotes, doubling any embedded single quote as required by SQL
        # string-literal syntax, so values such as "Men's" do not break the query.
        return "'{}'".format(str(value).replace("'", "''"))

    def rename_father_val(self, val_names):
        """Replace this node's text inside the father's ``rename`` with SQL value literal(s).

        With exactly one value, it is emitted unquoted when ``convert_type`` turns
        it into a non-string (e.g. a number), and as a quoted string otherwise.
        With any other count, the values are emitted as a parenthesised,
        comma-separated list of quoted strings (presumably for an ``IN (...)``).
        """
        if len(val_names) == 1:
            val_name = val_names[0]
            if isinstance(convert_type(val_name), str):
                new_val_equals_str = self._quote_sql_str(val_name)
            else:
                new_val_equals_str = "{}".format(val_name)
        else:
            new_val_equals_str = "({})".format(
                ", ".join(self._quote_sql_str(val_name) for val_name in val_names)
            )
        self.father.rename = self.father.rename.replace(self.name, new_val_equals_str)
def get_cfg_tree(nsql: str):
    """
    Build the tree of nested ``QA(...)`` calls contained in ``nsql``.

    The returned root node is named with the whole ``nsql`` string. Each ``QA(``
    whose ``Q`` is not the very first character becomes a child of the innermost
    enclosing QA node (or of the root). A child is named with its full
    ``QA(...)`` span once the matching ``)`` is reached. Parentheses are matched
    purely by position: those inside quoted text count too, and an unmatched
    ``)`` raises IndexError.

    @param nsql: the NSQL string to scan.
    @return: the root TreeNode (or the innermost still-open QA node if a QA call is never closed).
    """

    def opens_nested_qa(paren_pos: int) -> bool:
        # The "(" at paren_pos is preceded by "QA", and that "QA" is not at index 0.
        return paren_pos >= 3 and nsql[paren_pos - 2:paren_pos + 1] == "QA("

    open_paren_positions: List = []  # indices of "(" still waiting for their ")"
    ancestors: List = []  # enclosing nodes of the node currently being built
    node = TreeNode(name=nsql)
    for pos, char in enumerate(nsql):
        if char == "(":
            open_paren_positions.append(pos)
            if opens_nested_qa(pos):
                child = TreeNode()
                node.add_child(child)
                ancestors.append(node)
                node = child
        elif char == ")":
            opener = open_paren_positions.pop()
            if opens_nested_qa(opener):
                # Name the finished QA node by its text span, then step back out to its parent.
                node.set_name(nsql[opener - 2:pos + 1])
                node = ancestors.pop()
    return node
def get_steps(tree_node: 'TreeNode', steps: List):
    """Post-order traversal: append every node of the subtree to ``steps``.

    Children are visited recursively, in order, before their parent. Each node
    therefore lands in ``steps`` after all of its descendants, so inner QA calls
    come before the calls that contain them. ``steps`` is mutated in place and
    nothing is returned.

    @param tree_node: root of the subtree to walk (only ``.children`` is read).
    @param steps: list that receives the nodes.
    """
    for child in tree_node.children:
        get_steps(child, steps)
    steps.append(tree_node)
def parse_question_paras(nsql: str, qa_model):
    """Split a single, non-nested ``QA(...)`` call into its question and parameters.

    The question is the text between the first two double quotes. Everything after
    the second quote, up to the closing parenthesis, is split with
    ``sqlparse.split``, and each piece is stripped of spaces and semicolons.

    @param nsql: a ``QA("question"; para; ...)`` string; surrounding spaces/semicolons are allowed.
    @param qa_model: not used by this function; kept for interface compatibility.
    @return: ``(question, paras)``.
    @raise AssertionError: if ``nsql`` is not a single, non-nested QA call with a double-quoted question.
    """
    # We assume there's no nested qa inside when running this func
    nsql = nsql.strip(" ;")
    # Explicit raises instead of `assert` so the checks survive `python -O`;
    # AssertionError is kept so callers see the same exception type as before.
    if not (nsql[:3] == "QA(" and nsql[-1] == ")"):
        raise AssertionError("must start with QA( symbol and end with )")
    if "QA" in nsql[2:-1]:
        raise AssertionError("must have no nested qa inside")

    # Get question (between the first two double quotes) and the left part (paras_raw_str).
    first_quote = nsql.find('"')
    second_quote = nsql.find('"', first_quote + 1) if first_quote != -1 else -1
    if second_quote == -1:
        raise AssertionError("question must be wrapped in double quotes")
    question = nsql[first_quote + 1:second_quote]
    paras_raw_str = nsql[second_quote + 1:-1].strip(" ;")

    # Split Parameters(SQL/column/value) from all parameters.
    paras = [_para.strip(' ;') for _para in sqlparse.split(paras_raw_str)]
    return question, paras
def convert_type(value):
    """Best-effort conversion of ``value`` into the Python object it spells.

    ``value`` is passed to ``eval``. If evaluation raises any Exception (a bare
    word, SQL text, a non-string argument, ...), ``value`` is returned unchanged.

    WARNING: ``eval`` executes arbitrary code. If ``value`` can originate from
    untrusted input (e.g. model output — confirm with callers), this is a
    code-execution risk. ``ast.literal_eval`` would be safer, but it would
    change results for non-literal inputs such as ``1+1`` or builtin names.
    """
    try:
        converted = eval(value)
    except Exception:
        converted = value
    return converted
def nsql_role_recognize(nsql_like_str, all_headers, all_passage_titles, all_image_titles):
    """Recognize the role of one QA parameter (column / title / value / SQL).

    Checks run in order and the first match wins:

    1. ``'col'``: the string, after stripping one pair of surrounding back-ticks,
       is in ``all_headers`` or equals a lower-cased header.
    2. ``'passage_title_and_image_title'``: it matches, case-insensitively, both a
       passage title and an image title.
    3. ``'passage_title'``: it matches a passage title case-insensitively.
    4. ``'image_title'``: it matches an image title case-insensitively.
    5. ``'val'``: it can be evaluated with ``eval``.
    6. ``'complete_sql'``: anything else.

    For each title check the raw string is tried first, then ``str(eval(s))``.
    A quoted literal such as ``'"Paris"'`` therefore matches the title ``Paris``.

    @return: ``(role, text)``. ``text`` is the original, un-stripped input, except
        for a title matched through ``eval``, where it is the evaluated string.
    """
    orig_nsql_like_str = nsql_like_str
    # strip the first and the last '`'
    if nsql_like_str.startswith('`') and nsql_like_str.endswith('`'):
        nsql_like_str = nsql_like_str[1:-1]

    # Case 1: if col in header, it is column type.
    # NOTE(review): the lower-cased comparison only helps when the input itself is
    # lower-case ("name" matches header "Name", "NAME" does not); confirm whether
    # `nsql_like_str.lower()` was intended.
    if nsql_like_str in all_headers or nsql_like_str in {h.lower() for h in all_headers}:
        return 'col', orig_nsql_like_str

    # Lower-case each title collection once instead of once per check.
    passage_titles = {t.lower() for t in all_passage_titles}
    image_titles = {t.lower() for t in all_image_titles}
    lowered = nsql_like_str.lower()

    # fixme: add case when the this nsql_like_str both in table headers, images title and in passages title.
    # Case 2.1 (raw): title of both a passage and an image.
    if lowered in passage_titles and lowered in image_titles:
        return "passage_title_and_image_title", orig_nsql_like_str

    # Evaluate once and reuse the result. Empty locals keep eval from seeing this
    # function's own variables (previously the input "all_headers" evaluated to the
    # header list). SECURITY: eval executes arbitrary code; if this string can come
    # from model output or other untrusted input it is a code-execution risk.
    try:
        evaled = eval(nsql_like_str, globals(), {})
    except Exception:
        is_evaluable, evaled_str = False, None
    else:
        is_evaluable = True
        try:
            evaled_str = str(evaled)
        except Exception:
            evaled_str = None
    evaled_lowered = evaled_str.lower() if evaled_str is not None else None

    # Case 2.1 (evaluated): title of both a passage and an image.
    if evaled_lowered is not None and evaled_lowered in passage_titles and evaled_lowered in image_titles:
        return "passage_title_and_image_title", evaled_str

    # Case 2.2: title of a passage.
    if lowered in passage_titles:
        return "passage_title", orig_nsql_like_str
    if evaled_lowered is not None and evaled_lowered in passage_titles:
        return "passage_title", evaled_str

    # Case 2.3: title of an image.
    if lowered in image_titles:
        return "image_title", orig_nsql_like_str
    if evaled_lowered is not None and evaled_lowered in image_titles:
        return "image_title", evaled_str

    # Case 3: if it can be parsed by eval(), it is value type.
    if is_evaluable:
        return 'val', orig_nsql_like_str

    # Case 4: else it should be the sql, if it isn't, exception will be raised.
    return 'complete_sql', orig_nsql_like_str
def remove_duplicate(original_list):
    """Return the items of ``original_list`` without duplicates, keeping first occurrences in order.

    Duplicates are detected by equality. When every item is hashable this runs in
    O(n) via ``dict.fromkeys``. Otherwise it falls back to an O(n^2) pairwise scan
    that works for any items (e.g. list rows).

    @param original_list: any iterable; it is consumed once.
    @return: a new list.
    """
    items = list(original_list)  # materialize so the fallback can re-iterate
    try:
        return list(dict.fromkeys(items))
    except TypeError:
        # At least one item is unhashable.
        no_duplicate_list = []
        for item in items:
            if item not in no_duplicate_list:
                no_duplicate_list.append(item)
        return no_duplicate_list
def extract_answers(sub_table):
    """Flatten the cells of ``sub_table`` into a single list, row by row.

    ``sub_table`` is read as a mapping with ``'header'`` (the column names) and
    ``'rows'`` (a sequence of row sequences). If a ``'row_id'`` column is present,
    its cells are left out.

    @return: the list of cell values; ``[]`` when ``sub_table`` is empty/None or
        its header is None.
    """
    if not sub_table or sub_table['header'] is None:
        return []
    header = sub_table['header']
    if 'row_id' not in header:
        return [cell for row in sub_table['rows'] for cell in row]
    # Drop the row_id cell by its position in the header instead of assuming it is
    # always the first column.
    row_id_idx = list(header).index('row_id')
    return [
        cell
        for row in sub_table['rows']
        for col_idx, cell in enumerate(row)
        if col_idx != row_id_idx
    ]