File size: 3,398 Bytes
07ab211
 
 
 
 
 
 
 
 
 
 
 
 
d97f2ca
07ab211
 
 
 
d97f2ca
 
 
 
07ab211
 
 
 
 
 
 
 
 
 
 
 
 
 
bc819a0
07ab211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d97f2ca
 
 
07ab211
 
d97f2ca
07ab211
 
 
 
 
 
 
 
 
 
d97f2ca
2b65cc5
 
 
07ab211
 
 
 
 
2b65cc5
 
 
07ab211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import pandas as pd
import numpy as np
import pickle
import glob
import json  
from pandas.io.json import json_normalize  
from nltk.tokenize import sent_tokenize
import nltk
import scipy.spatial
from transformers import AutoTokenizer, AutoModel, pipeline, AutoModelForQuestionAnswering
from sentence_transformers import models, SentenceTransformer
import torch
import spacy
import subprocess
import streamlit as st
from utils import *


@st.cache(allow_output_mutation=True)
def load_spacy_model():
  """Download the ``en_core_web_sm`` spaCy model with ``-m spacy download``.

  The command is run with the interpreter that is executing this app
  (``sys.executable``) instead of whatever ``python`` resolves to on PATH, so
  the model is installed into the environment that will later import it.

  A non-zero exit status is reported as a ``RuntimeWarning`` rather than
  raised: the model may already be installed, in which case the subsequent
  ``spacy.load`` call can still succeed.

  Returns:
    None.
  """
  import sys
  import warnings

  # sys.executable can be empty/None in embedded interpreters; fall back to
  # the previous behaviour of looking up ``python`` on PATH in that case.
  interpreter = sys.executable or 'python'
  exit_code = subprocess.call(
      [interpreter, '-m', 'spacy', 'download', 'en_core_web_sm'])
  if exit_code != 0:
    warnings.warn(
        'spacy download of en_core_web_sm exited with code {}'.format(exit_code),
        RuntimeWarning)
  
@st.cache(allow_output_mutation=True)
def load_prep_data():
  """Load the pickled article list and sentence-split each entry's slot 1.

  Reads ``listfile_3.data`` from the working directory. For every entry whose
  element at index 1 is not equal to the empty list, that element is replaced
  in place with the result of ``sent_tokenize``.

  Returns:
    The unpickled collection, after the in-place tokenization.
  """
  # NOTE(review): pickle can execute arbitrary code while loading; this is
  # only safe as long as the data file comes from a trusted source.
  with open('listfile_3.data', 'rb') as source:
    corpus = pickle.load(source)

  for entry in corpus:
    body = entry[1]
    if body != []:
      entry[1] = sent_tokenize(body)

  return corpus

@st.cache(allow_output_mutation=True)
def build_sent_trans_model():
  """Assemble a SentenceTransformer from the BERT files in the current directory.

  The model is a two-module stack: the ``models.BERT`` word-embedding module
  loaded from ``./`` followed by a pooling module configured for mean-token
  pooling only.

  Returns:
    The assembled ``SentenceTransformer`` instance.
  """
  bert_module = models.BERT('./')
  embedding_dim = bert_module.get_word_embedding_dimension()

  # Mean pooling over token embeddings; CLS-token and max pooling are off.
  mean_pooling = models.Pooling(
      embedding_dim,
      pooling_mode_mean_tokens=True,
      pooling_mode_cls_token=False,
      pooling_mode_max_tokens=False)

  return SentenceTransformer(modules=[bert_module, mean_pooling])

@st.cache(allow_output_mutation=True)
def load_embedded_articles():
  """Return the object unpickled from ``list_of_articles.pkl``.

  The file is read from the working directory.

  Returns:
    Whatever object the pickle file contains.
  """
  # NOTE(review): pickle can execute arbitrary code while loading; this is
  # only safe as long as the data file comes from a trusted source.
  with open('list_of_articles.pkl', 'rb') as pickled:
    return pickle.load(pickled)

@st.cache(allow_output_mutation=True)
def load_comprehension_model():
  """Build the question-answering pipeline from ``graviraja/covidbert_squad``.

  Both the model and the tokenizer are loaded from the same checkpoint.

  Returns:
    The ``transformers`` question-answering pipeline.
  """
  checkpoint = "graviraja/covidbert_squad"

  # pipeline() interprets device=-1 as CPU and a non-negative integer as a
  # CUDA device ordinal. The previous hard-coded -1 therefore always ran on
  # CPU, contrary to its comment; use GPU 0 only when one is available.
  device = 0 if torch.cuda.is_available() else -1

  return pipeline(
      "question-answering",
      model=AutoModelForQuestionAnswering.from_pretrained(checkpoint),
      tokenizer=AutoTokenizer.from_pretrained(checkpoint),
      device=device)



def main():
  """Run the Streamlit question-answering app.

  Loads the NLTK/spaCy resources, the prepared articles, the sentence
  embedding model, the pre-embedded articles and the comprehension model,
  reads a query from a text input, passes it through the three fetch stages
  and writes the resulting answers to the page.
  """
  nltk.download('punkt')

  load_spacy_model()
  spacy_nlp = spacy.load('en_core_web_sm')

  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

  # These are the sentence-tokenized articles, not embedding vectors.
  articles = load_prep_data()

  model = build_sent_trans_model()
  model.to(device)

  list_of_articles = load_embedded_articles()
  comprehension_model = load_comprehension_model()

  query = st.text_input("Enter Query", 'example query ', key="query")
  st.write(query)

  query_embedding, results1 = fetch_stage1(query, model, list_of_articles)
  results2 = fetch_stage2(results1, model, articles, query_embedding)
  results3 = fetch_stage3(results2, query, articles, comprehension_model, spacy_nlp)

  if not results3:
    st.info("There isn't any answer")
    return

  # Each result is indexed as (article index, score, answer text).
  for rank, res in enumerate(results3, start=1):
    st.write('{}> {}'.format(rank, res[2]))
    st.write('Score: %.4f' % (res[1]))
    st.write("From the article with title: {}".format(articles[res[0]][0]))
    st.write("\n")
    # Keeps the existing cut-off: the check runs after writing, so at most
    # four answers are shown. NOTE(review): confirm four (not three) is intended.
    if rank > 3:
      break


# Start the app only when this file is executed as a script, not on import.
if __name__ == '__main__':
  main()