# TeresaK's picture
# upload version 2 (#2)
# d064c89 verified
# raw
# history blame
# 20.6 kB
import base64
import time
from pathlib import Path
import pandas as pd
import streamlit as st
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
from src.document_store.get_index import get_index
from src.rag.pipeline import RAGPipeline
from src.utils.data import load_css, load_json
from src.utils.writer import typewriter
# Location of the "data" directory, resolved relative to this file: the
# fourth ancestor directory of this module, then "data".
DATA_BASE_PATH = Path(__file__).parent.parent.parent.parent / "data"
def get_base64_image(image_path):
    """Return the content of the file at ``image_path`` as a base64 string.

    The file is read in binary mode, base64-encoded, and the encoded bytes
    are decoded to ``str`` so the result can be embedded in text (for example
    in a ``data:`` URI).
    """
    with open(image_path, "rb") as image_stream:
        raw_bytes = image_stream.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode()
@st.cache_data
def load_css_style() -> None:
    """Pass the application's stylesheet path to ``load_css``.

    The stylesheet is ``style/style.css`` under the fourth ancestor directory
    of this file.
    """
    base_dir = Path(__file__).parent.parent.parent.parent
    stylesheet_path = base_dir / "style" / "style.css"
    load_css(stylesheet_path)
@st.cache_data
def load_template() -> str:
    """Return the full text of the INC prompt template.

    The template is read from ``rag/prompt_templates/inc_template.txt`` under
    the third ancestor directory of this file, using the platform's default
    text encoding.
    """
    templates_dir = Path(__file__).parent.parent.parent / "rag" / "prompt_templates"
    return (templates_dir / "inc_template.txt").read_text()
@st.cache_resource
def load_inc_pipeline(template: str) -> tuple[QdrantDocumentStore, RAGPipeline]:
    """Create the INC document store and the RAG pipeline that queries it.

    Args:
        template: Prompt template text handed to ``RAGPipeline``.

    Returns:
        A ``(document_store, pipeline)`` tuple: the index returned by
        ``get_index(index="inc_data")`` and a ``RAGPipeline`` built on it
        with ``top_k=7``.
    """
    document_store = get_index(index="inc_data")
    pipeline = RAGPipeline(
        document_store=document_store,
        top_k=7,
        template=template,
    )
    return document_store, pipeline
@st.cache_data
def get_authors_taxonomy() -> list[str]:
    """Return the selectable authors from the authors taxonomy file.

    Reads ``taxonomies/authors_taxonomy.json`` and concatenates the entries
    of the "Countries" and "International and Regional State Associations"
    groups found under its "Members" key, in the order the groups appear in
    the file.
    """
    selectable_groups = (
        "Countries",
        "International and Regional State Associations",
    )
    taxonomy_path = DATA_BASE_PATH / "taxonomies" / "authors_taxonomy.json"
    member_groups = load_json(taxonomy_path)["Members"]
    return [
        author
        for group_name, group_authors in member_groups.items()
        if group_name in selectable_groups
        for author in group_authors
    ]
@st.cache_data
def get_draft_cat_taxonomy() -> list[str]:
    """Return all draft-category labels as one flat list.

    Reads ``taxonomies/draftcat_taxonomy_filter.json`` and flattens the
    labels of every sub-part, in file order. The keys of the taxonomy are
    not part of the result, so the return type is a list and not a mapping.
    """
    taxonomy = load_json(
        DATA_BASE_PATH / "taxonomies" / "draftcat_taxonomy_filter.json"
    )
    return [label for subpart in taxonomy.values() for label in subpart]
@st.cache_data
def get_negotiations_rounds() -> list[int]:
    """Return the selectable negotiation round numbers (1 through 4)."""
    first_round, last_round = 1, 4
    return list(range(first_round, last_round + 1))
@st.cache_data
def get_example_prompts() -> list[str]:
    """Return the example questions shown below the question box.

    Reads ``example_prompts/example_prompts_inc.json`` and collects the
    "question" value of each entry, in file order.
    """
    prompts_path = DATA_BASE_PATH / "example_prompts" / "example_prompts_inc.json"
    questions = []
    for entry in load_json(prompts_path):
        questions.append(entry["question"])
    return questions
@st.cache_data
def set_trigger_state_values() -> tuple[bool, bool]:
    """Initialise the page's trigger flag in session state and return it twice.

    Returns:
        ``(trigger_filter_inc, trigger_ask_inc)``; both values come from the
        "trigger_inc" session-state entry, which is created as ``False`` when
        missing.
    """
    # NOTE(review): both flags read the same "trigger_inc" key - confirm
    # whether separate keys for the filter and ask triggers were intended.
    # NOTE(review): this function touches per-session state but is wrapped in
    # st.cache_data - confirm the caching is intentional.
    session = st.session_state
    filter_flag = session.setdefault("trigger_inc", False)
    ask_flag = session.setdefault("trigger_inc", False)
    return filter_flag, ask_flag
@st.cache_data
def load_app_init() -> None:
    """Render the collapsible "About" description at the top of the page.

    The description sits in the first of two columns (width ratio 0.66 : 1)
    and is followed by two blank spacer writes.
    """
    about_html = """
<p class="description"> The Interactive Treaty Assistant will support you on your research and analysis of documents submitted by INC members in the previous rounds to quickly pinpoint crucial information. Together with treaty-specific queries make use of the filters to get more precise responses. Along with the answer, the Chatbot also provides you with direct links to relevant documents enabling a deeper examination. <br>
The tool excels at providing targeted information on countries and their positions in negotiations. Filter options by author and sections of the negotiation draft enhance accuracy, while direct links to filtered documents ensure quick access to detailed information. While the generated answers take into account up to eight documents at a time due to technical limitations, users can still access the full set of filtered documents via direct links for comprehensive exploration. </p>
"""
    about_column, _ = st.columns([0.66, 1])
    with about_column, st.expander("About", icon=":material/info:"):
        st.markdown(about_html, unsafe_allow_html=True)
    # Two blank writes act as vertical spacing below the description.
    for _ in range(2):
        st.write("\n")
@st.cache_data
def about_inc() -> None:
    """Render the feedback section and the "About" footer of the page.

    Shows a "Help us Improve!" header with a short description, two link
    buttons ("Feedback" and "Survey"), and a footer block whose logo is
    embedded as a base64 ``data:`` URI read from ``static/images/logo.png``.
    """
    feedback_header = """<p class="header"> Help us Improve! </p>"""
    feedback_description = """<p class="description"> We would appreciate your feedback and support to improve the app. You can fill out a quick feedback form (maximal 5 minutes) or use the in-depth survey (maximal 15 minutes). </p>"""

    st.markdown(feedback_header, unsafe_allow_html=True)
    st.markdown(feedback_description, unsafe_allow_html=True)

    feedback_col, survey_col, _ = st.columns(spec=[0.7, 1.0, 4], gap="large")
    with feedback_col:
        st.link_button(
            label="Feedback",
            url="https://forms.gle/PPT5g558utGDUAGh6",
            icon=":material/reviews:",
        )
    with survey_col:
        st.link_button(
            label="Survey",
            url="https://docs.google.com/forms/d/1-WNS0ZdAuystajf2i6iSR5HpRfvV1LYq_TcQfaIMvkA",
            icon=":material/rate_review:",
        )

    logo_b64 = get_base64_image("static/images/logo.png")

    # Three blank writes act as vertical spacing above the footer.
    for _ in range(3):
        st.write("\n")

    footer_html = f"""<div class="footer">
<h3>About</h3>
<div class="content">
The Deutsche Gesellschaft für Internationale Zusammenarbeit (GIZ) GmbH <br>
is a globally active service provider dedicated to international cooperation <br>
for sustainable development and it’s active in over 120 countries. <br> <br>
The GIZ Data Lab specializes in harnessing data for development, <br>
driving innovative solutions in international cooperation
to address <br> real-world challenges. <br> <br>
Our work on NegotiateAI started in 2023. You can find more information <br>
about the NegotiateAI project on our <a href="https://www.blog-datalab.com/home/negotiateai/">website</a>.
</div>
<img src="data:image/png;base64,{logo_b64}" class="logo" />
</div>
"""
    st.markdown(footer_html, unsafe_allow_html=True)
def init_inc_page() -> None:
    """Render the INC page of the application.

    The page consists of:

    1. The stylesheet and the collapsible "About" description.
    2. Three filter widgets (authors, sessions, draft categories).
    3. An "Ask a question" tab that runs the RAG pipeline on the filtered
       documents and prints the answer with its references.
    4. A "Filter documents" tab that lists the documents matching the
       filters in a table.
    5. A separator line and the feedback/about footer.

    Streamlit re-executes this function on every interaction, so the
    ``trigger_*`` flags are plain locals that decide, within one run, whether
    the corresponding output column is rendered.
    """
    load_css_style()
    load_app_init()

    # Load cached data and resources.
    authors = get_authors_taxonomy()
    draft_labs = get_draft_cat_taxonomy()
    negotiation_rounds = get_negotiations_rounds()
    example_prompts = get_example_prompts()
    template = load_template()
    trigger_filter_inc, trigger_ask_inc = set_trigger_state_values()
    inc_index, inc_rag = load_inc_pipeline(template=template)

    # Application column
    application_col_inc = st.columns(1)
    with application_col_inc[0]:
        st.markdown(
            """
<p class="header" style="display: flex; align-items: center;">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32" width="32" height="32" style="margin-right: 10px;">
<circle cx="16" cy="16" r="15" fill="none" stroke="#077493" stroke-width="2"/>
<text x="16" y="21" text-anchor="middle" font-size="16" font-family="Arial" font-weight="bold" fill="#077493">1</text>
</svg> Select Filters</span>
</p>
""",
            unsafe_allow_html=True,
        )
        # Only the left half of the row is used, so the text stays narrow.
        text_design_col_1, _ = st.columns([1, 1])
        with text_design_col_1:
            st.markdown(
                """<p class="description"> Selecting at least one filter is mandatory, because otherwise the model would have to analyze all available documents which results in inaccurate answers and long processing times. Please select at least one filter. We especially recommend to select countries you are interested in. """,
                unsafe_allow_html=True,
            )
        st.write("\n")

        col_1, col_2, col_3 = st.columns([1, 1, 1])
        with col_1:
            selected_authors_inc = st.multiselect(
                label="Countries or Associations",
                options=authors,
                label_visibility="visible",
                placeholder="Select",
                key="selected_authors_inc",
                help="Please select the countries of interest. Your selection will refine the database to include documents submitted by these countries or recognized groupings such as Small Developing States, the African States Group, etc.",
            )
        with col_2:
            selected_rounds_inc = st.multiselect(
                label="Session",
                options=negotiation_rounds,
                label_visibility="visible",
                placeholder="Select",
                key="selected_rounds_inc",
                # The original help text was a copy of the country filter's
                # text (including a stray "</p>"); describe this filter instead.
                help="Please select the negotiation sessions of interest. Your selection will refine the database to include only documents from these sessions.",
            )
        with col_3:
            selected_draft_cats_inc = st.multiselect(
                label="Draft Categories",
                options=draft_labs,
                label_visibility="visible",
                placeholder="Select",
                key="selected_draft_cats",
                help=" Please select the parts of the negotiation draft of interest. The negotiation draft can be accessed (https://www.unep.org/inc-plastic-pollution/session-4/documents)",
            )

        # Both tabs use the same selection, so build it once.
        filter_selection = {
            "author": selected_authors_inc,
            "draft_labs": selected_draft_cats_inc,
            "round": selected_rounds_inc,
        }
        no_filter_selected = not any(filter_selection.values())

        st.write("\n")
        st.write("\n")
        st.markdown(
            """
<p class="header" style="display: flex; align-items: center;">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 32 32" width="32" height="32" style="margin-right: 10px;">
<circle cx="16" cy="16" r="15" fill="none" stroke="#077493" stroke-width="2"/>
<text x="16" y="21" text-anchor="middle" font-size="16" font-family="Arial" font-weight="bold" fill="#077493">2</text>
</svg> Ask a question or show documents based on selected filters
</p>
""",
            unsafe_allow_html=True,
        )

        asking_inc, filtering_inc = st.tabs(["Ask a question", "Filter documents"])

        with asking_inc:
            application_col_ask_inc, output_col_ask_inc = st.columns([1, 1.5])
            with application_col_ask_inc:
                st.markdown(
                    """
<p class="description"> Ask a question, noting that the database has been restricted by filters and that your question should pertain to the selected data. \n
""",
                    unsafe_allow_html=True,
                )

                # The question box is pre-filled only when an example prompt
                # was stored in session state by the buttons further below.
                if "prompt" not in st.session_state:
                    prompt_inc = st.text_area("")
                elif st.session_state.prompt in example_prompts:
                    prompt_inc = st.text_area(
                        "Enter a question", value=st.session_state.prompt
                    )
                else:
                    # A stored prompt that is no longer an example is discarded.
                    del st.session_state["prompt"]
                    prompt_inc = st.text_area("Enter a question")

                trigger_ask_inc = st.session_state.setdefault("trigger_inc", False)

                if st.button("Ask", icon=":material/send:", type="primary"):
                    if prompt_inc == "":
                        st.error(
                            "Please enter a question. Reloading the app in few seconds"
                        )
                        # Give the user time to read the message before rerunning.
                        time.sleep(3)
                        st.rerun()

                    with st.spinner("Filtering data..."):
                        if no_filter_selected:
                            st.error(
                                "Selecting a filter is mandatory. We especially recommend to select countries you are interested in. Selecting at least one filter is mandatory, because otherwise the model would have to analyze all available documents which results in inaccurate answers and long processing times. Please select at least one filter."
                            )
                            st.stop()

                    with st.spinner("Analyzing Filters"):
                        filters = inc_rag.build_filter(
                            filter_selections=filter_selection
                        )
                        docs = inc_index.filter_documents(filters)
                        if not docs:
                            st.error(
                                "The combination of filters you've chosen does not match any documents. Please try another combination of filters. If a filter combination does not return any documents, it means that there are no documents that match the selected filters and therefore no answer can be given."
                            )
                            trigger_ask_inc = False
                            st.stop()
                        else:
                            st.success("Filtering completed.")

                    with st.spinner("Answering question..."):
                        result = inc_rag.run(
                            query=prompt_inc, filter_selections=filter_selection
                        )
                        trigger_ask_inc = True
                    st.success("Answering question completed.")

                st.markdown(
                    "***≡ Examples***",
                    help="These are example prompts that can be used to ask questions to the model. Click on a prompt to use it as a question. You can also type your own question in the text area above. In general we highly recommend to use the filter functions to narrow down the data.",
                )
                st.caption("Double click to select the prompt")
                # A separate loop variable keeps the user's question in
                # ``prompt_inc`` from being overwritten by the last example.
                for example_prompt in example_prompts:
                    if st.button(example_prompt):
                        if "key" not in st.session_state:
                            st.session_state["prompt"] = example_prompt

        with filtering_inc:
            application_col_filter, output_col_filter = st.columns([1, 1.5])
            with application_col_filter:
                st.markdown(
                    """
<p class="description">
This filter function allows you to see all documents that match the selected filters. The documents can be accessed via a link. \n
""",
                    unsafe_allow_html=True,
                )
                if st.button("Filter", icon=":material/filter_alt:", type="primary"):
                    if no_filter_selected:
                        st.info(
                            "No filters selected. All documents will be shown. Longer processing time expected."
                        )
                    with st.spinner("Filtering documents..."):
                        # Named ``doc_filter`` to avoid shadowing the builtin.
                        doc_filter = RAGPipeline.build_filter(
                            filter_selections=filter_selection
                        )
                        result = inc_index.filter_documents(doc_filter)

                        # Several results can share a "retriever_id"; list
                        # each id only once, keeping the first occurrence.
                        retriever_ids = set()
                        result_meta = []
                        for doc in result:
                            retriever_id = doc.meta["retriever_id"]
                            if retriever_id not in retriever_ids:
                                result_meta.append(
                                    {
                                        "author": doc.meta["author"],
                                        "doc_type": doc.meta["doc_type"],
                                        "session": doc.meta["round"],
                                        "href": doc.meta["href"],
                                        "draft_labs": doc.meta["draft_labs"],
                                    }
                                )
                                retriever_ids.add(retriever_id)

                        result_df = pd.DataFrame(result_meta)
                        if result_df.empty:
                            st.info(
                                "No documents found for the combination of filters you've chosen. All countries are represented at least once in the data. Remove the draft categories to see all documents for the countries selected or try other draft categories and/or sessions."
                            )
                            trigger_filter_inc = False
                        else:
                            trigger_filter_inc = True

            if trigger_filter_inc:
                with output_col_filter:
                    st.markdown("### Overview of all filtered documents")
                    st.dataframe(
                        result_df,
                        hide_index=True,
                        column_config={
                            "author": st.column_config.ListColumn("Authors"),
                            "href": st.column_config.LinkColumn("Link to Document"),
                            "draft_labs": st.column_config.ListColumn(
                                "Draft Categories"
                            ),
                            "session": st.column_config.NumberColumn("Session"),
                            "doc_type": st.column_config.TextColumn("Document Type"),
                        },
                    )

        if trigger_ask_inc:
            with output_col_ask_inc:
                if result is None:
                    st.error(
                        "Open AI rate limit exceeded. Please try again in a few seconds."
                    )
                    st.stop()

                retrieved_docs = result["retriever"]["documents"]

                # De-duplicate while keeping first-seen order. A set would
                # shuffle the references and move the leading newline.
                references = ["\n"]
                for doc in retrieved_docs:
                    references.append(
                        f"-[{doc.meta['retriever_id']}]: {doc.meta['href']} \n"
                    )
                references = list(dict.fromkeys(references))

                st.markdown(
                    """<svg xmlns="http://www.w3.org/2000/svg" height="24px" viewBox="0 -960 960 960" width="24px" fill="#077493"><path d="m640-480 80 80v80H520v240l-40 40-40-40v-240H240v-80l80-80v-280h-40v-80h400v80h-40v280Zm-286 80h252l-46-46v-314H400v314l-46 46Zm126 0Z"/></svg> <b>Answer</b>""",
                    unsafe_allow_html=True,
                )
                typewriter(
                    text=result["llm"]["replies"][0],
                    references=references,
                    speed=100,
                )

                with st.expander("Show more information to the documents"):
                    sorted_docs = sorted(
                        retrieved_docs,
                        key=lambda x: x.meta["retriever_id"],
                    )
                    # Group consecutive passages under one heading per id.
                    current_doc = None
                    markdown_parts = []
                    for doc in sorted_docs:
                        doc_id = doc.meta["retriever_id"]
                        if doc_id != current_doc:
                            markdown_parts.append(f"- Document: {doc_id}\n")
                            markdown_parts.append(" - Text passages\n")
                        markdown_parts.append(f" - {doc.content}\n")
                        current_doc = doc_id
                    st.write("".join(markdown_parts))
                trigger_ask_inc = False

    st.markdown(
        """<hr style="height:2px;border:none;color:#077493;background-color:#077493;" /> """,
        unsafe_allow_html=True,
    )
    about_inc()