File size: 3,605 Bytes
07ab211 2c34ee1 07ab211 d97f2ca 07ab211 d97f2ca 07ab211 bc819a0 07ab211 d97f2ca 07ab211 d97f2ca 07ab211 124530a b224783 07ab211 debd322 a25958d 2b65cc5 124530a a25958d 2c34ee1 a25958d 2c34ee1 2352b9f 2c34ee1 a25958d 2b65cc5 a25958d 2c34ee1 a25958d 07ab211 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import pandas as pd
import numpy as np
import datetime, time
import pickle
import glob
import json
from pandas.io.json import json_normalize
from nltk.tokenize import sent_tokenize
import nltk
import scipy.spatial
from transformers import AutoTokenizer, AutoModel, pipeline, AutoModelForQuestionAnswering
from sentence_transformers import models, SentenceTransformer
import torch
import spacy
import subprocess
import streamlit as st
from utils import *
@st.cache(allow_output_mutation=True)
def load_spacy_model():
    """Download the spaCy ``en_core_web_sm`` model package.

    Runs ``<interpreter> -m spacy download en_core_web_sm`` in a child
    process and waits for it to finish. The function is decorated with
    ``st.cache``, presumably so Streamlit does not repeat the download on
    every script rerun.

    Returns:
        None. The exit status of the download command is not inspected, so
        a failed download only surfaces later when the model is loaded.
    """
    import sys  # local import: only needed to locate the running interpreter

    # Use the interpreter that is executing this app instead of whatever
    # "python" resolves to on PATH; otherwise the model may be installed
    # into a different environment than the one that later loads it.
    interpreter = sys.executable or 'python'
    subprocess.call([interpreter, '-m', 'spacy', 'download', 'en_core_web_sm'])
@st.cache(allow_output_mutation=True)
def load_prep_data():
    """Load the pickled article records and sentence-split their text.

    Unpickles ``listfile_3.data`` and, for every record whose element at
    index 1 is not an empty list, replaces that element with the result of
    ``sent_tokenize`` applied to it. Records are modified in place.

    Returns:
        The unpickled collection of records after the in-place update.
    """
    with open('listfile_3.data', 'rb') as source:
        records = pickle.load(source)

    for record in records:
        # An empty list at index 1 is left untouched.
        if record[1] != []:
            record[1] = sent_tokenize(record[1])

    return records
@st.cache(allow_output_mutation=True)
def build_sent_trans_model():
    """Assemble a SentenceTransformer from a BERT module plus mean pooling.

    The word-embedding module is created with ``models.BERT('./')`` and is
    followed by a pooling module that averages token embeddings (CLS and
    max pooling are disabled).

    Returns:
        A ``SentenceTransformer`` built from the two modules, in that order.
    """
    bert_encoder = models.BERT('./')
    embedding_dim = bert_encoder.get_word_embedding_dimension()

    # Mean pooling is the only pooling strategy enabled.
    mean_pooler = models.Pooling(
        embedding_dim,
        pooling_mode_mean_tokens=True,
        pooling_mode_cls_token=False,
        pooling_mode_max_tokens=False,
    )

    return SentenceTransformer(modules=[bert_encoder, mean_pooler])
@st.cache(allow_output_mutation=True)
def load_embedded_articles():
    """Unpickle and return the contents of ``list_of_articles.pkl``."""
    with open('list_of_articles.pkl', 'rb') as pickled:
        return pickle.load(pickled)
@st.cache(allow_output_mutation=True)
def load_comprehension_model():
    """Build the question-answering pipeline used for answer extraction.

    Both the model and the tokenizer are loaded from the
    ``graviraja/covidbert_squad`` checkpoint.

    Returns:
        The object returned by ``pipeline("question-answering", ...)``.
    """
    checkpoint = "graviraja/covidbert_squad"

    # pipeline() treats device=-1 as "run on CPU"; a non-negative integer
    # is the CUDA device ordinal. Pick the first GPU when one is available
    # so the pipeline actually uses it, and fall back to CPU otherwise.
    device_index = 0 if torch.cuda.is_available() else -1

    return pipeline(
        "question-answering",
        model=AutoModelForQuestionAnswering.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
        device=device_index,
    )
def main():
    """Render the Co-Search page: take a query, run the search, show answers.

    Loads the NLP resources and models, reads the query from a text input,
    runs the three ``fetch_stage*`` steps in sequence, reports the elapsed
    time, and writes up to four results (answer text, score, and the title
    of the source article). If the final stage returns nothing truthy, an
    informational message is shown instead.
    """
    # --- resources and models -------------------------------------------
    nltk.download('punkt')
    load_spacy_model()
    spacy_nlp = spacy.load('en_core_web_sm')
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # load_prep_data() returns the article records (title at index 0,
    # sentence-split text at index 1), not embedding vectors.
    articles = load_prep_data()
    model = build_sent_trans_model()
    model.to(device)
    list_of_articles = load_embedded_articles()
    comprehension_model = load_comprehension_model()

    # --- page header and input ------------------------------------------
    st.title('Co-Search')
    query = st.text_input("Enter Query", 'What are the corona viruses?', key="query")
    st.write('Using device type: {}'.format(device))

    # NOTE(review): the extent of the spinner block was reconstructed from a
    # source without indentation; confirm it matches the intended layout.
    with st.spinner('Please wait...'):
        started = datetime.datetime.now()
        query_embedding, results1 = fetch_stage1(query, model, list_of_articles)
        results2 = fetch_stage2(results1, model, articles, query_embedding)
        results3 = fetch_stage3(results2, query, articles, comprehension_model, spacy_nlp)
        elapsed = datetime.datetime.now() - started

        # total_seconds() includes the fractional part (and any whole days),
        # whereas the .seconds attribute silently drops both.
        st.write('Time taken in minutes: %.2f' % (elapsed.total_seconds() / 60))

        if results3:
            max_results = 4  # results ranked 1 through 4 are displayed
            for rank, res in enumerate(results3, start=1):
                # res is indexed as: [0] article index, [1] score, [2] answer.
                st.write('{}> {}'.format(rank, res[2]))
                st.write('Score: %.4f' % (res[1]))
                st.write("From the article with title:\n{}".format(articles[res[0]][0]))
                st.write("\n")
                if rank >= max_results:
                    break
        else:
            st.info("There isn't any answer")

    st.success('Done!')
# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|