File size: 3,605 Bytes
07ab211
 
2c34ee1
07ab211
 
 
 
 
 
 
 
 
 
 
d97f2ca
07ab211
 
 
 
d97f2ca
 
 
 
07ab211
 
 
 
 
 
 
 
 
 
 
 
 
 
bc819a0
07ab211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d97f2ca
 
 
07ab211
 
d97f2ca
07ab211
124530a
b224783
07ab211
 
 
 
 
 
 
 
debd322
 
 
a25958d
2b65cc5
124530a
 
a25958d
2c34ee1
 
a25958d
 
 
2c34ee1
 
 
 
2352b9f
2c34ee1
a25958d
 
2b65cc5
a25958d
 
 
2c34ee1
a25958d
 
 
 
 
 
 
 
 
07ab211
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import pandas as pd
import numpy as np
import datetime, time
import pickle
import glob
import json  
from pandas.io.json import json_normalize  
from nltk.tokenize import sent_tokenize
import nltk
import scipy.spatial
from transformers import AutoTokenizer, AutoModel, pipeline, AutoModelForQuestionAnswering
from sentence_transformers import models, SentenceTransformer
import torch
import spacy
import subprocess
import streamlit as st
from utils import *


@st.cache(allow_output_mutation=True)
def load_spacy_model():
  """Download the spaCy ``en_core_web_sm`` model package.

  The download runs in a child process started with the interpreter that is
  executing this script (``sys.executable``), so the package is installed
  into the same environment this process imports spaCy from. A bare
  ``python`` would be looked up on PATH and may point at a different
  interpreter, or at none at all.

  Returns:
    None. The exit status of the download command is not inspected, so a
    failed download surfaces later when the model is loaded.
  """
  import sys  # function-local so the block does not depend on a file-level import

  subprocess.call([sys.executable, '-m', 'spacy', 'download', 'en_core_web_sm'])
  
@st.cache(allow_output_mutation=True)
def load_prep_data():
  """Load the pickled articles and sentence-split their second field.

  Unpickles ``listfile_3.data`` and, for every entry whose element at
  index 1 is not an empty list, replaces that element with the result of
  ``sent_tokenize`` applied to it.

  Returns:
    The unpickled collection, with its entries updated in place.
  """
  # NOTE(review): pickle.load can execute arbitrary code; only feed it a trusted file.
  with open('listfile_3.data', 'rb') as source:
    corpus = pickle.load(source)

  for position in range(len(corpus)):
    entry = corpus[position]
    if entry[1] == []:
      continue  # empty placeholder: nothing to tokenize
    entry[1] = sent_tokenize(entry[1])

  return corpus

@st.cache(allow_output_mutation=True)
def build_sent_trans_model():
  """Assemble the sentence-embedding model from the files in the current directory.

  The model chains a ``models.BERT`` module loaded from ``'./'`` with a
  pooling module sized to that module's word-embedding dimension.

  Returns:
    The assembled ``SentenceTransformer`` instance.
  """
  bert_module = models.BERT('./')
  embedding_dim = bert_module.get_word_embedding_dimension()

  # Only mean pooling over tokens is enabled; CLS and max pooling stay off.
  pooling_options = {
    'pooling_mode_mean_tokens': True,
    'pooling_mode_cls_token': False,
    'pooling_mode_max_tokens': False,
  }
  mean_pooler = models.Pooling(embedding_dim, **pooling_options)

  return SentenceTransformer(modules=[bert_module, mean_pooler])

@st.cache(allow_output_mutation=True)
def load_embedded_articles():
  """Return the object unpickled from ``list_of_articles.pkl``."""
  # NOTE(review): pickle.load can execute arbitrary code; only feed it a trusted file.
  with open('list_of_articles.pkl', 'rb') as pkl_file:
    return pickle.load(pkl_file)

@st.cache(allow_output_mutation=True)
def load_comprehension_model():
  """Build the question-answering pipeline for ``graviraja/covidbert_squad``.

  The model and tokenizer are both loaded from the same checkpoint. The
  pipeline runs on the first CUDA device when one is available and on the
  CPU otherwise.

  Returns:
    The ``transformers`` question-answering pipeline.
  """
  checkpoint = "graviraja/covidbert_squad"

  # pipeline() device convention: -1 selects the CPU, a non-negative integer
  # selects the CUDA device with that ordinal. A hard-coded -1 therefore
  # never uses the GPU, even when one is present.
  device_index = 0 if torch.cuda.is_available() else -1

  return pipeline("question-answering",
                  model=AutoModelForQuestionAnswering.from_pretrained(checkpoint),
                  tokenizer=AutoTokenizer.from_pretrained(checkpoint),
                  device=device_index)



def main():
  """Render the Co-Search page: load resources, answer the query, show results.

  Loads the tokenizer data, spaCy model, sentence-embedding model, article
  embeddings and question-answering pipeline, reads a query from a text
  input, runs the three ``fetch_stage*`` steps and writes at most four
  results together with the elapsed time.
  """
  nltk.download('punkt')

  load_spacy_model()
  spacy_nlp = spacy.load('en_core_web_sm')

  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

  # Tokenized articles; index 0 of an entry is shown below as the article title.
  articles = load_prep_data()

  model = build_sent_trans_model()
  model.to(device)

  list_of_articles = load_embedded_articles()
  comprehension_model = load_comprehension_model()

  st.title('Co-Search')
  query = st.text_input("Enter Query", 'What are the corona viruses?', key="query")
  st.write('Using device type: {}'.format(device))

  with st.spinner('Please wait...'):
    started = datetime.datetime.now()

    query_embedding, results1 = fetch_stage1(query, model, list_of_articles)
    results2 = fetch_stage2(results1, model, articles, query_embedding)
    results3 = fetch_stage3(results2, query, articles, comprehension_model, spacy_nlp)

    elapsed = datetime.datetime.now() - started
    # total_seconds() keeps the fractional seconds and any whole days;
    # the .seconds attribute alone drops both, reporting 0.00 for fast runs.
    st.write('Time taken in minutes: %.2f' % (elapsed.total_seconds() / 60))

    if results3:
      # Each result is indexed as (article index, score, answer text) below.
      for rank, res in enumerate(results3, start=1):
        st.write('{}> {}'.format(rank, res[2]))
        st.write('Score: %.4f' % (res[1]))
        st.write("From the article with title:\n{}".format(articles[res[0]][0]))
        st.write("\n")

        # Stop after the fourth result has been written.
        if rank > 3:
          break
    else:
      st.info("There isn't any answer")

  st.success('Done!')


# Script entry point: run the app only when executed directly, not on import.
if __name__ == '__main__':
  main()