Spaces:
Runtime error
Runtime error
Commit
·
1f1e9bd
1
Parent(s):
5cc7b84
don't use concat
Browse files
app.py
CHANGED
@@ -145,7 +145,7 @@ def init_models():
|
|
145 |
"question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
|
146 |
device=device
|
147 |
)
|
148 |
-
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-
|
149 |
# queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
|
150 |
# queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
|
151 |
return question_answerer, reranker, stop, device
|
@@ -211,6 +211,9 @@ st.markdown("""
|
|
211 |
""", unsafe_allow_html=True)
|
212 |
|
213 |
with st.expander("Settings (strictness, context limit, top hits)"):
|
|
|
|
|
|
|
214 |
support_all = st.radio(
|
215 |
"Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
|
216 |
('yes', 'no'))
|
@@ -224,8 +227,8 @@ with st.expander("Settings (strictness, context limit, top hits)"):
|
|
224 |
use_reranking = st.radio(
|
225 |
"Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
|
226 |
('yes', 'no'))
|
227 |
-
top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300,
|
228 |
-
context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300,
|
229 |
|
230 |
# def paraphrase(text, max_length=128):
|
231 |
# input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
|
@@ -313,14 +316,24 @@ def run_query(query):
|
|
313 |
scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
|
314 |
hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
|
315 |
sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
|
316 |
-
|
317 |
else:
|
318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
|
320 |
results = []
|
321 |
-
|
322 |
-
for result in model_results:
|
323 |
-
|
|
|
|
|
|
|
324 |
support = find_source(result['answer'], orig_docs, matched)
|
325 |
if not support:
|
326 |
continue
|
|
|
145 |
"question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
|
146 |
device=device
|
147 |
)
|
148 |
+
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
|
149 |
# queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
|
150 |
# queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
|
151 |
return question_answerer, reranker, stop, device
|
|
|
211 |
""", unsafe_allow_html=True)
|
212 |
|
213 |
with st.expander("Settings (strictness, context limit, top hits)"):
|
214 |
+
concat_passages = st.radio(
|
215 |
+
"Concatenate passages as one long context?",
|
216 |
+
('no', 'yes'))
|
217 |
support_all = st.radio(
|
218 |
"Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
|
219 |
('yes', 'no'))
|
|
|
227 |
use_reranking = st.radio(
|
228 |
"Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
|
229 |
('yes', 'no'))
|
230 |
+
top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 10)
|
231 |
+
context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 5)
|
232 |
|
233 |
# def paraphrase(text, max_length=128):
|
234 |
# input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
|
|
|
316 |
scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
|
317 |
hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
|
318 |
sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
|
319 |
+
contexts = sorted_contexts[:context_limit]
|
320 |
else:
|
321 |
+
contexts = contexts[:context_limit]
|
322 |
+
|
323 |
+
if concat_passages == 'yes':
|
324 |
+
context = '\n---'.join(contexts)
|
325 |
+
model_results = qa_model(question=query, context=context, top_k=10)
|
326 |
+
else:
|
327 |
+
context = ['\n---\n'+ctx for ctx in contexts]
|
328 |
+
model_results = qa_model(question=[query]*len(contexts), context=context)
|
329 |
|
330 |
results = []
|
331 |
+
|
332 |
+
for i, result in enumerate(model_results):
|
333 |
+
if concat_passages == 'yes':
|
334 |
+
matched = matched_context(result['start'], result['end'], context)
|
335 |
+
else:
|
336 |
+
matched = matched_context(result['start'], result['end'], context[i])
|
337 |
support = find_source(result['answer'], orig_docs, matched)
|
338 |
if not support:
|
339 |
continue
|