import pandas as pd
import numpy as np
import pickle
import glob
import json  
from pandas.io.json import json_normalize  
from nltk.tokenize import sent_tokenize
import nltk
import scipy.spatial
from transformers import AutoTokenizer, AutoModel, pipeline, AutoModelForQuestionAnswering
from sentence_transformers import models, SentenceTransformer


def get_full_sentence(spacy_nlp, para_text, start_index, end_index):
    """
    Return the sentence(s) of ``para_text`` that surround a character span.

    The paragraph is split into sentences with ``spacy_nlp``. The returned
    text runs from the start of the sentence containing ``start_index`` to
    the end of the sentence containing ``end_index``. An answer span that
    crosses a sentence boundary therefore yields every sentence it touches.

    Args:
        spacy_nlp: callable (e.g. a loaded spaCy pipeline) whose result has
            a ``.sents`` iterable of spans with ``start_char``/``end_char``.
        para_text: the paragraph (body text) to search.
        start_index: character offset where the span begins.
        end_index: character offset where the span ends.

    Returns:
        ``para_text[sent_start:sent_end]``. If no sentence contains
        ``start_index`` the result starts at offset 0. If no sentence
        contains ``end_index`` it runs to the end of the paragraph.
    """
    sent_start = 0
    sent_end = len(para_text)
    for sent in spacy_nlp(para_text).sents:
        # Boundaries are inclusive on both sides. When more than one
        # sentence matches, the last one yielded by ``.sents`` wins.
        if sent.start_char <= start_index <= sent.end_char:
            sent_start = sent.start_char
        if sent.start_char <= end_index <= sent.end_char:
            sent_end = sent.end_char
    # ``end_char`` is already an exclusive offset in spaCy. Slicing up to it
    # gives the whole sentence; adding 1 would pull in the next character.
    return para_text[sent_start:sent_end]


def fetch_stage1(query, model, list_of_articles):
    """
    Rank articles by their single best-matching abstract sentence.

    The query is encoded once with ``model.encode``. Every truthy entry of
    ``list_of_articles`` (a sequence of sentence vectors) is scored by
    cosine similarity against the query, and its top-scoring sentence is
    kept.

    Returns:
        ``(query_embedding, ranked)``. ``ranked`` is a list of
        ``(article_index, best_sentence_index, similarity)`` tuples sorted
        by similarity, highest first.
    """
    query_vec = model.encode([query])[0]

    best_per_article = []
    for article_pos, sentence_vecs in enumerate(list_of_articles):
        if not sentence_vecs:
            continue

        dist_column = scipy.spatial.distance.cdist(
            [query_vec], np.vstack(sentence_vecs), "cosine"
        ).reshape(-1, 1)

        scored = [
            (pos, 1 - dist_column[pos][0])
            for pos, _ in enumerate(sentence_vecs)
        ]
        # Stable sort: on ties the earliest sentence stays first.
        scored.sort(key=lambda pair: pair[1], reverse=True)

        if scored:
            top_pos, top_sim = scored[0]
            best_per_article.append((article_pos, top_pos, top_sim))

    best_per_article.sort(key=lambda triple: triple[2], reverse=True)
    return query_vec, best_per_article


def fetch_stage2(results, model, embeddings, query_embedding, top_k=20):
    """
    Re-rank the top stage-1 articles by their best-matching body paragraph.

    This examines the first ``top_k`` entries of ``results``; each entry's
    element 0 is an article index. For each article, the body paragraphs at
    ``embeddings[article_idx][2]`` are taken (element 0 of each is the
    paragraph text). They are encoded with ``model`` and compared with
    ``query_embedding`` by cosine similarity, and the best paragraph is kept.

    Args:
        results: ranked stage-1 output, e.g. from ``fetch_stage1``.
        model: encoder exposing ``encode(texts, show_progress_bar=...)``.
        embeddings: per-article records; index 2 holds the body paragraphs.
        query_embedding: the encoded query vector.
        top_k: how many leading entries of ``results`` to examine
            (default 20).

    Returns:
        A list of ``(article_index, best_paragraph_index, similarity)``
        tuples sorted by similarity, highest first. Articles with no body
        paragraphs are skipped.
    """
    all_text_distances = []
    for top in results[:top_k]:
        article_idx = top[0]

        body_texts = [text[0] for text in embeddings[article_idx][2]]
        if not body_texts:
            # Nothing to compare. np.vstack would raise on an empty encoding.
            continue
        body_text_embeddings = model.encode(body_texts, show_progress_bar=False)

        similarities = 1 - scipy.spatial.distance.cdist(
            [query_embedding], np.vstack(body_text_embeddings), "cosine"
        )[0]

        # Stable descending sort: on ties the earliest paragraph wins.
        best_idx, best_sim = sorted(
            enumerate(similarities), key=lambda pair: pair[1], reverse=True
        )[0]
        all_text_distances.append((article_idx, best_idx, best_sim))

    return sorted(all_text_distances, key=lambda x: x[2], reverse=True)


def fetch_stage3(results, query, embeddings, comprehension_model, spacy_nlp):
    """
    Run extractive QA over the best paragraphs and collect answer sentences.

    For each of the first 20 entries of ``results`` (element 0 is an article
    index, element 1 a paragraph index), the paragraph text at
    ``embeddings[article][2][paragraph][0]`` is passed with ``query`` to
    ``comprehension_model``. The full sentence(s) around the answer span are
    extracted with ``get_full_sentence``. This happens only when the answer
    is non-empty and its score, rounded to 4 decimals, is positive.

    Returns:
        A list of ``(article_index, rounded_score, sentence)`` tuples sorted
        by score, highest first.
    """
    collected = []

    for hit in results[0:20]:
        article_pos = hit[0]
        paragraph = embeddings[article_pos][2][hit[1]][0]

        qa_input = {"context": paragraph, "question": query}
        prediction = comprehension_model(qa_input, topk=1, show_progress_bar=False)

        # Check the answer before touching the score (same short-circuit
        # order as an ``and`` expression).
        if not prediction["answer"]:
            continue
        score = round(prediction["score"], 4)
        if not score > 0:
            continue

        # Expand the answer span to the sentence(s) that contain it.
        sentence = get_full_sentence(
            spacy_nlp, qa_input["context"], prediction["start"], prediction["end"]
        )
        collected.append((article_pos, score, sentence))

    collected.sort(key=lambda entry: entry[1], reverse=True)
    return collected