Spaces:
Runtime error
Runtime error
Commit
·
852c61d
0
Parent(s):
Duplicate from deepset/retrieval-augmentation-svb
Browse filesCo-authored-by: Tanay Soni <[email protected]>
- .gitattributes +34 -0
- .streamlit/config.toml +13 -0
- README.md +13 -0
- app.py +57 -0
- data/my_faiss_index.faiss +0 -0
- data/my_faiss_index.json +1 -0
- faiss_document_store.db +0 -0
- logo/haystack-logo-colored.png +0 -0
- requirements.txt +7 -0
- utils/__init__.py +0 -0
- utils/backend.py +64 -0
- utils/constants.py +14 -0
- utils/ui.py +112 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.streamlit/config.toml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[theme]
|
2 |
+
base = "light"
|
3 |
+
font="monospace"
|
4 |
+
[global]
|
5 |
+
|
6 |
+
# By default, Streamlit checks if the Python watchdog module is available and, if not, prints a warning asking for you to install it. The watchdog module is not required, but highly recommended. It improves Streamlit's ability to detect changes to files in your filesystem.
|
7 |
+
# If you'd like to turn off this warning, set this to True.
|
8 |
+
# Default: false
|
9 |
+
disableWatchdogWarning = true
|
10 |
+
|
11 |
+
# If True, will show a warning when you run a Streamlit-enabled script via "python my_script.py".
|
12 |
+
# Default: true
|
13 |
+
showWarningOnDirectExecution = false
|
README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Retrieval Augmented Generative QA
|
3 |
+
emoji: 👁
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: pink
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.19.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
duplicated_from: deepset/retrieval-augmentation-svb
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Streamlit entry point: compare a plain GPT answer with a retrieval-augmented one.

Lays out the page, collects a question about the SVB collapse, then runs a
plain LLM pipeline plus one retrieval-augmented pipeline (local FAISS store
or live web search, depending on the selected answer type).
"""
import streamlit as st

from utils.backend import (get_plain_pipeline, get_retrieval_augmented_pipeline,
                           get_web_retrieval_augmented_pipeline)
from utils.ui import left_sidebar, right_sidebar, main_column
from utils.constants import BUTTON_LOCAL_RET_AUG

st.set_page_config(
    page_title="Retrieval Augmentation with Haystack",
    layout="wide"
)
left_sidebar()

st.markdown("<center> <h2> Reduce Hallucinations 😵💫 with Retrieval Augmentation </h2> </center>", unsafe_allow_html=True)

st.markdown("<center>Ask a question about the collapse of the Silicon Valley Bank (SVB).</center>", unsafe_allow_html=True)

col_1, col_2 = st.columns([4, 2], gap="small")
with col_1:
    run_pressed, placeholder_plain_gpt, placeholder_retrieval_augmented = main_column()

with col_2:
    right_sidebar()

# Only run the pipelines when the user actually pressed Run with a non-empty query.
if st.session_state.get('query') and run_pressed:
    ip = st.session_state['query']
    with st.spinner('Loading pipelines... \n This may take a few mins and might also fail if OpenAI API server is down.'):
        p1 = get_plain_pipeline()
    with st.spinner('Fetching answers from plain GPT... '
                    '\n This may take a few mins and might also fail if OpenAI API server is down.'):
        answers = p1.run(ip)
    placeholder_plain_gpt.markdown(answers['results'][0])

    if st.session_state.get("query_type", BUTTON_LOCAL_RET_AUG) == BUTTON_LOCAL_RET_AUG:
        with st.spinner(
                'Loading Retrieval Augmented pipeline that can fetch relevant documents from local data store... '
                '\n This may take a few mins and might also fail if OpenAI API server is down.'):
            p2 = get_retrieval_augmented_pipeline()
        with st.spinner('Getting relevant documents from documented stores and calculating answers... '
                        '\n This may take a few mins and might also fail if OpenAI API server is down.'):
            answers_2 = p2.run(ip)
    else:
        # BUG FIX: the original spelled the spinner text as '... \<newline>n This ...',
        # which swallowed the intended '\n' and rendered a stray "n " in the message.
        with st.spinner(
                'Loading Retrieval Augmented pipeline that can fetch relevant documents from the web... '
                '\n This may take a few mins and might also fail if OpenAI API server is down.'):
            p3 = get_web_retrieval_augmented_pipeline()
        with st.spinner('Getting relevant documents from the Web and calculating answers... '
                        '\n This may take a few mins and might also fail if OpenAI API server is down.'):
            answers_2 = p3.run(ip)
    placeholder_retrieval_augmented.markdown(answers_2['results'][0])
    with st.expander("See source:"):
        # Escape "$" so Streamlit markdown does not start LaTeX math.
        # ("\\$" replaces the original "\$", an invalid escape with the same value.)
        src = answers_2['invocation_context']['documents'][0].content.replace("$", "\\$")
        split_marker = "\n\n" if "\n\n" in src else "\n"
        # Flatten paragraphs into one line and truncate the preview to 2000 chars.
        src = " ".join(src.split(split_marker))[0:2000] + "..."
        if answers_2['invocation_context']['documents'][0].meta.get('link'):
            title = answers_2['invocation_context']['documents'][0].meta.get('link')
            src = '"' + title + '": ' + src
        st.write(src)
data/my_faiss_index.faiss
ADDED
Binary file (154 kB). View file
|
|
data/my_faiss_index.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"faiss_index_factory_str": "Flat"}
|
faiss_document_store.db
ADDED
Binary file (274 kB). View file
|
|
logo/haystack-logo-colored.png
ADDED
![]() |
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
farm-haystack==1.17.1
|
2 |
+
faiss-cpu==1.7.2
|
3 |
+
sqlalchemy>=1.4.2,<2
|
4 |
+
sqlalchemy_utils
|
5 |
+
psycopg2-binary
|
6 |
+
streamlit==1.19.0
|
7 |
+
altair<5
|
utils/__init__.py
ADDED
File without changes
|
utils/backend.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from haystack import Pipeline
|
3 |
+
from haystack.document_stores import FAISSDocumentStore
|
4 |
+
from haystack.nodes import Shaper, PromptNode, PromptTemplate, PromptModel, EmbeddingRetriever
|
5 |
+
from haystack.nodes.retriever.web import WebRetriever
|
6 |
+
|
7 |
+
|
8 |
+
@st.cache_resource(show_spinner=False)
def get_plain_pipeline():
    """Build (and cache per process) a pipeline that queries OpenAI directly.

    No retrieval step is involved: the model answers purely from its own
    parametric knowledge, which is what makes it prone to hallucination.
    """
    openai_model = PromptModel(model_name_or_path="text-davinci-003", api_key=st.secrets["OPENAI_API_KEY"])
    plain_llm_template = PromptTemplate(name="plain_llm", prompt_text="Answer the following question: {query}")
    answer_node = PromptNode(openai_model, default_prompt_template=plain_llm_template, max_length=300)

    pipeline = Pipeline()
    pipeline.add_node(component=answer_node, name="prompt_node", inputs=["Query"])
    return pipeline
17 |
+
|
18 |
+
|
19 |
+
@st.cache_resource(show_spinner=False)
def get_retrieval_augmented_pipeline():
    """Build (and cache per process) a retrieval-augmented QA pipeline.

    Looks up the two most similar documents in the local FAISS news index and
    passes them, together with the query, to an OpenAI PromptNode.
    """
    document_store = FAISSDocumentStore(faiss_index_path="data/my_faiss_index.faiss",
                                        faiss_config_path="data/my_faiss_index.json")

    # Embedding-based retriever over the pre-built SVB news index.
    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
        top_k=2,
    )

    qa_template = PromptTemplate(
        name="question-answering",
        prompt_text="Given the context please answer the question. Context: {join(documents)}; Question: "
                    "{query}; Answer:",
    )

    answer_node = PromptNode("text-davinci-003", default_prompt_template=qa_template,
                             api_key=st.secrets["OPENAI_API_KEY"], max_length=500)

    # Wire retriever output into the prompt node.
    pipeline = Pipeline()
    pipeline.add_node(component=retriever, name='retriever', inputs=['Query'])
    pipeline.add_node(component=answer_node, name="prompt_node", inputs=["retriever"])
    return pipeline
46 |
+
|
47 |
+
|
48 |
+
@st.cache_resource(show_spinner=False)
def get_web_retrieval_augmented_pipeline():
    """Build (and cache per process) a retrieval-augmented pipeline over live web search.

    Uses SerperDev-backed WebRetriever results as the context for an OpenAI
    PromptNode, instead of the local FAISS index.
    """
    search_key = st.secrets["WEBRET_API_KEY"]
    web_retriever = WebRetriever(api_key=search_key, search_engine_provider="SerperDev")

    qa_template = PromptTemplate(
        name="question-answering",
        prompt_text="Given the context please answer the question. Context: {join(documents)}; Question: "
                    "{query}; Answer:",
    )

    answer_node = PromptNode("text-davinci-003", default_prompt_template=qa_template,
                             api_key=st.secrets["OPENAI_API_KEY"], max_length=500)

    # Wire web search results into the prompt node.
    pipeline = Pipeline()
    pipeline.add_node(component=web_retriever, name='retriever', inputs=['Query'])
    pipeline.add_node(component=answer_node, name="prompt_node", inputs=["retriever"])
    return pipeline
utils/constants.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Example questions offered as one-click buttons in the right sidebar.
QUERIES = [
    "Did SVB collapse?",
    "Why did SVB collapse?",
    "What does SVB failure mean for our economy?",
    "Who is responsible for SVB collapse?",
    "When did SVB collapse?",
]

# Headings shown above each answer box in the main column.
PLAIN_GPT_ANS = "Answer with plain GPT"
GPT_LOCAL_RET_AUG_ANS = "Answer with Retrieval augmented GPT (static news dataset)"
GPT_WEB_RET_AUG_ANS = "Answer with Retrieval augmented GPT (web search)"

# Labels for the "Answer Type" radio buttons.
BUTTON_LOCAL_RET_AUG = "Retrieval augmented (static news dataset)"
BUTTON_WEB_RET_AUG = "Retrieval augmented with web search"
utils/ui.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from PIL import Image
|
3 |
+
|
4 |
+
from .constants import (QUERIES, PLAIN_GPT_ANS, GPT_WEB_RET_AUG_ANS, GPT_LOCAL_RET_AUG_ANS,
|
5 |
+
BUTTON_LOCAL_RET_AUG, BUTTON_WEB_RET_AUG)
|
6 |
+
|
7 |
+
|
8 |
+
def set_question():
    """Copy the dropdown selection into the shared 'query' session key."""
    st.session_state['query'] = st.session_state['q_drop_down']


def _make_query_setter(index):
    """Return a zero-arg callback that stores QUERIES[index] as the current query.

    Replaces five copy-pasted set_qN functions with one factory; the public
    set_q1..set_q5 names below keep the original interface for st.button
    on_click callbacks.
    """
    def _setter():
        st.session_state['query'] = QUERIES[index]
    return _setter


# One callback per example question (wired to buttons in right_sidebar).
set_q1 = _make_query_setter(0)
set_q2 = _make_query_setter(1)
set_q3 = _make_query_setter(2)
set_q4 = _make_query_setter(3)
set_q5 = _make_query_setter(4)
30 |
+
|
31 |
+
|
32 |
+
def main_column():
    """Render the query box, Run button, answer-type radio, and answer areas.

    Returns:
        A 3-tuple ``(run_pressed, placeholder_plain_gpt,
        placeholder_retrieval_augmented)``: whether Run was clicked, plus the
        two ``st.empty`` placeholders the caller fills with answers.
    """
    placeholder = st.empty()
    with placeholder:
        search_bar, button = st.columns([3, 1])
        with search_bar:
            # The text area writes straight into st.session_state['query'].
            st.text_area(" ", max_chars=200, key='query')

        with button:
            # Two blank writes push the button down level with the text area.
            st.write(" ")
            st.write(" ")
            run_pressed = st.button("Run", key="run")

    st.write(" ")
    st.radio("Answer Type:", (BUTTON_LOCAL_RET_AUG, BUTTON_WEB_RET_AUG), key="query_type")

    st.markdown(f"<h5>{PLAIN_GPT_ANS}</h5>", unsafe_allow_html=True)
    placeholder_plain_gpt = st.empty()
    placeholder_plain_gpt.text_area(" ", placeholder="The answer will appear here.", disabled=True,
                                    key=PLAIN_GPT_ANS, height=1, label_visibility='collapsed')

    # Heading tracks the selected answer type; the text area itself is shared.
    local_selected = st.session_state.get("query_type", BUTTON_LOCAL_RET_AUG) == BUTTON_LOCAL_RET_AUG
    heading = GPT_LOCAL_RET_AUG_ANS if local_selected else GPT_WEB_RET_AUG_ANS
    st.markdown(f"<h5>{heading}</h5>", unsafe_allow_html=True)
    placeholder_retrieval_augmented = st.empty()
    placeholder_retrieval_augmented.text_area(" ", placeholder="The answer will appear here.", disabled=True,
                                              key=GPT_LOCAL_RET_AUG_ANS, height=1, label_visibility='collapsed')

    return run_pressed, placeholder_plain_gpt, placeholder_retrieval_augmented
|
60 |
+
|
61 |
+
|
62 |
+
def right_sidebar():
    """Render the example-question buttons in the right-hand column."""
    st.write("")
    st.write("")
    st.markdown("<h5> Example questions </h5>", unsafe_allow_html=True)
    # One button per canned query; clicking stores that query in session state.
    for query, on_click in zip(QUERIES, (set_q1, set_q2, set_q3, set_q4, set_q5)):
        st.button(query, on_click=on_click, use_container_width=True)
71 |
+
|
72 |
+
|
73 |
+
def left_sidebar():
    """Render the left sidebar: intro blurb, 'How this works' notes, and logo."""
    with st.sidebar:
        image = Image.open('logo/haystack-logo-colored.png')
        st.markdown("Thanks for coming to this :hugging_face: space. \n\n"
                    "This is an effort towards showcasing how you can use Haystack for Retrieval Augmented QA, "
                    "with local [FAISSDocumentStore](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstore)"
                    " or a [WebRetriever](https://docs.haystack.deepset.ai/docs/retriever#retrieval-from-the-web). \n\n"
                    "More information on how this was built and instructions along "
                    "with a repository will be published soon and updated here.")

        st.markdown("---")
        st.markdown(
            "## How this works\n"
            "This app was built with [Haystack](https://haystack.deepset.ai) using the"
            " [PromptNode](https://docs.haystack.deepset.ai/docs/prompt_node), "
            "[Retriever](https://docs.haystack.deepset.ai/docs/retriever#embedding-retrieval-recommended),"
            "and [FAISSDocumentStore](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstore).\n\n"
            " You can find the source code in **Files and versions** tab."
        )

        st.markdown("---")
        st.image(image, width=250)
|