import streamlit as st
from transformers import pipeline
import requests
from bs4 import BeautifulSoup
import nltk
import string
from streamlit.components.v1 import html
from sentence_transformers.cross_encoder import CrossEncoder as CE
import numpy as np
from typing import List, Tuple
import torch
SCITE_API_KEY = st.secrets["SCITE_API_KEY"]
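
# Thin wrapper around the sentence-transformers CrossEncoder so the reranker
# exposes a simple predict() over (query, context) pairs.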
class CrossEncoder:
    def __init__(self, model_path: str, **kwargs):
        self.model = CE(model_path, **kwargs)

    def predict(self, sentences: List[Tuple[str, str]], batch_size: int = 32, show_progress_bar: bool = True) -> List[float]:
        return self.model.predict(
            sentences=sentences,
            batch_size=batch_size,
            show_progress_bar=show_progress_bar)
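
# Strip HTML tags (e.g. the <strong class="highlight"> markers scite returns)
# from snippets and abstracts.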
def remove_html(x):
    soup = BeautifulSoup(x, 'html.parser')
    return soup.get_text()
# Up to 4 searches per question: strict vs. lenient query, citation statements vs. abstracts.
# Results are deduplicated, with one search per query.
# Search options: abstracts only, abstracts as an additional source, or "all" mode.
def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=True, abstract_only=False):
    term = clean_query(term, clean=clean, strict=strict)
    # heuristic: run strict and lenient searches and then merge?
    # example request:
    # https://api.scite.ai/search?mode=all&term=unit%20testing%20software&limit=10&date_from=2000&date_to=2022&offset=0&supporting_from=1&contrasting_from=0&contrasting_to=0&user_slug=domenic-rosati-keW5&compute_aggregations=true
    contexts, docs = [], []
    if not abstract_only:
        mode = 'all'
        if not all_mode:
            mode = 'citations'
        search = f"https://api.scite.ai/search?mode={mode}&term={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
        req = requests.get(
            search,
            headers={
                'Authorization': f'Bearer {SCITE_API_KEY}'
            }
        )
        try:
            # citation snippets become the retrieval contexts
            contexts += [remove_html('\n'.join([cite['snippet'] for cite in doc['citations']])) for doc in req.json()['hits']]
            docs += [(doc['doi'], doc['citations'], doc['title'], doc['abstract'] or '')
                     for doc in req.json()['hits']]
        except Exception:
            pass

    if abstracts or abstract_only:
        search = f"https://api.scite.ai/search?mode=papers&abstract={term}&limit={limit}&offset=0&user_slug=domenic-rosati-keW5&compute_aggregations=false"
        req = requests.get(
            search,
            headers={
                'Authorization': f'Bearer {SCITE_API_KEY}'
            }
        )
        try:
            # abstracts become additional retrieval contexts
            contexts += [remove_html(doc['abstract'] or '') for doc in req.json()['hits']]
            docs += [(doc['doi'], doc['citations'], doc['title'], doc['abstract'] or '')
                     for doc in req.json()['hits']]
        except Exception:
            pass

    return (
        contexts,
        docs
    )
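
# Map a model answer span back to its source: scan citation snippets first,
# then abstracts, and return the matching sentence plus citation metadata.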
def find_source(text, docs):
    for doc in docs:
        for snippet in doc[1]:
            if text in remove_html(snippet.get('snippet', '')):
                new_text = text
                for sent in nltk.sent_tokenize(remove_html(snippet.get('snippet', ''))):
                    if text in sent:
                        new_text = sent
                return {
                    'citation_statement': snippet['snippet'].replace('<strong class="highlight">', '').replace('</strong>', ''),
                    'text': new_text,
                    'from': snippet['source'],
                    'supporting': snippet['target'],
                    'source_title': remove_html(doc[2]),
                    'source_link': f"https://scite.ai/reports/{doc[0]}"
                }
        if text in remove_html(doc[3]):
            new_text = text
            for sent in nltk.sent_tokenize(remove_html(doc[3])):
                if text in sent:
                    new_text = sent
            return {
                'citation_statement': "ABSTRACT: " + remove_html(doc[3]).replace('<strong class="highlight">', '').replace('</strong>', ''),
                'text': new_text,
                'from': doc[0],
                'supporting': doc[0],
                'source_title': "ABSTRACT of " + remove_html(doc[2]),
                'source_link': f"https://scite.ai/reports/{doc[0]}"
            }
    return None
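
# Load models and NLTK resources once per session (cached as a singleton):
# a biomedical extractive QA pipeline and a lightweight cross-encoder reranker.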
@st.experimental_singleton
def init_models():
    nltk.download('stopwords')
    nltk.download('punkt')  # required by nltk.sent_tokenize
    from nltk.corpus import stopwords
    stop = set(stopwords.words('english') + list(string.punctuation))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    question_answerer = pipeline(
        "question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
        device=device
    )
    reranker = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2', device=device)
    # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
    # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
    return question_answerer, reranker, stop, device  # queryexp_model, queryexp_tokenizer


qa_model, reranker, stop, device = init_models()  # queryexp_model, queryexp_tokenizer
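
# Normalize the query: lowercase, optionally drop stopwords and punctuation,
# and join terms with AND when strict matching is requested.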
def clean_query(query, strict=True, clean=True):
    operator = ' '
    if strict:
        operator = ' AND '
    # only drop stopwords when cleaning is enabled
    query = operator.join(
        [i for i in query.lower().split(' ') if not clean or i not in stop])
    if clean:
        query = query.translate(str.maketrans('', '', string.punctuation))
    return query
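
# Render one answer as a card: the highlighted citation context, the answer
# score, a link to the source report, and an embedded scite badge.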
def card(title, context, score, link, supporting):
    st.markdown(f"""
    <div class="container-fluid">
        <div class="row align-items-start">
            <div class="col-md-12 col-sm-12">
                <br>
                <span>
                    {context}
                    [<b>Score: </b>{score}]
                </span>
                <br>
                <b>From <a href="{link}">{title}</a></b>
            </div>
        </div>
    </div>
    """, unsafe_allow_html=True)
    html(f"""
    <div
        class="scite-badge"
        data-doi="{supporting}"
        data-layout="horizontal"
        data-show-zero="false"
        data-show-labels="false"
        data-tally-show="true"
    />
    <script
        async
        type="application/javascript"
        src="https://cdn.scite.ai/badge/scite-badge-latest.min.js">
    </script>
    """, width=None, height=42, scrolling=False)
st.title("Scientific Question Answering with Citations")
st.write("""
Ask a scientific question and get an answer drawn from the [scite.ai](https://scite.ai) corpus of over 1.1 billion citation statements.
Each answer is linked to the source documents it comes from, so you can explore further evidence from the scientific literature.

For example, try: *Do tanning beds cause cancer?*
""")
st.markdown("""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-Gn5384xqQ1aoWXA+058RXPxPg6fy4IWvTNh0E263XmFcJlSAwiGgFAW/dAiS6JXm" crossorigin="anonymous">
""", unsafe_allow_html=True)
with st.expander("Settings (strictness, context limit, top hits)"):
    support_all = st.radio(
        "Use abstracts and titles as a ranking signal (if the query words are matched in the abstract, the document is considered more relevant)?",
        ('yes', 'no'))
    support_abstracts = st.radio(
        "Use abstracts as a source document?",
        ('yes', 'no', 'abstract only'))
    strict_lenient_mix = st.radio(
        "Type of strict+lenient combination: Fallback or Mix? Fallback runs the strict search first and only adds a lenient search if fewer than the context limit are found. Mix runs both searches and lets reranking sort them out.",
        ('fallback', 'mix'))
    confidence_threshold = st.slider('Confidence threshold for answering questions? The minimum confidence (out of 100%) the model must have in an answer for it to be shown.', 0, 100, 1)
    use_reranking = st.radio(
        "Use reranking? Reranking will rerank the top hits by the semantic similarity between document and query.",
        ('yes', 'no'))
    top_hits_limit = st.slider('Top hits? How many documents to retrieve for reranking. Larger is slower but higher quality.', 10, 300, 10)
    context_lim = st.slider('Context limit? How many documents to answer from. Larger is slower but higher quality.', 10, 300, 10)
# def paraphrase(text, max_length=128):
# input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
# generated_ids = queryexp_model.generate(input_ids=input_ids, num_return_sequences=suggested_queries or 5, num_beams=suggested_queries or 5, max_length=max_length)
# queries = set([queryexp_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids])
# preds = '\n * '.join(queries)
# return preds
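
# Merge answers that were extracted from the same citation context, keeping
# every answer span and the highest score for the group.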
def group_results_by_context(results):
    result_groups = {}
    for result in results:
        if result['context'] not in result_groups:
            result_groups[result['context']] = result
            result_groups[result['context']]['texts'] = []
        result_groups[result['context']]['texts'].append(
            result['answer']
        )
        if result['score'] > result_groups[result['context']]['score']:
            result_groups[result['context']]['score'] = result['score']
    return list(result_groups.values())
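
# Full pipeline for a user question: search scite (strict, then lenient or
# mixed), optionally rerank the contexts, run extractive QA, map answers back
# to their sources, and render the results as cards.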
def run_query(query):
    # if use_query_exp == 'yes':
    #     query_exp = paraphrase(f"question2question: {query}")
    #     st.markdown(f"""
    #     If you are not getting good results try one of:
    #     * {query_exp}
    #     """)

    # address periods inside highlights, e.g. "avoidability. Risk factors"
    # address poor tokenization, e.g. "Deletions involving chromosome region 4p16.3 cause WolfHirschhorn syndrome (WHS, OMIM 194190) [Battaglia et al, 2001]."
    # address highlight html
    # could also try fallback if there are no good answers by score...
    limit = top_hits_limit or 100
    context_limit = context_lim or 10
    contexts_strict, orig_docs_strict = search(query, limit=limit, strict=True, all_mode=support_all == 'yes', abstracts=support_abstracts == 'yes', abstract_only=support_abstracts == 'abstract only')
    if strict_lenient_mix == 'fallback' and len(contexts_strict) < context_limit:
        # strict search found too few contexts: also run a lenient search
        contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False, all_mode=support_all == 'yes', abstracts=support_abstracts == 'yes', abstract_only=support_abstracts == 'abstract only')
        contexts = list(
            set(contexts_strict + contexts_lenient)
        )
        orig_docs = orig_docs_strict + orig_docs_lenient
    elif strict_lenient_mix == 'mix':
        # run both searches and let reranking sort the combined contexts out
        contexts_lenient, orig_docs_lenient = search(query, limit=limit, strict=False, all_mode=support_all == 'yes', abstracts=support_abstracts == 'yes', abstract_only=support_abstracts == 'abstract only')
        contexts = list(
            set(contexts_strict + contexts_lenient)
        )
        orig_docs = orig_docs_strict + orig_docs_lenient
    else:
        contexts = list(
            set(contexts_strict)
        )
        orig_docs = orig_docs_strict
    if len(contexts) == 0 or not ''.join(contexts).strip():
        return st.markdown("""
        <div class="container-fluid">
            <div class="row align-items-start">
                <div class="col-md-12 col-sm-12">
                    Sorry... no results for that question! Try another...
                </div>
            </div>
        </div>
        """, unsafe_allow_html=True)
    if use_reranking == 'yes':
        sentence_pairs = [[query, context] for context in contexts]
        scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
        hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
        # keep the contexts with the highest reranker scores
        sorted_contexts = [k for k, v in sorted(hits.items(), key=lambda x: x[1], reverse=True)]
        context = '\n'.join(sorted_contexts[:context_limit])
    else:
        context = '\n'.join(contexts[:context_limit])
    results = []
    model_results = qa_model(question=query, context=context, top_k=10)
    for result in model_results:
        support = find_source(result['answer'], orig_docs)
        if not support:
            continue
        results.append({
            "answer": support['text'],
            "title": support['source_title'],
            "link": support['source_link'],
            "context": support['citation_statement'],
            "score": result['score'],
            "doi": support["supporting"]
        })

    grouped_results = group_results_by_context(results)
    sorted_result = sorted(grouped_results, key=lambda x: x['score'], reverse=True)

    if confidence_threshold == 0:
        threshold = 0
    else:
        threshold = (confidence_threshold or 10) / 100

    sorted_result = filter(
        lambda x: x['score'] > threshold,
        sorted_result
    )
    for r in sorted_result:
        ctx = remove_html(r["context"])
        for answer in r['texts']:
            # highlight every extracted answer span inside its citation context
            ctx = ctx.replace(answer, f"<mark>{answer}</mark>")
            # .replace('<cite', '<a').replace('</cite', '</a').replace('data-doi="', 'href="https://scite.ai/reports/')
        title = r.get("title", '')
        score = round(r["score"], 4)
        card(title, ctx, score, r['link'], r['doi'])
query = st.text_input("Ask scientific literature a question", "")

if query != "":
    with st.spinner('Loading...'):
        run_query(query)