import hashlib
import os
import time
from datetime import datetime

import gradio as gr
import numpy as np
import pandas as pd
import textract
import torch
from scipy.special import softmax
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AutoModelForQuestionAnswering, pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"
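# Retriever: sentence-transformers model whose CLS embeddings rank passages by dot-product similarity.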
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1")
model = AutoModel.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1").to(device).eval()
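# Reader: extractive QA model that pulls an answer span out of each retrieved passage.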
tokenizer_ans = AutoTokenizer.from_pretrained("deepset/roberta-large-squad2")
model_ans = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-large-squad2").to(device).eval()
if device == "cuda:0":
    pipe = pipeline("question-answering", model=model_ans, tokenizer=tokenizer_ans, device=0)
else:
    pipe = pipeline("question-answering", model=model_ans, tokenizer=tokenizer_ans)

def cls_pooling(model_output):
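    # multi-qa-mpnet-base-dot-v1 was trained with CLS pooling: the first token's hidden state is the embedding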
    return model_output.last_hidden_state[:,0]

def encode_query(query):
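    """Embed the query with the retriever model and return the CLS embedding on CPU."""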
    encoded_input = tokenizer(query, truncation=True, return_tensors='pt').to(device)

    with torch.no_grad():
        model_output = model(**encoded_input, return_dict=True)

    embeddings = cls_pooling(model_output)

    return embeddings.cpu()


def encode_docs(docs, maxlen=64, stride=32):
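    """Split a document into overlapping word windows, embed each window,
    and cache embeddings, spans, and file names to .npy files keyed by document name.
    """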
    encoded_input = []
    embeddings = []
    spans = []
    file_names = []
    name, text = docs
    
    text = text.split(" ")
    if len(text) < maxlen:
        text = " ".join(text)
        
        encoded_input.append(tokenizer(temp_text, return_tensors='pt', truncation = True).to(device))
        spans.append(temp_text)
        file_names.append(name)

    else:
        num_iters = -(-len(text) // maxlen)  # ceil division, so the last partial window is covered exactly once
        for i in range(num_iters):
            if i == 0:
                # first window: maxlen words plus a stride-word lookahead
                temp_text = " ".join(text[i*maxlen:(i+1)*maxlen+stride])
            else:
                # later windows: stride-word overlap with the previous window, then maxlen new words
                temp_text = " ".join(text[(i-1)*maxlen:i*maxlen][-stride:] + text[i*maxlen:(i+1)*maxlen])

            encoded_input.append(tokenizer(temp_text, return_tensors='pt', truncation = True).to(device))
            spans.append(temp_text)
            file_names.append(name)

    # embed each window; no gradients needed at inference time
    with torch.no_grad():
        for encoded in tqdm(encoded_input):
            model_output = model(**encoded, return_dict=True)
            embeddings.append(cls_pooling(model_output))
    
    # stack the per-window CLS embeddings into a float32 array of shape (num_windows, 768)
    embeddings = torch.cat(embeddings).cpu().numpy().astype(np.float32)
    
    np.save("emb_{}.npy".format(name),dict(zip(list(range(len(embeddings))),embeddings))) 
    np.save("spans_{}.npy".format(name),dict(zip(list(range(len(spans))),spans)))
    np.save("file_{}.npy".format(name),dict(zip(list(range(len(file_names))),file_names)))
    
    return embeddings, spans, file_names
   
def predict(query, data):
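    """Retrieve the most relevant passages of the uploaded document for the query,
    run extractive QA on the likeliest ones, and return the results as a DataFrame.
    """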
    # strip the directory, the extension, and the 8-character suffix gradio appends to uploaded temp-file names
    name_to_save = os.path.splitext(os.path.basename(data.name))[0][:-8]
    st = str([query, name_to_save])
    # sha256 is a stable cache key; the built-in hash() is salted anew in every process
    cache_key = hashlib.sha256(st.encode()).hexdigest()
    hist = st + " " + cache_key
    now = datetime.now()
    current_time = now.strftime("%H:%M:%S")
    try:
        # return the cached answer table if this (query, document) pair was seen before
        df = pd.read_csv("{}.csv".format(cache_key))
        return df
    except Exception as e:
        print(e)
        print(st)

    if name_to_save+".txt" in os.listdir():
        doc_emb = np.load('emb_{}.npy'.format(name_to_save),allow_pickle='TRUE').item()
        doc_text = np.load('spans_{}.npy'.format(name_to_save),allow_pickle='TRUE').item()
        file_names_dicto = np.load('file_{}.npy'.format(name_to_save),allow_pickle='TRUE').item()
        
        doc_emb = np.array(list(doc_emb.values())).reshape(-1,768)
        doc_text = list(doc_text.values())
        file_names = list(file_names_dicto.values())
    
    else:
        # extract raw text from the uploaded file (pdf, docx, txt, ...) and flatten line breaks
        text = textract.process("{}".format(data.name)).decode('utf8')
        text = text.replace("\r", " ")
        text = text.replace("\n", " ")
        text = text.replace(" . ", " ")
        
        doc_emb, doc_text, file_names = encode_docs((name_to_save,text),maxlen = 64, stride = 32)
        
        doc_emb = doc_emb.reshape(-1, 768)
        with open("{}.txt".format(name_to_save),"w",encoding="utf-8") as f:
            f.write(text)

    start = time.time()
    query_emb = encode_query(query)
    
    # dot-product similarity between the query embedding and every passage embedding
    scores = np.matmul(query_emb.numpy(), doc_emb.T)[0].tolist()
    doc_score_pairs = list(zip(doc_text, scores, file_names))
    doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
    k = 5
    # softmax over the top-k retrieval scores approximates P(passage | query)
    probs = softmax(sorted(scores, reverse=True)[:k])
    table = {"Passage":[],"Answer":[],"Probabilities":[],"Source":[]}
    
    for i, (passage, _, names) in enumerate(doc_score_pairs[:k]):
        passage = passage.replace("\n","")
        passage = passage.replace(" . "," ")
        
        if probs[i] > 0.1 or (i < 3 and probs[i] > 0.05): # answer likely passages, but always consider the top 3 if they clear the lower bar
            QA = {'question':query,'context':passage}
            ans = pipe(QA)
            probabilities = "P(a|p): {}, P(a|p,q): {}, P(p|q): {}".format(round(ans["score"],5), 
                                                                          round(ans["score"]*probs[i],5), 
                                                                          round(probs[i],5))
            passage = passage.replace(str(ans["answer"]),str(ans["answer"]).upper()) 
            table["Passage"].append(passage)
            table["Passage"].append("---")
            table["Answer"].append(str(ans["answer"]).upper())
            table["Answer"].append("---")
            table["Probabilities"].append(probabilities)
            table["Probabilities"].append("---")
            table["Source"].append(names)
            table["Source"].append("---")
        else:
            table["Passage"].append(passage)
            table["Passage"].append("---")
            table["Answer"].append("no_answer_calculated")
            table["Answer"].append("---")
            table["Probabilities"].append("P(p|q): {}".format(round(probs[i],5)))
            table["Probabilities"].append("---")
            table["Source"].append(names)
            table["Source"].append("---")
    df = pd.DataFrame(table)
    print("time: "+ str(time.time()-start))
    
    with open("HISTORY.txt","a", encoding = "utf-8") as f:
        f.write(hist)
        f.write(" " + str(current_time))
        f.write("\n")
        f.close()
    df.to_csv("{}.csv".format(hash(st)), index=False)
    
    return df

iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.Textbox(default="What is Open-domain question answering?"),
        gr.inputs.File(),
    ],
    outputs=[
        gr.outputs.Dataframe(),
    ],
    examples=[
        ["How high is the highest mountain?", "China.pdf"],
        ["Where does UK prime minister live?", "London.pdf"],
    ],
    allow_flagging="manual",
    flagging_options=["correct", "wrong"],
    allow_screenshot=False,
)

iface.launch(share=True, enable_queue=True, show_error=True)