import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')))
import pandas as pd
from src.rag.pipeline import RAGPipeline
import streamlit as st
from src.utils.data import (
    build_filter,
    get_filter_values,
    get_meta,
    load_json,
    load_css,
)
from src.utils.writer import typewriter
st.set_page_config(layout="wide")
EMBEDDING_MODEL = "sentence-transformers/distiluse-base-multilingual-cased-v1"
PROMPT_TEMPLATE = os.path.join("src", "rag", "prompt_template.yaml")
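
# Cached helpers: Streamlit reruns this script on every interaction, so the
# metadata, taxonomies, example prompts and the RAG pipeline are cached below.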
@st.cache_data
def load_css_style(path: str) -> None:
    load_css(path)


@st.cache_data
def get_meta_data() -> pd.DataFrame:
    return pd.read_csv(
        os.path.join("database", "meta_data.csv"), dtype={"retriever_id": str}
    )


@st.cache_data
def get_authors_taxonomy() -> dict[str, list[str]]:
    return load_json(os.path.join("data", "authors_filter.json"))


@st.cache_data
def get_draft_cat_taxonomy() -> dict[str, list[str]]:
    return load_json(os.path.join("data", "draftcat_taxonomy_filter.json"))


@st.cache_data
def get_example_prompts() -> list[str]:
    return [
        example["question"]
        for example in load_json(os.path.join("data", "example_prompts.json"))
    ]


@st.cache_resource
def load_pipeline() -> RAGPipeline:
    return RAGPipeline(
        embedding_model=EMBEDDING_MODEL,
        prompt_template=PROMPT_TEMPLATE,
    )
@st.cache_data
def load_app_init() -> None:
    # Define the title of the app
    st.title("INC Plastic Treaty - Q&A")

    # add warning emoji and style
    st.markdown(
        """
<p class="remark"> ⚠️ Remark:
The app is a beta version that serves as a basis for further development. We are aware that the performance is not yet sufficient and that the data basis is not yet complete. We are grateful for any feedback that contributes to the further development and improvement of the app!
</p>
""",
        unsafe_allow_html=True,
    )

    # add explanation to the app
    st.markdown(
        """
<p class="description">
The app aims to facilitate the search for information and documents related to the UN Plastics Treaty Negotiations. The database includes all relevant documents that are available <a href="https://www.unep.org/inc-plastic-pollution" target="_blank">here</a>. Users can query the data through a chatbot. Please note that, due to technical constraints, only a maximum of 10 documents can be used to generate the answer. A comprehensive response can therefore not be guaranteed. However, all relevant documents can be accessed via a link using the filter functions.
Filter functions are available to narrow down the data by country/author, zero draft categories and negotiation rounds. Pre-selecting relevant data enhances the accuracy of the generated answers. Additionally, all documents selected via the filter function can be accessed via a link.
</p>
""",
        unsafe_allow_html=True,
    )
load_css_style("style/style.css")
# Load the data
metadata = get_meta_data()
authors_taxonomy = get_authors_taxonomy()
draft_cat_taxonomy = get_draft_cat_taxonomy()
example_prompts = get_example_prompts()
# Load pipeline
pipeline = load_pipeline()
# Load app init
load_app_init()
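
# Filter row: narrows the document set by author/country, negotiation round
# and zero draft category before a question is asked.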
filter_col = st.columns(1)
# Filter column
with filter_col[0]:
    st.markdown("## Select Filters")
    author_col, round_col, draft_cat_col = st.columns([1, 1, 1])

    with author_col:
        st.markdown("### Authors")
        selected_author_parent = st.multiselect(
            "Entity Parent", list(authors_taxonomy.keys())
        )
        available_child_items = []
        for category in selected_author_parent:
            available_child_items.extend(authors_taxonomy[category])
        selected_authors = st.multiselect("Entity", available_child_items)

    with round_col:
        st.markdown("### Round")
        negotiation_rounds = get_filter_values(metadata, "round")
        selected_rounds = st.multiselect("Round", negotiation_rounds)

    with draft_cat_col:
        st.markdown("### Draft Categories")
        selected_draft_cats_parent = st.multiselect(
            "Draft Categories Parent", list(draft_cat_taxonomy.keys())
        )
        available_draft_cats_child_items = []
        for category in selected_draft_cats_parent:
            available_draft_cats_child_items.extend(draft_cat_taxonomy[category])
        selected_draft_cats = st.multiselect(
            "Draft Categories", available_draft_cats_child_items
        )
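
# Two-column layout below the filters: question and controls on the left,
# generated answers and document overviews on the right.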
prompt_col, output_col = st.columns([1, 1.5])
# make the buttons text smaller
# GPT column
with prompt_col:
    st.markdown("## Filter documents")
    st.markdown(
        """
* The filter function allows you to see all documents that match the selected filters.
* Additionally, all documents selected via the filter function can be accessed via a link.
* Alternatively, you can ask a question to the model. The model will then provide you with an answer based on the filtered documents.
"""
    )
    trigger_filter = st.session_state.setdefault("trigger", False)

    if st.button("Filter documents"):
        filter_selection_transformed = build_filter(
            meta_data=metadata,
            authors_filter=selected_authors,
            draft_cats_filter=selected_draft_cats,
            round_filter=selected_rounds,
        )
        documents = pipeline.document_store.get_all_documents(
            filters=filter_selection_transformed
        )
        trigger_filter = True
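
    # Question input: a clicked example prompt is stored in
    # st.session_state["prompt"] and pre-filled here; stale session values
    # that no longer match an example prompt are discarded.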
st.markdown("## Ask a question")
if "prompt" not in st.session_state:
prompt = st.text_area("")
if (
"prompt" in st.session_state
and st.session_state.prompt in example_prompts # noqa: E501
): # noqa: E501
prompt = st.text_area(
"Enter a question", value=st.session_state.prompt
) # noqa: E501
if (
"prompt" in st.session_state
and st.session_state.prompt not in example_prompts # noqa: E501
): # noqa: E501
del st.session_state["prompt"]
prompt = st.text_area("Enter a question")
    trigger_ask = st.session_state.setdefault("trigger", False)

    if st.button("Ask"):
        with st.status("Filtering documents...", expanded=False) as status:
            # Warn if the user has not selected any filters in the widgets above
            if not (selected_authors or selected_rounds or selected_draft_cats):
                st.warning(
                    "No filters selected. We highly recommend using filters, otherwise the answer might not be accurate. In addition, you might experience performance issues since the model has to analyze all available documents."
                )
            filter_selection_transformed = build_filter(
                meta_data=metadata,
                authors_filter=selected_authors,
                draft_cats_filter=selected_draft_cats,
                round_filter=selected_rounds,
            )
            documents = pipeline.document_store.get_all_documents(
                filters=filter_selection_transformed
            )
            status.update(
                label="Filtering documents completed!", state="complete", expanded=False
            )
        with st.status("Answering question...", expanded=True) as status:
            result = pipeline(prompt=prompt, filters=filter_selection_transformed)
            trigger_ask = True
            status.update(
                label="Answering question completed!", state="complete", expanded=False
            )
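
    # Clickable example prompts: selecting one stores it in session state so
    # that the text area above is pre-filled on the next rerun.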
st.markdown("### Examples")
st.markdown(
"""
* These are example prompts that can be used to ask questions to the model
* Click on a prompt to use it as a question. You can also type your own question in the text area above.
* For questions like "How do country a, b and c [...]" please make sure to select the countries in the filter section. Otherwise the answer will not be accurate. In general we highly recommend to use the filter functions to narrow down the data.
"""
)
for i, prompt in enumerate(example_prompts):
# with col[i % 4]:
if st.button(prompt):
if "key" not in st.session_state:
st.session_state["prompt"] = prompt
# Answer block: rendered once the "Ask" button has been pressed
if trigger_ask:
    with output_col:
        meta_data = get_meta(result=result)
        answer = result["answers"][0].answer
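
        # Group retrieved passages by document: one entry per retriever_id,
        # collecting every matched text passage for that document.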
        meta_data_cleaned = []
        seen_retriever_ids = set()
        for data in meta_data:
            retriever_id = data["retriever_id"]
            content = data["content"]
            if retriever_id not in seen_retriever_ids:
                meta_data_cleaned.append(
                    {
                        "retriever_id": retriever_id,
                        "href": data["href"],
                        "content": [content],
                    }
                )
                seen_retriever_ids.add(retriever_id)
            else:
                for i, item in enumerate(meta_data_cleaned):
                    if item["retriever_id"] == retriever_id:
                        meta_data_cleaned[i]["content"].append(content)

        references = ["\n"]
        for data in meta_data_cleaned:
            retriever_id = data["retriever_id"]
            href = data["href"]
            references.append(f"- [{retriever_id}]: {href} \n")

        st.write("#### 📌 Answer")
        typewriter(
            text=answer,
            references=references,
            speed=100,
        )

        with st.expander("Show more information about the documents"):
            for data in meta_data_cleaned:
                markdown_text = f"- Document: {data['retriever_id']}\n"
                markdown_text += "  - Text passages\n"
                for content in data["content"]:
                    content = content.replace("[", "").replace("]", "").replace("'", "")
                    content = " ".join(content.split())
                    markdown_text += f"    - {content}\n"
                st.write(markdown_text)
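
        # Flat, deduplicated list of all documents that match the current
        # filters, shown beneath the generated answer.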
        col4 = st.columns(1)
        with col4[0]:
            references = []
            for document in documents:
                authors = document.meta["author"]
                authors = authors.replace("'", "").replace("[", "").replace("]", "")
                href = document.meta["href"]
                source = f"- {authors}: {href}"
                references.append(source)
            references = sorted(set(references))

            st.markdown("### Overview of all filtered documents")
            st.markdown(
                f"<p class='description'> The answer above results from the most similar text passages (top 7) from the documents that you can find under 'References' in the answer block. Below you will find an overview of all documents that match the filters you have selected. Please note that the above answer is based specifically on the highlighted references above and does not include the findings from all the filtered documents shown below. \n For your current filtering, {len(references)} documents were found. </p>",
                unsafe_allow_html=True,
            )
            for reference in references:
                st.write(reference)
trigger = 0
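
# Document overview table, shown when the "Filter documents" button was pressed.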
if trigger_filter:
    with output_col:
        references = []
        for document in documents:
            authors = document.meta["author"]
            authors = authors.replace("'", "").replace("[", "").replace("]", "")
            href = document.meta["href"]
            round_ = document.meta["round"]
            draft_labs = document.meta["draft_labs"]
            references.append(
                {
                    "author": authors,
                    "href": href,
                    "draft_labs": draft_labs,
                    "round": round_,
                }
            )
        references = pd.DataFrame(references).drop_duplicates()

        st.markdown("### Overview of all filtered documents")
        # Readable column labels; href is rendered as a clickable link
        st.dataframe(
            references,
            hide_index=True,
            column_config={
                "author": st.column_config.ListColumn("Authors"),
                "href": st.column_config.LinkColumn("Link to Document"),
                "draft_labs": st.column_config.ListColumn("Draft Categories"),
                "round": st.column_config.NumberColumn("Round"),
            },
        )