Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import base64 | |
import time | |
from pathlib import Path | |
import pandas as pd | |
import streamlit as st | |
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore | |
from src.document_store.get_index import get_index | |
from src.rag.pipeline import RAGPipeline | |
from src.utils.data import load_css, load_json | |
from src.utils.writer import typewriter | |
# The "data" directory located four parent levels above this module; the
# taxonomy and example-prompt JSON files below are read relative to it.
DATA_BASE_PATH = Path(__file__).parent.parent.parent.parent / "data"
def get_base64_image(image_path):
    """Return the contents of a binary file as a Base64 text string.

    Args:
        image_path: Path (``str`` or path-like) of the file to encode.

    Returns:
        The Base64 encoding of the raw file bytes, decoded to ``str`` so it
        can be embedded directly, e.g. in an HTML ``data:`` URI.
    """
    raw_bytes = Path(image_path).read_bytes()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode()
def load_css_style() -> None:
    """Hand the project stylesheet ``style/style.css`` to ``load_css``.

    The stylesheet path is resolved four directory levels above this module.
    """
    project_root = Path(__file__).parent.parent.parent.parent
    stylesheet_path = project_root / "style" / "style.css"
    load_css(stylesheet_path)
def load_template() -> str:
    """Return the text of the INC prompt template.

    The template lives at ``rag/prompt_templates/inc_template.txt``, three
    directory levels above this module.

    Returns:
        The full contents of the template file.

    Raises:
        OSError: if the template file cannot be opened (e.g. it is missing).
        UnicodeDecodeError: if the file is not valid UTF-8.
    """
    path = (
        Path(__file__).parent.parent.parent
        / "rag"
        / "prompt_templates"
        / "inc_template.txt"
    )
    # Decode explicitly: open()'s default text encoding depends on the
    # platform/locale (e.g. cp1252 on Windows), which can garble the template.
    return path.read_text(encoding="utf-8")
def load_inc_pipeline(template: str) -> tuple[QdrantDocumentStore, RAGPipeline]:
    """Create the INC document store and a RAG pipeline bound to it.

    Args:
        template: Prompt template text passed to ``RAGPipeline``.

    Returns:
        ``(document_store, rag_pipeline)``: the ``"inc_data"`` index and a
        pipeline over it configured with ``top_k=7``.
    """
    # NOTE(review): the "About" text says answers consider "up to eight
    # documents", while top_k is 7 here -- confirm which is intended.
    document_store = get_index(index="inc_data")
    rag_pipeline = RAGPipeline(
        document_store=document_store,
        top_k=7,
        template=template,
    )
    return document_store, rag_pipeline
def get_authors_taxonomy() -> list[str]:
    """Return the author names offered in the "Countries or Associations" filter.

    Reads ``taxonomies/authors_taxonomy.json`` and concatenates, in file
    order, the entries of the ``"Countries"`` and ``"International and
    Regional State Associations"`` groups found under ``"Members"``. All
    other groups are ignored.
    """
    wanted_groups = ("Countries", "International and Regional State Associations")
    taxonomy = load_json(DATA_BASE_PATH / "taxonomies" / "authors_taxonomy.json")
    return [
        author
        for group_name, group_members in taxonomy["Members"].items()
        if group_name in wanted_groups
        for author in group_members
    ]
def get_draft_cat_taxonomy() -> list[str]:
    """Return all draft-category labels from the draft-category filter taxonomy.

    ``taxonomies/draftcat_taxonomy_filter.json`` maps top-level keys to
    iterables of labels. The labels of every key are flattened into a single
    list, preserving file order. The top-level keys themselves are not
    included.

    Returns:
        The flat list of labels. The annotation is ``list[str]``, not a dict,
        because that is what the function returns.
    """
    taxonomy = load_json(
        DATA_BASE_PATH / "taxonomies" / "draftcat_taxonomy_filter.json"
    )
    return [label for labels in taxonomy.values() for label in labels]
def get_negotiations_rounds() -> list[int]:
    """Return the session numbers (1 through 4) offered in the "Session" filter.

    A fresh list is built on every call, so callers may mutate it safely.
    """
    first_session, last_session = 1, 4
    return list(range(first_session, last_session + 1))
def get_example_prompts() -> list[str]:
    """Return the ``"question"`` text of every entry in the INC example-prompts file.

    The file is ``example_prompts/example_prompts_inc.json`` under
    ``DATA_BASE_PATH``. It is expected to be a list of objects, each with a
    ``"question"`` key.
    """
    prompts_file = DATA_BASE_PATH / "example_prompts" / "example_prompts_inc.json"
    entries = load_json(prompts_file)
    return [entry["question"] for entry in entries]
def set_trigger_state_values() -> tuple[bool, bool]:
    """Make sure the ``"trigger_inc"`` session-state flag exists and return it.

    ``setdefault`` stores ``False`` the first time and otherwise returns the
    value already in session state.

    Returns:
        ``(trigger_filter_inc, trigger_ask_inc)``. Both values come from the
        same ``"trigger_inc"`` key.
    """
    # NOTE(review): both flags read the single key "trigger_inc"; separate
    # keys may have been intended -- confirm before using them independently.
    session = st.session_state
    filter_flag = session.setdefault("trigger_inc", False)
    ask_flag = session.setdefault("trigger_inc", False)
    return filter_flag, ask_flag
def load_app_init() -> None:
    """Render the collapsible "About" description at the top of the page.

    The expander is placed in the narrower left column of a two-column row
    and is followed by two blank spacer lines.
    """
    about_html = """
<p class="description"> The Interactive Treaty Assistant will support you on your research and analysis of documents submitted by INC members in the previous rounds to quickly pinpoint crucial information. Together with treaty-specific queries make use of the filters to get more precise responses. Along with the answer, the Chatbot also provides you with direct links to relevant documents enabling a deeper examination. <br>
The tool excels at providing targeted information on countries and their positions in negotiations. Filter options by author and sections of the negotiation draft enhance accuracy, while direct links to filtered documents ensure quick access to detailed information. While the generated answers take into account up to eight documents at a time due to technical limitations, users can still access the full set of filtered documents via direct links for comprehensive exploration. </p>
"""
    left_column, _ = st.columns([0.66, 1])
    with left_column, st.expander("About", icon=":material/info:"):
        st.markdown(about_html, unsafe_allow_html=True)
    for _ in range(2):
        st.write("\n")
def about_inc() -> None:
    """Render the page footer.

    The footer contains:
      * a "Help us Improve!" call to action with two link buttons (a short
        feedback form and an in-depth survey);
      * three spacer lines;
      * an HTML "About" box that embeds ``static/images/logo.png`` as a
        Base64 ``data:`` URI.
    """
    st.markdown("""<p class="header"> Help us Improve! </p>""", unsafe_allow_html=True)
    st.markdown(
        """<p class="description"> We would appreciate your feedback and support to improve the app. You can fill out a quick feedback form (maximal 5 minutes) or use the in-depth survey (maximal 15 minutes). </p>""",
        unsafe_allow_html=True,
    )

    # (label, url, icon) for each link button, in column order.
    link_buttons = (
        ("Feedback", "https://forms.gle/PPT5g558utGDUAGh6", ":material/reviews:"),
        (
            "Survey",
            "https://docs.google.com/forms/d/1-WNS0ZdAuystajf2i6iSR5HpRfvV1LYq_TcQfaIMvkA",
            ":material/rate_review:",
        ),
    )
    # The third (widest) column stays empty and only provides spacing.
    *button_columns, _ = st.columns(spec=[0.7, 1.0, 4], gap="large")
    for column, (label, url, icon) in zip(button_columns, link_buttons):
        with column:
            st.link_button(label=label, url=url, icon=icon)

    # NOTE(review): relative path -- resolved against the process's working
    # directory, unlike the __file__-based paths used elsewhere in this module.
    logo = get_base64_image("static/images/logo.png")
    for _ in range(3):
        st.write("\n")

    footer_html = f"""<div class="footer">
<h3>About</h3>
<div class="content">
The Deutsche Gesellschaft für Internationale Zusammenarbeit (GIZ) GmbH <br>
is a globally active service provider dedicated to international cooperation <br>
for sustainable development and it’s active in over 120 countries. <br> <br>
The GIZ Data Lab specializes in harnessing data for development, <br>
driving innovative solutions in international cooperation
to address <br> real-world challenges. <br> <br>
Our work on NegotiateAI started in 2023. You can find more information <br>
about the NegotiateAI project on our <a href="https://www.blog-datalab.com/home/negotiateai/">website</a>.
</div>
<img src="data:image/png;base64,{logo}" class="logo" />
</div>
"""
    st.markdown(footer_html, unsafe_allow_html=True)
def init_inc_page():
    """Render the complete INC treaty-assistant page.

    Layout, top to bottom:
      1. stylesheet and the collapsible "About" description;
      2. step 1: three multiselect filters (authors, sessions, draft
         categories);
      3. step 2: two tabs. "Ask a question" runs the RAG pipeline over the
         filtered documents. "Filter documents" lists every matching
         document, one row per ``retriever_id``;
      4. the answer or document overview for whichever button was clicked;
      5. a divider and the feedback/about footer (``about_inc``).

    Clicking "Ask" with an empty question shows an error, sleeps 3 seconds
    and reruns the script. Clicking it with no filter selected, or with a
    filter combination that matches no documents, shows an error and stops
    the current script run (``st.stop``).
    """
    load_css_style()
    load_app_init()

    # Load taxonomies, example prompts, the prompt template and the pipeline.
    authors = get_authors_taxonomy()
    draft_labs = get_draft_cat_taxonomy()
    negotiation_rounds = get_negotiations_rounds()
    example_prompts = get_example_prompts()
    template = load_template()
    trigger_filter_inc, trigger_ask_inc = set_trigger_state_values()
    inc_index, inc_rag = load_inc_pipeline(template=template)

    # Outputs of this script run. They are only filled when the matching
    # button is clicked; initialising them avoids a NameError below if a
    # trigger flag arrives already set from session state.
    result = None
    result_df = None

    application_col_inc = st.columns(1)
    with application_col_inc[0]:
        _render_step_header(1, "Select Filters")
        text_design_col_1, _ = st.columns([1, 1])
        with text_design_col_1:
            st.markdown(
                """<p class="description"> Selecting at least one filter is mandatory, because otherwise the model would have to analyze all available documents which results in inaccurate answers and long processing times. Please select at least one filter. We especially recommend to select countries you are interested in. </p>""",
                unsafe_allow_html=True,
            )
        st.write("\n")

        col_1, col_2, col_3 = st.columns([1, 1, 1])
        with col_1:
            selected_authors_inc = st.multiselect(
                label="Countries or Associations",
                options=authors,
                label_visibility="visible",
                placeholder="Select",
                key="selected_authors_inc",
                help="Please select the countries of interest. Your selection will refine the database to include documents submitted by these countries or recognized groupings such as Small Developing States, the African States Group, etc.",
            )
        with col_2:
            selected_rounds_inc = st.multiselect(
                label="Session",
                options=negotiation_rounds,
                label_visibility="visible",
                placeholder="Select",
                key="selected_rounds_inc",
                # Previously a copy of the countries tooltip with a stray "</p>".
                help="Please select the negotiation sessions of interest. Your selection will refine the database to include only documents from these sessions.",
            )
        with col_3:
            selected_draft_cats_inc = st.multiselect(
                label="Draft Categories",
                options=draft_labs,
                label_visibility="visible",
                placeholder="Select",
                key="selected_draft_cats",
                help=" Please select the parts of the negotiation draft of interest. The negotiation draft can be accessed (https://www.unep.org/inc-plastic-pollution/session-4/documents)",
            )
        st.write("\n")
        st.write("\n")

        _render_step_header(
            2, "Ask a question or show documents based on selected filters"
        )
        filters_selected = bool(
            selected_authors_inc or selected_draft_cats_inc or selected_rounds_inc
        )

        asking_inc, filtering_inc = st.tabs(["Ask a question", "Filter documents"])
        with asking_inc:
            application_col_ask_inc, output_col_ask_inc = st.columns([1, 1.5])
            with application_col_ask_inc:
                st.markdown(
                    """
<p class="description"> Ask a question, noting that the database has been restricted by filters and that your question should pertain to the selected data. </p>
""",
                    unsafe_allow_html=True,
                )
                prompt_inc = _render_question_input(example_prompts)
                if st.button("Ask", icon=":material/send:", type="primary"):
                    if prompt_inc == "":
                        st.error(
                            "Please enter a question. Reloading the app in few seconds"
                        )
                        time.sleep(3)
                        st.rerun()
                    with st.spinner("Filtering data..."):
                        if not filters_selected:
                            st.error(
                                "Selecting a filter is mandatory. We especially recommend to select countries you are interested in. Selecting at least one filter is mandatory, because otherwise the model would have to analyze all available documents which results in inaccurate answers and long processing times. Please select at least one filter."
                            )
                            st.stop()
                        with st.spinner("Analyzing Filters"):
                            filter_selection = _build_filter_selection(
                                selected_authors_inc,
                                selected_draft_cats_inc,
                                selected_rounds_inc,
                            )
                            filters = inc_rag.build_filter(
                                filter_selections=filter_selection
                            )
                            docs = inc_index.filter_documents(filters)
                        if not docs:
                            st.error(
                                "The combination of filters you've chosen does not match any documents. Please try another combination of filters. If a filter combination does not return any documents, it means that there are no documents that match the selected filters and therefore no answer can be given."
                            )
                            st.stop()
                        st.success("Filtering completed.")
                    with st.spinner("Answering question..."):
                        result = inc_rag.run(
                            query=prompt_inc, filter_selections=filter_selection
                        )
                    trigger_ask_inc = True
                    st.success("Answering question completed.")
                _render_example_prompts(example_prompts)

        with filtering_inc:
            application_col_filter, output_col_filter = st.columns([1, 1.5])
            with application_col_filter:
                st.markdown(
                    """
<p class="description">
This filter function allows you to see all documents that match the selected filters. The documents can be accessed via a link. </p>
""",
                    unsafe_allow_html=True,
                )
                if st.button("Filter", icon=":material/filter_alt:", type="primary"):
                    if not filters_selected:
                        st.info(
                            "No filters selected. All documents will be shown. Longer processing time expected."
                        )
                    with st.spinner("Filtering documents..."):
                        # Named doc_filter to avoid shadowing the builtin filter().
                        doc_filter = RAGPipeline.build_filter(
                            filter_selections=_build_filter_selection(
                                selected_authors_inc,
                                selected_draft_cats_inc,
                                selected_rounds_inc,
                            )
                        )
                        filtered_docs = inc_index.filter_documents(doc_filter)
                        result_df = _filtered_documents_overview(filtered_docs)
                    if result_df.empty:
                        st.info(
                            "No documents found for the combination of filters you've chosen. All countries are represented at least once in the data. Remove the draft categories to see all documents for the countries selected or try other draft categories and/or sessions."
                        )
                        trigger_filter_inc = False
                    else:
                        trigger_filter_inc = True
            if trigger_filter_inc and result_df is not None:
                with output_col_filter:
                    _render_filtered_overview(result_df)

    if trigger_ask_inc:
        with output_col_ask_inc:
            # The pipeline result is None here when no answer was produced;
            # the message below attributes this to the OpenAI rate limit.
            if result is None:
                st.error(
                    "Open AI rate limit exceeded. Please try again in a few seconds."
                )
                st.stop()
            _render_answer(result)

    st.markdown(
        """<hr style="height:2px;border:none;color:#077493;background-color:#077493;" /> """,
        unsafe_allow_html=True,
    )
    about_inc()


def _render_step_header(step: int, title: str) -> None:
    """Render a numbered step heading: a circled step number (SVG) followed by *title*."""
    st.markdown(
        f"""
<p class="header" style="display: flex; align-items: center;">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32" width="32" height="32" style="margin-right: 10px;">
<circle cx="16" cy="16" r="15" fill="none" stroke="#077493" stroke-width="2"/>
<text x="16" y="21" text-anchor="middle" font-size="16" font-family="Arial" font-weight="bold" fill="#077493">{step}</text>
</svg> {title}
</p>
""",
        unsafe_allow_html=True,
    )


def _build_filter_selection(authors, draft_cats, rounds) -> dict:
    """Build the ``filter_selections`` mapping passed to ``RAGPipeline``."""
    return {"author": authors, "draft_labs": draft_cats, "round": rounds}


def _render_question_input(example_prompts: list[str]) -> str:
    """Render the question text area and return its current value.

    A question stored in ``st.session_state["prompt"]`` (set by an example
    button) pre-fills the box only if it is one of *example_prompts*. Any
    other stored value is discarded.
    """
    if "prompt" in st.session_state:
        if st.session_state.prompt in example_prompts:
            return st.text_area("Enter a question", value=st.session_state.prompt)
        del st.session_state["prompt"]
    # Same non-empty label in every branch; Streamlit warns on empty labels.
    return st.text_area("Enter a question")


def _render_example_prompts(example_prompts: list[str]) -> None:
    """Render one button per example prompt; clicking stores it as ``"prompt"``."""
    st.markdown(
        "***≡ Examples***",
        help="These are example prompts that can be used to ask questions to the model. Click on a prompt to use it as a question. You can also type your own question in the text area above. In general we highly recommend to use the filter functions to narrow down the data.",
    )
    # The text area is rendered before these buttons, so a stored prompt only
    # shows up on the next rerun -- hence the "double click" hint.
    st.caption("Double click to select the prompt")
    for example_prompt in example_prompts:
        if st.button(example_prompt):
            # NOTE(review): nothing in the visible code sets session key "key",
            # so this guard looks always true -- confirm intent.
            if "key" not in st.session_state:
                st.session_state["prompt"] = example_prompt


def _filtered_documents_overview(documents) -> pd.DataFrame:
    """Return one overview row per distinct ``retriever_id`` (first occurrence wins)."""
    seen_ids = set()
    rows = []
    for doc in documents:
        retriever_id = doc.meta["retriever_id"]
        if retriever_id in seen_ids:
            continue
        rows.append(
            {
                "author": doc.meta["author"],
                "doc_type": doc.meta["doc_type"],
                "session": doc.meta["round"],
                "href": doc.meta["href"],
                "draft_labs": doc.meta["draft_labs"],
            }
        )
        seen_ids.add(retriever_id)
    return pd.DataFrame(rows)


def _render_filtered_overview(result_df: pd.DataFrame) -> None:
    """Render the filtered-documents table with readable column labels."""
    st.markdown("### Overview of all filtered documents")
    st.dataframe(
        result_df,
        hide_index=True,
        column_config={
            "author": st.column_config.ListColumn("Authors"),
            "href": st.column_config.LinkColumn("Link to Document"),
            "draft_labs": st.column_config.ListColumn("Draft Categories"),
            "session": st.column_config.NumberColumn("Session"),
            "doc_type": st.column_config.TextColumn("Document Type"),
        },
    )


def _build_references(documents) -> list[str]:
    """Return ``["\\n", "-[id]: href \\n", ...]`` without duplicates, in first-seen order."""
    references = ["\n"]
    references.extend(
        f"-[{doc.meta['retriever_id']}]: {doc.meta['href']} \n" for doc in documents
    )
    # dict.fromkeys de-duplicates while keeping insertion order; list(set(...))
    # would put the references (and the leading "\n") in arbitrary hash order.
    return list(dict.fromkeys(references))


def _format_document_passages(documents) -> str:
    """Build a nested Markdown list of retrieved passages grouped by ``retriever_id``."""
    lines = []
    current_doc = None
    for doc in sorted(documents, key=lambda d: d.meta["retriever_id"]):
        doc_id = doc.meta["retriever_id"]
        if doc_id != current_doc:
            lines.append(f"- Document: {doc_id}\n")
            # 2/4-space indents nest these items under their document; a
            # 1-space indent renders as a flat sibling list in Markdown.
            lines.append("  - Text passages\n")
            current_doc = doc_id
        lines.append(f"    - {doc.content}\n")
    return "".join(lines)


def _render_answer(result) -> None:
    """Render the generated answer with its references and the source passages.

    Expects ``result["retriever"]["documents"]`` and
    ``result["llm"]["replies"]``, as read by the original page code.
    """
    documents = result["retriever"]["documents"]
    references = _build_references(documents)
    st.markdown(
        """<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#077493"><path d="m640-480 80 80v80H520v240l-40 40-40-40v-240H240v-80l80-80v-280h-40v-80h400v80h-40v280Zm-286 80h252l-46-46v-314H400v314l-46 46Zm126 0Z"/></svg> <b>Answer</b>""",
        unsafe_allow_html=True,
    )
    typewriter(
        text=result["llm"]["replies"][0],
        references=references,
        speed=100,
    )
    with st.expander("Show more information to the documents"):
        st.write(_format_document_passages(documents))