feat: huggingface space pipeline with resrer model
- app.py +97 -2
- model.py +86 -0
- requirements.txt +3 -0
app.py
CHANGED
@@ -1,2 +1,97 @@
import os

import streamlit as st
from pymilvus import MilvusClient

from model import encode_dpr_question, get_dpr_encoder
from model import summarize_text, get_summarizer
from model import ask_reader, get_reader


TITLE = 'ReSRer: Retriever-Summarizer-Reader'
INITIAL = "What is the population of NYC"

st.set_page_config(page_title=TITLE)
st.header(TITLE)
st.markdown('''
### Ask a short-answer question that can be found in Wikipedia data.
''', unsafe_allow_html=True)


@st.cache_resource
def load_models():
    models = {}
    models['encoder'] = get_dpr_encoder()
    models['summarizer'] = get_summarizer()
    models['reader'] = get_reader()
    return models


@st.cache_resource
def load_client():
    client = MilvusClient(user='resrer', password=os.environ['MILVUS_PW'],
                          uri=f"http://{os.environ['MILVUS_HOST']}:19530", db_name='psgs_w100')
    return client


client = load_client()
models = load_models()

# Center the Streamlit status widget and hide its stop button.
styl = """
<style>
.StatusWidget-enter-done{
  position: fixed;
  left: 50%;
  top: 50%;
  transform: translate(-50%, -50%);
}
.StatusWidget-enter-done button{
  display: none;
}
</style>
"""
st.markdown(styl, unsafe_allow_html=True)


question = st.text_area("Question", INITIAL, height=400)


def main(question: str):
    if question in st.session_state:
        print("Cache hit!")
        ctx, summary, answer = st.session_state[question]
    else:
        print(f"Input: {question}")
        # Embedding
        question_vectors = encode_dpr_question(
            models['encoder'][0], models['encoder'][1], [question])
        query_vector = question_vectors.detach().cpu().numpy().tolist()[0]

        # Retriever
        results = client.search(collection_name='dpr_nq', data=[
            query_vector], limit=10, output_fields=['title', 'text'])
        texts = [result['entity']['text'] for result in results[0]]
        ctx = '\n'.join(texts)

        # Summarizer
        summary = summarize_text(models['summarizer'][0],
                                 models['summarizer'][1], [ctx])[0]

        # Reader
        answers = ask_reader(models['reader'][0],
                             models['reader'][1], [question], [ctx])
        answer = answers[0]['answer']
        print(f"\nAnswer: {answer}")

        st.session_state[question] = (ctx, summary, answer)

    # Show the answer, then the summary, then the retrieved context
    st.markdown(answer)
    st.write("## Summary")
    st.markdown(
        f"<h6 style='padding: 0'>{summary}</h6><hr style='margin: 1em 0px'>", unsafe_allow_html=True)
    st.markdown(ctx)

    st.write(f"{question}", unsafe_allow_html=True)


if question:
    main(question)
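For reference, the retrieval step can be exercised on its own. A minimal sketch, not part of the commit, assuming the same Milvus deployment as above (`psgs_w100` database, `dpr_nq` collection), a CUDA device (the DPR encoder is placed on GPU 0), and the `MILVUS_HOST` and `MILVUS_PW` environment variables set:

import os

from pymilvus import MilvusClient

from model import encode_dpr_question, get_dpr_encoder

# Same connection settings app.py uses.
client = MilvusClient(user='resrer', password=os.environ['MILVUS_PW'],
                      uri=f"http://{os.environ['MILVUS_HOST']}:19530", db_name='psgs_w100')

# Embed the question with the DPR encoder, then search by vector similarity.
tokenizer, model = get_dpr_encoder()
vectors = encode_dpr_question(tokenizer, model, ["What is the population of NYC"])
query_vector = vectors.detach().cpu().numpy().tolist()[0]

results = client.search(collection_name='dpr_nq', data=[query_vector],
                        limit=10, output_fields=['title', 'text'])
for hit in results[0]:
    print(hit['distance'], hit['entity']['title'])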
model.py
ADDED
@@ -0,0 +1,86 @@
from typing import List, Tuple, TypedDict
from re import sub

from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, logging
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from transformers import QuestionAnsweringPipeline
from transformers import PegasusXForConditionalGeneration, PegasusTokenizerFast
import torch

max_answer_len = 8
logging.set_verbosity_error()


def summarize_text(tokenizer: PegasusTokenizerFast, model: PegasusXForConditionalGeneration,
                   input_texts: List[str]) -> List[str]:
    # The summarizer lives on the second GPU (device 1); see get_summarizer.
    inputs = tokenizer(input_texts, padding=True,
                       return_tensors='pt', truncation=True).to(1)
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        summary_ids = model.generate(inputs["input_ids"])
    summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
                                       clean_up_tokenization_spaces=False)
    return summaries


def get_summarizer(model_id="seonglae/resrer") -> Tuple[PegasusTokenizerFast, PegasusXForConditionalGeneration]:
    tokenizer = PegasusTokenizerFast.from_pretrained(model_id)
    model = PegasusXForConditionalGeneration.from_pretrained(model_id).to(1)
    model = torch.compile(model)
    return tokenizer, model


# Reader


class AnswerInfo(TypedDict):
    score: float
    start: int
    end: int
    answer: str


@torch.inference_mode()
def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
               questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        pipeline = QuestionAnsweringPipeline(
            model=model, tokenizer=tokenizer, device='cuda', max_answer_len=max_answer_len)
        results = pipeline(question=questions, context=ctxs)
    # The pipeline returns a single dict for a single example; normalize to a list.
    answer_infos: List[AnswerInfo] = results if isinstance(results, list) else [results]
    # Strip punctuation from the answer span.
    for answer_info in answer_infos:
        answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
    return answer_infos


def get_reader(model_id="mrm8488/longformer-base-4096-finetuned-squadv2"):
    # The checkpoint is a Longformer QA model, so load it with the Auto classes.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForQuestionAnswering.from_pretrained(model_id).to(0)
    return tokenizer, model


def encode_dpr_question(tokenizer: DPRQuestionEncoderTokenizer, model: DPRQuestionEncoder,
                        questions: List[str]) -> torch.FloatTensor:
    """Encode questions with a DPR question encoder.
    https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

    Args:
      tokenizer: tokenizer matching the encoder
      model: DPR question encoder
      questions (List[str]): question strings to encode
    """
    batch_dict = tokenizer(questions, return_tensors="pt",
                           padding=True, truncation=True).to(0)
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        embeddings: torch.FloatTensor = model(**batch_dict).pooler_output
    return embeddings


def get_dpr_encoder(model_id="facebook/dpr-question_encoder-single-nq-base") -> Tuple[DPRQuestionEncoderTokenizer, DPRQuestionEncoder]:
    """Load a DPR question encoder and its tokenizer.
    https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

    Args:
      model_id (str, optional): defaults to the NQ encoder; use
        "facebook/dpr-question_encoder-multiset-base" for the multiset variant.
    """
    tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_id)
    model = DPRQuestionEncoder.from_pretrained(model_id).to(0)
    return tokenizer, model
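Taken together, the three stages compose as below. A sketch, not part of the commit, assuming a machine with two CUDA GPUs (the summarizer loads onto device 1, the encoder and reader onto device 0) and a placeholder string standing in for passages retrieved from Milvus:

from model import ask_reader, encode_dpr_question, get_dpr_encoder, get_reader, get_summarizer, summarize_text

question = "What is the population of NYC"
context = "New York City is the most populous city in the United States..."  # placeholder for retrieved passages

enc_tokenizer, enc_model = get_dpr_encoder()  # GPU 0
sum_tokenizer, sum_model = get_summarizer()   # GPU 1
rdr_tokenizer, rdr_model = get_reader()       # GPU 0

# The embedding is what app.py sends to Milvus; here it is just computed.
embedding = encode_dpr_question(enc_tokenizer, enc_model, [question])

# Compress the context, then extract a short answer span from it.
summary = summarize_text(sum_tokenizer, sum_model, [context])[0]
answers = ask_reader(rdr_tokenizer, rdr_model, [question], [context])
print(summary)
print(answers[0]['answer'])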
requirements.txt
ADDED
@@ -0,0 +1,3 @@
transformers
torch
pymilvus
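Note that app.py also imports streamlit, which is not listed here; on Hugging Face Spaces the Streamlit runtime is supplied by the Space's SDK setting, but running the app anywhere else requires installing streamlit alongside these three packages.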