# NOTE(review): removed extraction artifacts that preceded this file
# ("Spaces:" / "Runtime error" header lines) — they were not valid Python.
import os | |
import random | |
from generation.prompt import OpenAIQAPromptBuilder | |
from generation.generator import Generator | |
from retrieval.retriever import OpenAIQARetriever | |
from retrieval.retrieve_pool import OpenAIQARetrievePool, QAItem | |
# Number of prompts batched to the API in parallel.
num_parallel_prompts = 10
# Number of few-shot QA examples retrieved and prepended to each prompt.
num_qa_shots = 8
infinite_rows_len = 50  # If the table contain rows larger than this number, it will be handled rows by rows.
# Cap on completion length per API call.
max_tokens = 1024
# Repository root, resolved relative to this file (two levels up).
ROOT_DIR = os.path.join(os.path.dirname(__file__), "../../")
class OpenAIQAModel(object):
    """Answers questions over (sub-)tables by prompting an OpenAI completion model.

    Two QA modes are supported (see :meth:`qa`):
      * ``"map"``: answer the question row by row, producing one new column.
      * ``"ans"``: answer the question once for the whole table.
    """

    def __init__(self, args, keys=None):
        super().__init__()
        # Prepare keys; shuffle deterministically so key rotation is reproducible.
        self.key_current_id = 0
        self.keys = keys
        random.seed(42)
        random.shuffle(self.keys)
        # Pool of few-shot QA examples used for retrieval-augmented prompting.
        retrieve_pool = OpenAIQARetrievePool(
            data_path=os.path.join(ROOT_DIR, args.qa_retrieve_pool_file)
        )
        self.retriever = OpenAIQARetriever(retrieve_pool)
        self.generator = Generator(args=None, keys=self.keys)  # Just to use its call api function

        self.prompting_method = 'new_db'
        self.answer_split_token: str = ';'
        self.db_mapping_token = "\t"

    def call_openai_api_completion(self, prompt):
        """Send ``prompt`` to the completion endpoint and return the raw response."""
        completion = self.generator._call_codex_api(engine="text-davinci-002",
                                                    prompt=prompt,
                                                    max_tokens=max_tokens,
                                                    temperature=0,  # deterministic decoding
                                                    top_p=1,
                                                    n=1,
                                                    stop=["\n\n"])
        return completion

    def call_openai_for_completion_text(self, prompt, openai_usage_type="completion"):
        """Return the text of the first completion choice for ``prompt``.

        Raises:
            ValueError: if ``openai_usage_type`` is not ``"completion"``.
        """
        if openai_usage_type == "completion":
            completion = self.call_openai_api_completion(prompt)
            return completion.choices[0].text
        else:
            raise ValueError("The model usage type '{}' doesn't exists!".format(openai_usage_type))

    @staticmethod
    def merge_tables(tables, by='row_id'):
        """Column-wise merge of ``tables`` on the ``by`` column.

        All tables must have the same number of rows (asserted). Duplicate
        column names receive a numeric suffix (``col_1`` etc.).

        FIX: declared as ``@staticmethod``. The original definition had neither
        ``self`` nor a decorator, so it only worked when called via the class
        object (as in :meth:`qa`); ``self.merge_tables(...)`` would have passed
        the instance as ``tables``. The class-level call still works unchanged.
        """
        assert len(set([len(_table['rows']) for _table in tables])) == 1, "Tables must have the same rows!"
        merged_header = [by]
        by_idx = tables[0]['header'].index(by)
        # Seed each merged row with its `by` value taken from the first table.
        merged_rows = [[_row[by_idx]] for _row in tables[0]['rows']]
        for _table in tables:
            header, rows = _table['header'], _table['rows']
            for col_idx, col in enumerate(header):
                if col == by:
                    continue
                if col in merged_header:
                    # When the column is duplicate, add postfix _1, _2 etc.
                    # NOTE(review): count() ignores already-suffixed copies, so a
                    # third duplicate of the same name repeats "_1" — kept as-is
                    # to preserve the existing output column names.
                    col = "{}_{}".format(col, merged_header.count(col))
                merged_header.append(col)
                for i, row in enumerate(rows):
                    merged_rows[i].append(row[col_idx])
        return {"header": merged_header, "rows": merged_rows}

    def wrap_with_prompt_for_table_qa(self,
                                      question,
                                      sub_table,
                                      table_title=None,
                                      answer_split_token=None,
                                      qa_type="ans",
                                      prompting_method="new_db",
                                      db_mapping_token="π ",
                                      verbose=True):
        """Build the full prompt (few-shot examples + table + question) for one QA call.

        Args:
            question: natural-language question to ask.
            sub_table: dict with ``header``/``rows`` keys.
            table_title: optional title rendered with the table.
            answer_split_token: separator for "map" answers (required for
                ``prompting_method == "basic"``).
            qa_type: ``"map"`` (row-by-row) or ``"ans"`` (whole table).
            prompting_method: ``"basic"`` or ``"new_db"``.
            db_mapping_token: row→answer mapping token in few-shot examples.
                NOTE(review): this default ("π ") differs from the instance
                attribute ``self.db_mapping_token`` ("\\t"); callers in this
                file always pass it explicitly for "map" — confirm before
                relying on the default.
            verbose: forwarded to the few-shot prompt builder.

        Raises:
            ValueError: if ``qa_type`` is neither "map" nor "ans".
        """
        prompt = "Question Answering Over Database:\n\n"
        if qa_type in ['map', 'ans'] and num_qa_shots > 0:
            query_item = QAItem(qa_question=question, table=sub_table, title=table_title)
            retrieved_items = self.retriever.retrieve(item=query_item, num_shots=num_qa_shots, qa_type=qa_type)
            few_shot_prompt_list = []
            for item in retrieved_items:
                one_shot_prompt = OpenAIQAPromptBuilder.build_one_shot_prompt(
                    item=item,
                    answer_split_token=answer_split_token,
                    verbose=verbose,
                    prompting_method=prompting_method,
                    db_mapping_token=db_mapping_token
                )
                few_shot_prompt_list.append(one_shot_prompt)
            few_shot_prompt = '\n'.join(few_shot_prompt_list[:num_qa_shots])
            # NOTE(review): this intentionally (?) replaces the header line above
            # rather than appending to it — kept as-is to preserve prompt text.
            prompt = few_shot_prompt
        prompt += "\nGive a database as shown below:\n{}\n\n".format(
            OpenAIQAPromptBuilder.table2codex_prompt(sub_table, table_title)
        )
        if qa_type == "map":
            prompt += "Q: Answer question \"{}\" row by row.".format(question)
            assert answer_split_token is not None
            if prompting_method == "basic":
                prompt += " The answer should be a list split by '{}' and have {} items in total.".format(
                    answer_split_token, len(sub_table['rows']))
        elif qa_type == "ans":
            prompt += "Q: Answer question \"{}\" for the table.".format(question)
            prompt += " "
        else:
            raise ValueError("The QA type is not supported!")

        prompt += "\n"
        if qa_type == "map":
            if prompting_method == "basic":
                prompt += "A:"
        elif qa_type == "ans":
            prompt += "A:"
        return prompt

    def qa(self, question, sub_tables, qa_type: str, verbose: bool = True, **args):
        """Answer ``question`` over ``sub_tables`` with the OpenAI QA model.

        Args:
            question: natural-language question.
            sub_tables: list of ``header``/``rows`` table dicts; merged on
                ``row_id`` before prompting.
            qa_type: ``"map"`` returns a one-column table keyed by row_id;
                ``"ans"`` returns a single-element list of answer strings.
            verbose: print prompts/answers for debugging.
            **args: requires ``table_title``; "map" also requires
                ``new_col_name_s`` (list with the new column's name).

        Raises:
            ValueError: if ``qa_type`` is neither "map" nor "ans".
        """
        # If it is not a problem API can handle, answer it with a QA model.
        merged_table = OpenAIQAModel.merge_tables(sub_tables)
        if verbose:
            print("Make Question {} on {}".format(question, merged_table))
        if qa_type == "map":
            # Map: col(s) -question> one col
            # Make model make a QA towards a sub-table; all rows of one chunk
            # are answered in a single completion.
            def do_map(_table):
                _prompt = self.wrap_with_prompt_for_table_qa(question,
                                                             _table,
                                                             args['table_title'],
                                                             self.answer_split_token,
                                                             qa_type,
                                                             prompting_method=self.prompting_method,
                                                             db_mapping_token=self.db_mapping_token,
                                                             verbose=verbose)
                completion_str = self.call_openai_for_completion_text(_prompt).lower().strip(' []')

                if verbose:
                    print(f'QA map@ input:\n{_prompt}')
                    print(f'QA map@ output:\n{completion_str}')

                if self.prompting_method == "basic":
                    # Answers come back as a delimited list.
                    answers = [_answer.strip(" '").lower() for _answer in
                               completion_str.split(self.answer_split_token)]
                elif self.prompting_method == "new_db":
                    # Answers come back as a rendered table: skip the two header
                    # lines and the trailing line, take text after the mapping token.
                    answers = [line.split(self.db_mapping_token)[-1] for line in completion_str.split("\n")[2:-1]]
                else:
                    raise ValueError("No such prompting methods: '{}'! ".format(self.prompting_method))
                return answers

            # Handle infinite rows, rows by rows (chunks of `infinite_rows_len`).
            answers = []
            rows_len = len(merged_table['rows'])
            # FIX: exact integer ceil division instead of float-based
            # int(rows_len / n), which can misround for very large counts.
            run_times = (rows_len + infinite_rows_len - 1) // infinite_rows_len

            for run_idx in range(run_times):
                # Last chunk takes all remaining rows.
                if run_idx == run_times - 1:
                    chunk_rows = merged_table['rows'][run_idx * infinite_rows_len:]
                else:
                    chunk_rows = merged_table['rows'][run_idx * infinite_rows_len:(run_idx + 1) * infinite_rows_len]
                _table = {"header": merged_table['header'], "rows": chunk_rows}
                answers.extend(do_map(_table))
            if verbose:
                print("The map@ openai answers are {}".format(answers))
            # Add row_id in addition for finding to corresponding rows.
            return {"header": ['row_id'] + args['new_col_name_s'],
                    "rows": [[row[0], answer] for row, answer in zip(merged_table['rows'], answers)]}
        elif qa_type == "ans":
            # Ans: col(s) -question> answer
            prompt = self.wrap_with_prompt_for_table_qa(question,
                                                        merged_table,
                                                        args['table_title'],
                                                        prompting_method=self.prompting_method,
                                                        verbose=verbose)
            answers = [self.call_openai_for_completion_text(prompt).lower().strip(' []')]

            if verbose:
                print(f'QA ans@ input:\n{prompt}')
                print(f'QA ans@ output:\n{answers}')
            return answers
        else:
            raise ValueError("Please choose from map and ans in the qa usage!!")