Spaces:
Runtime error
Runtime error
File size: 6,594 Bytes
f6f97d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
from typing import List
import re
import sqlparse
class TreeNode(object):
    """A node in the tree of ``QA(...)`` calls built from an NSQL string.

    ``name`` is the text span this node stands for. ``rename`` starts equal to
    ``name`` and is rewritten when children substitute their results into it
    (see ``rename_father_col`` / ``rename_father_val``). Equality and hashing
    are both based on ``rename``.
    """

    def __init__(self, name=None, father=None):
        self.name: str = name
        self.rename: str = name
        # NOTE: passing ``father`` here does not add this node to the father's
        # ``children``; ``father.add_child(node)`` does both.
        self.father: TreeNode = father
        self.children: List = []
        # Column name(s) this node's result was written to; set by
        # rename_father_col.
        self.produced_col_name_s = None

    def __eq__(self, other):
        # Comparing with a non-TreeNode used to raise AttributeError; deferring
        # with NotImplemented lets Python fall back to identity (i.e. False).
        if not isinstance(other, TreeNode):
            return NotImplemented
        return self.rename == other.rename

    def __hash__(self):
        # NOTE(review): ``rename`` is mutable; mutating a node while it is a
        # set member or dict key leaves it in the wrong hash bucket.
        return hash(self.rename)

    def set_name(self, name):
        """Set both ``name`` and ``rename`` to ``name``."""
        self.name = name
        self.rename = name

    def add_child(self, child):
        """Append ``child`` to ``children`` and make this node its father."""
        self.children.append(child)
        child.father = self

    def rename_father_col(self, col_idx: int, col_prefix: str = "col_"):
        """Replace this node's span in the father's ``rename`` with a column name.

        The column is named ``col_prefix`` followed by ``col_idx`` (e.g.
        ``col_3``) and is recorded in ``produced_col_name_s``.
        """
        new_col_name = "{}{}".format(col_prefix, col_idx)
        self.father.rename = self.father.rename.replace(self.name, new_col_name)
        # FIXME: only a single output column per QA call is supported.
        self.produced_col_name_s = [new_col_name]

    def rename_father_val(self, val_names):
        """Replace this node's span in the father's ``rename`` with SQL literal(s).

        With exactly one value: it is inserted as a quoted string literal when
        ``convert_type`` leaves it a ``str``, otherwise verbatim (e.g. numbers).
        With any other count: a parenthesised, comma-separated list of quoted
        strings, e.g. ``('a', 'b')``.
        """
        def quote(val) -> str:
            # Double embedded single quotes (standard SQL escaping) so values
            # such as "O'Brien" still produce a well-formed string literal.
            return "'{}'".format(str(val).replace("'", "''"))

        if len(val_names) == 1:
            val_name = val_names[0]
            if isinstance(convert_type(val_name), str):
                new_val_equals_str = quote(val_name)
            else:
                new_val_equals_str = "{}".format(val_name)
        else:
            new_val_equals_str = '({})'.format(', '.join(quote(val) for val in val_names))
        self.father.rename = self.father.rename.replace(self.name, new_val_equals_str)
def get_cfg_tree(nsql: str):
    """
    Parse QA() into a tree for execution guiding.

    The returned root node is named after the whole ``nsql`` string. Every
    ``QA(`` call that does not start at position 0 becomes a child of the
    innermost such call enclosing it (or of the root). Its name is set to the
    full ``QA(...)`` span once the matching ``)`` is reached.
    @param nsql: the NSQL program text.
    @return: the root TreeNode (or, if a nested QA call is never closed, the
        innermost unclosed QA node).
    """
    def opens_nested_qa(paren_pos: int) -> bool:
        # The "(" at paren_pos is the third char of a "QA(" token that does
        # not begin the string (a leading QA is represented by the root).
        return paren_pos >= 3 and nsql.startswith("QA(", paren_pos - 2)

    open_parens: List = []  # positions of "(" not yet closed
    enclosing_nodes: List = []  # nodes to resume when a QA call closes
    node = TreeNode(name=nsql)
    for pos, char in enumerate(nsql):
        if char == "(":
            open_parens.append(pos)
            if opens_nested_qa(pos):
                child = TreeNode()
                node.add_child(child)
                enclosing_nodes.append(node)
                node = child
        elif char == ")":
            start = open_parens.pop()
            if opens_nested_qa(start):
                node.set_name(nsql[start - 2:pos + 1])
                node = enclosing_nodes.pop()
    return node
def get_steps(tree_node: TreeNode, steps: List):
    """Append the subtree rooted at ``tree_node`` to ``steps`` in post-order.

    Children are visited left to right, and every node is appended after all
    of its descendants. ``steps`` is extended in place; nothing is returned.
    (The previous docstring called this pre-order, which it is not.)
    """
    # Explicit stack instead of recursion, so a deeply nested tree cannot hit
    # the interpreter's recursion limit. Entries are (node, children_done).
    pending = [(tree_node, False)]
    while pending:
        node, children_done = pending.pop()
        if children_done:
            steps.append(node)
            continue
        pending.append((node, True))
        # Push children reversed so the first child is popped (visited) first.
        pending.extend((child, False) for child in reversed(node.children))
def parse_question_paras(nsql: str, qa_model):
    """Split a single, non-nested ``QA(...)`` call into its question and parameters.

    @param nsql: text of one ``QA("question"; para; ...)`` call; surrounding
        spaces and semicolons are ignored.
    @param qa_model: not used by this function.
    @return: ``(question, paras)``. ``question`` is the text between the first
        two double quotes. ``paras`` is the list of statements
        ``sqlparse.split`` finds in the rest of the call, each stripped of
        spaces and semicolons.
    @raise AssertionError: if ``nsql`` does not start with ``QA(`` and end with
        ``)``, or contains a nested ``QA``.
    @raise IndexError: if ``nsql`` contains fewer than two double quotes.
    """
    nsql = nsql.strip(" ;")
    # Explicit raises instead of ``assert`` statements, which ``python -O``
    # strips. AssertionError is kept so existing callers' handling still works.
    if not (nsql.startswith("QA(") and nsql.endswith(")")):
        raise AssertionError("must start with QA( symbol and end with )")
    # We assume there's no nested qa inside when running this func.
    if "QA" in nsql[2:-1]:
        raise AssertionError("must have no nested qa inside")

    # The question sits between the first two double quotes; the remainder up
    # to the closing ")" holds the parameters (SQL/column/value).
    quote_positions = [match.start() for match in re.finditer('"', nsql)]
    question = nsql[quote_positions[0] + 1: quote_positions[1]]
    paras_raw_str = nsql[quote_positions[1] + 1:-1].strip(" ;")
    paras = [para.strip(' ;') for para in sqlparse.split(paras_raw_str)]
    return question, paras
def convert_type(value):
    """Try to turn ``value`` into the Python object it spells.

    Returns ``eval(value)`` when evaluation succeeds; if it raises any
    ``Exception`` (e.g. a bare word that is not a defined name, or a
    non-string argument), ``value`` is returned unchanged.
    """
    # NOTE(review): eval() executes arbitrary code and ``value`` may come from
    # model output. If only literals are expected, ast.literal_eval is the
    # safe alternative (it rejects names and calls).
    try:
        converted = eval(value)
    except Exception:
        converted = value
    return converted
def _eval_as_str(expr):
    """Return ``str(eval(expr))``, or None if evaluating ``expr`` raises.

    Evaluation uses this module's globals and a fresh, empty locals mapping,
    so the result never depends on a calling function's local variable names.
    """
    # NOTE(review): eval() executes arbitrary code and ``expr`` comes from
    # generated NSQL. If only literals are expected, ast.literal_eval is safer.
    try:
        return str(eval(expr, globals(), {}))
    except Exception:
        return None


def nsql_role_recognize(nsql_like_str, all_headers, all_passage_titles, all_image_titles):
    """Recognize role. (SQL/column/value)

    If the string both starts and ends with a backtick, those backticks are
    removed before matching. Roles are tried in this order:

    1. ``'col'``: the string equals a header, or equals a lower-cased header.
    2.1 ``"passage_title_and_image_title"``: its lower-cased text, or failing
        that the lower-cased ``str(eval(...))`` of it, matches both a passage
        title and an image title (case-insensitively).
    2.2 ``"passage_title"`` / 2.3 ``"image_title"``: same test against only
        the passage titles / only the image titles.
    4. ``'val'``: the string can be evaluated as a Python expression.
    5. ``'complete_sql'``: anything else.

    @return: ``(role, text)``. ``text`` is the original argument (backticks
        kept), except for title matches found through evaluation, where it is
        the evaluated string.
    """
    orig_nsql_like_str = nsql_like_str
    # strip the first and the last '`'
    if nsql_like_str.startswith('`') and nsql_like_str.endswith('`'):
        nsql_like_str = nsql_like_str[1:-1]

    # Case 1: if col in header, it is column type.
    if nsql_like_str in all_headers or nsql_like_str in {header.lower() for header in all_headers}:
        return 'col', orig_nsql_like_str

    # fixme: add case when the this nsql_like_str both in table headers, images title and in passages title.
    passage_titles = {title.lower() for title in all_passage_titles}
    image_titles = {title.lower() for title in all_image_titles}
    lowered = nsql_like_str.lower()

    # Case 2.1: title of both a passage and an image.
    if lowered in passage_titles and lowered in image_titles:
        return "passage_title_and_image_title", orig_nsql_like_str
    # Evaluate once and reuse for the remaining cases; None means eval failed.
    evaled = _eval_as_str(nsql_like_str)
    evaled_lower = evaled.lower() if evaled is not None else None
    if evaled_lower in passage_titles and evaled_lower in image_titles:
        return "passage_title_and_image_title", evaled

    # Case 2.2: title of a passage.
    if lowered in passage_titles:
        return "passage_title", orig_nsql_like_str
    if evaled_lower in passage_titles:
        return "passage_title", evaled

    # Case 2.3: title of an image.
    if lowered in image_titles:
        return "image_title", orig_nsql_like_str
    if evaled_lower in image_titles:
        return "image_title", evaled

    # Case 4: if it can be evaluated, it is value type.
    if evaled is not None:
        return 'val', orig_nsql_like_str

    # Case 5: else it should be the sql; if it isn't, execution will fail later.
    return 'complete_sql', orig_nsql_like_str
def remove_duplicate(original_list):
    """Return the items of ``original_list`` with repeated items dropped.

    The first occurrence of each item is kept, in its original order.
    Hashable items are tracked in a set for O(1) lookups. Unhashable items
    (e.g. lists) fall back to an equality scan of the result, so they are
    still supported.
    """
    no_duplicate_list = []
    seen = set()
    for item in original_list:
        try:
            if item in seen:
                continue
            seen.add(item)
        except TypeError:
            # Unhashable item: compare by equality against what was kept so far.
            if item in no_duplicate_list:
                continue
        no_duplicate_list.append(item)
    return no_duplicate_list
def extract_answers(sub_table):
    """Flatten a sub-table's cells into one list, row by row.

    Only ``sub_table['header']`` and ``sub_table['rows']`` are read
    (presumably a list of column names and a list of row sequences; verify
    against callers). If the header contains ``'row_id'``, that column's
    cells are left out.

    @return: list of cell values; ``[]`` if ``sub_table`` is falsy or its
        header is None.
    """
    if not sub_table or sub_table['header'] is None:
        return []
    header = sub_table['header']
    if 'row_id' not in header:
        return [cell for row in sub_table['rows'] for cell in row]
    # Drop the row_id column wherever it sits. Previously the first column was
    # dropped unconditionally, which is only right when row_id comes first.
    row_id_idx = list(header).index('row_id')
    return [
        cell
        for row in sub_table['rows']
        for col_idx, cell in enumerate(row)
        if col_idx != row_id_idx
    ]
|