import pickle

import nltk
import spacy
import streamlit as st
import torch
from nltk.tokenize import sent_tokenize
from sentence_transformers import models, SentenceTransformer
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline

from utils import *


@st.cache(allow_output_mutation=True)
def load_prep_data():
  with open('listfile_3.data', 'rb') as filehandle:
    articles = pickle.load(filehandle)

  # Split each article body into sentences; skip entries with no body text.
  for article in articles:
    if article[1]:
      article[1] = sent_tokenize(article[1])

  return articles

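# Illustrative shape of 'listfile_3.data' (an assumption inferred from the
# indexing in this file): a list of [title, body_text] entries, e.g.
#   articles = [['Some paper title', 'Full body text of the paper ...'], ...]
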
@st.cache(allow_output_mutation=True)
def build_sent_trans_model():
  # models.BERT is the older sentence-transformers API (newer releases use
  # models.Transformer); it loads the CovidBERT-NLI weights as the word
  # embedding layer.
  word_embedding_model = models.BERT('covidbert_nli')

  # Mean-pool the token embeddings to get a fixed-size sentence vector.
  pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                                 pooling_mode_mean_tokens=True,
                                 pooling_mode_cls_token=False,
                                 pooling_mode_max_tokens=False)

  model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
  return model

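# Illustrative usage (not executed by the app): the assembled model maps
# sentences to fixed-size dense vectors that can be compared by cosine
# similarity, e.g.
#   vecs = build_sent_trans_model().encode(['What is the incubation period?'])
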
@st.cache(allow_output_mutation=True)
def load_embedded_articles():
  with open('list_of_articles.pkl', 'rb') as f:
    list_of_articles = pickle.load(f)

  return list_of_articles

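# 'list_of_articles.pkl' is assumed to hold precomputed sentence embeddings for
# each article, encoded offline with the same CovidBERT model, so fetch_stage1
# can compare the query embedding against them.
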
@st.cache(allow_output_mutation=True)
def load_comprehension_model():
  # device=-1 runs the pipeline on CPU; pass device=0 to use the first GPU.
  comprehension_model = pipeline(
      "question-answering",
      model=AutoModelForQuestionAnswering.from_pretrained("graviraja/covidbert_squad"),
      tokenizer=AutoTokenizer.from_pretrained("graviraja/covidbert_squad"),
      device=-1)

  return comprehension_model

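# Illustrative usage (not executed by the app): a question-answering pipeline
# takes a question and a context passage and returns an answer span with a
# confidence score, e.g.
#   result = comprehension_model(question='What are the common symptoms?',
#                                context='Common symptoms include fever and cough.')
#   # result -> {'answer': 'fever and cough', 'score': 0.87, 'start': 24, 'end': 39}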


def main():

  # One-time setup: NLTK sentence tokenizer data and the spaCy English model.
  nltk.download('punkt')
  spacy_nlp = spacy.load('en_core_web_sm')

  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

  articles = load_prep_data()

  model = build_sent_trans_model()
  model.to(device)

  list_of_articles = load_embedded_articles()

  comprehension_model = load_comprehension_model()

  query = st.text_input("Enter Query", "example query", key="query")

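  # Assumed three-stage flow, inferred from the helper names in utils:
  # stage 1 ranks whole articles against the query embedding, stage 2 ranks
  # sentences within the shortlisted articles, and stage 3 runs the QA
  # pipeline over the top sentences to extract candidate answer spans.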
  query_embedding, results1 = fetch_stage1(query, model, list_of_articles)

  results2 = fetch_stage2(results1, model, articles, query_embedding)

  results3 = fetch_stage3(results2, query, articles, comprehension_model, spacy_nlp)

  if results3:
    # Show at most the top four answers.
    for count, res in enumerate(results3, start=1):
      st.write('{}> {}'.format(count, res[2]))
      st.write('Score: %.4f' % res[1])
      st.write("From the article with title: {}".format(articles[res[0]][0]))
      st.write("\n")
      if count >= 4:
        break
  else:
    st.info("No answer was found for this query.")


if __name__ == '__main__':
  main()