intoxication tanaysoni committed on
Commit
852c61d
·
0 Parent(s):

Duplicate from deepset/retrieval-augmentation-svb

Browse files

Co-authored-by: Tanay Soni <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.streamlit/config.toml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [theme]
2
+ base = "light"
3
+ font="monospace"
4
+ [global]
5
+
6
+ # By default, Streamlit checks if the Python watchdog module is available and, if not, prints a warning asking you to install it. The watchdog module is not required, but highly recommended. It improves Streamlit's ability to detect changes to files in your filesystem.
7
+ # If you'd like to turn off this warning, set this to True.
8
+ # Default: false
9
+ disableWatchdogWarning = true
10
+
11
+ # If True, will show a warning when you run a Streamlit-enabled script via "python my_script.py".
12
+ # Default: true
13
+ showWarningOnDirectExecution = false
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Retrieval Augmented Generative QA
3
+ emoji: 👁
4
+ colorFrom: blue
5
+ colorTo: pink
6
+ sdk: streamlit
7
+ sdk_version: 1.19.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: deepset/retrieval-augmentation-svb
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Streamlit entry point: compares a plain GPT answer against a
retrieval-augmented answer (local FAISS store or live web search)."""
import streamlit as st
from utils.backend import (get_plain_pipeline, get_retrieval_augmented_pipeline,
                           get_web_retrieval_augmented_pipeline)
from utils.ui import left_sidebar, right_sidebar, main_column
from utils.constants import BUTTON_LOCAL_RET_AUG

st.set_page_config(
    page_title="Retrieval Augmentation with Haystack",
    layout="wide"
)
left_sidebar()

st.markdown("<center> <h2> Reduce Hallucinations 😵‍💫 with Retrieval Augmentation </h2> </center>", unsafe_allow_html=True)

st.markdown("<center>Ask a question about the collapse of the Silicon Valley Bank (SVB).</center>", unsafe_allow_html=True)

col_1, col_2 = st.columns([4, 2], gap="small")
with col_1:
    run_pressed, placeholder_plain_gpt, placeholder_retrieval_augmented = main_column()

with col_2:
    right_sidebar()

# Only run when the user pressed "Run" AND there is a non-empty query.
if st.session_state.get('query') and run_pressed:
    ip = st.session_state['query']
    with st.spinner('Loading pipelines... \n This may take a few mins and might also fail if OpenAI API server is down.'):
        p1 = get_plain_pipeline()
    with st.spinner('Fetching answers from plain GPT... '
                    '\n This may take a few mins and might also fail if OpenAI API server is down.'):
        answers = p1.run(ip)
    placeholder_plain_gpt.markdown(answers['results'][0])

    if st.session_state.get("query_type", BUTTON_LOCAL_RET_AUG) == BUTTON_LOCAL_RET_AUG:
        # Retrieval from the bundled FAISS news dataset.
        with st.spinner(
                'Loading Retrieval Augmented pipeline that can fetch relevant documents from local data store... '
                '\n This may take a few mins and might also fail if OpenAI API server is down.'):
            p2 = get_retrieval_augmented_pipeline()
        with st.spinner('Getting relevant documents from documented stores and calculating answers... '
                        '\n This may take a few mins and might also fail if OpenAI API server is down.'):
            answers_2 = p2.run(ip)
    else:
        # Retrieval from a live web search.
        # NOTE: the original source broke the '\n' across a backslash line
        # continuation inside the string literal, which rendered a stray
        # "n " in the spinner text; fixed here.
        with st.spinner(
                'Loading Retrieval Augmented pipeline that can fetch relevant documents from the web... '
                '\n This may take a few mins and might also fail if OpenAI API server is down.'):
            p3 = get_web_retrieval_augmented_pipeline()
        with st.spinner('Getting relevant documents from the Web and calculating answers... '
                        '\n This may take a few mins and might also fail if OpenAI API server is down.'):
            answers_2 = p3.run(ip)
    placeholder_retrieval_augmented.markdown(answers_2['results'][0])
    with st.expander("See source:"):
        top_doc = answers_2['invocation_context']['documents'][0]
        # Escape '$' so Streamlit's markdown does not treat it as LaTeX
        # (raw string avoids the invalid '\$' escape-sequence warning).
        src = top_doc.content.replace("$", r"\$")
        # Collapse paragraphs/lines into one string and truncate for display.
        split_marker = "\n\n" if "\n\n" in src else "\n"
        src = " ".join(src.split(split_marker))[0:2000] + "..."
        if top_doc.meta.get('link'):
            title = top_doc.meta.get('link')
            src = '"' + title + '": ' + src
        st.write(src)
data/my_faiss_index.faiss ADDED
Binary file (154 kB). View file
 
data/my_faiss_index.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"faiss_index_factory_str": "Flat"}
faiss_document_store.db ADDED
Binary file (274 kB). View file
 
logo/haystack-logo-colored.png ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ farm-haystack==1.17.1
2
+ faiss-cpu==1.7.2
3
+ sqlalchemy>=1.4.2,<2
4
+ sqlalchemy_utils
5
+ psycopg2-binary
6
+ streamlit==1.19.0
7
+ altair<5
utils/__init__.py ADDED
File without changes
utils/backend.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from haystack import Pipeline
3
+ from haystack.document_stores import FAISSDocumentStore
4
+ from haystack.nodes import Shaper, PromptNode, PromptTemplate, PromptModel, EmbeddingRetriever
5
+ from haystack.nodes.retriever.web import WebRetriever
6
+
7
+
8
@st.cache_resource(show_spinner=False)
def get_plain_pipeline():
    """Build (and cache for the process lifetime) a bare GPT pipeline:
    the user query is sent straight to an OpenAI PromptNode with no
    retrieval step."""
    openai_model = PromptModel(model_name_or_path="text-davinci-003", api_key=st.secrets["OPENAI_API_KEY"])
    # Now let make one PromptNode use the default model and the other one the OpenAI model:
    template = PromptTemplate(name="plain_llm", prompt_text="Answer the following question: {query}")
    answer_node = PromptNode(openai_model, default_prompt_template=template, max_length=300)
    pipe = Pipeline()
    pipe.add_node(component=answer_node, name="prompt_node", inputs=["Query"])
    return pipe
17
+
18
+
19
@st.cache_resource(show_spinner=False)
def get_retrieval_augmented_pipeline():
    """Build (and cache) a retrieval-augmented pipeline: an embedding
    retriever over the bundled FAISS index feeds the top documents as
    context into an OpenAI PromptNode."""
    # Pre-built FAISS index shipped with the Space under data/.
    document_store = FAISSDocumentStore(faiss_index_path="data/my_faiss_index.faiss",
                                        faiss_config_path="data/my_faiss_index.json")

    dense_retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
        top_k=2
    )

    qa_template = PromptTemplate(
        name="question-answering",
        prompt_text="Given the context please answer the question. Context: {join(documents)}; Question: "
                    "{query}; Answer:",
    )

    # PromptNode that answers from the retrieved context.
    answer_node = PromptNode("text-davinci-003", default_prompt_template=qa_template,
                             api_key=st.secrets["OPENAI_API_KEY"], max_length=500)

    # Wire retriever -> prompt node.
    pipe = Pipeline()
    pipe.add_node(component=dense_retriever, name='retriever', inputs=['Query'])
    pipe.add_node(component=answer_node, name="prompt_node", inputs=["retriever"])
    return pipe
46
+
47
+
48
@st.cache_resource(show_spinner=False)
def get_web_retrieval_augmented_pipeline():
    """Build (and cache) a web-retrieval-augmented pipeline: documents are
    fetched live via SerperDev web search and passed as context to an
    OpenAI PromptNode."""
    serper_key = st.secrets["WEBRET_API_KEY"]
    web_retriever = WebRetriever(api_key=serper_key, search_engine_provider="SerperDev")
    qa_template = PromptTemplate(
        name="question-answering",
        prompt_text="Given the context please answer the question. Context: {join(documents)}; Question: "
                    "{query}; Answer:",
    )
    # PromptNode that answers from the retrieved context.
    answer_node = PromptNode("text-davinci-003", default_prompt_template=qa_template,
                             api_key=st.secrets["OPENAI_API_KEY"], max_length=500)
    # Wire web retriever -> prompt node.
    pipe = Pipeline()
    pipe.add_node(component=web_retriever, name='retriever', inputs=['Query'])
    pipe.add_node(component=answer_node, name="prompt_node", inputs=["retriever"])
    return pipe
utils/constants.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Example questions offered as one-click buttons in the right sidebar
# (see utils.ui.right_sidebar).
QUERIES = [
    "Did SVB collapse?",
    "Why did SVB collapse?",
    "What does SVB failure mean for our economy?",
    "Who is responsible for SVB collapse?",
    "When did SVB collapse?"
]
# Headings shown above the answer boxes (also reused as widget keys).
PLAIN_GPT_ANS = "Answer with plain GPT"
GPT_LOCAL_RET_AUG_ANS = "Answer with Retrieval augmented GPT (static news dataset)"
GPT_WEB_RET_AUG_ANS = "Answer with Retrieval augmented GPT (web search)"


# Labels for the "Answer Type:" radio selector; the selected label is stored
# in session state under "query_type" and chooses the local vs. web pipeline.
BUTTON_LOCAL_RET_AUG = "Retrieval augmented (static news dataset)"
BUTTON_WEB_RET_AUG = "Retrieval augmented with web search"
utils/ui.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+
4
+ from .constants import (QUERIES, PLAIN_GPT_ANS, GPT_WEB_RET_AUG_ANS, GPT_LOCAL_RET_AUG_ANS,
5
+ BUTTON_LOCAL_RET_AUG, BUTTON_WEB_RET_AUG)
6
+
7
+
8
def set_question():
    """Copy the dropdown selection into the active query in session state."""
    st.session_state['query'] = st.session_state['q_drop_down']


def _make_query_setter(index):
    """Return a button callback that sets the active query to QUERIES[index].

    Replaces five copy-pasted set_qN functions; the index is bound at
    creation time, so each callback is independent.
    """
    def _setter():
        st.session_state['query'] = QUERIES[index]
    return _setter


# Public callback names are part of this module's interface
# (right_sidebar wires them to the example-question buttons).
set_q1 = _make_query_setter(0)
set_q2 = _make_query_setter(1)
set_q3 = _make_query_setter(2)
set_q4 = _make_query_setter(3)
set_q5 = _make_query_setter(4)
30
+
31
+
32
def main_column():
    """Render the query box, Run button, answer-type radio and two answer
    placeholders.

    Returns:
        (run_pressed, placeholder_plain_gpt, placeholder_retrieval_augmented)
    """
    placeholder = st.empty()
    with placeholder:
        query_col, run_col = st.columns([3, 1])
        with query_col:
            _ = st.text_area(f" ", max_chars=200, key='query')

        with run_col:
            # Blank writes push the button down to line up with the text area.
            st.write(" ")
            st.write(" ")
            run_pressed = st.button("Run", key="run")

    st.write(" ")
    st.radio("Answer Type:", (BUTTON_LOCAL_RET_AUG, BUTTON_WEB_RET_AUG), key="query_type")

    # Plain-GPT answer box.
    st.markdown(f"<h5>{PLAIN_GPT_ANS}</h5>", unsafe_allow_html=True)
    placeholder_plain_gpt = st.empty()
    placeholder_plain_gpt.text_area(f" ", placeholder="The answer will appear here.", disabled=True,
                                    key=PLAIN_GPT_ANS, height=1, label_visibility='collapsed')
    # Heading reflects the currently selected answer type.
    if st.session_state.get("query_type", BUTTON_LOCAL_RET_AUG) == BUTTON_LOCAL_RET_AUG:
        st.markdown(f"<h5>{GPT_LOCAL_RET_AUG_ANS}</h5>", unsafe_allow_html=True)
    else:
        st.markdown(f"<h5>{GPT_WEB_RET_AUG_ANS}</h5>", unsafe_allow_html=True)
    placeholder_retrieval_augmented = st.empty()
    placeholder_retrieval_augmented.text_area(f" ", placeholder="The answer will appear here.", disabled=True,
                                              key=GPT_LOCAL_RET_AUG_ANS, height=1, label_visibility='collapsed')

    return run_pressed, placeholder_plain_gpt, placeholder_retrieval_augmented
60
+
61
+
62
def right_sidebar():
    """Render the example-question buttons; clicking one fills the query box."""
    st.write("")
    st.write("")
    st.markdown("<h5> Example questions </h5>", unsafe_allow_html=True)
    # One button per example query, each wired to its setter callback.
    for question, callback in zip(QUERIES, (set_q1, set_q2, set_q3, set_q4, set_q5)):
        st.button(question, on_click=callback, use_container_width=True)
71
+
72
+
73
def left_sidebar():
    """Render the left sidebar: intro blurb, a "How this works" section and
    the Haystack logo.

    (Removed a large block of commented-out, disabled API-key-input code
    that was dead weight in the original.)
    """
    with st.sidebar:
        image = Image.open('logo/haystack-logo-colored.png')
        st.markdown("Thanks for coming to this :hugging_face: space. \n\n"
                    "This is an effort towards showcasing how you can use Haystack for Retrieval Augmented QA, "
                    "with local [FAISSDocumentStore](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstore)"
                    " or a [WebRetriever](https://docs.haystack.deepset.ai/docs/retriever#retrieval-from-the-web). \n\n"
                    "More information on how this was built and instructions along "
                    "with a repository will be published soon and updated here.")

        st.markdown("---")
        st.markdown(
            "## How this works\n"
            "This app was built with [Haystack](https://haystack.deepset.ai) using the"
            " [PromptNode](https://docs.haystack.deepset.ai/docs/prompt_node), "
            "[Retriever](https://docs.haystack.deepset.ai/docs/retriever#embedding-retrieval-recommended),"
            "and [FAISSDocumentStore](https://docs.haystack.deepset.ai/reference/document-store-api#faissdocumentstore).\n\n"
            " You can find the source code in **Files and versions** tab."
        )

        st.markdown("---")
        st.image(image, width=250)