import os
import time
import hashlib
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import textract
import gradio as gr
from scipy.special import softmax
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AutoModelForQuestionAnswering, pipeline

description = "Do you have a long document and a bunch of questions that can be answered from it? Fear not, because this demo can do it for you. Upload your PDF, ask a question and wait for the magic to happen."
title = "Question answering from a PDF"

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Retrieval encoder: embeds the query and the document passages into the same vector space.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1")
model = AutoModel.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1").to(device).eval()

# Reader: extractive question-answering model applied to the retrieved passages.
tokenizer_ans = AutoTokenizer.from_pretrained("deepset/roberta-large-squad2")
model_ans = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-large-squad2").to(device).eval()

if device == "cuda:0":
    pipe = pipeline("question-answering", model=model_ans, tokenizer=tokenizer_ans, device=0)
else:
    pipe = pipeline("question-answering", model=model_ans, tokenizer=tokenizer_ans)
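# A minimal sketch of how the reader pipeline is used further down; the question and
# context here are made up for illustration, the real ones come from the user and the PDF:
#   pipe({"question": "Who wrote Hamlet?", "context": "Hamlet is a tragedy written by William Shakespeare."})
#   # -> {"score": ..., "start": ..., "end": ..., "answer": "William Shakespeare"}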
def cls_pooling(model_output):
    # Use the [CLS] token embedding as the passage/query representation.
    return model_output.last_hidden_state[:, 0]

def encode_query(query):
    encoded_input = tokenizer(query, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        model_output = model(**encoded_input, return_dict=True)
    embeddings = cls_pooling(model_output)
    return embeddings.cpu()
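# encode_query returns a single CLS embedding of shape (1, 768) for this model; a rough
# sketch of the expected call (the query string is only an example):
#   q_emb = encode_query("What is open-domain question answering?")
#   q_emb.shape  # torch.Size([1, 768])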
def encode_docs(docs, maxlen=64, stride=32):
    encoded_input = []
    embeddings = []
    spans = []
    file_names = []
    name, text = docs
    text = text.split(" ")

    if len(text) < maxlen:
        # Short document: encode it as a single span.
        temp_text = " ".join(text)
        encoded_input.append(tokenizer(temp_text, return_tensors="pt", truncation=True).to(device))
        spans.append(temp_text)
        file_names.append(name)
    else:
        # Long document: split into windows of `maxlen` words with a `stride`-word
        # overlap so answers are not cut in half at chunk boundaries.
        num_iters = int(len(text) / maxlen) + 1
        for i in range(num_iters):
            if i == 0:
                temp_text = " ".join(text[i * maxlen:(i + 1) * maxlen + stride])
            else:
                temp_text = " ".join(text[(i - 1) * maxlen:i * maxlen][-stride:] + text[i * maxlen:(i + 1) * maxlen])
            encoded_input.append(tokenizer(temp_text, return_tensors="pt", truncation=True).to(device))
            spans.append(temp_text)
            file_names.append(name)

    with torch.no_grad():
        for encoded in tqdm(encoded_input):
            model_output = model(**encoded, return_dict=True)
            embeddings.append(cls_pooling(model_output))

    embeddings = np.float32(torch.stack(embeddings).transpose(0, 1).cpu())

    # Cache embeddings, text spans and source names so repeated questions on the same
    # document do not require re-encoding it.
    np.save("emb_{}.npy".format(name), dict(zip(list(range(len(embeddings))), embeddings)))
    np.save("spans_{}.npy".format(name), dict(zip(list(range(len(spans))), spans)))
    np.save("file_{}.npy".format(name), dict(zip(list(range(len(file_names))), file_names)))

    return embeddings, spans, file_names
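# To make the windowing above concrete: with maxlen=64 and stride=32, a 150-word document
# is split into words 0-95, 32-127 and 96-149, so each chunk shares 32 words with its
# neighbour. (The word counts are only an illustration of the scheme, not taken from a real file.)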
def predict(query, data):
    # Gradio passes the upload as a temporary file; strip the path, the extension and
    # the temporary suffix appended to the original file name.
    name_to_save = data.name.split("/")[-1].split(".")[0][:-8]
    print(name_to_save)

    st = str([query, name_to_save])
    st_hash = hashlib.sha256(st.encode()).hexdigest()
    hist = st + " " + st_hash
    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")

    # If this exact (query, document) pair was answered before, return the cached table.
    try:
        df = pd.read_csv("{}.csv".format(st_hash))
        return df
    except Exception as e:
        print(e)
    print(st)

    if name_to_save + ".txt" in os.listdir():
        # The document was already processed: load cached embeddings, spans and source names.
        doc_emb = np.load("emb_{}.npy".format(name_to_save), allow_pickle=True).item()
        doc_text = np.load("spans_{}.npy".format(name_to_save), allow_pickle=True).item()
        file_names_dicto = np.load("file_{}.npy".format(name_to_save), allow_pickle=True).item()

        doc_emb = np.array(list(doc_emb.values())).reshape(-1, 768)
        doc_text = list(doc_text.values())
        file_names = list(file_names_dicto.values())
    else:
        # First time this document is seen: extract its text, clean it up and embed it.
        text = textract.process("{}".format(data.name)).decode("utf8")
        text = text.replace("\r", " ")
        text = text.replace("\n", " ")
        text = text.replace(" . ", " ")

        doc_emb, doc_text, file_names = encode_docs((name_to_save, text), maxlen=64, stride=32)
        doc_emb = doc_emb.reshape(-1, 768)

        with open("{}.txt".format(name_to_save), "w", encoding="utf-8") as f:
            f.write(text)
    start = time.time()

    # Score every document chunk against the query with a dot product in embedding
    # space, then keep the k best-matching chunks.
    query_emb = encode_query(query)
    scores = np.matmul(query_emb, doc_emb.transpose(1, 0))[0].tolist()

    doc_score_pairs = list(zip(doc_text, scores, file_names))
    doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)

    k = 5
    probs = softmax(sorted(scores, reverse=True)[:k])
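    # The top-k dot-product scores are turned into a probability distribution with a
    # softmax, so probs[i] can be read as P(passage_i | query) relative to the other
    # retrieved passages. For example, raw scores of [21.0, 19.5, 18.0] would give
    # probabilities of roughly [0.79, 0.18, 0.04] (illustrative numbers only).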
    table = {"Passage": [], "Answer": [], "Probabilities": [], "Source": []}

    for i, (passage, _, names) in enumerate(doc_score_pairs[:k]):
        passage = passage.replace("\n", "")
        passage = passage.replace(" . ", " ")

        # Run the reader only on sufficiently likely passages, but always consider the
        # three highest-ranked ones as long as they clear a lower bar.
        if probs[i] > 0.1 or (i < 3 and probs[i] > 0.05):
            QA = {"question": query, "context": passage}
            ans = pipe(QA)

            # P(a|p) is the reader's confidence given the passage, P(p|q) is the
            # retrieval probability, and P(a|p,q) is their product.
            probabilities = "P(a|p): {}, P(a|p,q): {}, P(p|q): {}".format(
                round(ans["score"], 5),
                round(ans["score"] * probs[i], 5),
                round(probs[i], 5))

            # Highlight the answer span inside the passage.
            passage = passage.replace(str(ans["answer"]), str(ans["answer"]).upper())

            table["Passage"].append(passage)
            table["Passage"].append("---")
            table["Answer"].append(str(ans["answer"]).upper())
            table["Answer"].append("---")
            table["Probabilities"].append(probabilities)
            table["Probabilities"].append("---")
            table["Source"].append(names)
            table["Source"].append("---")
        else:
            table["Passage"].append(passage)
            table["Passage"].append("---")
            table["Answer"].append("no_answer_calculated")
            table["Answer"].append("---")
            table["Probabilities"].append("P(p|q): {}".format(round(probs[i], 5)))
            table["Probabilities"].append("---")
            table["Source"].append(names)
            table["Source"].append("---")
    df = pd.DataFrame(table)
    print("time: " + str(time.time() - start))

    # Log the query, its hash and the timestamp, then cache the result table.
    with open("HISTORY.txt", "a", encoding="utf-8") as f:
        f.write(hist)
        f.write(" " + str(current_time))
        f.write("\n")

    df.to_csv("{}.csv".format(st_hash), index=False)

    return df
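# A hypothetical direct call, outside the Gradio UI. `Upload` is only a stand-in for the
# object gradio passes to `predict`: anything with a `.name` attribute pointing at a local PDF.
#   class Upload: name = "/tmp/China12345678.pdf"
#   answers = predict("How high is the highest mountain?", Upload())
#   print(answers)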
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.Textbox(default="What is Open-domain question answering?"),
        gr.inputs.File(),
    ],
    outputs=[
        gr.outputs.Dataframe(),
    ],
    examples=[
        ["How high is the highest mountain?", "China.pdf"],
        ["Where does the UK prime minister live?", "London.pdf"],
    ],
    title=title,
    description=description,
    allow_flagging="manual",
    flagging_options=["correct", "wrong"],
    allow_screenshot=False)

iface.launch(share=True, enable_queue=True, show_error=True)