TeresaK committed on
Commit 93b63ff · verified · 1 Parent(s): 8dc32c9

Delete src

src/__init__.py DELETED
File without changes
src/analysis/__init__.py DELETED
File without changes
src/app/v1/app.py DELETED
@@ -1,329 +0,0 @@
- import os
- import sys
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')))
- import pandas as pd
-
-
- from src.rag.pipeline import RAGPipeline
- import streamlit as st
- from src.utils.data import (
-     build_filter,
-     get_filter_values,
-     get_meta,
-     load_json,
-     load_css,
- )
- from src.utils.writer import typewriter
-
- st.set_page_config(layout="wide")
-
-
-
- EMBEDDING_MODEL = "sentence-transformers/distiluse-base-multilingual-cased-v1"
- PROMPT_TEMPLATE = os.path.join("src", "rag", "prompt_template.yaml")
-
-
- @st.cache_data
- def load_css_style(path: str) -> None:
-     load_css(path)
-
-
- @st.cache_data
- def get_meta_data() -> pd.DataFrame:
-     return pd.read_csv(
-         os.path.join("database", "meta_data.csv"), dtype=({"retriever_id": str})
-     )
-
-
- @st.cache_data
- def get_authors_taxonomy() -> dict[str, list[str]]:
-     return load_json(os.path.join("data", "authors_filter.json"))
-
-
- @st.cache_data
- def get_draft_cat_taxonomy() -> dict[str, list[str]]:
-     return load_json(os.path.join("data", "draftcat_taxonomy_filter.json"))
-
-
- @st.cache_data
- def get_example_prompts() -> list[str]:
-     return [
-         example["question"]
-         for example in load_json(os.path.join("data", "example_prompts.json"))
-     ]
-
-
- @st.cache_resource
- def load_pipeline() -> RAGPipeline:
-     return RAGPipeline(
-         embedding_model=EMBEDDING_MODEL,
-         prompt_template=PROMPT_TEMPLATE,
-     )
-
-
- @st.cache_data
- def load_app_init() -> None:
-     # Define the title of the app
-     st.title("INC Plastic Treaty - Q&A")
-
-     # add warning emoji and style
-     st.markdown(
-         """
-         <p class="remark"> ⚠️ Remark:
-         The app is a beta version that serves as a basis for further development. We are aware that the performance is not yet sufficient and that the data basis is not yet complete. We are grateful for any feedback that contributes to the further development and improvement of the app!
-         """,
-         unsafe_allow_html=True,
-     )
-
-     # add explanation to the app
-     st.markdown(
-         """
-         <p class="description">
-         The app aims to facilitate the search for information and documents related to the UN Plastics Treaty Negotiations. The database includes all relevant documents that are available <a href=https://www.unep.org/inc-plastic-pollution target="_blank">here</a>. Users can query the data through a chatbot. Please note that, due to technical constraints, only a maximum of 10 documents can be used to generate the answer. A comprehensive response can therefore not be guaranteed. However, all relevant documents can be accessed via a link using the filter functions.
-         Filter functions are available to narrow down the data by country/author, zero draft categories and negotiation rounds. Pre-selecting relevant data enhances the accuracy of generated answers. Additionally, all documents selected via the filter function can be accessed via a link.
-         """,
-         unsafe_allow_html=True,
-     )
-
-
- load_css_style("style/style.css")
-
-
- # Load the data
- metadata = get_meta_data()
- authors_taxonomy = get_authors_taxonomy()
- draft_cat_taxonomy = get_draft_cat_taxonomy()
- example_prompts = get_example_prompts()
-
- # Load pipeline
- pipeline = load_pipeline()
-
- # Load app init
- load_app_init()
-
-
- filter_col = st.columns(1)
- # Filter column
- with filter_col[0]:
-     st.markdown("## Select Filters")
-     author_col, round_col, draft_cat_col = st.columns([1, 1, 1])
-
-     with author_col:
-         st.markdown("### Authors")
-         selected_author_parent = st.multiselect(
-             "Entity Parent", list(authors_taxonomy.keys())
-         )
-
-         available_child_items = []
-         for category in selected_author_parent:
-             available_child_items.extend(authors_taxonomy[category])
-
-         selected_authors = st.multiselect("Entity", available_child_items)
-
-     with round_col:
-         st.markdown("### Round")
-         negotiation_rounds = get_filter_values(metadata, "round")
-         selected_rounds = st.multiselect("Round", negotiation_rounds)
-
-     with draft_cat_col:
-         st.markdown("### Draft Categories")
-         selected_draft_cats_parent = st.multiselect(
-             "Draft Categories Parent", list(draft_cat_taxonomy.keys())
-         )
-         available_draft_cats_child_items = []
-         for category in selected_draft_cats_parent:
-             available_draft_cats_child_items.extend(draft_cat_taxonomy[category])
-
-         selected_draft_cats = st.multiselect(
-             "Draft Categories", available_draft_cats_child_items
-         )
-
-
- prompt_col, output_col = st.columns([1, 1.5])
- # make the buttons text smaller
-
-
- # GPT column
- with prompt_col:
-     st.markdown("## Filter documents")
-     st.markdown(
-         """
-         * The filter function allows you to see all documents that match the selected filters.
-         * Additionally, all documents selected via the filter function can be accessed via a link.
-         * Alternatively, you can ask a question to the model. The model will then provide you with an answer based on the filtered documents.
-         """
-     )
-     trigger_filter = st.session_state.setdefault("trigger", False)
-     if st.button("Filter documents"):
-         filter_selection_transformed = build_filter(
-             meta_data=metadata,
-             authors_filter=selected_authors,
-             draft_cats_filter=selected_draft_cats,
-             round_filter=selected_rounds,
-         )
-         documents = pipeline.document_store.get_all_documents(
-             filters=filter_selection_transformed
-         )
-         trigger_filter = True
-
-     st.markdown("## Ask a question")
-     if "prompt" not in st.session_state:
-         prompt = st.text_area("")
-     if (
-         "prompt" in st.session_state
-         and st.session_state.prompt in example_prompts  # noqa: E501
-     ):  # noqa: E501
-         prompt = st.text_area(
-             "Enter a question", value=st.session_state.prompt
-         )  # noqa: E501
-     if (
-         "prompt" in st.session_state
-         and st.session_state.prompt not in example_prompts  # noqa: E501
-     ):  # noqa: E501
-         del st.session_state["prompt"]
-         prompt = st.text_area("Enter a question")
-
-     trigger_ask = st.session_state.setdefault("trigger", False)
-     if st.button("Ask"):
-         with st.status("Filtering documents...", expanded=False) as status:
-             # build the filter first so it is always defined before it is checked
-             filter_selection_transformed = build_filter(
-                 meta_data=metadata,
-                 authors_filter=selected_authors,
-                 draft_cats_filter=selected_draft_cats,
-                 round_filter=selected_rounds,
-             )
-             if filter_selection_transformed == {}:
-                 st.warning(
-                     "No filters selected. We highly recommend using filters; otherwise the answer might not be accurate. In addition, you might experience performance issues since the model has to analyze all available documents."
-                 )
-
-             documents = pipeline.document_store.get_all_documents(
-                 filters=filter_selection_transformed
-             )
-             status.update(
-                 label="Filtering documents completed!", state="complete", expanded=False
-             )
-         with st.status("Answering question...", expanded=True) as status:
-             result = pipeline.run(prompt=prompt, filters=filter_selection_transformed)
-             trigger_ask = True
-             status.update(
-                 label="Answering question completed!", state="complete", expanded=False
-             )
-
-     st.markdown("### Examples")
-     st.markdown(
-         """
-         * These are example prompts that can be used to ask questions to the model
-         * Click on a prompt to use it as a question. You can also type your own question in the text area above.
-         * For questions like "How do country a, b and c [...]" please make sure to select the countries in the filter section. Otherwise the answer will not be accurate. In general we highly recommend using the filter functions to narrow down the data.
-         """
-     )
-
-     for i, prompt in enumerate(example_prompts):
-         # with col[i % 4]:
-         if st.button(prompt):
-             if "key" not in st.session_state:
-                 st.session_state["prompt"] = prompt
-     # Define the button
-
-
- if trigger_ask:
-     with output_col:
-         meta_data = get_meta(result=result)
-         answer = result["answers"][0].answer
-
-         meta_data_cleaned = []
-         seen_retriever_ids = set()
-
-         for data in meta_data:
-             retriever_id = data["retriever_id"]
-             content = data["content"]
-             if retriever_id not in seen_retriever_ids:
-                 meta_data_cleaned.append(
-                     {
-                         "retriever_id": retriever_id,
-                         "href": data["href"],
-                         "content": [content],
-                     }
-                 )
-                 seen_retriever_ids.add(retriever_id)
-             else:
-                 for i, item in enumerate(meta_data_cleaned):
-                     if item["retriever_id"] == retriever_id:
-                         meta_data_cleaned[i]["content"].append(content)
-
-         references = ["\n"]
-         for data in meta_data_cleaned:
-             retriever_id = data["retriever_id"]
-             href = data["href"]
-             references.append(f"-[{retriever_id}]: {href} \n")
-         st.write("#### 📌 Answer")
-         typewriter(
-             text=answer,
-             references=references,
-             speed=100,
-         )
-
-         with st.expander("Show more information about the documents"):
-             for data in meta_data_cleaned:
-                 markdown_text = f"- Document: {data['retriever_id']}\n"
-                 markdown_text += " - Text passages\n"
-                 for content in data["content"]:
-                     content = content.replace("[", "").replace("]", "").replace("'", "")
-                     content = " ".join(content.split())
-                     markdown_text += f" - {content}\n"
-                 st.write(markdown_text)
-
-         col4 = st.columns(1)
-         with col4[0]:
-             references = []
-             for document in documents:
-                 authors = document.meta["author"]
-                 authors = authors.replace("'", "").replace("[", "").replace("]", "")
-                 href = document.meta["href"]
-                 source = f"- {authors}: {href}"
-                 references.append(source)
-             references = list(set(references))
-             references = sorted(references)
-             st.markdown("### Overview of all filtered documents")
-             st.markdown(
-                 f"<p class='description'> The answer above results from the most similar text passages (top 7) from the documents that you can find under 'References' in the answer block. Below you will find an overview of all documents that match the filters you have selected. Please note that the above answer is based specifically on the highlighted references above and does not include the findings from all the filtered documents shown below. \n For your current filtering, {len(references)} documents were found. </p>",
-                 unsafe_allow_html=True,
-             )
-             for reference in references:
-                 st.write(reference)
-             trigger = 0
-
-
- if trigger_filter:
-     with output_col:
-         references = []
-         for document in documents:
-             authors = document.meta["author"]
-             authors = authors.replace("'", "").replace("[", "").replace("]", "")
-             href = document.meta["href"]
-             round_ = document.meta["round"]
-             draft_labs = document.meta["draft_labs"]
-             references.append(
-                 {
-                     "author": authors,
-                     "href": href,
-                     "draft_labs": draft_labs,
-                     "round": round_,
-                 }
-             )
-         references = pd.DataFrame(references)
-         references = references.drop_duplicates()
-         st.markdown("### Overview of all filtered documents")
-         # show
-         # make columns author and draft_labs bigger and make href width smaller and round width smaller
-         st.dataframe(
-             references,
-             hide_index=True,
-             column_config={
-                 "author": st.column_config.ListColumn("Authors"),
-                 "href": st.column_config.LinkColumn("Link to Document"),
-                 "draft_labs": st.column_config.ListColumn("Draft Categories"),
-                 "round": st.column_config.NumberColumn("Round"),
-             },
-         )
src/app/v2/app.py DELETED
@@ -1,385 +0,0 @@
- import os
- import sys
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')))
- import pandas as pd
- import streamlit as st
- import time
-
-
- from src.rag.pipeline import RAGPipeline
- from src.utils.data_v2 import (
-     build_filter,
-     get_meta,
-     load_json,
-     load_css,
- )
- from src.utils.writer import typewriter
-
-
- st.set_page_config(layout="wide")
-
- EMBEDDING_MODEL = "sentence-transformers/distiluse-base-multilingual-cased-v1"
- PROMPT_TEMPLATE = os.path.join("src", "rag", "prompt_template.yaml")
-
-
- @st.cache_data
- def load_css_style(path: str) -> None:
-     load_css(path)
-
-
- @st.cache_data
- def get_meta_data() -> pd.DataFrame:
-     return pd.read_csv(
-         os.path.join("database", "meta_data.csv"), dtype=({"retriever_id": str})
-     )
-
-
- @st.cache_data
- def get_df() -> pd.DataFrame:
-     return pd.read_csv(
-         os.path.join("data", "inc_df.csv"), dtype=({"retriever_id": str})
-     )[["retriever_id", "draft_labs", "author", "href", "round"]]
-
-
- @st.cache_data
- def get_authors_taxonomy() -> list[str]:
-     taxonomy = load_json(os.path.join("data", "authors_taxonomy.json"))
-     countries = []
-     members = taxonomy["Members"]
-     for key, value in members.items():
-         if key == "Countries" or key == "International and Regional State Associations":
-             countries.extend(value)
-     return countries
-
-
- @st.cache_data
- def get_draft_cat_taxonomy() -> list[str]:
-     taxonomy = load_json(os.path.join("data", "draftcat_taxonomy_filter.json"))
-     draft_labels = []
-     for _, subpart in taxonomy.items():
-         for label in subpart:
-             draft_labels.append(label)
-     return draft_labels
-
-
- @st.cache_data
- def get_example_prompts() -> list[str]:
-     return [
-         example["question"]
-         for example in load_json(os.path.join("data", "example_prompts.json"))
-     ]
-
-
- @st.cache_data
- def set_trigger_state_values() -> tuple[bool, bool]:
-     trigger_filter = st.session_state.setdefault("trigger", False)
-     trigger_ask = st.session_state.setdefault("trigger", False)
-     return trigger_filter, trigger_ask
-
-
- @st.cache_resource
- def load_pipeline() -> RAGPipeline:
-     return RAGPipeline(
-         embedding_model=EMBEDDING_MODEL,
-         prompt_template=PROMPT_TEMPLATE,
-     )
-
-
- @st.cache_data
- def load_app_init() -> None:
-     # Define the title of the app
-     st.title("INC Plastic Pollution Country Profile Analysis")
-
-     # add warning emoji and style
-
-     st.markdown(
-         """
-         <div class="remark">
-         <div class="remark-content">
-         <p class="remark-text" style="font-size: 20px;"> ⚠️ The app is a beta version that serves as a basis for further development. We are aware that the performance is not yet sufficient and that the data basis is not yet complete. We are grateful for any feedback that contributes to the further development and improvement of the app!</p>
-         </div>
-         </div>
-         """,
-         unsafe_allow_html=True,
-     )
-
-     st.markdown(
-         """
-         <a href="mailto:[email protected]" class="feedback-link">Send feedback</a>
-         """,
-         unsafe_allow_html=True,
-     )
-
-     # add explanation to the app
-     st.markdown(
-         """
-         <p class="description">
-         The app is tailored to enhance the efficiency of finding and accessing information on the UN Plastics Treaty Negotiations. It hosts a comprehensive database of relevant documents submitted by the members available <a href="https://www.unep.org/inc-plastic-pollution"> here</a>, which users can explore through an intuitive chatbot interface as well as simple filtering options.
-         The app excels in querying specific information about countries and their positions in the negotiations, providing targeted and precise answers. However, it can process only up to 8 relevant documents at a time, which may limit responses to more complex inquiries. Filter options by authors and sections of the negotiation draft ensure the accuracy of the answers. Each document found via these filters is also directly accessible via a link, ensuring complete and easy access to the desired information.
-         </p>
-         """,
-         unsafe_allow_html=True,
-     )
-
-
- load_css_style("style/style.css")
-
-
- # Load the data
- df = get_df()
- df_transformed = get_meta_data()
- countries = get_authors_taxonomy()
- draft_labels = get_draft_cat_taxonomy()
- example_prompts = get_example_prompts()
- trigger_filter, trigger_ask = set_trigger_state_values()
-
- # Load pipeline
- pipeline = load_pipeline()
-
- # Load app init
- load_app_init()
-
-
- application_col = st.columns(1)
-
-
- with application_col[0]:
-     st.markdown("""<p class="header"> 1️⃣ Select countries""", unsafe_allow_html=True)
-     st.markdown(
-         """
-         <p class="description">
-         Please select the countries of interest. Your selection will refine the database to include documents submitted by these countries or recognized groupings such as Small Developing States, the African States Group, etc. </p>
-         """,
-         unsafe_allow_html=True,
-     )
-     selected_authors = st.multiselect(
-         label="country",
-         options=countries,
-         label_visibility="collapsed",
-         placeholder="Select country/countries",
-     )
-
-     st.write("\n")
-     st.write("\n")
-
-     st.markdown(
-         """<p class="header"> 2️⃣ Select parts of the negotiation draft""",
-         unsafe_allow_html=True,
-     )
-     st.markdown(
-         """
-         <p class="description">
-         Please select the parts of the negotiation draft of interest. The negotiation draft can be accessed <a href="https://www.unep.org/inc-plastic-pollution/session-4/documents"> here</a>. </p>
-         """,
-         unsafe_allow_html=True,
-     )
-     selected_draft_cats = st.multiselect(
-         label="Subpart",
-         options=draft_labels,
-         label_visibility="collapsed",
-         placeholder="Select draft category/draft categories",
-     )
-
-     st.write("\n")
-     st.write("\n")
-
-     st.markdown(
-         """<p class="header"> 3️⃣ Ask a question or show documents based on selected filters""",
-         unsafe_allow_html=True,
-     )
-
-     asking, filtering = st.tabs(["Ask a question", "Filter documents"])
-
-     with filtering:
-         application_col_filter, output_col_filter = st.columns([1, 1.5])
-         # make the buttons text smaller
-         with application_col_filter:
-             st.markdown(
-                 """
-                 <p class="description">
-                 This filter function allows you to see all documents that match the selected filters. The documents can be accessed via a link. \n
-                 """,
-                 unsafe_allow_html=True,
-             )
-             if st.button("Filter documents"):
-                 filters, status = build_filter(
-                     meta_data=df_transformed,
-                     authors_filter=selected_authors,
-                     draft_cats_filter=selected_draft_cats,
-                 )
-                 if status == "no filters selected":
-                     st.info("No filters selected. All documents will be shown.")
-                     df_filtered = df[
-                         ["author", "href", "draft_labs", "round"]
-                     ].sort_values(by="author")
-                     trigger_filter = True
-                 if status == "no results found":
-                     st.info(
-                         "No documents found for the combination of filters you've chosen. All countries are represented at least once in the data. Remove the draft categories to see all documents for the countries selected or try other draft categories."
-                     )
-                 if status == "success":
-                     df_filtered = df[df["retriever_id"].isin(filters["retriever_id"])][
-                         ["author", "href", "draft_labs", "round"]
-                     ].sort_values(by="author")
-                     trigger_filter = True
-
-     with asking:
-         application_col_ask, output_col_ask = st.columns([1, 1.5])
-         with application_col_ask:
-             st.markdown(
-                 """
-                 <p class="description"> Ask a question, noting that the database has been restricted by filters and that your question should pertain to the selected data. \n
-                 """,
-                 unsafe_allow_html=True,
-             )
-             if "prompt" not in st.session_state:
-                 prompt = st.text_area("Enter a question")
-             if (
-                 "prompt" in st.session_state
-                 and st.session_state.prompt in example_prompts  # noqa: E501
-             ):  # noqa: E501
-                 prompt = st.text_area(
-                     "Enter a question", value=st.session_state.prompt
-                 )  # noqa: E501
-             if (
-                 "prompt" in st.session_state
-                 and st.session_state.prompt not in example_prompts  # noqa: E501
-             ):  # noqa: E501
-                 del st.session_state["prompt"]
-                 prompt = st.text_area("Enter a question")
-
-             trigger_ask = st.session_state.setdefault("trigger", False)
-
-             if st.button("Ask"):
-                 if prompt == "":
-                     st.error(
-                         "Please enter a question. Reloading the app in a few seconds"
-                     )
-                     time.sleep(3)
-                     st.rerun()
-                 with st.spinner("Filtering data...") as status:
-                     filter_selection_transformed, status = build_filter(
-                         meta_data=df_transformed,
-                         authors_filter=selected_authors,
-                         draft_cats_filter=selected_draft_cats,
-                     )
-
-                     if status == "no filters selected":
-                         st.info(
-                             "No filters selected. This will increase the processing time significantly. Please select at least one filter."
-                         )
-                         # st.error(
-                         #     "Selecting a filter is mandatory. We especially recommend to select countries you are interested in. Selecting at least one filter is mandatory, because otherwise the model would have to analyze all available documents which results in inaccurate answers and long processing times. Please select at least one filter."
-                         # )
-                         # st.stop()
-
-                     documents = pipeline.document_store.get_all_documents(
-                         filters=filter_selection_transformed
-                     )
-
-                     st.success("Filtering data completed.")
-                 with st.spinner("Answering question...") as status:
-                     if filter_selection_transformed == {}:
-                         st.warning(
-                             "The combination of filters you've chosen does not match any documents. Giving an answer based on all documents. Please note that the answer might not be accurate. We highly recommend using a combination of filters that match the data. All countries are represented at least once in the data. Thus, for example, you could remove the draft categories to match the documents. Or you could check with the Filter documents function which documents are available for the selected countries by removing the draft categories and filtering the documents."
-                         )
-
-                     result = pipeline.run(
-                         prompt=prompt, filters=filter_selection_transformed
-                     )
-                     trigger_ask = True
-                     st.success("Answering question completed.")
-
-             st.markdown("### Examples")
-             for i, prompt in enumerate(example_prompts):
-                 # with col[i % 4]:
-                 if st.button(prompt):
-                     if "key" not in st.session_state:
-                         st.session_state["prompt"] = prompt
-             st.markdown(
-                 """
-                 <ul class="description" style="font-size: 20px;">
-                 <li style="font-size: 17px;">These are example prompts that can be used to ask questions to the model</li>
-                 <li style="font-size: 17px;">Click on a prompt to use it as a question. You can also type your own question in the text area above.</li>
-                 <li style="font-size: 17px;">For questions like "How do country a, b and c [...]" please make sure to select the countries in the filter section. Otherwise the answer will not be accurate. In general we highly recommend using the filter functions to narrow down the data.</li>
-                 </ul>
-                 """,
-                 unsafe_allow_html=True,
-             )
-
-             # for i, prompt in enumerate(example_prompts):
-             #     # with col[i % 4]:
-             #     if st.button(prompt):
-             #         if "key" not in st.session_state:
-             #             st.session_state["prompt"] = prompt
-             # Define the button
-
- if trigger_ask:
-     with output_col_ask:
-         if result is None:
-             st.error(
-                 "OpenAI rate limit exceeded. Please try again in a few seconds."
-             )
-             st.stop()
-         meta_data = get_meta(result=result)
-         answer = result["answers"][0].answer
-
-         meta_data_cleaned = []
-         seen_retriever_ids = set()
-
-         for data in meta_data:
-             retriever_id = data["retriever_id"]
-             content = data["content"]
-             if retriever_id not in seen_retriever_ids:
-                 meta_data_cleaned.append(
-                     {
-                         "retriever_id": retriever_id,
-                         "href": data["href"],
-                         "content": [content],
-                     }
-                 )
-                 seen_retriever_ids.add(retriever_id)
-             else:
-                 for i, item in enumerate(meta_data_cleaned):
-                     if item["retriever_id"] == retriever_id:
-                         meta_data_cleaned[i]["content"].append(content)
-
-         references = ["\n"]
-         for data in meta_data_cleaned:
-             retriever_id = data["retriever_id"]
-             href = data["href"]
-             references.append(f"-[{retriever_id}]: {href} \n")
-         st.write("#### 📌 Answer")
-         typewriter(
-             text=answer,
-             references=references,
-             speed=100,
-         )
-
-         with st.expander("Show more information about the documents"):
-             for data in meta_data_cleaned:
-                 markdown_text = f"- Document: {data['retriever_id']}\n"
-                 markdown_text += " - Text passages\n"
-                 for content in data["content"]:
-                     content = (
-                         content.replace("[", "").replace("]", "").replace("'", "")
-                     )
-                     content = " ".join(content.split())
-                     markdown_text += f" - {content}\n"
-                 st.write(markdown_text)
-
-         trigger = 0
-
- if trigger_filter:
-     with output_col_filter:
-         st.markdown("### Overview of all filtered documents")
-         st.dataframe(
-             df_filtered,
-             hide_index=True,
-             column_config={
-                 "author": st.column_config.ListColumn("Authors"),
-                 "href": st.column_config.LinkColumn("Link to Document"),
-                 "draft_labs": st.column_config.ListColumn("Draft Categories"),
-                 "round": st.column_config.NumberColumn("Round"),
-             },
-         )
src/data_processing/__init__.py DELETED
File without changes
src/data_processing/document_store_data.py DELETED
@@ -1,93 +0,0 @@
- import pandas as pd
- import ast
- import json
-
- DATASET = "data/inc_df_v6_small_4.csv"
- DATASET_PROCESSED = "data/inc_df.csv"
- MEMBERS = "data/authors_filter.json"
-
-
- def main():
-     print(f"Length of dataset: {len(pd.read_csv(DATASET))}")
-     df = pd.read_csv(DATASET)
-     df["retriever_id"] = df.index
-     columns = [
-         "retriever_id",
-         "description",
-         "href",
-         "draft_labs_list",
-         "authors_list",
-         "draft_allcats",
-         "doc_subtype",
-         "doc_type",
-         "text",
-         "round",
-     ]
-
-     df = df[columns]
-
-     df.rename(
-         mapper={
-             "draft_labs_list": "draft_labs",
-             "draft_allcats": "draft_cats",
-             "authors_list": "author",
-         },
-         axis=1,
-         inplace=True,
-     )
-
-     ### Subselect for countries and country groups
-     with open(MEMBERS, "r") as f:
-         authors = json.load(f)
-     special_character_words_mapper = {
-         "Côte D'Ivoire": "Cote DIvoire",
-         "Ligue Camerounaise Des Droits De L'Homme": "Ligue Camerounaise Des Droits De LHomme",
-         "Association Pour L'Integration Et La Developpement Durable Au Burundi": "Association Pour LIntegration Et La Developpement Durable Au Burundi",
-     }
-     members = [
-         authors[key]
-         for key in [
-             "Members - Countries",
-             "Members - International and Regional State Associations",
-         ]
-     ]
-     members = [item for sublist in members for item in sublist]
-     members = [special_character_words_mapper.get(member, member) for member in members]
-
-     nonmembers = [
-         authors[key]
-         for key in [
-             "Intergovernmental Negotiation Committee",
-             "Observers and Other Participants",
-         ]
-     ]
-     nonmembers = [item for sublist in nonmembers for item in sublist]
-
-     df["author"][df["author"] == "['Côte D'Ivoire']"] = "['Cote DIvoire']"
-     df["author"][
-         df["author"] == "['Ligue Camerounaise Des Droits De L'Homme']"
-     ] = "['Ligue Camerounaise Des Droits De LHomme']"
-     df["author"][
-         df["author"]
-         == "['Association Pour L'Integration Et La Developpement Durable Au Burundi']"
-     ] = "['Association Pour LIntegration Et La Developpement Durable Au Burundi']"
-
-     df["author"] = df["author"].apply(ast.literal_eval)
-     df = df[df["author"].apply(lambda x: any(item in members for item in x))]
-     df["author"] = df["author"].apply(
-         lambda x: [item for item in x if item not in nonmembers]
-     )
-     df["author"] = df["author"].apply(
-         lambda x: [item.replace("Côte DIvoire", "Cote D'Ivoire") for item in x]
-     )
-     df["draft_labs"] = df["draft_labs"].fillna("[]")
-     df["author"][
-         df["author"] == "['The Alliance Of Small Island States (AOSIS)']"
-     ] = "['Alliance Of Small Island States (AOSIS)']"
-
-     print(f"Filtered dataset to {len(df)} entries")
-     df.to_csv(DATASET_PROCESSED, index=False)
-
-
- if __name__ == "__main__":
-     main()
src/data_processing/document_store_data_all.py DELETED
@@ -1,93 +0,0 @@
- import pandas as pd
- import ast
- import json
-
- DATASET = "data/inc_df_v6_small.csv"
- DATASET_PROCESSED = "data/inc_df.csv"
- MEMBERS = "data/authors_filter.json"
-
-
- def main():
-     print(f"Length of dataset: {len(pd.read_csv(DATASET))}")
-     df = pd.read_csv(DATASET)
-     df["retriever_id"] = df.index
-     columns = [
-         "retriever_id",
-         "description",
-         "href",
-         "draft_labs_list",
-         "authors_list",
-         "draft_allcats",
-         "doc_subtype",
-         "doc_type",
-         "text",
-         "round",
-     ]
-
-     df = df[columns]
-
-     df.rename(
-         mapper={
-             "draft_labs_list": "draft_labs",
-             "draft_allcats": "draft_cats",
-             "authors_list": "author",
-         },
-         axis=1,
-         inplace=True,
-     )
-
-     ### Subselect for countries and country groups
-     with open(MEMBERS, "r") as f:
-         authors = json.load(f)
-     special_character_words_mapper = {
-         "Côte D'Ivoire": "Côte DIvoire",
-         "Ligue Camerounaise Des Droits De L'Homme": "Ligue Camerounaise Des Droits De LHomme",
-         "Association Pour L'Integration Et La Developpement Durable Au Burundi": "Association Pour LIntegration Et La Developpement Durable Au Burundi",
-     }
-     members = [
-         authors[key]
-         for key in [
-             "Members - Countries",
-             "Members - International and Regional State Associations",
-         ]
-     ]
-     members = [item for sublist in members for item in sublist]
-     members = [special_character_words_mapper.get(member, member) for member in members]
-
-     nonmembers = [
-         authors[key]
-         for key in [
-             "Intergovernmental Negotiation Committee",
-             "Observers and Other Participants",
-         ]
-     ]
-     nonmembers = [item for sublist in nonmembers for item in sublist]
-
-     df["author"][df["author"] == "['Côte D'Ivoire']"] = "['Côte DIvoire']"
-     df["author"][
-         df["author"] == "['Ligue Camerounaise Des Droits De L'Homme']"
-     ] = "['Ligue Camerounaise Des Droits De LHomme']"
-     df["author"][
-         df["author"]
-         == "['Association Pour L'Integration Et La Developpement Durable Au Burundi']"
-     ] = "['Association Pour LIntegration Et La Developpement Durable Au Burundi']"
-
-     df["author"] = df["author"].apply(ast.literal_eval)
-     df = df[df["author"].apply(lambda x: any(item in members for item in x))]
-     df["author"] = df["author"].apply(
-         lambda x: [item for item in x if item not in nonmembers]
-     )
-     df["author"] = df["author"].apply(
-         lambda x: [item.replace("Côte DIvoire", "Côte D'Ivoire") for item in x]
-     )
-     df["draft_labs"] = df["draft_labs"].fillna("[]")
-     df["author"][
-         df["author"] == "['The Alliance Of Small Island States (AOSIS)']"
-     ] = "['Alliance Of Small Island States (AOSIS)']"
-
-     print(f"Filtered dataset to {len(df)} entries")
-     df.to_csv(DATASET_PROCESSED, index=False)
-
-
- if __name__ == "__main__":
-     main()
src/data_processing/get_meta_data_filter.py DELETED
@@ -1,21 +0,0 @@
- import pandas as pd
- import sys
- import os
- sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
- from src.rag.pipeline import RAGPipeline
-
- DATASET = os.path.join("data", "inc_df.csv")
- META_DATA = os.path.join("database", "meta_data.csv")
-
- rag_pipeline = RAGPipeline(
-     embedding_model="sentence-transformers/distiluse-base-multilingual-cased-v1",
-     prompt_template="src/rag/prompt_template.yaml",
- )
-
- meta_data = pd.DataFrame(
-     [document.meta for document in rag_pipeline.document_store.get_all_documents()]
- )
-
- meta_data = meta_data.drop_duplicates(subset=["retriever_id"], keep="first")
-
- meta_data.to_csv(META_DATA, index=False)
src/data_processing/taxonomy_processing.py DELETED
@@ -1,47 +0,0 @@
- import os
-
- from src.utils.data import load_json, save_json
-
- AUTHORS_TAXONOMY = os.path.join("data", "authors_taxonomy.json")
- AUTHORS_FILTER = os.path.join("data", "authors_filter.json")
-
- DRAFT_CATEGORIES_TAXONOMY = os.path.join("data", "draftcat_taxonomy.json")
- DRAFT_CATEGORIES_FILTER = os.path.join("data", "draftcat_taxonomy_filter.json")
-
-
- def get_authors(taxonomy: dict) -> dict:
-     countries = taxonomy["Members"]["Countries"]
-     associations = taxonomy["Members"][
-         "International and Regional State Associations"
-     ]  # noqa: E501
-     intergovernmental_negotiations = taxonomy[
-         "Intergovernmental Negotiation Committee"
-     ]  # noqa: E501
-     observers = taxonomy["Observers and Other Participants"]  # noqa: E501
-     return {
-         "Members - Countries": countries,
-         "Members - International and Regional State Associations": associations,  # noqa: E501
-         "Intergovernmental Negotiation Committee": intergovernmental_negotiations,  # noqa: E501
-         "Observers and Other Participants": observers,
-     }
-
-
- def get_draftcategories(taxonomy: dict) -> dict:
-     taxonomy_filter = {}
-     for draft_part, part_values in taxonomy.items():
-         part = draft_part
-         temp_values = []
-         for part_name, part_value in part_values.items():
-             temp_values.append(part_value)
-         taxonomy_filter[part] = temp_values
-     return taxonomy_filter
-
-
- if __name__ == "__main__":
-     authors_taxonomy = load_json(AUTHORS_TAXONOMY)
-     authors_filter = get_authors(authors_taxonomy)
-     save_json(file_path=AUTHORS_FILTER, data=authors_filter)
-
-     draft_categories_taxonomy = load_json(DRAFT_CATEGORIES_TAXONOMY)
-     draft_categories_filter = get_draftcategories(draft_categories_taxonomy)
-     save_json(file_path=DRAFT_CATEGORIES_FILTER, data=draft_categories_filter)
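For orientation, a minimal sketch of what `get_draftcategories` does: it flattens each draft part's named subsections into a plain list. The taxonomy fragment below is invented; the real structure lives in `data/draftcat_taxonomy.json`, which is not part of this diff.

```python
from src.data_processing.taxonomy_processing import get_draftcategories

# Hypothetical taxonomy fragment (real keys live in data/draftcat_taxonomy.json).
taxonomy = {
    "Part II": {
        "II.1": "Objectives",
        "II.2": "Primary plastic polymers",
    }
}

print(get_draftcategories(taxonomy))
# {'Part II': ['Objectives', 'Primary plastic polymers']}
```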
src/document_store/document_store.py DELETED
@@ -1,180 +0,0 @@
- from haystack.document_stores import InMemoryDocumentStore
- import pandas as pd
- import os
- import pathlib
- import ast
- from sklearn.preprocessing import MultiLabelBinarizer
- from langchain_community.document_loaders import DataFrameLoader
- from langchain.text_splitter import (
-     RecursiveCharacterTextSplitter,
- )
- from typing import Any
-
-
- INC_TEST_DATASET_PATH = os.path.join("data", "inc_df.csv")
- EMBEDDING_DIMENSION = 512
-
- special_character_words_mapper = {
-     "Côte D'Ivoire": "Côte DIvoire",
-     "Ligue Camerounaise Des Droits De L'Homme": "Ligue Camerounaise Des Droits De LHomme",
-     "Association Pour L'Integration Et La Developpement Durable Au Burundi": "Association Pour LIntegration Et La Developpement Durable Au Burundi",
- }
- special_character_words_reverse_mapper = {}
- for key, value in special_character_words_mapper.items():
-     special_character_words_reverse_mapper[value] = key
-
-
- def transform_to_list(row):
-     special_characters = False
-     if str(row) == "[]" or str(row) == "nan":
-         return []
-     else:
-         # replace special characters
-         for key, value in special_character_words_mapper.items():
-             if key in row:
-                 row = row.replace(key, value)
-                 special_characters = True
-         row = ast.literal_eval(row)
-         if special_characters:
-             for key, value in special_character_words_reverse_mapper.items():
-                 if key in row:
-                     # get the index of the special character word
-                     index = row.index(key)
-                     # replace the special character word with the original word
-                     row[index] = value
-         return row
-
-
- def transform_data(df: pd.DataFrame):
-     # df["author"] = df["authors"].drop(columns=["authors"], axis=1)
-     df = df[df["doc_subtype"] != "Working documents"]
-     df = df[df["doc_subtype"] != "Contact Groups"]
-     df = df[df["doc_subtype"] != "Unsolicitated Submissions"]
-     df = df[df["doc_type"] != "official document"]
-     df = df[df["doc_subtype"] != "Stakeholder Dialogue"]
-     df["text"] = df["text"].astype(str).str.replace("_x000D_", " ")
-     df["text"] = df["text"].astype(str).str.replace("\n", " ")
-     # df["text"] = df["text"].astype(str).str.replace("\r", " ")
-     df["author"] = df["author"].str.replace("\xa0", " ")
-     df["author"] = df["author"].str.replace("ü", "u")
-     df["author"] = df["author"].str.strip()
-     df["author"] = df["author"].astype(str).str.replace("\r", " ")
-
-     df = df[
-         [
-             "author",
-             "doc_type",
-             "round",
-             "text",
-             "href",
-             "draft_labs",
-             "draft_cats",
-             "retriever_id",
-         ]
-     ].copy()
-
-     df = df.rename(columns={"text": "page_content"}).copy()
-
-     df["draft_labs2"] = df["draft_labs"]
-     df["author2"] = df["author"]
-
-     df["draft_labs"] = df.apply(lambda x: transform_to_list(x["draft_labs"]), axis=1)
-     df["author"] = df.apply(lambda x: transform_to_list(x["author"]), axis=1)
-
-     # df["draft_labs"] = df["draft_labs"].apply(
-     #     lambda x: ast.literal_eval(x) if str(x) != "[]" or str(x) != "nan" else []
-     # )
-     # df["author"] = df["author"].apply(
-     #     lambda x: ast.literal_eval(x) if str(x) != "[]" else []
-     # )
-
-     mlb = MultiLabelBinarizer(sparse_output=True)
-     mlb = MultiLabelBinarizer()
-     df = df.join(
-         pd.DataFrame(
-             mlb.fit_transform(df.pop("draft_labs")),
-             columns=mlb.classes_,
-             index=df.index,
-         )
-     ).join(
-         pd.DataFrame(
-             mlb.fit_transform(df.pop("author")), columns=mlb.classes_, index=df.index
-         )
-     )
-
-     df["draft_labs"] = df["draft_labs2"]
-     df = df.drop(columns=["draft_labs2"], axis=1)
-
-     df["author"] = df["author2"]
-     df = df.drop(columns=["author2"], axis=1)
-
-     loader = DataFrameLoader(df, page_content_column="page_content")
-     docs = loader.load()
-     return docs
-
-
- def process_data(docs):
-
-     chunk_size = 512
-     text_splitter = RecursiveCharacterTextSplitter(
-         chunk_size=chunk_size,
-         chunk_overlap=int(chunk_size / 10),
-         add_start_index=True,
-         strip_whitespace=True,
-         separators=["\n\n", "\n", " ", ""],
-     )
-
-     docs_chunked = text_splitter.transform_documents(docs)
-
-     df = pd.DataFrame(docs_chunked, columns=["page_content", "metadata", "type"]).drop(
-         "type", axis=1
-     )
-     df["page_content"] = df["page_content"].astype(str)
-     df["page_content"] = df["page_content"].str.replace("'page_content'", "")
-     df["page_content"] = df["page_content"].str.replace("(", "")
-     df["page_content"] = df["page_content"].str.replace(")", "").str[1:]
-     df = pd.concat(
-         [df.drop("metadata", axis=1), df["metadata"].apply(pd.Series)], axis=1
-     )
-     df = df.rename(columns={0: "a", 1: "b"})
-     df = pd.concat([df.drop(["a", "b"], axis=1), df["b"].apply(pd.Series)], axis=1)
-
-     cols = ["author", "draft_labs"]
-     for c in cols:
-         df[c] = df[c].apply(
-             lambda x: "".join(x) if isinstance(x, (list, tuple)) else str(x)
-         )
-         chars = ["[", "]", "'"]
-         for g in chars:
-             df[c] = df[c].str.replace(g, "")
-
-     df["page_content"] = df["page_content"].astype(str).str.replace("\n", " ")
-     df["page_content"] = df["page_content"].astype(str).str.replace("\r", " ")
-
-     cols = ["author", "draft_labs", "page_content"]
-     df["page_content"] = df[cols].apply(lambda row: " | ".join(row.astype(str)), axis=1)
-     df = df.rename(columns={"page_content": "content"})
-
-     documents = []
-     for _, row in df.iterrows():
-         row_meta: dict[str, Any] = {}
-         for column in df.columns:
-             if column != "content":
-                 if column == "retriever_id":
-                     row_meta[column] = str(row[column])
-                 else:
-                     row_meta[column] = row[column]
-         documents.append({"content": row["content"], "meta": row_meta})
-     return documents
-
-
- def get_document_store():
-     df = pd.read_csv(INC_TEST_DATASET_PATH)
-     # df["retriever_id"] = [str(i) for i in range(len(df))]
-     pathlib.Path("database").mkdir(parents=True, exist_ok=True)
-     document_store = InMemoryDocumentStore(
-         embedding_field="embedding", embedding_dim=EMBEDDING_DIMENSION, use_bm25=False
-     )
-     docs = transform_data(df=df)
-     document_store.write_documents(process_data(docs=docs))
-     return document_store
src/rag/__init__.py DELETED
File without changes
src/rag/pipeline.py DELETED
@@ -1,93 +0,0 @@
- import os
- import pickle
- from typing import Any
-
- from dotenv import load_dotenv
- from haystack.nodes import (  # type: ignore
-     AnswerParser,
-     EmbeddingRetriever,
-     PromptNode,
-     PromptTemplate,
- )
- from haystack.pipelines import Pipeline
-
- from src.document_store.document_store import get_document_store
-
- load_dotenv()
-
- OPENAI_API_KEY = os.environ.get("OPEN_API_KEY")
-
-
- class RAGPipeline:
-     def __init__(
-         self,
-         embedding_model: str,
-         prompt_template: str,
-     ):
-         self.load_document_store()
-         self.embedding_model = embedding_model
-         self.prompt_template = prompt_template
-         self.retriever_node = self.generate_retriever_node()
-         self.prompt_node = self.generate_prompt_node()
-         self.update_embeddings()
-         self.pipe = self.build_pipeline()
-
-     def run(self, prompt: str, filters: dict) -> Any:
-         try:
-             result = self.pipe.run(query=prompt, params={"filters": filters})
-             return result
-         except Exception as e:
-             print(e)
-             return None
-
-     def build_pipeline(self):
-         pipe = Pipeline()
-         pipe.add_node(component=self.retriever_node, name="retriever", inputs=["Query"])
-         pipe.add_node(
-             component=self.prompt_node,
-             name="prompt_node",
-             inputs=["retriever"],
-         )
-         return pipe
-
-     def load_document_store(self):
-         if os.path.exists(os.path.join("database", "document_store.pkl")):
-             with open(
-                 file=os.path.join("database", "document_store.pkl"), mode="rb"
-             ) as f:
-                 self.document_store = pickle.load(f)
-         else:
-             self.document_store = get_document_store()
-
-     def generate_retriever_node(self):
-         retriever_node = EmbeddingRetriever(
-             document_store=self.document_store,
-             embedding_model=self.embedding_model,
-             top_k=7,
-         )
-         return retriever_node
-
-     def update_embeddings(self):
-         if not os.path.exists(os.path.join("database", "document_store.pkl")):
-             self.document_store.update_embeddings(
-                 self.retriever_node, update_existing_embeddings=True
-             )
-
-             with open(
-                 file=os.path.join("database", "document_store.pkl"), mode="wb"
-             ) as f:
-                 pickle.dump(self.document_store, f)
-
-     def generate_prompt_node(self):
-         rag_prompt = PromptTemplate(
-             prompt=self.prompt_template,
-             output_parser=AnswerParser(reference_pattern=r"Document\[(\d+)\]"),
-         )
-         prompt_node = PromptNode(
-             model_name_or_path="gpt-4",
-             default_prompt_template=rag_prompt,
-             api_key=OPENAI_API_KEY,
-             max_length=4000,
-             model_kwargs={"temperature": 0.2, "max_tokens": 4096},
-         )
-         return prompt_node
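For reference, a minimal usage sketch of the deleted `RAGPipeline`, mirroring how the two apps above construct and call it. The question and the retriever ids below are invented for illustration; only the model name, template path, and `run()` signature come from this diff.

```python
import os

from src.rag.pipeline import RAGPipeline

pipeline = RAGPipeline(
    embedding_model="sentence-transformers/distiluse-base-multilingual-cased-v1",
    prompt_template=os.path.join("src", "rag", "prompt_template.yaml"),
)

# Filters use the {"retriever_id": [...]} shape produced by build_filter;
# the ids here are hypothetical.
result = pipeline.run(
    prompt="Which members support binding reuse targets?",  # made-up question
    filters={"retriever_id": ["0", "17", "42"]},
)

# run() swallows exceptions and returns None (e.g., on an OpenAI rate limit).
if result is not None:
    print(result["answers"][0].answer)
```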
src/rag/prompt_template.yaml DELETED
@@ -1,20 +0,0 @@
- name: deepset/question-answering-with-document-references
- text: |
-   Answer the question '{query}' using only the provided documents and avoiding text.
-   Formulate your answer in the style of an academic report.
-   Provide example quotes and citations using extracted text from the documents.
-   Use facts and numbers from the documents in your answer.
-   ALWAYS include the references of the documents used from documents at the end of each applicable sentence using the format [number].
-   If the answer isn't in the document say 'Answering is not possible given the available information'.
-   Documents: \n
-   {join(documents, delimiter=new_line, pattern=new_line+'Document($retriever_id): $content', str_replace={new_line: ' ', '[': '(', ']': ')'})} \n
-   Answer:
- tags:
-   - question-answering
- description: Perform question answering with references to documents.
- meta:
-   authors:
-     - deepset-ai
- version: '0.1.0'
-
-
src/utils/__init__.py DELETED
File without changes
src/utils/data.py DELETED
@@ -1,76 +0,0 @@
- import json
- from typing import Any
-
- import pandas as pd
- import streamlit as st
- from functools import reduce
-
-
- def get_filter_values(df: pd.DataFrame, column_name: str) -> list:
-     return df[column_name].unique().tolist()
-
-
- def build_filter(
-     meta_data: pd.DataFrame,
-     authors_filter: list[str],
-     draft_cats_filter: list[str],
-     round_filter: list[int],
- ) -> dict[str, list[str]] | dict:
-     authors = authors_filter
-     round_number = round_filter
-     draft_cats = draft_cats_filter
-
-     # set authors_flag to True if not empty list
-     authors_flag = True if len(authors) > 0 else False
-     draft_cats_flag = True if len(draft_cats) > 0 else False
-     round_number_flag = True if len(round_number) > 0 else False
-
-     conditions = []
-
-     if authors_flag:
-         authors_condition = (meta_data[col] == 1 for col in authors)
-         authors_conditions_list = reduce(lambda a, b: a | b, authors_condition)
-         conditions.append(authors_conditions_list)
-
-     if draft_cats_flag:
-         draft_cat_condition = (meta_data[col] for col in draft_cats)
-         draft_cat_conditions_list = reduce(lambda a, b: a | b, draft_cat_condition)
-         conditions.append(draft_cat_conditions_list)
-
-     if round_number_flag:
-         round_condition = meta_data["round"].isin(round_number)
-         conditions.append(round_condition)
-
-     if len(conditions) == 0:
-         filtered_retriever_ids = []
-     else:
-         final_condition = reduce(lambda a, b: a & b, conditions)
-         filtered_retriever_ids = meta_data[final_condition]["retriever_id"].tolist()
-     if len(filtered_retriever_ids) == 0:
-         return {}
-     else:
-         return {"retriever_id": filtered_retriever_ids}
-
-
- def load_json(file_path: str) -> dict:
-     with open(file_path, "r") as f:
-         return json.load(f)
-
-
- def save_json(file_path: str, data: dict) -> None:
-     with open(file_path, "w") as f:
-         json.dump(data, f, indent=4)
-
-
- def get_meta(result: dict[str, Any]) -> list[dict[str, Any]]:
-     meta_data = []
-     for doc in result["documents"]:
-         current_meta = doc.meta
-         current_meta["content"] = doc.content
-         meta_data.append(current_meta)
-     return meta_data
-
-
- def load_css(file_name) -> None:
-     with open(file_name) as f:
-         st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
src/utils/data_v2.py DELETED
@@ -1,80 +0,0 @@
- import json
- from typing import Any
-
- import pandas as pd
- import streamlit as st
- from functools import reduce
-
-
- def get_filter_values(df: pd.DataFrame, column_name: str) -> list:
-     return df[column_name].unique().tolist()
-
-
- def build_filter(
-     meta_data: pd.DataFrame,
-     authors_filter: list[str],
-     draft_cats_filter: list[str],
-     # round_filter: list[int],
- ) -> tuple[dict, str]:
-     authors = authors_filter
-     # round_number = round_filter
-     draft_cats = draft_cats_filter
-
-     # set authors_flag to True if not empty list
-     authors_flag = True if len(authors) > 0 else False
-     draft_cats_flag = True if len(draft_cats) > 0 else False
-     # round_number_flag = True if len(round_number) > 0 else False
-
-     if authors_flag is False and draft_cats_flag is False:
-         return {}, "no filters selected"
-
-     conditions = []
-
-     if authors_flag:
-         authors_condition = (meta_data[col] == 1 for col in authors)
-         authors_conditions_list = reduce(lambda a, b: a | b, authors_condition)
-         conditions.append(authors_conditions_list)
-
-     if draft_cats_flag:
-         draft_cat_condition = (meta_data[col] for col in draft_cats)
-         draft_cat_conditions_list = reduce(lambda a, b: a | b, draft_cat_condition)
-         conditions.append(draft_cat_conditions_list)
-
-     # if round_number_flag:
-     #     round_condition = meta_data["round"].isin(round_number)
-     #     conditions.append(round_condition)
-
-     if len(conditions) == 0:
-         filtered_retriever_ids = []
-     else:
-         final_condition = reduce(lambda a, b: a & b, conditions)
-         filtered_retriever_ids = meta_data[final_condition]["retriever_id"].tolist()
-
-     if len(filtered_retriever_ids) == 0:
-         return {}, "no results found"
-     else:
-         return {"retriever_id": filtered_retriever_ids}, "success"
-
-
- def load_json(file_path: str) -> dict:
-     with open(file_path, "r") as f:
-         return json.load(f)
-
-
- def save_json(file_path: str, data: dict) -> None:
-     with open(file_path, "w") as f:
-         json.dump(data, f, indent=4)
-
-
- def get_meta(result: dict[str, Any]) -> list[dict[str, Any]]:
-     meta_data = []
-     for doc in result["documents"]:
-         current_meta = doc.meta
-         current_meta["content"] = doc.content
-         meta_data.append(current_meta)
-     return meta_data
-
-
- def load_css(file_name) -> None:
-     with open(file_name) as f:
-         st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
src/utils/writer.py DELETED
@@ -1,17 +0,0 @@
- import streamlit as st
- import time
-
-
- def typewriter(text: str, references: list, speed: int):
-     tokens = text.split()
-     container = st.empty()
-     for index in range(len(tokens) + 1):
-         curr_full_text = " ".join(tokens[:index])
-         container.markdown(curr_full_text)
-         time.sleep(1 / speed)
-     curr_full_text += "\n"
-     container.markdown(curr_full_text)
-     curr_full_text += "\n **References** \n"
-     container.markdown(curr_full_text)
-     curr_full_text += "\n".join(references)
-     container.markdown(curr_full_text)
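A short sketch of how the apps above drive this typewriter effect; the answer text and reference entry are invented, and the snippet only renders inside a Streamlit app (e.g., via `streamlit run demo.py`):

```python
from src.utils.writer import typewriter

typewriter(
    text="Several members support binding reuse targets [1].",  # made-up answer
    references=["\n", "-[1]: https://www.unep.org/inc-plastic-pollution \n"],
    speed=100,  # tokens per second: sleeps 1/speed between rendered tokens
)
```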