Spaces:
Runtime error
Runtime error
import json
import os
from typing import List, Dict

from nsql.qa_module.openai_qa import OpenAIQAModel
from nsql.qa_module.vqa import vqa_call
from nsql.database import NeuralDB
from nsql.parser import get_cfg_tree, get_steps, remove_duplicate, TreeNode, parse_question_paras, nsql_role_recognize, \
    extract_answers
class NSQLExecutor(object):
    """Execute NSQL programs (plain SQL mixed with ``QA("map@..."/"ans@...")`` calls) on a NeuralDB.

    Besides returning the final answer, each execution dumps intermediate
    artifacts (the flattened steps and every QA step's input/output) as JSON
    files into ``vis_dir``, prefixed with the caller-supplied ``stamp``.
    """

    # Directory receiving the per-execution JSON dumps; created on demand.
    vis_dir = "tmp_for_vis"

    def __init__(self, args, keys=None):
        """Create an executor.

        Args:
            args: configuration forwarded unchanged to ``OpenAIQAModel``.
            keys: optional keys forwarded unchanged to ``OpenAIQAModel``.
        """
        self.new_col_name_id = 0
        self.qa_model = OpenAIQAModel(args, keys)

    def generate_new_col_names(self, number):
        """Return ``number`` fresh names ``col_<k>`` and advance the instance counter.

        The counter is never reset, so names are unique per executor instance.
        """
        col_names = ["col_{}".format(i) for i in range(self.new_col_name_id, self.new_col_name_id + number)]
        self.new_col_name_id += number
        return col_names

    def sql_exec(self, sql: str, db: NeuralDB, verbose=True):
        """Run a plain SQL string via ``db.execute_query`` and return its result unchanged."""
        if verbose:
            print("Exec SQL '{}' with additional row_id on {}".format(sql, db))
        result = db.execute_query(sql)
        return result

    def _dump_for_vis(self, file_name, obj):
        """Write ``obj`` as JSON to ``<vis_dir>/<file_name>``, creating ``vis_dir`` if missing."""
        # Without this, open(..., "w") raises FileNotFoundError on a fresh checkout.
        os.makedirs(self.vis_dir, exist_ok=True)
        with open(os.path.join(self.vis_dir, file_name), "w", encoding="utf-8") as f:
            json.dump(obj, f)

    def _execute_sql_params(self, sql_s, db: NeuralDB, verbose=True):
        """Resolve each QA() argument into a sub-table dict with "header" and "rows" keys.

        The role of each argument is decided by ``nsql_role_recognize``; arguments
        whose role is not one of the handled ones are skipped.
        """
        sub_tables = []
        for sql_item in sql_s:
            # Header/titles are re-read per item because executing SQL may touch the db.
            role, sql_item = nsql_role_recognize(sql_item,
                                                 db.get_header(),
                                                 db.get_passages_titles(),
                                                 db.get_images_titles())
            if role in ('col', 'complete_sql'):
                sub_tables.append(self.sql_exec(sql_item, db, verbose=verbose))
            elif role == 'val':
                # SECURITY NOTE(review): eval() executes arbitrary Python and sql_item
                # comes from the NSQL program text. If only literals are expected,
                # ast.literal_eval would be a safer drop-in — confirm before switching.
                val = eval(sql_item)
                sub_tables.append({"header": ["row_id", "val"], "rows": [["0", val]]})
            elif role == 'passage_title_and_image_title':
                sub_tables.append({
                    "header": ["row_id", "{}".format(sql_item)],
                    "rows": [["0", db.get_passage_by_title(sql_item) + db.get_image_caption_by_title(sql_item)]],
                })
            elif role == 'passage_title':
                sub_tables.append({
                    "header": ["row_id", "{}".format(sql_item)],
                    "rows": [["0", db.get_passage_by_title(sql_item)]],
                })
            elif role == 'image_title':
                # NOTE: a VQA-based backup (vqa_call on db.get_image_by_title) was
                # previously sketched here and is intentionally disabled.
                sub_tables.append({
                    "header": ["row_id", "{}".format(sql_item)],
                    "rows": [["0", db.get_image_caption_by_title(sql_item)]],
                })
        return sub_tables

    @staticmethod
    def _append_linked_info(sub_tables, db: NeuralDB):
        """In place, append " (<passage text>)" / " (<image caption>)" to cells that are linker keys."""
        passage_linker = db.get_passage_linker()
        image_linker = db.get_image_linker()
        for sub_table in sub_tables:
            for row in sub_table['rows']:
                for j, cell in enumerate(row):
                    # Membership is checked on the original cell value, so a cell linked
                    # to both a passage and an image receives both suffixes.
                    if cell in passage_linker:
                        row[j] += " ({})".format(db.get_passage_by_title(passage_linker[cell]))
                    if cell in image_linker:
                        row[j] += " ({})".format(db.get_image_caption_by_title(image_linker[cell]))

    def nsql_exec(self, stamp, nsql: str, db: NeuralDB, verbose=True):
        """Execute an NSQL program step by step and return its final answer.

        The program is parsed with ``get_cfg_tree``, flattened with ``get_steps``
        and de-duplicated with ``remove_duplicate``. Every step except possibly
        the last is a ``QA(...)`` call; a non-QA step is executed as plain SQL and
        ends execution. Intermediate results are dumped to ``vis_dir`` using
        ``stamp`` as file-name prefix.

        Args:
            stamp: prefix for the visualization dump files.
            nsql: the NSQL program text.
            db: database the program runs against; ``map@`` steps with a parent
                add their produced sub-table to it via ``db.add_sub_table``.
            verbose: forwarded to SQL execution, QA calls and ``add_sub_table``.

        Returns:
            ``extract_answers(sub_table)`` when the final step is plain SQL or a
            ``map@`` QA; the list returned by ``self.qa_model.qa`` for a final
            ``ans@`` QA. ``None`` if the loop ends without reaching a final step.

        Raises:
            ValueError: if a QA question starts with neither ``map@`` nor ``ans@``.
        """
        steps = []
        root_node = get_cfg_tree(nsql)  # Parse execution tree from nsql.
        get_steps(root_node, steps)  # Flatten the execution tree and get the steps.
        steps = remove_duplicate(steps)  # Remove the duplicate steps.
        step_names = [s.rename for s in steps]
        if verbose:
            print("Steps:", step_names)
        self._dump_for_vis("{}_tmp_for_vis_steps.txt".format(stamp), step_names)

        col_idx = 0
        for step_idx, step in enumerate(steps):
            assert isinstance(step, TreeNode), "step must be treenode"
            step_nsql = step.rename
            output_file = "{}_result_step_{}.txt".format(stamp, step_idx)

            if not step_nsql.startswith('QA('):
                # Only the last step may be plain SQL: execute it and finish.
                sub_table = self.sql_exec(step_nsql, db, verbose=verbose)
                self._dump_for_vis(output_file, sub_table)
                return extract_answers(sub_table)

            input_file = "{}_result_step_{}_input.txt".format(stamp, step_idx)
            question, sql_s = parse_question_paras(step_nsql, self.qa_model)
            sql_executed_sub_tables = self._execute_sql_params(sql_s, db, verbose=verbose)
            self._append_linked_info(sql_executed_sub_tables, db)
            is_final_step = not step.father

            if question.lower().startswith("map@"):
                # Mapping question: the QA model produces a new column.
                question = question[len("map@"):]
                if is_final_step:
                    new_col_name_s = ["col_{}".format(col_idx)]
                else:
                    # Must run before reading produced_col_name_s.
                    step.rename_father_col(col_idx=col_idx)
                    new_col_name_s = step.produced_col_name_s
                sub_table: Dict = self.qa_model.qa(question,
                                                   sql_executed_sub_tables,
                                                   table_title=db.table_title,
                                                   qa_type="map",
                                                   new_col_name_s=new_col_name_s,
                                                   verbose=verbose)
                self._dump_for_vis(input_file, sql_executed_sub_tables)
                self._dump_for_vis(output_file, sub_table)
                if is_final_step:
                    return extract_answers(sub_table)
                db.add_sub_table(sub_table, verbose=verbose)
                col_idx += 1
            elif question.lower().startswith("ans@"):
                # Answering question: the QA model returns an answer list.
                question = question[len("ans@"):]
                answer: List = self.qa_model.qa(question,
                                                sql_executed_sub_tables,
                                                table_title=db.table_title,
                                                qa_type="ans",
                                                verbose=verbose)
                self._dump_for_vis(input_file, sql_executed_sub_tables)
                self._dump_for_vis(output_file, answer)
                if is_final_step:
                    return answer
                step.rename_father_val(answer)
            else:
                raise ValueError(
                    "Except for operators or NL question must start with 'map@' or 'ans@'!, check '{}'".format(
                        question))