Spaces: Runtime error
Commit 5cc7b84 · 1 Parent(s): f1fd3e1
remove summarization

app.py CHANGED
@@ -78,7 +78,6 @@ def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=Tru
     except:
         pass
 
-
     return (
         contexts,
         docs
@@ -149,11 +148,9 @@ def init_models():
     reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2', device=device)
     # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
     # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
-
-    summ_mdl = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384-ms2')
-    return question_answerer, reranker, stop, device, summ_mdl, summ_tok
+    return question_answerer, reranker, stop, device
 
-qa_model, reranker, stop, device, summ_mdl, summ_tok = init_models() # queryexp_model, queryexp_tokenizer
+qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer
 
 
 def clean_query(query, strict=True, clean=True):
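Pieced together from the context lines, the trimmed init_models() now appears to load just the extractive QA pipeline, the cross-encoder reranker, and a stopword list. A minimal sketch under those assumptions follows; the CrossEncoder checkpoint comes from the diff, while the QA pipeline construction and the NLTK stopword list are guesses:

import torch
from nltk.corpus import stopwords
from sentence_transformers import CrossEncoder
from transformers import pipeline

def init_models():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Assumption: default extractive QA checkpoint; the app's actual
    # question_answerer setup is not shown in this hunk.
    question_answerer = pipeline('question-answering', device=device)
    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2', device=device)
    stop = set(stopwords.words('english'))  # assumption: NLTK stopwords, pre-downloaded
    return question_answerer, reranker, stop, device

qa_model, reranker, stop, device = init_models()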
@@ -214,9 +211,6 @@ st.markdown("""
 """, unsafe_allow_html=True)
 
 with st.expander("Settings (strictness, context limit, top hits)"):
-    use_mds = st.radio(
-        "Use multi-document summarization to summarize answer?",
-        ('yes', 'no'))
     support_all = st.radio(
         "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
         ('yes', 'no'))
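The surviving setting keeps the usual Streamlit pattern: st.radio inside st.expander returns the selected option string, which downstream code can compare against 'yes' (as the deleted use_mds branch did). A tiny sketch of that pattern, with the label shortened here for brevity:

import streamlit as st

with st.expander("Settings (strictness, context limit, top hits)"):
    support_all = st.radio(
        "Use abstracts and titles as a ranking signal?",  # shortened label
        ('yes', 'no'))

# The widget returns the chosen option string, so code can branch on it.
use_abstracts = support_all == 'yes'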
@@ -271,77 +265,6 @@ def matched_context(start_i, end_i, contexts_string, seperator='---'):
         return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(seperator, '')
     return None
 
-def process_document(documents, tokenizer, docsep_token_id, pad_token_id, device=device):
-    input_ids_all=[]
-    for data in documents:
-        all_docs = data.split("|||||")
-        for i, doc in enumerate(all_docs):
-            doc = doc.replace("\n", " ")
-            doc = " ".join(doc.split())
-            all_docs[i] = doc
-
-        #### concat with global attention on doc-sep
-        input_ids = []
-        for doc in all_docs:
-            input_ids.extend(
-                tokenizer.encode(
-                    doc,
-                    truncation=True,
-                    max_length=4096 // len(all_docs),
-                )[1:-1]
-            )
-            input_ids.append(docsep_token_id)
-        input_ids = (
-            [tokenizer.bos_token_id]
-            + input_ids
-            + [tokenizer.eos_token_id]
-        )
-        input_ids_all.append(torch.tensor(input_ids))
-    input_ids = torch.nn.utils.rnn.pad_sequence(
-        input_ids_all, batch_first=True, padding_value=pad_token_id
-    )
-    return input_ids
-
-
-def batch_process(batch, model, tokenizer, docsep_token_id, pad_token_id, device=device):
-    input_ids=process_document(batch['document'], tokenizer, docsep_token_id, pad_token_id)
-    # get the input ids and attention masks together
-    global_attention_mask = torch.zeros_like(input_ids).to(device)
-    input_ids = input_ids.to(device)
-    # put global attention on <s> token
-
-    global_attention_mask[:, 0] = 1
-    global_attention_mask[input_ids == docsep_token_id] = 1
-    generated_ids = model.generate(
-        input_ids=input_ids,
-        global_attention_mask=global_attention_mask,
-        use_cache=True,
-        max_length=1024,
-        num_beams=5,
-    )
-    generated_str = tokenizer.batch_decode(
-        generated_ids.tolist(), skip_special_tokens=True
-    )
-    result={}
-    result['generated_summaries'] = generated_str
-    return result
-
-
-def gen_summary(query, sorted_result):
-    pad_token_id = summ_tok.pad_token_id
-    docsep_token_id = summ_tok.convert_tokens_to_ids("</s>")
-    out = batch_process({ 'document': [f'||||'.join([f'{query} '.join(r['texts']) + r['context'] for r in sorted_result])]}, summ_mdl, summ_tok, docsep_token_id, pad_token_id)
-    st.markdown(f"""
-    <div class="container-fluid">
-        <div class="row align-items-start">
-            <div class="col-md-12 col-sm-12">
-                <strong>Answer:</strong> {out['generated_summaries'][0]}
-            </div>
-        </div>
-    </div>
-    """, unsafe_allow_html=True)
-    st.markdown("<br /><br /><h5>Sources:</h5>", unsafe_allow_html=True)
-
 
 def run_query(query):
     # if use_query_exp == 'yes':
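The deleted helpers follow the standard LED long-input recipe: concatenate the documents with a separator token, put global attention on the leading <s> token and on every separator, then beam-search a summary. A minimal standalone driver for that path might have looked like the sketch below; the checkpoint name comes from the diff, while loading summ_tok via AutoTokenizer, the toy input, and the assumption that the deleted process_document/batch_process are still in scope are all hypothetical:

import torch
from transformers import AutoTokenizer, LEDForConditionalGeneration

device = 'cuda' if torch.cuda.is_available() else 'cpu'
summ_tok = AutoTokenizer.from_pretrained('allenai/led-base-16384-ms2')
summ_mdl = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384-ms2').to(device)

# Two toy "abstracts" joined with the five-pipe separator that
# process_document() splits on.
batch = {'document': ['Aspirin lowered mortality in trial A.|||||No effect was observed in trial B.']}

out = batch_process(
    batch, summ_mdl, summ_tok,
    docsep_token_id=summ_tok.convert_tokens_to_ids('</s>'),
    pad_token_id=summ_tok.pad_token_id,
)
print(out['generated_summaries'][0])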
@@ -395,7 +318,7 @@ def run_query(query):
     context = '\n---'.join(contexts[:context_limit])
 
     results = []
-    model_results = qa_model(question=query, context=context, top_k=10)
+    model_results = qa_model(question=query, context=query+'---'+context, top_k=10)
     for result in model_results:
         matched = matched_context(result['start'], result['end'], context)
         support = find_source(result['answer'], orig_docs, matched)
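The one-line change above prepends the query to the context handed to the extractive QA pipeline. The surrounding code depends on that pipeline's contract: with top_k > 1 it returns a list of dicts whose 'start'/'end' fields are character offsets into the context string, which matched_context() then maps back to a single source document. A small sketch of that contract, using the pipeline's default checkpoint (an assumption; the app's actual qa_model checkpoint is not shown in this hunk):

from transformers import pipeline

qa = pipeline('question-answering')  # default checkpoint; the app's qa_model may differ
preds = qa(
    question='What reduced mortality?',
    context='Aspirin reduced mortality in trial A.---No effect was observed in trial B.',
    top_k=2,
)
# With top_k > 1 the pipeline returns a list of dicts, each carrying
# 'score', 'start', 'end', and 'answer'. The offsets are character
# positions in whatever string was passed as `context`, which is what
# matched_context() relies on when mapping a span back to its document.
for p in preds:
    print(p['answer'], p['start'], p['end'], round(p['score'], 3))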
@@ -423,9 +346,6 @@ def run_query(query):
         sorted_result
     ))
 
-    if use_mds == 'yes':
-        gen_summary(query, sorted_result)
-
     for r in sorted_result:
         ctx = remove_html(r["context"])
         for answer in r['texts']: